from collections import OrderedDict
from typing import Any, Iterator, List, NamedTuple, Optional, Tuple, Union
import torch
import torch.nn as nn
from kornia.augmentation.base import _AugmentationBase, MixAugmentationBase
__all__ = ["SequentialBase", "ParamItem"]
class ParamItem(NamedTuple):
name: str
data: Optional[Union[dict, list]]
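
# Illustrative note: a ParamItem pairs a module's generated name with the
# parameters sampled for it, e.g. ParamItem("RandomAffine_0", {...}); the exact
# payload depends on the augmentation that produced it, and may be a list for
# nested containers.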
class SequentialBase(nn.Sequential):
r"""SequentialBase for creating kornia modulized processing pipeline.
Args:
        *args: a list of kornia augmentation and image operation modules.
same_on_batch: apply the same transformation across the batch.
If None, it will not overwrite the function-wise settings.
        return_transform: if ``True``, return the matrix describing the transformation
            applied to each input. If None, it will not overwrite the function-wise settings.
keepdim: whether to keep the output shape the same as input (True) or broadcast it
to the batch form (False). If None, it will not overwrite the function-wise settings.
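
    Example:
        A minimal sketch (illustrative; assumes ``kornia.augmentation.RandomAffine``
        is available, though any ``nn.Module`` child works):

        >>> import kornia.augmentation as K
        >>> seq = SequentialBase(K.RandomAffine(360.0), same_on_batch=True)
        >>> seq.same_on_batch
        True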
"""
def __init__(
self,
*args: nn.Module,
same_on_batch: Optional[bool] = None,
return_transform: Optional[bool] = None,
keepdim: Optional[bool] = None,
) -> None:
# To name the modules properly
_args = OrderedDict()
for idx, mod in enumerate(args):
if not isinstance(mod, nn.Module):
                raise NotImplementedError(f"Only nn.Module is supported at this moment. Got {mod}.")
_args.update({f"{mod.__class__.__name__}_{idx}": mod})
super().__init__(_args)
self._same_on_batch = same_on_batch
self._return_transform = return_transform
self._keepdim = keepdim
self._params: Optional[List[ParamItem]] = None
self.update_attribute(same_on_batch, return_transform, keepdim)
def update_attribute(
self,
same_on_batch: Optional[bool] = None,
return_transform: Optional[bool] = None,
keepdim: Optional[bool] = None,
) -> None:
for mod in self.children():
            if isinstance(mod, (_AugmentationBase, MixAugmentationBase)):
                if same_on_batch is not None:
                    mod.same_on_batch = same_on_batch
                if keepdim is not None:
                    mod.keepdim = keepdim
            # MixAugmentationBase does not support return_transform
            if isinstance(mod, _AugmentationBase):
if return_transform is not None:
mod.return_transform = return_transform
if isinstance(mod, SequentialBase):
mod.update_attribute(same_on_batch, return_transform, keepdim)
def get_submodule(self, target: str) -> nn.Module:
"""Get submodule.
        This code is taken from torch 1.9.0, since it is not available in
        torch 1.7.1. We include it to keep compatibility with older torch
        versions.
Args:
            target: the fully-qualified string name of the submodule to
                look for, given as a dot-separated path of attribute names
                (e.g. ``"child.grandchild"``).
Returns:
torch.nn.Module: The submodule referenced by ``target``
Raises:
AttributeError: If the target string references an invalid
path or resolves to something that is not an
``nn.Module``
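
        Example:
            An illustrative sketch (the ``ClassName_idx`` naming comes from
            ``__init__`` above; this example is not from the torch docstring):

            >>> net = SequentialBase(nn.Sequential(nn.Conv2d(1, 1, 1)))
            >>> isinstance(net.get_submodule("Sequential_0.0"), nn.Conv2d)
            True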
"""
if target == "":
return self
atoms: List[str] = target.split(".")
mod: torch.nn.Module = self
for item in atoms:
if not hasattr(mod, item):
                raise AttributeError(f"{mod._get_name()} has no attribute `{item}`")
mod = getattr(mod, item)
if not isinstance(mod, torch.nn.Module):
raise AttributeError("`" + item + "` is not " "an nn.Module")
return mod
@property
def same_on_batch(self) -> Optional[bool]:
return self._same_on_batch
@same_on_batch.setter
def same_on_batch(self, same_on_batch: Optional[bool]) -> None:
self._same_on_batch = same_on_batch
self.update_attribute(same_on_batch=same_on_batch)
@property
def return_transform(self) -> Optional[bool]:
return self._return_transform
@return_transform.setter
def return_transform(self, return_transform: Optional[bool]) -> None:
self._return_transform = return_transform
self.update_attribute(return_transform=return_transform)
@property
def keepdim(self) -> Optional[bool]:
return self._keepdim
@keepdim.setter
def keepdim(self, keepdim: Optional[bool]) -> None:
self._keepdim = keepdim
self.update_attribute(keepdim=keepdim)
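
    # Note (illustrative): assigning to these properties, e.g.
    # ``seq.same_on_batch = True``, re-runs ``update_attribute`` so the new
    # value is pushed down to every child augmentation, including those inside
    # nested ``SequentialBase`` containers.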
def clear_state(self) -> None:
self._params = None
def update_params(self, param: Any) -> None:
if self._params is None:
self._params = [param]
else:
self._params.append(param)
# TODO: Implement this for all submodules.
def forward_parameters(self, batch_shape: torch.Size) -> List[ParamItem]:
raise NotImplementedError
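    # Illustrative note: concrete containers in kornia override this to sample
    # one ParamItem per child for the given batch shape.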
def get_children_by_indices(self, indices: torch.Tensor) -> Iterator[Tuple[str, nn.Module]]:
modules = list(self.named_children())
for idx in indices:
yield modules[idx]
    def get_children_by_params(self, params: List[ParamItem]) -> Iterator[Tuple[str, nn.Module]]:
        modules = list(self.named_children())
        names = [name for name, _ in modules]
        # TODO: Wrong params passed here when nested ImageSequential
        for param in params:
            yield modules[names.index(param.name)]
def get_params_by_module(self, named_modules: Iterator[Tuple[str, nn.Module]]) -> Iterator[ParamItem]:
        # Note: this ignores each module's _params and yields an empty placeholder per module.
for name, _ in named_modules:
yield ParamItem(name, None)
def contains_label_operations(self, params: List) -> bool:
raise NotImplementedError
def autofill_dim(self, input: torch.Tensor, dim_range: Tuple[int, int] = (2, 4)) -> Tuple[torch.Size, torch.Size]:
"""Fill tensor dim to the upper bound of dim_range.
If input tensor dim is smaller than the lower bound of dim_range, an error will be thrown out.
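
        Example:
            A minimal sketch (illustrative only; any ``nn.Module`` child works):

            >>> seq = SequentialBase(nn.Identity())
            >>> seq.autofill_dim(torch.zeros(3, 4))
            (torch.Size([3, 4]), torch.Size([1, 1, 3, 4]))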
"""
ori_shape = input.shape
if len(ori_shape) < dim_range[0] or len(ori_shape) > dim_range[1]:
            raise RuntimeError(f"input shape expected to be within {dim_range} but got {ori_shape}.")
while len(input.shape) < dim_range[1]:
input = input[None]
return ori_shape, input.shape