|
|
from abc import ABCMeta, abstractmethod |
|
|
from functools import partial |
|
|
from typing import Callable, cast, Dict, Iterator, List, Optional, Tuple, Type, Union |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
import kornia |
|
|
from kornia.augmentation.base import ( |
|
|
_AugmentationBase, |
|
|
GeometricAugmentationBase2D, |
|
|
MixAugmentationBase, |
|
|
TensorWithTransformMat, |
|
|
) |
|
|
from kornia.constants import DataKey |
|
|
from kornia.geometry.bbox import transform_bbox |
|
|
from kornia.geometry.linalg import transform_points |
|
|
from kornia.utils.helpers import _torch_inverse_cast |
|
|
|
|
|
from .base import ParamItem |
|
|
|
|
|
|
|
|
def _get_geometric_only_param(
    module: "kornia.augmentation.container.ImageSequential", param: List[ParamItem]
) -> List[ParamItem]:
    """Keep only the parameters that belong to geometric 2D augmentations.

    The entries yielded by ``module.get_forward_sequence(param)`` are paired
    positionally with the items of ``param``; an item is kept when the module
    it is paired with is a ``GeometricAugmentationBase2D``.

    Args:
        module: the sequential container whose forward sequence is inspected.
        param: the parameter items passed to ``module.get_forward_sequence``.

    Returns:
        The kept parameter items, in their original order.
    """
    forward_sequence: Iterator[Tuple[str, nn.Module]] = module.get_forward_sequence(param)
    return [
        item
        for (_, submodule), item in zip(forward_sequence, param)
        if isinstance(submodule, GeometricAugmentationBase2D)
    ]
