import warnings
from typing import cast, Dict, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.distributions import Bernoulli

import kornia
from kornia.utils.helpers import _torch_inverse_cast

from .utils import (
    _adapted_sampling,
    _transform_input,
    _transform_input3d,
    _transform_output_shape,
    _validate_input_dtype,
)

TensorWithTransformMat = Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]


def _check_transform_batching(input: TensorWithTransformMat, batched_ndim: int, expected_ranks: str) -> None:
    """Validate that a ``(tensor, matrix)`` input pair agrees on batching mode.

    Nothing is checked when ``input`` is not a tuple.

    Args:
        input: a plain tensor, or a ``(tensor, transformation matrix)`` tuple.
        batched_ndim: number of dimensions of a batched tensor (4 for 2D data, 5 for 3D data).
            Tensors with one or two fewer dimensions are treated as non-batched.
        expected_ranks: human readable list of accepted ranks, interpolated into the error message.

    Raises:
        AssertionError: if the tensor is batched but the matrix is not 3-dimensional, if their
            batch sizes differ, or if the tensor is non-batched but the matrix is not 2-dimensional.
        ValueError: if the tensor rank is none of the accepted ones.
    """
    if not isinstance(input, tuple):
        return
    inp, mat = input
    ndim = len(inp.shape)
    if ndim == batched_ndim:
        if len(mat.shape) != 3:
            raise AssertionError('Input tensor is in batch mode but transformation matrix is not')
        if mat.shape[0] != inp.shape[0]:
            raise AssertionError(
                f"In batch dimension, input has {inp.shape[0]} but transformation matrix has {mat.shape[0]}"
            )
    elif ndim in (batched_ndim - 2, batched_ndim - 1):
        if len(mat.shape) != 2:
            raise AssertionError("Input tensor is in non-batch mode but transformation matrix is not")
    else:
        raise ValueError(f'Unrecognized output shape. Expected {expected_ranks}, got {ndim}')


