|
|
from itertools import zip_longest |
|
|
from typing import Iterator, List, Optional, Tuple, Type, Union |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
import kornia |
|
|
from kornia.augmentation.base import ( |
|
|
_AugmentationBase, |
|
|
GeometricAugmentationBase2D, |
|
|
IntensityAugmentationBase2D, |
|
|
MixAugmentationBase, |
|
|
TensorWithTransformMat, |
|
|
) |
|
|
|
|
|
from .base import ParamItem, SequentialBase |
|
|
from .utils import ApplyInverseInterface, InputApplyInverse |
|
|
|
|
|
# Public API of this module: only the image pipeline container is exported.
__all__ = ["ImageSequential"]
|
|
|
|
|
|
|
|
class ImageSequential(SequentialBase):
    r"""Sequential for creating kornia image processing pipeline.

    Args:
        *args: a list of kornia augmentation and image operation modules.
        same_on_batch: apply the same transformation across the batch.
            If None, it will not overwrite the function-wise settings.
        return_transform: if ``True`` return the matrix describing the transformation
            applied to each input. If None, it will not overwrite the function-wise settings.
        keepdim: whether to keep the output shape the same as input (True) or broadcast it
            to the batch form (False). If None, it will not overwrite the function-wise settings.
        random_apply: randomly select a sublist (order agnostic) of args to
            apply transformation.
            If int, a fixed number of transformations will be selected.
            If (a,), x number of transformations (a <= x <= len(args)) will be selected.
            If (a, b), x number of transformations (a <= x <= b) will be selected.
            If True, the whole list of args will be processed as a sequence in a random order.
            If False, the whole list of args will be processed as a sequence in original order.
            When more transformations are requested than there are non-mix modules, modules are
            sampled with replacement.
        if_unsupported_ops: policy string stored as ``self.if_unsupported_ops``. It is not read by
            this class itself; presumably it is consulted by subclasses or helpers.

    .. note::
        Transformation matrix returned only considers the transformation applied in ``kornia.augmentation`` module.
        Those transformations in ``kornia.geometry`` will not be taken into account.

    Examples:
        >>> _ = torch.manual_seed(77)
        >>> import kornia
        >>> input, label = torch.randn(2, 3, 5, 6), torch.tensor([0, 1])
        >>> aug_list = ImageSequential(
        ...     kornia.color.BgrToRgb(),
        ...     kornia.augmentation.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
        ...     kornia.filters.MedianBlur((3, 3)),
        ...     kornia.augmentation.RandomAffine(360, p=1.0),
        ...     kornia.enhance.Invert(),
        ...     kornia.augmentation.RandomMixUp(p=1.0),
        ...     return_transform=True,
        ...     same_on_batch=True,
        ...     random_apply=10,
        ... )
        >>> out, lab = aug_list(input, label=label)
        >>> lab
        tensor([[0.0000, 1.0000, 0.1214],
                [1.0000, 0.0000, 0.1214]])
        >>> out[0].shape, out[1].shape
        (torch.Size([2, 3, 5, 6]), torch.Size([2, 3, 3]))

        Reproduce with provided params.
        >>> out2, lab2 = aug_list(input, label=label, params=aug_list._params)
        >>> torch.equal(out[0], out2[0]), torch.equal(out[1], out2[1]), torch.equal(lab[1], lab2[1])
        (True, True, True)
    """

    def __init__(
        self,
        *args: nn.Module,
        same_on_batch: Optional[bool] = None,
        return_transform: Optional[bool] = None,
        keepdim: Optional[bool] = None,
        random_apply: Union[int, bool, Tuple[int, int]] = False,
        if_unsupported_ops: str = 'raise',
    ) -> None:
        super().__init__(*args, same_on_batch=same_on_batch, return_transform=return_transform, keepdim=keepdim)
        # Either False (disabled) or a half-open ``(low, high)`` range later fed to ``torch.randint``.
        self.random_apply: Union[Tuple[int, int], bool] = self._read_random_apply(random_apply, len(args))
        # None means "decide on the first forward pass" (see ``forward``).
        self.return_label: Optional[bool] = None
        self.apply_inverse_func: Type[ApplyInverseInterface] = InputApplyInverse
        self.if_unsupported_ops = if_unsupported_ops

    def _read_random_apply(
        self, random_apply: Union[int, bool, Tuple[int, int]], max_length: int
    ) -> Union[Tuple[int, int], bool]:
        """Process the scenarios for random apply.

        Args:
            random_apply: the user-facing setting (see the class docstring).
            max_length: number of modules in the sequence.

        Returns:
            ``False`` when random application is disabled, otherwise a tuple ``(low, high)`` such that
            ``torch.randint(low, high)`` draws the number of modules to apply (``high`` is exclusive).

        Raises:
            ValueError: if ``random_apply`` has an unsupported form.
            AssertionError: if the normalised value is not a tuple of two ints.
        """
        # ``bool`` is a subclass of ``int``, so the boolean cases must be handled before the int case.
        if isinstance(random_apply, bool):
            random_apply = (max_length, max_length + 1) if random_apply else False
        elif isinstance(random_apply, int):
            random_apply = (random_apply, random_apply + 1)
        elif (
            isinstance(random_apply, tuple)
            and len(random_apply) == 2
            and isinstance(random_apply[0], int)
            and isinstance(random_apply[1], int)
        ):
            random_apply = (random_apply[0], random_apply[1] + 1)
        elif isinstance(random_apply, tuple) and len(random_apply) == 1 and isinstance(random_apply[0], int):
            random_apply = (random_apply[0], max_length + 1)
        else:
            raise ValueError(f"Non-readable random_apply. Got {random_apply}.")

        if random_apply is not False and not (
            isinstance(random_apply, tuple)
            and len(random_apply) == 2
            and isinstance(random_apply[0], int)
            and isinstance(random_apply[1], int)
        ):
            raise AssertionError(f"Expect a tuple of (int, int). Got {random_apply}.")
        return random_apply

    def get_random_forward_sequence(self, with_mix: bool = True) -> Tuple[Iterator[Tuple[str, nn.Module]], bool]:
        """Get a forward sequence when random apply is in need.

        The number of modules is drawn from ``self.random_apply`` and the non-mix modules are sampled with
        ``torch.multinomial`` (with replacement when more modules are requested than can be drawn without it).

        Note:
            Mix augmentations (e.g. RandomMixUp) will be only applied once even in a random forward.

        Args:
            with_mix: whether a mix augmentation may be inserted into the sampled sequence.

        Returns:
            The sampled ``(name, module)`` sequence and whether a mix augmentation was inserted.
        """
        num_samples = int(torch.randint(*self.random_apply, (1,)).item())
        multinomial_weights = torch.ones((len(self),))
        # Mix augmentations are excluded from the regular draw; at most one is added below.
        mix_indices = self.get_mix_augmentation_indices(self.named_children())
        multinomial_weights[mix_indices] = 0

        if num_samples == 0:
            # Nothing to apply: there is no sampled slot a mix augmentation could replace, and
            # ``torch.multinomial`` may reject a request for zero samples.
            return self.get_children_by_indices(torch.zeros((0,), dtype=torch.long)), False

        indices = torch.multinomial(
            multinomial_weights,
            num_samples,
            # Sample with replacement only when more modules are requested than available non-mix modules.
            replacement=num_samples > multinomial_weights.sum().item(),
        )

        mix_added = False
        if with_mix and len(mix_indices) != 0:
            # Replace the last sampled slot by one mix augmentation, then shuffle so it can land anywhere.
            if (torch.rand(1) < ((len(mix_indices) + len(indices)) / len(self))).item():
                indices[-1] = torch.multinomial((~multinomial_weights.bool()).float(), 1)
                indices = indices[torch.randperm(len(indices))]
                mix_added = True

        return self.get_children_by_indices(indices), mix_added

    def get_mix_augmentation_indices(self, named_modules: Iterator[Tuple[str, nn.Module]]) -> List[int]:
        """Get all the mix augmentations since they are label-involved.

        Special operations needed for label-involved augmentations.

        Args:
            named_modules: ``(name, module)`` pairs, e.g. from ``self.named_children()``.

        Returns:
            The positions (in iteration order) of the modules that are ``MixAugmentationBase`` instances.
        """
        return [idx for idx, (_, child) in enumerate(named_modules) if isinstance(child, MixAugmentationBase)]

    def get_forward_sequence(self, params: Optional[List[ParamItem]] = None) -> Iterator[Tuple[str, nn.Module]]:
        """Return the ``(name, module)`` pairs to run in a forward pass.

        Args:
            params: when given, the sequence is rebuilt from these params (replay); otherwise it is either
                randomly sampled (if ``random_apply`` is enabled) or the children in their original order.

        Raises:
            ValueError: if more than one mix augmentation is present and ``random_apply`` is disabled.
        """
        if params is not None:
            return self.get_children_by_params(params)

        if self.random_apply:
            return self.get_random_forward_sequence()[0]

        mix_indices = self.get_mix_augmentation_indices(self.named_children())
        if len(mix_indices) > 1:
            raise ValueError(
                "Multiple mix augmentation is prohibited without enabling random_apply. "
                f"Detected {len(mix_indices)}."
            )
        return self.named_children()

    def apply_to_input(
        self,
        input: TensorWithTransformMat,
        label: Optional[torch.Tensor],
        module: Optional[nn.Module],
        param: ParamItem,
    ) -> Tuple[TensorWithTransformMat, Optional[torch.Tensor]]:
        """Apply one module to ``input``/``label`` through ``self.apply_inverse_func.apply_trans``.

        If ``module`` is None it is looked up by ``param.name``.
        """
        if module is None:
            module = self.get_submodule(param.name)
        return self.apply_inverse_func.apply_trans(input, label, module, param)

    def forward_parameters(self, batch_shape: torch.Size) -> List[ParamItem]:
        """Sample parameters for every module of the forward sequence.

        Augmentations and nested ``ImageSequential`` modules provide their own parameters (a nested
        sequential yields a nested list); any other module gets ``None``.
        """
        params: List[ParamItem] = []
        for name, module in self.get_forward_sequence():
            if isinstance(module, (_AugmentationBase, MixAugmentationBase, ImageSequential)):
                params.append(ParamItem(name, module.forward_parameters(batch_shape)))
            else:
                params.append(ParamItem(name, None))
        return params

    def contains_label_operations(self, params: List[ParamItem]) -> bool:
        """Check if current sequential contains label-involved operations like MixUp.

        A param counts as label-involved when its name starts with ``RandomMixUp``/``RandomCutMix``, when the
        module it names is a ``MixAugmentationBase``, or when it names a nested ``ImageSequential`` whose own
        params contain such an operation.
        """
        for param in params:
            if param.name.startswith(("RandomMixUp", "RandomCutMix")):
                return True
            module = self.get_submodule(param.name)
            if isinstance(module, MixAugmentationBase):
                return True
            if (
                isinstance(module, ImageSequential)
                and isinstance(param.data, list)
                and module.contains_label_operations(param.data)
            ):
                return True
        return False

    def __packup_output__(
        self, output: TensorWithTransformMat, label: Optional[torch.Tensor] = None
    ) -> Union[TensorWithTransformMat, Tuple[TensorWithTransformMat, torch.Tensor]]:
        """Return ``(output, label)`` when ``self.return_label`` is truthy, otherwise ``output`` alone."""
        if self.return_label:
            return output, label
        return output

    def get_transformation_matrix(
        self, input: torch.Tensor, params: Optional[List[ParamItem]] = None,
    ) -> torch.Tensor:
        """Compute the transformation matrix according to the provided parameters.

        Matrices of augmentations and nested sequentials are left-multiplied in forward order, starting from
        a 3x3 identity; other modules do not contribute.

        Raises:
            NotImplementedError: if ``params`` is None.
        """
        if params is None:
            raise NotImplementedError("requires params to be provided.")
        named_modules: Iterator[Tuple[str, nn.Module]] = self.get_forward_sequence(params)

        res_mat: torch.Tensor = kornia.eye_like(3, input)
        for (_, module), param in zip(named_modules, params):
            if isinstance(module, (_AugmentationBase, MixAugmentationBase)):
                res_mat = module.compute_transformation(input, param.data) @ res_mat
            elif isinstance(module, ImageSequential):
                res_mat = module.get_transformation_matrix(input, param.data) @ res_mat
        return res_mat

    def is_intensity_only(self, strict: bool = True) -> bool:
        """Check if all transformations are intensity-based.

        Args:
            strict: if strict is False, it will allow non-augmentation nn.Modules to be passed.
                e.g. `kornia.enhance.AdjustBrightness` will be recognized as non-intensity module
                if strict is set to True. Geometric augmentations are rejected in both modes.

        Note: patch processing would break the continuity of labels (e.g. bounding boxes, masks).
        """
        for arg in self.children():
            if isinstance(arg, ImageSequential):
                if not arg.is_intensity_only(strict):
                    return False
            elif isinstance(arg, IntensityAugmentationBase2D):
                continue
            elif isinstance(arg, GeometricAugmentationBase2D):
                # Geometric augmentations are augmentation modules, so ``strict=False`` must not let them pass.
                return False
            elif strict:
                # Unregistered (non-augmentation) modules are only tolerated in non-strict mode.
                return False
        return True

    def inverse(
        self,
        input: torch.Tensor,
        params: Optional[List[ParamItem]] = None,
    ) -> torch.Tensor:
        """Inverse transformation.

        Used to inverse a tensor according to the performed transformation by a forward pass, or with respect to
        provided parameters.

        Modules are visited in reverse order. Intensity augmentations, intensity-only nested sequentials and
        non-augmentation modules are skipped; nested sequentials are inverted recursively and geometric
        augmentations are inverted through ``self.apply_inverse_func``.

        Args:
            input: the tensor to invert.
            params: parameters of the pass to undo. Defaults to the params recorded by the last forward.

        Returns:
            The inverted tensor.

        Raises:
            ValueError: if ``params`` is None and no forward pass has been run yet.
        """
        if params is None:
            if self._params is None:
                raise ValueError(
                    "No parameters available for inversing, please run a forward pass first "
                    "or passing valid params into this function."
                )
            params = self._params

        # Pair each module with the params that drove it, then walk the pairs back-to-front.
        # Presumably ``get_forward_sequence(params)`` yields exactly one pair per param.
        pairs = list(zip(self.get_forward_sequence(params), params))
        for (_, module), param in reversed(pairs):
            if isinstance(module, IntensityAugmentationBase2D):
                continue
            if isinstance(module, ImageSequential):
                if not module.is_intensity_only():
                    input = module.inverse(input, param.data)
            elif isinstance(module, GeometricAugmentationBase2D):
                input = self.apply_inverse_func.inverse(input, module, param)
        return input

    def forward(
        self,
        input: TensorWithTransformMat,
        label: Optional[torch.Tensor] = None,
        params: Optional[List[ParamItem]] = None,
    ) -> Union[TensorWithTransformMat, Tuple[TensorWithTransformMat, torch.Tensor]]:
        """Run the pipeline.

        Args:
            input: an image tensor, or a tuple/list whose first element is the image tensor.
            label: optional labels passed along to every module.
            params: parameters to replay; if None, fresh parameters are sampled with ``forward_parameters``.

        Returns:
            The transformed input, or ``(input, label)`` when ``self.return_label`` is truthy.
        """
        self.clear_state()
        if params is None:
            inp = input[0] if isinstance(input, (tuple, list)) else input
            _, out_shape = self.autofill_dim(inp, dim_range=(2, 4))
            params = self.forward_parameters(out_shape)
        # NOTE(review): ``return_label`` is resolved once and then cached, so with ``random_apply`` a later call
        # that samples a different set of operations keeps the first call's decision — confirm this is intended.
        if self.return_label is None:
            self.return_label = label is not None or self.contains_label_operations(params)

        for param in params:
            module = self.get_submodule(param.name)
            input, label = self.apply_to_input(input, label, module, param=param)
            # Record the params each stateful module actually used (presumably accumulated into
            # ``self._params``), which is what allows replaying via ``params=`` and ``inverse``.
            if isinstance(module, (_AugmentationBase, MixAugmentationBase, SequentialBase)):
                param = ParamItem(param.name, module._params)
            else:
                param = ParamItem(param.name, None)
            self.update_params(param)
        return self.__packup_output__(input, label)
|
|
|