File size: 13,646 Bytes
36c95ba |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 |
import warnings
from itertools import zip_longest
from typing import cast, List, Optional, Tuple, Union
import torch
from kornia.augmentation.base import (
_AugmentationBase,
GeometricAugmentationBase2D,
IntensityAugmentationBase2D,
TensorWithTransformMat,
)
from kornia.constants import DataKey
from .base import SequentialBase
from .image import ImageSequential, ParamItem
from .patch import PatchSequential
from .utils import ApplyInverse
from .video import VideoSequential
# Public API of this module: names exported by ``from <module> import *``.
__all__ = ["AugmentationSequential"]
class AugmentationSequential(ImageSequential):
    r"""AugmentationSequential for handling multiple input types like inputs, masks, keypoints at once.

    .. image:: https://kornia-tutorials.readthedocs.io/en/latest/_images/data_augmentation_sequential_5_1.png
        :width: 49 %

    .. image:: https://kornia-tutorials.readthedocs.io/en/latest/_images/data_augmentation_sequential_7_0.png
        :width: 49 %

    Args:
        *args: a list of kornia augmentation modules.
        data_keys: the input type sequential for applying augmentations.
            Accepts "input", "mask", "bbox", "bbox_xyxy", "bbox_xywh", "keypoints".
        same_on_batch: apply the same transformation across the batch.
            If None, it will not overwrite the function-wise settings.
        return_transform: if ``True`` return the matrix describing the transformation
            applied to each. If None, it will not overwrite the function-wise settings.
        keepdim: whether to keep the output shape the same as input (True) or broadcast it
            to the batch form (False). If None, it will not overwrite the function-wise settings.
        random_apply: randomly select a sublist (order agnostic) of args to
            apply transformation.
            If int, a fixed number of transformations will be selected.
            If (a,), x number of transformations (a <= x <= len(args)) will be selected.
            If (a, b), x number of transformations (a <= x <= b) will be selected.
            If True, the whole list of args will be processed as a sequence in a random order.
            If False, the whole list of args will be processed as a sequence in original order.

    .. note::
        Mix augmentations (e.g. RandomMixUp, RandomCutMix) can only be working with "input" data key.
        It is not clear how to deal with the conversions of masks, bounding boxes and keypoints.

    .. note::
        See a working example `here <https://kornia-tutorials.readthedocs.io/en/
        latest/data_augmentation_sequential.html>`__.

    Examples:
        >>> import kornia
        >>> input = torch.randn(2, 3, 5, 6)
        >>> bbox = torch.tensor([[
        ...     [1., 1.],
        ...     [2., 1.],
        ...     [2., 2.],
        ...     [1., 2.],
        ... ]]).expand(2, -1, -1)
        >>> points = torch.tensor([[[1., 1.]]]).expand(2, -1, -1)
        >>> aug_list = AugmentationSequential(
        ...     kornia.augmentation.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
        ...     kornia.augmentation.RandomAffine(360, p=1.0),
        ...     data_keys=["input", "mask", "bbox", "keypoints"],
        ...     return_transform=False,
        ...     same_on_batch=False,
        ...     random_apply=10,
        ... )
        >>> out = aug_list(input, input, bbox, points)
        >>> [o.shape for o in out]
        [torch.Size([2, 3, 5, 6]), torch.Size([2, 3, 5, 6]), torch.Size([2, 4, 2]), torch.Size([2, 1, 2])]
        >>> out_inv = aug_list.inverse(*out)
        >>> [o.shape for o in out_inv]
        [torch.Size([2, 3, 5, 6]), torch.Size([2, 3, 5, 6]), torch.Size([2, 4, 2]), torch.Size([2, 1, 2])]

    This example demonstrates the integration of VideoSequential and AugmentationSequential.

    Examples:
        >>> import kornia
        >>> input = torch.randn(2, 3, 5, 6)[None]
        >>> bbox = torch.tensor([[
        ...     [1., 1.],
        ...     [2., 1.],
        ...     [2., 2.],
        ...     [1., 2.],
        ... ]]).expand(2, -1, -1)[None]
        >>> points = torch.tensor([[[1., 1.]]]).expand(2, -1, -1)[None]
        >>> aug_list = AugmentationSequential(
        ...     VideoSequential(
        ...         kornia.augmentation.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
        ...         kornia.augmentation.RandomAffine(360, p=1.0),
        ...     ),
        ...     data_keys=["input", "mask", "bbox", "keypoints"]
        ... )
        >>> out = aug_list(input, input, bbox, points)
        >>> [o.shape for o in out]
        [torch.Size([1, 2, 3, 5, 6]), torch.Size([1, 2, 3, 5, 6]), torch.Size([1, 2, 4, 2]), torch.Size([1, 2, 1, 2])]
    """

    def __init__(
        self,
        *args: Union[_AugmentationBase, ImageSequential],
        data_keys: List[Union[str, int, DataKey]] = [DataKey.INPUT],
        same_on_batch: Optional[bool] = None,
        return_transform: Optional[bool] = None,
        keepdim: Optional[bool] = None,
        random_apply: Union[int, bool, Tuple[int, int]] = False,
    ) -> None:
        """Register the augmentations and validate ``data_keys``.

        Raises:
            AssertionError: if any resolved key is not a member of ``DataKey``.
            NotImplementedError: if ``data_keys`` is empty or its first entry is not ``DataKey.INPUT``.
        """
        super().__init__(
            *args,
            same_on_batch=same_on_batch,
            return_transform=return_transform,
            keepdim=keepdim,
            random_apply=random_apply,
        )

        # The default list is only read here; a fresh list is stored, so the shared default is never mutated.
        self.data_keys = [DataKey.get(inp) for inp in data_keys]

        if not all(in_type in DataKey for in_type in self.data_keys):
            raise AssertionError(f"`data_keys` must be in {DataKey}. Got {data_keys}.")

        # An empty ``data_keys`` has no first entry at all; report it the same way as a wrong first entry
        # instead of failing with an opaque IndexError.
        if not self.data_keys or self.data_keys[0] != DataKey.INPUT:
            raise NotImplementedError(f"The first input must be {DataKey.INPUT}.")

        # Remembered so that ``forward`` can pick the dimension range passed to ``autofill_dim``.
        self.contains_video_sequential: bool = False
        for arg in args:
            if isinstance(arg, PatchSequential) and not arg.is_intensity_only():
                warnings.warn("Geometric transformation detected in PatchSequential, which would break bbox, mask.")
            if isinstance(arg, VideoSequential):
                self.contains_video_sequential = True

    def inverse(  # type: ignore
        self,
        *args: torch.Tensor,
        params: Optional[List[ParamItem]] = None,
        data_keys: Optional[List[Union[str, int, DataKey]]] = None,
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Reverse the transformation applied.

        Number of input tensors must align with the number of ``data_keys``. If ``data_keys`` is not set, use
        ``self.data_keys`` by default.

        Args:
            *args: one tensor per data key. For an ``INPUT`` key a ``(tensor, transform)`` pair is also
                accepted; the transform part is ignored.
            params: parameters to invert. If None, ``self._params`` is used.
            data_keys: keys describing each positional input; entries are resolved with ``DataKey.get``.

        Returns:
            The single inverted tensor when exactly one input is given, otherwise a list of inverted tensors.

        Raises:
            AssertionError: if the number of inputs differs from the number of data keys.
            ValueError: if no parameters are available, or an unsupported sequential module is met.
            NotImplementedError: for a non intensity-only ``PatchSequential`` or an unhandled
                module / data key combination.
        """
        if data_keys is None:
            data_keys = cast(List[Union[str, int, DataKey]], self.data_keys)
        else:
            # Resolve str/int keys exactly as ``__init__`` and ``forward`` do, so the ``DataKey``
            # comparisons and membership tests below work for every accepted key spelling.
            data_keys = [DataKey.get(inp) for inp in data_keys]

        if len(args) != len(data_keys):
            raise AssertionError(
                "The number of inputs must align with the number of data_keys, "
                f"Got {len(args)} and {len(data_keys)}."
            )

        if params is None:
            if self._params is None:
                raise ValueError(
                    "No parameters available for inversing, please run a forward pass first "
                    "or passing valid params into this function."
                )
            params = self._params

        outputs = []
        for tensor, dcate in zip(args, data_keys):
            if dcate == DataKey.INPUT and isinstance(tensor, (tuple, list)):
                tensor, _ = tensor  # ignore the transformation matrix whilst inverse
            # Walk the forward sequence backwards, pairing each module with the reversed params.
            for (name, module), param in zip_longest(list(self.get_forward_sequence(params))[::-1], params[::-1]):
                if isinstance(module, (_AugmentationBase, ImageSequential)):
                    # NOTE(review): ``params`` is annotated as a list, so ``name in params`` /
                    # ``params[name]`` looks like a leftover from a mapping-based API - confirm.
                    param = params[name] if name in params else param
                else:
                    param = None
                if isinstance(module, IntensityAugmentationBase2D) and dcate in DataKey:
                    pass  # Do nothing
                elif isinstance(module, ImageSequential) and module.is_intensity_only() and dcate in DataKey:
                    pass  # Do nothing
                elif isinstance(module, VideoSequential) and dcate not in [DataKey.INPUT, DataKey.MASK]:
                    # Merge the two leading dims for the call, then split them again afterwards.
                    batch_size: int = tensor.size(0)
                    tensor = tensor.view(-1, *tensor.shape[2:])
                    tensor = ApplyInverse.inverse_by_key(tensor, module, param, dcate)
                    tensor = tensor.view(batch_size, -1, *tensor.shape[1:])
                elif isinstance(module, PatchSequential):
                    raise NotImplementedError("Geometric involved PatchSequential is not supported.")
                elif isinstance(module, (GeometricAugmentationBase2D, ImageSequential)) and dcate in DataKey:
                    tensor = ApplyInverse.inverse_by_key(tensor, module, param, dcate)
                elif isinstance(module, (SequentialBase,)):
                    raise ValueError(f"Unsupported Sequential {module}.")
                else:
                    raise NotImplementedError(f"data_key {dcate} is not implemented for {module}.")
            outputs.append(tensor)

        # ``outputs`` is always a list here; unwrap it when there is a single result.
        if len(outputs) == 1:
            return outputs[0]

        return outputs

    def __packup_output__(  # type: ignore
        self, output: List[TensorWithTransformMat], label: Optional[torch.Tensor] = None
    ) -> Union[
        TensorWithTransformMat,
        Tuple[TensorWithTransformMat, Optional[torch.Tensor]],
        List[TensorWithTransformMat],
        Tuple[List[TensorWithTransformMat], Optional[torch.Tensor]],
    ]:
        """Shape the final return value of ``forward``.

        A one-element list/tuple is unwrapped to its only element, and ``label`` is appended
        as a second tuple item whenever ``self.return_label`` is truthy.
        """
        if len(output) == 1 and isinstance(output, (tuple, list)) and self.return_label:
            return output[0], label
        if len(output) == 1 and isinstance(output, (tuple, list)):
            return output[0]
        if self.return_label:
            return output, label
        return output

    def forward(  # type: ignore
        self,
        *args: TensorWithTransformMat,
        label: Optional[torch.Tensor] = None,
        params: Optional[List[ParamItem]] = None,
        data_keys: Optional[List[Union[str, int, DataKey]]] = None,
    ) -> Union[
        TensorWithTransformMat,
        Tuple[TensorWithTransformMat, Optional[torch.Tensor]],
        List[TensorWithTransformMat],
        Tuple[List[TensorWithTransformMat], Optional[torch.Tensor]],
    ]:
        """Compute multiple tensors simultaneously according to ``self.data_keys``.

        Args:
            *args: one input per data key.
            label: optional label tensor forwarded to the augmentations.
            params: parameters to apply. If None, they are generated from the first ``INPUT`` entry.
            data_keys: keys describing each positional input; defaults to ``self.data_keys``.

        Returns:
            The outputs packed by ``__packup_output__`` (a single item or a list, optionally with the label).

        Raises:
            AssertionError: if the number of inputs differs from the number of data keys.
            ValueError: if ``params`` is None and no ``INPUT`` key is present, or an unsupported
                sequential module is met.
            NotImplementedError: for a non intensity-only ``PatchSequential`` or an unhandled
                module / data key combination.
        """
        if data_keys is None:
            data_keys = cast(List[Union[str, int, DataKey]], self.data_keys)
        else:
            data_keys = [DataKey.get(inp) for inp in data_keys]

        if len(args) != len(data_keys):
            raise AssertionError(
                f"The number of inputs must align with the number of data_keys. Got {len(args)} and {len(data_keys)}."
            )

        if params is None:
            if DataKey.INPUT in data_keys:
                _input = args[data_keys.index(DataKey.INPUT)]
                # A (tensor, transform) pair may be given; only the tensor is used for the shape.
                if isinstance(_input, (tuple, list)):
                    inp = _input[0]
                else:
                    inp = _input
                if self.contains_video_sequential:
                    _, out_shape = self.autofill_dim(inp, dim_range=(3, 5))
                else:
                    _, out_shape = self.autofill_dim(inp, dim_range=(2, 4))
                params = self.forward_parameters(out_shape)
            else:
                raise ValueError("`params` must be provided whilst INPUT is not in data_keys.")

        outputs: List[TensorWithTransformMat] = [None] * len(data_keys)  # type: ignore
        if DataKey.INPUT in data_keys:
            # The first INPUT entry goes through the parent pipeline in one go.
            idx = data_keys.index(DataKey.INPUT)
            out = super().forward(args[idx], label, params=params)
            # NOTE(review): ``self.return_label`` is read before the assignment below, so this
            # presumably relies on the parent ``forward`` having refreshed it - confirm.
            if self.return_label:
                first_out, label = cast(Tuple[TensorWithTransformMat, torch.Tensor], out)
            else:
                first_out = cast(TensorWithTransformMat, out)
            outputs[idx] = first_out

        self.return_label = label is not None or self.contains_label_operations(params)

        for idx, (data, dcate, out) in enumerate(zip(args, data_keys, outputs)):
            if out is not None:
                continue  # already produced by the parent forward above
            for param in params:
                module = self.get_submodule(param.name)
                if dcate == DataKey.INPUT:
                    data, label = self.apply_to_input(data, label, module=module, param=param)
                elif isinstance(module, IntensityAugmentationBase2D) and dcate in DataKey:
                    pass  # Do nothing
                elif isinstance(module, ImageSequential) and module.is_intensity_only() and dcate in DataKey:
                    pass  # Do nothing
                elif isinstance(module, VideoSequential) and dcate not in [DataKey.INPUT, DataKey.MASK]:
                    # Merge the two leading dims for the call, then split them again afterwards.
                    batch_size: int = data.size(0)
                    data = data.view(-1, *data.shape[2:])
                    data, label = ApplyInverse.apply_by_key(data, label, module, param, dcate)
                    data = data.view(batch_size, -1, *data.shape[1:])
                elif isinstance(module, PatchSequential):
                    raise NotImplementedError("Geometric involved PatchSequential is not supported.")
                elif isinstance(module, (GeometricAugmentationBase2D, ImageSequential,)) and dcate in DataKey:
                    data, label = ApplyInverse.apply_by_key(data, label, module, param, dcate)
                elif isinstance(module, (SequentialBase,)):
                    raise ValueError(f"Unsupported Sequential {module}.")
                else:
                    raise NotImplementedError(f"data_key {dcate} is not implemented for {module}.")
            outputs[idx] = data

        return self.__packup_output__(outputs, label)
|