class _BasicAugmentationBase(nn.Module):
    r"""_BasicAugmentationBase base class for customized augmentation implementations.

    Plain augmentation base class without the functionality of transformation matrix calculations.
    By default, the random computations will be happened on CPU with ``torch.get_default_dtype()``.
    To change this behaviour, please use ``set_rng_device_and_dtype``.

    Args:
        p: probability for applying an augmentation. This param controls the augmentation
            probabilities element-wise.
        p_batch: probability for applying an augmentation to a batch. This param controls the
            augmentation probabilities batch-wise.
        same_on_batch: apply the same transformation across the batch.
        keepdim: whether to keep the output shape the same as input ``True`` or broadcast it to
            the batch form ``False``.
    """

    def __init__(
        self, p: float = 0.5, p_batch: float = 1.0, same_on_batch: bool = False, keepdim: bool = False
    ) -> None:
        super().__init__()
        self.p = p
        self.p_batch = p_batch
        self.same_on_batch = same_on_batch
        self.keepdim = keepdim
        self._params: Dict[str, torch.Tensor] = {}
        # The former guards (`p != 0.0 or p != 1.0`) were always true, so both samplers have
        # always been created. They are kept unconditional so that the attributes keep existing.
        self._p_gen = Bernoulli(self.p)
        self._p_batch_gen = Bernoulli(self.p_batch)
        self.set_rng_device_and_dtype(torch.device('cpu'), torch.get_default_dtype())

    def __repr__(self) -> str:
        return f"p={self.p}, p_batch={self.p_batch}, same_on_batch={self.same_on_batch}"

    def __unpack_input__(self, input: torch.Tensor) -> torch.Tensor:
        """Return the tensor to augment; the base implementation returns ``input`` unchanged."""
        return input

    def __check_batching__(self, input: TensorWithTransformMat):
        """Check if a transformation matrix is returned, it has to be in the same batching mode as output."""
        raise NotImplementedError

    def transform_tensor(self, input: torch.Tensor) -> torch.Tensor:
        """Standardize input tensors."""
        raise NotImplementedError

    def generate_parameters(self, batch_shape: torch.Size) -> Dict[str, torch.Tensor]:
        """Generate the random parameters of the augmentation; the base implementation has none."""
        return {}

    def apply_transform(self, input: torch.Tensor, params: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Apply the augmentation to ``input`` using ``params``. Must be implemented by subclasses."""
        raise NotImplementedError

    def set_rng_device_and_dtype(self, device: torch.device, dtype: torch.dtype) -> None:
        """Change the random generation device and dtype.

        Note:
            The generated random numbers are not reproducible across different devices and dtypes.
        """
        self.device = device
        self.dtype = dtype

    def __batch_prob_generator__(
        self, batch_shape: torch.Size, p: float, p_batch: float, same_on_batch: bool
    ) -> torch.Tensor:
        """Sample the boolean mask selecting which batch elements get augmented.

        A single batch-level draw is made with ``p_batch``. If it succeeds, one draw per element is
        made with ``p``; otherwise the (all ``False``) batch-level draw is repeated for every element.

        Args:
            batch_shape: shape of the standardized input; only ``batch_shape[0]`` is used.
            p: element-wise probability.
            p_batch: batch-wise probability.
            same_on_batch: forwarded to the sampling helper.

        Returns:
            a boolean tensor with ``batch_shape[0]`` elements.
        """
        batch_size = batch_shape[0]
        batch_prob: torch.Tensor
        if p_batch == 1:
            batch_prob = torch.tensor([True])
        elif p_batch == 0:
            batch_prob = torch.tensor([False])
        else:
            batch_prob = _adapted_sampling((1,), self._p_batch_gen, same_on_batch).bool()

        if batch_prob.sum().item() == 1:
            elem_prob: torch.Tensor
            # torch.ones/zeros with an explicit dtype keep the mask boolean for an empty batch,
            # whereas torch.tensor([]) would silently produce a float tensor.
            if p == 1:
                elem_prob = torch.ones(batch_size, dtype=torch.bool)
            elif p == 0:
                elem_prob = torch.zeros(batch_size, dtype=torch.bool)
            else:
                elem_prob = _adapted_sampling((batch_size,), self._p_gen, same_on_batch).bool()
            batch_prob = batch_prob * elem_prob
        else:
            batch_prob = batch_prob.repeat(batch_size)
        return batch_prob

    def forward_parameters(self, batch_shape) -> Dict[str, torch.Tensor]:
        """Sample the application mask and the augmentation parameters for one forward pass.

        ``generate_parameters`` is called with a batch size equal to the number of selected
        elements. The mask is stored under the ``'batch_prob'`` key of the returned dict.
        """
        to_apply = self.__batch_prob_generator__(batch_shape, self.p, self.p_batch, self.same_on_batch)
        num_selected = int(to_apply.sum().item())
        _params = self.generate_parameters(torch.Size((num_selected, *batch_shape[1:])))
        if _params is None:
            _params = {}
        _params['batch_prob'] = to_apply
        return _params

    def apply_func(self, input: torch.Tensor, params: Dict[str, torch.Tensor]) -> TensorWithTransformMat:
        """Standardize ``input`` and apply the augmentation with ``params``."""
        input = self.transform_tensor(input)
        return self.apply_transform(input, params)

    def forward(  # type: ignore
        self, input: torch.Tensor, params: Optional[Dict[str, torch.Tensor]] = None  # type: ignore
    ) -> TensorWithTransformMat:
        """Run the augmentation.

        Args:
            input: the tensor to augment.
            params: previously generated parameters to reuse. When ``None``, new parameters are
                sampled with ``forward_parameters``.

        Returns:
            the augmented output, reshaped to the original input shape when ``keepdim`` is set.
        """
        in_tensor = self.__unpack_input__(input)
        self.__check_batching__(input)
        ori_shape = in_tensor.shape
        in_tensor = self.transform_tensor(in_tensor)
        batch_shape = in_tensor.shape
        if params is None:
            params = self.forward_parameters(batch_shape)
        self._params = params

        # apply_func standardizes the tensor itself, hence the raw input is handed over.
        output = self.apply_func(input, self._params)
        return _transform_output_shape(output, ori_shape) if self.keepdim else output


class _AugmentationBase(_BasicAugmentationBase):
    r"""_AugmentationBase base class for customized augmentation implementations.

    Advanced augmentation base class with the functionality of transformation matrix calculations.

    Args:
        p: probability for applying an augmentation. This param controls the augmentation
            probabilities element-wise for a batch.
        p_batch: probability for applying an augmentation to a batch. This param controls the
            augmentation probabilities batch-wise.
        return_transform: if ``True`` return the matrix describing the geometric transformation
            applied to each input tensor. If ``False`` and the input is a tuple the applied
            transformation won't be concatenated.
        same_on_batch: apply the same transformation across the batch.
        keepdim: whether to keep the output shape the same as input ``True`` or broadcast it to
            the batch form ``False``.
    """

    def __init__(
        self,
        return_transform: Optional[bool] = None,
        same_on_batch: bool = False,
        p: float = 0.5,
        p_batch: float = 1.0,
        keepdim: bool = False,
    ) -> None:
        # `p` and `p_batch` are stored by the parent constructor.
        super().__init__(p, p_batch=p_batch, same_on_batch=same_on_batch, keepdim=keepdim)
        self.return_transform = return_transform

    def __repr__(self) -> str:
        return super().__repr__() + f", return_transform={self.return_transform}"

    def identity_matrix(self, input: torch.Tensor) -> torch.Tensor:
        """Return the identity transformation matching ``input``. Must be implemented by subclasses."""
        raise NotImplementedError

    def compute_transformation(self, input: torch.Tensor, params: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Return the transformation matrix described by ``params``. Must be implemented by subclasses."""
        raise NotImplementedError

    def apply_transform(
        self, input: torch.Tensor, params: Dict[str, torch.Tensor], transform: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Apply the augmentation to ``input``. Must be implemented by subclasses."""
        raise NotImplementedError

    def __unpack_input__(  # type: ignore
        self, input: TensorWithTransformMat
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Split ``input`` into ``(tensor, transformation)``; the latter is ``None`` for a plain tensor."""
        if isinstance(input, tuple):
            return input[0], input[1]
        return input, None

    def apply_func(  # type: ignore
        self,
        in_tensor: torch.Tensor,
        in_transform: Optional[torch.Tensor],  # type: ignore
        params: Dict[str, torch.Tensor],
        return_transform: bool = False,
    ) -> TensorWithTransformMat:
        """Apply the augmentation to the elements selected by ``params['batch_prob']``.

        The transformation matrix of the whole batch (identity for the skipped elements) is
        cached in ``self._transform_matrix``.

        Args:
            in_tensor: standardized input tensor.
            in_transform: transformation that came along with the input, if any.
            params: augmentation parameters, including the ``'batch_prob'`` mask.
            return_transform: whether to return the transformation matrix as well.

        Returns:
            the output tensor alone, or an ``(output, transformation)`` tuple. With
            ``return_transform`` the transformation is the applied matrix, chained onto
            ``in_transform`` when one was given; without it, a given ``in_transform`` is
            passed through unchanged.
        """
        to_apply = params['batch_prob']
        num_selected = int(torch.sum(to_apply).item())

        if num_selected == 0:
            # no augmentation needed
            output = in_tensor
            trans_matrix = self.identity_matrix(in_tensor)
        elif num_selected == len(to_apply):
            # all data needs to be augmented
            trans_matrix = self.compute_transformation(in_tensor, params)
            output = self.apply_transform(in_tensor, params, trans_matrix)
        else:
            output = in_tensor.clone()
            trans_matrix = self.identity_matrix(in_tensor)
            trans_matrix[to_apply] = self.compute_transformation(in_tensor[to_apply], params)
            output[to_apply] = self.apply_transform(in_tensor[to_apply], params, trans_matrix[to_apply])

        self._transform_matrix = trans_matrix

        if return_transform:
            out_transformation = trans_matrix if in_transform is None else trans_matrix @ in_transform
            return output, out_transformation

        if in_transform is not None:
            return output, in_transform

        return output

    def forward(  # type: ignore
        self,
        input: TensorWithTransformMat,
        params: Optional[Dict[str, torch.Tensor]] = None,  # type: ignore
        return_transform: Optional[bool] = None,
    ) -> TensorWithTransformMat:
        """Run the augmentation.

        Args:
            input: a tensor, or a ``(tensor, transformation matrix)`` tuple.
            params: previously generated parameters to reuse. When ``None``, new parameters are
                sampled. A dict without ``'batch_prob'`` gets an all-``True`` mask added to it.
            return_transform: overrides ``self.return_transform`` for this call when not ``None``.

        Returns:
            the result of ``apply_func``, reshaped to the original input shape when ``keepdim`` is set.
        """
        in_tensor, in_transform = self.__unpack_input__(input)
        self.__check_batching__(input)
        ori_shape = in_tensor.shape
        in_tensor = self.transform_tensor(in_tensor)
        batch_shape = in_tensor.shape

        if return_transform is None:
            return_transform = self.return_transform
        return_transform = cast(bool, return_transform)

        if params is None:
            params = self.forward_parameters(batch_shape)
        if 'batch_prob' not in params:
            # TODO(jian): we cannot throw a warning every time, so the missing key is
            # silently interpreted as "apply on all data".
            params['batch_prob'] = torch.ones(batch_shape[0], dtype=torch.bool)

        self._params = params
        output = self.apply_func(in_tensor, in_transform, self._params, return_transform)
        return _transform_output_shape(output, ori_shape) if self.keepdim else output


class AugmentationBase2D(_AugmentationBase):
    r"""AugmentationBase2D base class for customized augmentation implementations.

    For any augmentation, the implementation of "generate_parameters" and "apply_transform" are
    required while the "compute_transformation" is only required when passing "return_transform"
    as True.

    Args:
        p: probability for applying an augmentation. This param controls the augmentation
            probabilities element-wise for a batch.
        p_batch: probability for applying an augmentation to a batch. This param controls the
            augmentation probabilities batch-wise.
        return_transform: if ``True`` return the matrix describing the geometric transformation
            applied to each input tensor. If ``False`` and the input is a tuple the applied
            transformation won't be concatenated.
        same_on_batch: apply the same transformation across the batch.
        keepdim: whether to keep the output shape the same as input ``True`` or broadcast it to
            the batch form ``False``.
    """

    def __check_batching__(self, input: TensorWithTransformMat):
        """Check that a 2D input tensor and its transformation matrix share the batching mode."""
        _check_transform_batching(input, 4, '2, 3, or 4')

    def transform_tensor(self, input: torch.Tensor) -> torch.Tensor:
        """Convert any incoming (H, W), (C, H, W) and (B, C, H, W) into (B, C, H, W)."""
        _validate_input_dtype(input, accepted_dtypes=[torch.float16, torch.float32, torch.float64])
        return _transform_input(input)

    def identity_matrix(self, input) -> torch.Tensor:
        """Return 3x3 identity matrix."""
        return kornia.eye_like(3, input)


class IntensityAugmentationBase2D(AugmentationBase2D):
    r"""IntensityAugmentationBase2D base class for customized intensity augmentation implementations.

    For any augmentation, the implementation of "generate_parameters" and "apply_transform" are
    required while the "compute_transformation" is only required when passing "return_transform"
    as True.

    Args:
        p: probability for applying an augmentation. This param controls the augmentation
            probabilities element-wise for a batch.
        p_batch: probability for applying an augmentation to a batch. This param controls the
            augmentation probabilities batch-wise.
        return_transform: if ``True`` return the matrix describing the geometric transformation
            applied to each input tensor. If ``False`` and the input is a tuple the applied
            transformation won't be concatenated.
        same_on_batch: apply the same transformation across the batch.
        keepdim: whether to keep the output shape the same as input ``True`` or broadcast it to
            the batch form ``False``.
    """

    def compute_transformation(self, input: torch.Tensor, params: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Return the identity matrix for ``input``, regardless of ``params``."""
        return self.identity_matrix(input)


class GeometricAugmentationBase2D(AugmentationBase2D):
    r"""GeometricAugmentationBase2D base class for customized geometric augmentation implementations.

    For any augmentation, the implementation of "generate_parameters" and "apply_transform" are
    required while the "compute_transformation" is only required when passing "return_transform"
    as True.

    Args:
        p: probability for applying an augmentation. This param controls the augmentation
            probabilities element-wise for a batch.
        p_batch: probability for applying an augmentation to a batch. This param controls the
            augmentation probabilities batch-wise.
        return_transform: if ``True`` return the matrix describing the geometric transformation
            applied to each input tensor. If ``False`` and the input is a tuple the applied
            transformation won't be concatenated.
        same_on_batch: apply the same transformation across the batch.
        keepdim: whether to keep the output shape the same as input ``True`` or broadcast it to
            the batch form ``False``.
    """

    def inverse_transform(
        self,
        input: torch.Tensor,
        transform: Optional[torch.Tensor] = None,
        size: Optional[Tuple[int, int]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """By default, the exact transformation as ``apply_transform`` will be used."""
        raise NotImplementedError

    def compute_inverse_transformation(self, transform: torch.Tensor):
        """Compute the inverse transform of given transformation matrices."""
        return _torch_inverse_cast(transform)

    def get_transformation_matrix(
        self, input: torch.Tensor, params: Optional[Dict[str, torch.Tensor]] = None
    ) -> torch.Tensor:
        """Return the transformation matrix for ``input``.

        Resolution order: computed from ``params`` when given; otherwise the matrix cached by the
        last ``apply_func`` call; otherwise a matrix computed from freshly sampled parameters.
        The result is converted to the device and dtype of ``input``.
        """
        if params is not None:
            transform = self.compute_transformation(input, params)
        elif hasattr(self, "_transform_matrix"):
            transform = self._transform_matrix
        else:
            params = self.forward_parameters(input.shape)
            mask = params['batch_prob']
            transform = self.identity_matrix(input)
            transform[mask] = self.compute_transformation(input[mask], params)
        return torch.as_tensor(transform, device=input.device, dtype=input.dtype)

    def inverse(
        self,
        input: TensorWithTransformMat,
        params: Optional[Dict[str, torch.Tensor]] = None,
        size: Optional[Tuple[int, int]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Undo the augmentation on ``input``.

        Args:
            input: a tensor, or a ``(tensor, transformation matrix)`` pair.
            params: augmentation parameters. When given, the forward transformation is recomputed
                from them (taking precedence over a matrix passed along with ``input``). When
                ``None``, the parameters of the last forward pass (``self._params``) are used.
                A dict without ``'batch_prob'`` gets an all-``True`` mask added, with a warning.
            size: output size forwarded to ``inverse_transform``. When ``None`` and the params
                contain ``'input_size'``, it is derived from there.
            **kwargs: forwarded to ``inverse_transform``.

        Returns:
            the inverted tensor, reshaped to the original input shape when ``keepdim`` is set.
        """
        transform: Optional[torch.Tensor] = None
        if isinstance(input, (list, tuple)):
            input, transform = input

        # All shape-dependent work below is done on the standardized tensor so that
        # non-batched inputs are handled like a batch of one.
        ori_shape = input.shape
        in_tensor = self.transform_tensor(input)
        batch_shape = in_tensor.shape

        params_given = params is not None
        if params is None:
            params = self._params
        # Resolve the mask before it is first read, so params without it do not raise KeyError.
        if 'batch_prob' not in params:
            params['batch_prob'] = torch.ones(batch_shape[0], dtype=torch.bool)
            warnings.warn("`batch_prob` is not found in params. Will assume applying on all data.")
        to_apply = params['batch_prob']

        if params_given:
            transform = self.identity_matrix(in_tensor)
            transform[to_apply] = self.compute_transformation(in_tensor[to_apply], params)
        elif transform is None:
            transform = self.get_transformation_matrix(in_tensor)

        if size is None and "input_size" in params:
            # Majorly for cropping functions
            # `.tolist()` is used directly so this also works when the tensor is not on CPU.
            size = params['input_size'].unique(dim=0).squeeze().tolist()
            size = (size[0], size[1])

        num_selected = int(torch.sum(to_apply).item())
        if num_selected == 0:
            # no augmentation needed
            output = in_tensor
        elif num_selected == len(to_apply):
            # all data needs to be augmented
            inv_transform = self.compute_inverse_transformation(transform)
            output = self.inverse_transform(in_tensor, inv_transform, size, **kwargs)
        else:
            # The inverse is kept in a separate tensor: `transform` may be the cached
            # `self._transform_matrix` or the caller's own matrix and must not be modified.
            inv_transform = self.compute_inverse_transformation(transform[to_apply])
            output = in_tensor.clone()
            output[to_apply] = self.inverse_transform(in_tensor[to_apply], inv_transform, size, **kwargs)

        return cast(torch.Tensor, _transform_output_shape(output, ori_shape)) if self.keepdim else output


class AugmentationBase3D(_AugmentationBase):
    r"""AugmentationBase3D base class for customized augmentation implementations.

    For any augmentation, the implementation of "generate_parameters" and "apply_transform" are
    required while the "compute_transformation" is only required when passing "return_transform"
    as True.

    Args:
        p: probability for applying an augmentation. This param controls the augmentation
            probabilities element-wise for a batch.
        p_batch: probability for applying an augmentation to a batch. This param controls the
            augmentation probabilities batch-wise.
        return_transform: if ``True`` return the matrix describing the geometric transformation
            applied to each input tensor. If ``False`` and the input is a tuple the applied
            transformation won't be concatenated.
        same_on_batch: apply the same transformation across the batch.
        keepdim: whether to keep the output shape the same as input ``True`` or broadcast it to
            the batch form ``False``.
    """

    def __check_batching__(self, input: TensorWithTransformMat):
        """Check that a 3D input tensor and its transformation matrix share the batching mode."""
        _check_transform_batching(input, 5, '3, 4 or 5')

    def transform_tensor(self, input: torch.Tensor) -> torch.Tensor:
        """Convert any incoming (D, H, W), (C, D, H, W) and (B, C, D, H, W) into (B, C, D, H, W)."""
        _validate_input_dtype(input, accepted_dtypes=[torch.float16, torch.float32, torch.float64])
        return _transform_input3d(input)

    def identity_matrix(self, input) -> torch.Tensor:
        """Return 4x4 identity matrix."""
        return kornia.eye_like(4, input)


class MixAugmentationBase(_BasicAugmentationBase):
    r"""MixAugmentationBase base class for customized mix augmentation implementations.

    For any augmentation, the implementation of "generate_parameters" and "apply_transform" are
    required. "apply_transform" will need to handle the probabilities internally.

    Args:
        p: probability for applying an augmentation. This param controls if to apply the
            augmentation for the batch.
        p_batch: probability for applying an augmentation to a batch. This param controls the
            augmentation probabilities batch-wise.
        same_on_batch: apply the same transformation across the batch.
        keepdim: whether to keep the output shape the same as input ``True`` or broadcast it to
            the batch form ``False``.
    """

    def __init__(self, p: float, p_batch: float, same_on_batch: bool = False, keepdim: bool = False) -> None:
        super().__init__(p, p_batch=p_batch, same_on_batch=same_on_batch, keepdim=keepdim)

    def __check_batching__(self, input: TensorWithTransformMat):
        """Check that a 2D input tensor and its transformation matrix share the batching mode."""
        _check_transform_batching(input, 4, '2, 3, or 4')

    def __unpack_input__(  # type: ignore
        self, input: TensorWithTransformMat
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Split ``input`` into ``(tensor, transformation)``; the latter is ``None`` for a plain tensor."""
        if isinstance(input, tuple):
            return input[0], input[1]
        return input, None

    def transform_tensor(self, input: torch.Tensor) -> torch.Tensor:
        """Convert any incoming (H, W), (C, H, W) and (B, C, H, W) into (B, C, H, W)."""
        _validate_input_dtype(input, accepted_dtypes=[torch.float16, torch.float32, torch.float64])
        return _transform_input(input)

    def apply_transform(  # type: ignore
        self, input: torch.Tensor, label: torch.Tensor, params: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the mix augmentation to ``input`` and ``label``. Must be implemented by subclasses."""
        raise NotImplementedError

    def apply_func(  # type: ignore
        self, in_tensor: torch.Tensor, label: torch.Tensor, params: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the augmentation to the whole batch or not at all.

        Raises:
            ValueError: if ``params['batch_prob']`` selects only a part of the batch.
        """
        to_apply = params['batch_prob']
        num_selected = int(torch.sum(to_apply).item())

        if num_selected == 0:
            # no augmentation needed
            output = in_tensor
        elif num_selected == len(to_apply):
            # all data needs to be augmented
            output, label = self.apply_transform(in_tensor, label, params)
        else:
            raise ValueError(
                "Mix augmentations must be performed batch-wisely. Element-wise augmentation is not supported."
            )

        return output, label

    def forward(  # type: ignore
        self,
        input: TensorWithTransformMat,
        label: Optional[torch.Tensor] = None,
        params: Optional[Dict[str, torch.Tensor]] = None,
    ) -> Tuple[TensorWithTransformMat, torch.Tensor]:
        """Run the mix augmentation.

        Args:
            input: a tensor, or a ``(tensor, transformation matrix)`` tuple.
            label: labels matching the input. When ``None``, the element indices
                ``0 .. batch_size - 1`` are used instead.
            params: previously generated parameters to reuse. When ``None``, new parameters are
                sampled with ``forward_parameters``.

        Returns:
            ``(output, label)``. When the input carried a transformation matrix, ``output`` is the
            ``(tensor, transformation matrix)`` pair with the matrix passed through unchanged.
        """
        in_tensor, in_trans = self.__unpack_input__(input)
        ori_shape = in_tensor.shape
        in_tensor = self.transform_tensor(in_tensor)

        # If label is not provided, it would output the indices instead.
        if label is None:
            device = input[0].device if isinstance(input, (tuple, list)) else input.device
            label = torch.arange(0, in_tensor.size(0), device=device, dtype=torch.long)

        if params is None:
            params = self.forward_parameters(in_tensor.shape)
        self._params = params

        output, lab = self.apply_func(in_tensor, label, self._params)
        output = _transform_output_shape(output, ori_shape) if self.keepdim else output  # type: ignore

        if in_trans is not None:
            return (output, in_trans), lab
        return output, lab