File size: 6,654 Bytes
36c95ba |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 |
from collections import OrderedDict
from typing import Any, Iterator, List, NamedTuple, Optional, Tuple, Union
import torch
import torch.nn as nn
from kornia.augmentation.base import _AugmentationBase, MixAugmentationBase
__all__ = ["SequentialBase", "ParamItem"]
# A named (module-name, parameters) pair recorded for each applied module.
# ``data`` holds the sampled parameter dict/list, or None when the module
# takes no parameters. Functional NamedTuple form; runtime-identical to the
# class-syntax declaration.
ParamItem = NamedTuple("ParamItem", [("name", str), ("data", Optional[Union[dict, list]])])
class SequentialBase(nn.Sequential):
    r"""SequentialBase for creating kornia modulized processing pipeline.

    Args:
        *args : a list of kornia augmentation and image operation modules.
        same_on_batch: apply the same transformation across the batch.
            If None, it will not overwrite the function-wise settings.
        return_transform: if ``True`` return the matrix describing the transformation
            applied to each. If None, it will not overwrite the function-wise settings.
        keepdim: whether to keep the output shape the same as input (True) or broadcast it
            to the batch form (False). If None, it will not overwrite the function-wise settings.
    """

    def __init__(
        self,
        *args: nn.Module,
        same_on_batch: Optional[bool] = None,
        return_transform: Optional[bool] = None,
        keepdim: Optional[bool] = None,
    ) -> None:
        # Name each child "<ClassName>_<index>" so that ParamItem names emitted
        # later can be mapped back to the producing module unambiguously.
        _args: OrderedDict = OrderedDict()
        for idx, mod in enumerate(args):
            if not isinstance(mod, nn.Module):
                raise NotImplementedError(f"Only nn.Module are supported at this moment. Got {mod}.")
            _args[f"{mod.__class__.__name__}_{idx}"] = mod
        super().__init__(_args)
        self._same_on_batch = same_on_batch
        self._return_transform = return_transform
        self._keepdim = keepdim
        # Last-sampled parameters for replay; None until a forward pass records them.
        self._params: Optional[List[ParamItem]] = None
        # Propagate the container-level flags down to every child once at build time.
        self.update_attribute(same_on_batch, return_transform, keepdim)

    def update_attribute(
        self,
        same_on_batch: Optional[bool] = None,
        return_transform: Optional[bool] = None,
        keepdim: Optional[bool] = None,
    ) -> None:
        """Recursively overwrite child-module flags; a None argument leaves the child's own setting."""
        for mod in self.children():
            # MixAugmentation does not have return transform
            if isinstance(mod, (_AugmentationBase, MixAugmentationBase)):
                if same_on_batch is not None:
                    mod.same_on_batch = same_on_batch
                if keepdim is not None:
                    mod.keepdim = keepdim
            if isinstance(mod, _AugmentationBase):
                if return_transform is not None:
                    mod.return_transform = return_transform
            if isinstance(mod, SequentialBase):
                # Nested pipelines forward the flags to their own children.
                mod.update_attribute(same_on_batch, return_transform, keepdim)

    def get_submodule(self, target: str) -> nn.Module:
        """Get submodule.

        This code is taken from torch 1.9.0 since it is not introduced
        back to torch 1.7.1. We included this for maintaining more
        backward torch versions.

        Args:
            target: The fully-qualified string name of the submodule
                to look for. (See above example for how to specify a
                fully-qualified string.)

        Returns:
            torch.nn.Module: The submodule referenced by ``target``

        Raises:
            AttributeError: If the target string references an invalid
                path or resolves to something that is not an
                ``nn.Module``
        """
        if target == "":
            return self
        # Walk the dotted path one attribute at a time, verifying each hop.
        atoms: List[str] = target.split(".")
        mod: torch.nn.Module = self
        for item in atoms:
            if not hasattr(mod, item):
                raise AttributeError(mod._get_name() + " has no " "attribute `" + item + "`")
            mod = getattr(mod, item)
            if not isinstance(mod, torch.nn.Module):
                raise AttributeError("`" + item + "` is not " "an nn.Module")
        return mod

    @property
    def same_on_batch(self) -> Optional[bool]:
        return self._same_on_batch

    @same_on_batch.setter
    def same_on_batch(self, same_on_batch: Optional[bool]) -> None:
        # Setting the flag re-propagates it to all children.
        self._same_on_batch = same_on_batch
        self.update_attribute(same_on_batch=same_on_batch)

    @property
    def return_transform(self) -> Optional[bool]:
        return self._return_transform

    @return_transform.setter
    def return_transform(self, return_transform: Optional[bool]) -> None:
        # Setting the flag re-propagates it to all children.
        self._return_transform = return_transform
        self.update_attribute(return_transform=return_transform)

    @property
    def keepdim(self) -> Optional[bool]:
        return self._keepdim

    @keepdim.setter
    def keepdim(self, keepdim: Optional[bool]) -> None:
        # Setting the flag re-propagates it to all children.
        self._keepdim = keepdim
        self.update_attribute(keepdim=keepdim)

    def clear_state(self) -> None:
        """Reset the recorded parameters so the next forward samples fresh ones."""
        self._params = None

    def update_params(self, param: Any) -> None:
        """Append one recorded parameter item, creating the list on first use."""
        if self._params is None:
            self._params = [param]
        else:
            self._params.append(param)

    # TODO: Implement this for all submodules.
    def forward_parameters(self, batch_shape: torch.Size) -> List[ParamItem]:
        raise NotImplementedError

    def get_children_by_indices(self, indices: torch.Tensor) -> Iterator[Tuple[str, nn.Module]]:
        """Yield ``(name, module)`` pairs for the children selected by ``indices``."""
        modules = list(self.named_children())
        for idx in indices:
            yield modules[idx]

    def get_children_by_params(self, params: List[ParamItem]) -> Iterator[Tuple[str, nn.Module]]:
        """Yield the ``(name, module)`` child matching each param's recorded name."""
        modules = list(self.named_children())
        # Hoisted out of the loop: previously dict(self.named_children()) and its
        # key list were rebuilt for every single param.
        names = [name for name, _ in modules]
        # TODO: Wrong params passed here when nested ImageSequential
        for param in params:
            yield modules[names.index(param.name)]

    def get_params_by_module(self, named_modules: Iterator[Tuple[str, nn.Module]]) -> Iterator[ParamItem]:
        # This will not take module._params
        for name, _ in named_modules:
            yield ParamItem(name, None)

    def contains_label_operations(self, params: List) -> bool:
        raise NotImplementedError

    def autofill_dim(self, input: torch.Tensor, dim_range: Tuple[int, int] = (2, 4)) -> Tuple[torch.Size, torch.Size]:
        """Fill tensor dim to the upper bound of dim_range.

        If input tensor dim is smaller than the lower bound of dim_range, an error will be thrown out.
        """
        ori_shape = input.shape
        if len(ori_shape) < dim_range[0] or len(ori_shape) > dim_range[1]:
            raise RuntimeError(f"input shape expected to be in {dim_range} while got {ori_shape}.")
        # Prepend singleton dims until the tensor reaches the upper bound rank.
        while len(input.shape) < dim_range[1]:
            input = input[None]
        return ori_shape, input.shape
|