|
|
|
|
|
|
|
|
class ApplyInverseInterface(metaclass=ABCMeta):
    """Abstract interface for applying and inversing transformations.

    Both operations are abstract class methods, so implementations are used
    through the class itself rather than through instances.
    """

    @classmethod
    @abstractmethod
    def apply_trans(
        cls, input: torch.Tensor, label: Optional[torch.Tensor], module: nn.Module, param: ParamItem
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Apply a transformation with respect to the parameters.

        Args:
            input: the input tensor.
            label: the optional label tensor.
            module: any torch Module but only kornia augmentation modules will count
                to apply transformations.
            param: the corresponding parameters to the module.

        Returns:
            A tuple of the transformed input and the (optional) label.
        """
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def inverse(cls, input: torch.Tensor, module: nn.Module, param: Optional[ParamItem] = None) -> torch.Tensor:
        """Inverse a transformation with respect to the parameters.

        Args:
            input: the input tensor.
            module: any torch Module but only kornia augmentation modules will count
                to apply transformations.
            param: the corresponding parameters to the module.

        Returns:
            The input tensor after the inverse operation.
        """
        raise NotImplementedError
|
|
|
|
|
|
|
|
class ApplyInverseImpl(ApplyInverseInterface):
    """Standard matrix apply and inverse methods.

    Subclasses provide ``apply_func``, which is called as
    ``apply_func(matrix, input)`` to transform the data.
    """

    # Callable invoked as ``apply_func(matrix, input)``; set by subclasses.
    apply_func: Callable

    @classmethod
    def apply_trans(
        cls, input: torch.Tensor, label: Optional[torch.Tensor], module: nn.Module, param: ParamItem
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Apply a transformation with respect to the parameters.

        Args:
            input: the input tensor.
            label: the optional label tensor.
            module: any torch Module but only kornia augmentation modules will count
                to apply transformations.
            param: the corresponding parameters to the module.

        Returns:
            The input (transformed when the module yields a matrix, otherwise
            untouched) together with the label, which is passed through as is.
        """
        matrix: Optional[torch.Tensor] = cls._get_transformation(input, module, param)
        if matrix is None:
            # The module provides no transformation matrix: nothing to apply.
            return input, label
        return cls.apply_func(matrix, input), label

    @classmethod
    def inverse(cls, input: torch.Tensor, module: nn.Module, param: Optional[ParamItem] = None) -> torch.Tensor:
        """Inverse a transformation with respect to the parameters.

        Args:
            input: the input tensor.
            module: any torch Module but only kornia augmentation modules will count
                to apply transformations.
            param: the corresponding parameters to the module.

        Returns:
            The input with the inverted matrix applied, or the untouched input
            when the module yields no matrix.
        """
        matrix: Optional[torch.Tensor] = cls._get_transformation(input, module, param)
        if matrix is None:
            return input
        inverted: torch.Tensor = cls._get_inverse_transformation(matrix)
        # Align the inverted matrix with the device and dtype of the input before applying it.
        inverted = torch.as_tensor(inverted, device=input.device, dtype=input.dtype)
        return cls.apply_func(inverted, input)

    @classmethod
    def _get_transformation(
        cls, input: torch.Tensor, module: nn.Module, param: Optional[ParamItem] = None
    ) -> Optional[torch.Tensor]:
        """Fetch the transformation matrix of ``module`` for ``input``.

        Returns:
            The result of ``module.get_transformation_matrix`` for geometric 2D
            augmentations and for ``ImageSequential`` containers that are not
            intensity-only; ``None`` for every other module.

        Raises:
            ValueError: if ``module`` is a geometric 2D augmentation or an
                ``ImageSequential`` and ``param`` is ``None``.
        """
        is_geometric = isinstance(module, GeometricAugmentationBase2D)
        is_sequential = isinstance(module, kornia.augmentation.container.ImageSequential)

        if param is None and (is_geometric or is_sequential):
            raise ValueError(f"Parameters of transformation matrix for {module} has not been computed.")

        if is_geometric:
            geometric_param = cast(Dict[str, torch.Tensor], param.data)
            return module.get_transformation_matrix(input, geometric_param)

        if is_sequential and not module.is_intensity_only():
            sequential_param = cast(List[ParamItem], param.data)
            return module.get_transformation_matrix(input, sequential_param)

        return None

    @classmethod
    def _get_inverse_transformation(cls, transform: torch.Tensor) -> torch.Tensor:
        """Invert ``transform`` by delegating to ``_torch_inverse_cast``."""
        return _torch_inverse_cast(transform)
|
|
|
|
|
|
|
|
class InputApplyInverse(ApplyInverseImpl):
    """Apply and inverse transformations for (image) input tensors."""

    @classmethod
    def apply_trans(
        cls,
        input: TensorWithTransformMat,
        label: Optional[torch.Tensor],
        module: nn.Module,
        param: ParamItem,
    ) -> Tuple[TensorWithTransformMat, Optional[torch.Tensor]]:
        """Apply a transformation with respect to the parameters.

        Args:
            input: the input tensor.
            label: the optional label tensor.
            module: any torch Module but only kornia augmentation modules will count
                to apply transformations.
            param: the corresponding parameters to the module.

        Returns:
            A tuple of the processed input and the (possibly updated) label.

        Raises:
            AssertionError: if ``module`` is not an augmentation or an
                ``ImageSequential`` but ``param.data`` is not ``None``.
        """
        if isinstance(module, (MixAugmentationBase,)):
            # Mix augmentations consume and produce the label as well.
            input, label = module(input, label, params=param.data)
        elif isinstance(module, (_AugmentationBase,)):
            input = module(input, params=param.data)
        elif isinstance(module, kornia.augmentation.container.ImageSequential):
            previous_func = module.apply_inverse_func
            previous_return_label = module.return_label
            module.apply_inverse_func = InputApplyInverse
            module.return_label = True
            try:
                input, label = module(input, label, param.data)
            finally:
                # Restore the module state even if the nested call raises, so a
                # failure does not leave the container permanently reconfigured.
                module.apply_inverse_func = previous_func
                module.return_label = previous_return_label
        else:
            if param.data is not None:
                raise AssertionError(f"Non-augmentaion operation {param.name} require empty parameters. Got {param}.")

            if isinstance(input, (tuple, list)):
                # Only the first element goes through the module; the second is kept as is.
                input = (module(input[0]), input[1])
            else:
                input = module(input)
        return input, label

    @classmethod
    def inverse(cls, input: torch.Tensor, module: nn.Module, param: Optional[ParamItem] = None) -> torch.Tensor:
        """Inverse a transformation with respect to the parameters.

        Args:
            input: the input tensor.
            module: any torch Module but only kornia augmentation modules will count
                to apply transformations.
            param: the corresponding parameters to the module.

        Returns:
            The result of ``module.inverse`` for geometric 2D augmentations and
            ``ImageSequential`` containers; the untouched input otherwise.
        """
        if isinstance(module, GeometricAugmentationBase2D):
            input = module.inverse(input, None if param is None else cast(Dict, param.data))
        elif isinstance(module, kornia.augmentation.container.ImageSequential):
            previous_func = module.apply_inverse_func
            module.apply_inverse_func = InputApplyInverse
            try:
                input = module.inverse(input, None if param is None else cast(List, param.data))
            finally:
                # Always put back the original apply/inverse implementation.
                module.apply_inverse_func = previous_func
        return input
|
|
|
|
|
|
|
|
class MaskApplyInverse(ApplyInverseImpl):
    """Apply and inverse transformations for mask tensors."""

    @classmethod
    def make_input_only_sequential(cls, module: "kornia.augmentation.container.ImageSequential") -> Callable:
        """Wrap ``module`` so that it is called with its extra outputs switched off.

        The returned callable sets ``module.return_transform`` and
        ``module.return_label`` to ``False`` for the duration of the call and
        restores their previous values afterwards, including when the call raises.

        Args:
            module: the sequential container to wrap.

        Returns:
            A callable that forwards its arguments to ``module``.
        """

        def f(*args, **kwargs):
            previous_return_transform = module.return_transform
            previous_return_label = module.return_label
            module.return_transform = False
            module.return_label = False
            try:
                return module(*args, **kwargs)
            finally:
                # Restore the flags even on failure so the module is not left reconfigured.
                module.return_transform = previous_return_transform
                module.return_label = previous_return_label

        return f

    @classmethod
    def apply_trans(
        cls, input: torch.Tensor, label: Optional[torch.Tensor], module: nn.Module, param: Optional[ParamItem] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Apply a transformation with respect to the parameters.

        Args:
            input: the input tensor.
            label: the optional label tensor.
            module: any torch Module but only kornia augmentation modules will count
                to apply transformations.
            param: the corresponding parameters to the module.

        Returns:
            The input, transformed only by geometric 2D augmentations or by
            ``ImageSequential`` containers that are not intensity-only, together
            with the label, which is passed through as is.
        """
        _param = None if param is None else param.data

        if isinstance(module, GeometricAugmentationBase2D):
            _param = cast(Dict[str, torch.Tensor], _param)
            input = module(input, _param, return_transform=False)
        elif isinstance(module, kornia.augmentation.container.ImageSequential) and not module.is_intensity_only():
            _param = cast(List[ParamItem], _param)
            previous_func = module.apply_inverse_func
            module.apply_inverse_func = MaskApplyInverse
            try:
                # Only the geometric parameters are forwarded to the nested container.
                geo_param: List[ParamItem] = _get_geometric_only_param(module, _param)
                input = cls.make_input_only_sequential(module)(input, None, geo_param)
            finally:
                # Always put back the original apply/inverse implementation.
                module.apply_inverse_func = previous_func
        # Any other module leaves the input untouched.
        return input, label

    @classmethod
    def inverse(cls, input: torch.Tensor, module: nn.Module, param: Optional[ParamItem] = None) -> torch.Tensor:
        """Inverse a transformation with respect to the parameters.

        Args:
            input: the input tensor.
            module: any torch Module but only kornia augmentation modules will count
                to apply transformations.
            param: the corresponding parameters to the module.

        Returns:
            The result of ``module.inverse`` for geometric 2D augmentations and
            ``ImageSequential`` containers; the untouched input otherwise.
        """
        if isinstance(module, GeometricAugmentationBase2D):
            input = module.inverse(input, None if param is None else cast(Dict, param.data))
        elif isinstance(module, kornia.augmentation.container.ImageSequential):
            previous_func = module.apply_inverse_func
            module.apply_inverse_func = MaskApplyInverse
            try:
                input = module.inverse(input, None if param is None else cast(List, param.data))
            finally:
                # Always put back the original apply/inverse implementation.
                module.apply_inverse_func = previous_func
        return input
|
|
|
|
|
|
|
|
class BBoxXYXYApplyInverse(ApplyInverseImpl):
    """Apply and inverse transformations for bounding box tensors.

    Boxes are expected in the format [xmin, ymin, xmax, ymax]; the matrix is
    applied through ``transform_bbox`` with ``mode="xyxy"``.
    """

    apply_func = partial(transform_bbox, mode="xyxy")
|
|
|
|
|
|
|
|
class BBoxXYWHApplyInverse(ApplyInverseImpl):
    """Apply and inverse transformations for bounding box tensors.

    Boxes are expected in the format [xmin, ymin, width, height]; the matrix is
    applied through ``transform_bbox`` with ``mode="xywh"``.
    """

    apply_func = partial(transform_bbox, mode="xywh")
|
|
|
|
|
|
|
|
class KeypointsApplyInverse(ApplyInverseImpl):
    """Apply and inverse transformations for keypoints tensors.

    The matrix is applied to the keypoints through ``transform_points``.
    """

    apply_func = transform_points
|
|
|
|
|
|
|
|
class ApplyInverse:
    """Apply and inverse transformations for any tensors (e.g. mask, box, points)."""

    @classmethod
    def _get_func_by_key(cls, dcate: Union[str, int, DataKey]) -> Type[ApplyInverseInterface]:
        """Select the apply/inverse implementation for a data category.

        Args:
            dcate: data category, resolved through ``DataKey.get``.

        Returns:
            The implementation class registered for the resolved key.

        Raises:
            NotImplementedError: if the resolved key has no implementation.
        """
        # Resolve the key once instead of once per comparison.
        key = DataKey.get(dcate)
        if key == DataKey.INPUT:
            return InputApplyInverse
        if key == DataKey.MASK:
            return MaskApplyInverse
        if key in (DataKey.BBOX, DataKey.BBOX_XYXY):
            return BBoxXYXYApplyInverse
        if key == DataKey.BBOX_XYHW:
            return BBoxXYWHApplyInverse
        if key == DataKey.KEYPOINTS:
            return KeypointsApplyInverse
        raise NotImplementedError(f"input type of {dcate} is not implemented.")

    @classmethod
    def apply_by_key(
        cls,
        input: TensorWithTransformMat,
        label: Optional[torch.Tensor],
        module: nn.Module,
        param: ParamItem,
        dcate: Union[str, int, DataKey] = DataKey.INPUT,
    ) -> Tuple[TensorWithTransformMat, Optional[torch.Tensor]]:
        """Apply a transformation with respect to the parameters.

        Args:
            input: the input tensor.
            label: the optional label tensor.
            module: any torch Module but only kornia augmentation modules will count
                to apply transformations.
            param: the corresponding parameters to the module.
            dcate: data category. 'input', 'mask', 'bbox', 'bbox_xyxy', 'bbox_xyhw', 'keypoints'.
                By default, it is set to 'input'.

        Returns:
            A ``(output, label)`` pair. When ``input`` is a tuple, only its first
            element is transformed and ``output`` is a tuple that carries the
            remaining elements along unchanged.
        """
        func: Type[ApplyInverseInterface] = cls._get_func_by_key(dcate)

        if isinstance(input, (tuple,)):
            # The input is a tuple such as (tensor, matrix): transform the first
            # element only and rebuild the tuple, so the result keeps the same
            # (output, label) layout as the non-tuple path below.
            output, label = func.apply_trans(input[0], label, module, param)
            return (output, *input[1:]), label
        return func.apply_trans(input, label, module=module, param=param)

    @classmethod
    def inverse_by_key(
        cls,
        input: torch.Tensor,
        module: nn.Module,
        param: Optional[ParamItem] = None,
        dcate: Union[str, int, DataKey] = DataKey.INPUT,
    ) -> torch.Tensor:
        """Inverse a transformation with respect to the parameters.

        Args:
            input: the input tensor.
            module: any torch Module but only kornia augmentation modules will count
                to apply transformations.
            param: the corresponding parameters to the module.
            dcate: data category. 'input', 'mask', 'bbox', 'bbox_xyxy', 'bbox_xyhw', 'keypoints'.
                By default, it is set to 'input'.

        Returns:
            The result of the selected implementation's ``inverse``.
        """
        func: Type[ApplyInverseInterface] = cls._get_func_by_key(dcate)
        return func.inverse(input, module, param)
|
|
|