File size: 11,240 Bytes
36c95ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
from typing import cast, List, Optional, Tuple, Union

import torch
import torch.nn as nn

import kornia
from kornia.augmentation.base import _AugmentationBase, MixAugmentationBase, TensorWithTransformMat
from kornia.augmentation.container.base import SequentialBase
from kornia.augmentation.container.utils import InputApplyInverse, MaskApplyInverse

from .image import ImageSequential, ParamItem

__all__ = ["VideoSequential"]


class VideoSequential(ImageSequential):
    r"""VideoSequential for processing 5-dim video data like (B, T, C, H, W) and (B, C, T, H, W).

    `VideoSequential` is used to replace `nn.Sequential` for processing video data augmentations.
    By default, `VideoSequential` enables `same_on_frame` so that the same augmentation is applied
    to every frame along the temporal dimension. This does not affect other augmentation behaviours
    such as the `same_on_batch` setting of each module.

    Internally, a video is flattened into a (B * T, C, H, W) image batch, processed by the wrapped
    modules, and restored to the original ``data_format`` afterwards.

    Args:
        *args: a list of augmentation modules.
        data_format: layout of the input video. Only ``"BCTHW"`` and ``"BTCHW"`` are supported
            (compared case-insensitively).
        same_on_frame: apply the same transformation to every frame of a video, i.e. share the
            sampled parameters across the temporal dimension. If False, every frame gets its own
            independently sampled parameters.
        random_apply: randomly select a sublist (order agnostic) of args to
            apply transformation.
            If int, a fixed number of transformations will be selected.
            If (a,), x number of transformations (a <= x <= len(args)) will be selected.
            If (a, b), x number of transformations (a <= x <= b) will be selected.
            If False (the default) or None, the whole list of args will be processed as a sequence.

    Raises:
        AssertionError: if ``data_format`` is not one of the supported layouts.

    Note:
        Transformation matrix returned only considers the transformation applied in ``kornia.augmentation`` module.
        Those transformations in ``kornia.geometry`` will not be taken into account.

    Note:
        Nested sequential containers raise ``ValueError`` when ``same_on_frame=True``.

    Example:
        If set `same_on_frame` to True, we would expect the same augmentation has been applied to each
        timeframe.

        >>> input, label = torch.randn(2, 3, 1, 5, 6).repeat(1, 1, 4, 1, 1), torch.tensor([0, 1])
        >>> aug_list = VideoSequential(
        ...     kornia.augmentation.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
        ...     kornia.color.BgrToRgb(),
        ...     kornia.augmentation.RandomAffine(360, p=1.0),
        ...     random_apply=10,
        ...     data_format="BCTHW",
        ...     same_on_frame=True)
        >>> output = aug_list(input)
        >>> (output[0, :, 0] == output[0, :, 1]).all()
        tensor(True)
        >>> (output[0, :, 1] == output[0, :, 2]).all()
        tensor(True)
        >>> (output[0, :, 2] == output[0, :, 3]).all()
        tensor(True)

        If set `same_on_frame` to False, every frame is augmented independently. Labels are returned
        reshaped to (B, T, -1):

        >>> aug_list = VideoSequential(
        ...     kornia.augmentation.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
        ...     kornia.augmentation.RandomAffine(360, p=1.0),
        ...     kornia.augmentation.RandomMixUp(p=1.0),
        ...     data_format="BCTHW",
        ...     same_on_frame=False)
        >>> output, lab = aug_list(input)
        >>> output.shape, lab.shape
        (torch.Size([2, 3, 4, 5, 6]), torch.Size([2, 4, 3]))
        >>> (output[0, :, 0] == output[0, :, 1]).all()
        tensor(False)

        Reproduce with provided params.

        >>> out2, lab2 = aug_list(input, label, params=aug_list._params)
        >>> torch.equal(output, out2)
        True
    """

    def __init__(
        self,
        *args: nn.Module,
        data_format: str = "BTCHW",
        same_on_frame: bool = True,
        random_apply: Union[int, bool, Tuple[int, int]] = False,
    ) -> None:
        super().__init__(*args, same_on_batch=None, return_transform=None, keepdim=None, random_apply=random_apply)
        self.same_on_frame = same_on_frame
        self.data_format = data_format.upper()
        if self.data_format not in ("BCTHW", "BTCHW"):
            raise AssertionError(f"Only `BCTHW` and `BTCHW` are supported. Got `{data_format}`.")
        # Index of the temporal (T) dimension within the 5-dim input.
        self._temporal_channel: int = 2 if self.data_format == "BCTHW" else 1

    def __infer_channel_exclusive_batch_shape__(self, batch_shape: torch.Size, chennel_index: int) -> torch.Size:
        """Return ``batch_shape`` with the dimension at ``chennel_index`` removed.

        The result is explicitly built as a ``torch.Size``. This way downstream parameter generators
        always receive a ``torch.Size`` rather than whatever the tuple concatenation yields.
        """
        return torch.Size(batch_shape[:chennel_index] + batch_shape[chennel_index + 1:])

    def __repeat_param_across_channels__(self, param: torch.Tensor, frame_num: int) -> torch.Tensor:
        """Repeat per-video parameters so that every frame of a video shares them.

        Despite the method name, the repetition runs along the temporal (frame) dimension. The input
        is shaped (B, ...) and the output (B * frame_num, ...). Each batch entry is repeated
        ``frame_num`` times consecutively, so it lines up with the flattened (B * T, ...) video batch.

        (B1, B2, ..., Bn) => (B1, ..., B1, B2, ..., B2, ..., Bn, ..., Bn)
                              | frame_num | | frame_num | ..., | frame_num |
        """
        return param.repeat_interleave(frame_num, dim=0)

    def _input_shape_convert_in(
        self, input: torch.Tensor, label: Optional[torch.Tensor], frame_num: int
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Flatten a video into an image batch and align the labels with it.

        Args:
            input: 5-dim tensor laid out as ``self.data_format``.
            label: optional labels of shape (B, T), (B,) or (B * T,).
            frame_num: number of frames T.

        Returns:
            The input reshaped to (B * T, C, H, W) and the labels flattened to (B * T,).

        Raises:
            NotImplementedError: if the label shape is none of the supported ones.
        """
        if self.data_format == "BCTHW":
            # (B, C, T, H, W) -> (B, T, C, H, W)
            input = input.transpose(1, 2)

        if label is not None:
            if label.shape == input.shape[:2]:
                # One label per frame, given as (B, T).
                label = label.reshape(-1)
            elif label.shape == input.shape[:1]:
                # One label per video, given as (B,): share it with every frame.
                label = label[..., None].repeat(1, frame_num).reshape(-1)
            elif label.shape == torch.Size([input.shape[0] * input.shape[1]]):
                # Already flattened as (B * T,); nothing to do.
                pass
            else:
                raise NotImplementedError(f"Invalid label shape of {label.shape}.")
        # `reshape` rather than `view`: after the transpose above the tensor is not contiguous.
        return input.reshape(-1, *input.shape[2:]), label

    def _input_shape_convert_back(
        self, input: torch.Tensor, label: Optional[torch.Tensor], frame_num: int
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Undo ``_input_shape_convert_in``.

        Args:
            input: flattened tensor of shape (B * T, ...).
            label: optional flattened labels of shape (B * T, ...).
            frame_num: number of frames T.

        Returns:
            The input restored to ``self.data_format`` and the labels reshaped to (B, T, -1).
        """
        # `reshape` tolerates non-contiguous module outputs, which `view` would reject.
        video = input.reshape(-1, frame_num, *input.shape[1:])
        if self.data_format == "BCTHW":
            # (B, T, C, H, W) -> (B, C, T, H, W)
            video = video.transpose(1, 2)

        if label is not None:
            label = label.reshape(video.size(0), frame_num, -1)
        return video, label

    def forward_parameters(self, batch_shape: torch.Size) -> List[ParamItem]:
        """Sample the parameters of every module for a 5-dim video batch shape.

        Parameters are generated for an image batch with the temporal dimension removed. With
        ``same_on_frame`` they are sampled once per video and repeated for each frame. Otherwise
        they are sampled independently for all B * T frames.

        Args:
            batch_shape: shape of the 5-dim input, laid out as ``self.data_format``.

        Returns:
            One ``ParamItem`` per module returned by ``get_forward_sequence()``. Modules that are
            neither augmentations nor sequential containers get ``None`` as data.

        Raises:
            ValueError: if a nested sequential container is used together with ``same_on_frame``.
        """
        frame_num = batch_shape[self._temporal_channel]
        named_modules = self.get_forward_sequence()
        # Parameter generation shape: (B, C, H, W), i.e. T is dropped.
        batch_shape = self.__infer_channel_exclusive_batch_shape__(batch_shape, self._temporal_channel)

        if not self.same_on_frame:
            # One independent draw per frame: (B * T, C, H, W).
            batch_shape = torch.Size([batch_shape[0] * frame_num, *batch_shape[1:]])

        params: List[ParamItem] = []
        for name, module in named_modules:
            if isinstance(module, SequentialBase):
                # Reject before sampling, so an invalid configuration does no work and
                # consumes no random state.
                if self.same_on_frame:
                    raise ValueError("Sequential is currently unsupported for ``same_on_frame``.")
                param = ParamItem(name, module.forward_parameters(batch_shape))
            elif isinstance(module, (_AugmentationBase, MixAugmentationBase)):
                mod_param = module.forward_parameters(batch_shape)
                if self.same_on_frame:
                    is_color_jitter = isinstance(module, kornia.augmentation.ColorJitter)
                    # Iterate over a snapshot of the keys and update the same dict object in place.
                    for key in list(mod_param):
                        # TODO: revise colorjitter order param in the future to align the standard.
                        # ColorJitter's "order" entry is deliberately not repeated per frame.
                        if is_color_jitter and key == "order":
                            continue
                        mod_param[key] = self.__repeat_param_across_channels__(mod_param[key], frame_num)
                param = ParamItem(name, mod_param)
            else:
                param = ParamItem(name, None)
            params.append(param)
        return params

    def inverse(self, input: torch.Tensor, params: Optional[List[ParamItem]] = None) -> torch.Tensor:
        """Inverse transformation.

        Used to inverse a tensor according to the performed transformation by a forward pass, or with respect to
        provided parameters.

        Args:
            input: tensor to invert. For input/mask data it is a video laid out as ``self.data_format``.
            params: parameters to invert with; forwarded to ``ImageSequential.inverse``.

        Returns:
            The inverted tensor in the same layout as ``input``.
        """
        # Decide the layout once so that flattening and restoring are guaranteed to agree.
        image_layout = self.apply_inverse_func in (InputApplyInverse, MaskApplyInverse)
        if image_layout:
            frame_num: int = input.size(self._temporal_channel)
            input, _ = self._input_shape_convert_in(input, None, frame_num)
        else:
            batch_size: int = input.size(0)
            # NOTE(review): merges the first two dims regardless of ``data_format`` — presumably
            # non-image data laid out as (B, T, ...); confirm.
            input = input.reshape(-1, *input.shape[2:])

        input = super().inverse(input, params)
        if image_layout:
            input, _ = self._input_shape_convert_back(input, None, frame_num)
        else:
            input = input.reshape(batch_size, -1, *input.shape[1:])

        return input

    def forward(  # type: ignore
        self, input: torch.Tensor, label: Optional[torch.Tensor] = None, params: Optional[List[ParamItem]] = None
    ) -> Union[TensorWithTransformMat, Tuple[TensorWithTransformMat, torch.Tensor]]:
        """Define the video computation performed.

        Args:
            input: 5-dim video tensor laid out as ``self.data_format``.
            label: optional labels of shape (B,), (B, T) or (B * T,). Only accepted when processing
                input/mask data.
            params: pre-computed parameters (e.g. ``self._params`` from a previous call). They are
                sampled via ``forward_parameters`` when None.

        Returns:
            The augmented video in its original layout, optionally paired with its transformation
            matrix. The result is packed by ``__packup_output__`` together with the labels, which for
            input/mask data are reshaped to (B, T, -1).

        Raises:
            AssertionError: if ``input`` is not 5-dimensional.
            ValueError: if labels are involved while processing non input/mask data.
            NotImplementedError: if the label shape is not supported.
        """
        if len(input.shape) != 5:
            raise AssertionError(f"Input must be a 5-dim tensor. Got {input.shape}.")

        if params is None:
            params = self.forward_parameters(input.shape)

        # Decide the layout once so that flattening and restoring are guaranteed to agree.
        image_layout = self.apply_inverse_func in (InputApplyInverse, MaskApplyInverse)
        if image_layout:
            # Size of T
            frame_num: int = input.size(self._temporal_channel)
            input, label = self._input_shape_convert_in(input, label, frame_num)
        else:
            if label is not None:
                raise ValueError(f"Invalid label value. Got {label}")
            batch_size: int = input.size(0)
            # NOTE(review): merges the first two dims regardless of ``data_format`` — presumably
            # non-image data laid out as (B, T, ...); confirm.
            input = input.reshape(-1, *input.shape[2:])

        out = super().forward(input, label, params)  # type: ignore
        if self.return_label:
            output, label = cast(Tuple[TensorWithTransformMat, torch.Tensor], out)
        else:
            output = cast(TensorWithTransformMat, out)

        if isinstance(output, (tuple, list)):
            if image_layout:
                _out, label = self._input_shape_convert_back(output[0], label, frame_num)
                output = (_out, output[1])
            else:
                if label is not None:
                    raise ValueError(f"Invalid label value. Got {label}")
                # NOTE(review): unlike the branch above, the transformation matrix (output[1]) is
                # dropped here; confirm this is intended.
                output = output[0].reshape(batch_size, -1, *output[0].shape[1:])
        else:
            if image_layout:
                output, label = self._input_shape_convert_back(output, label, frame_num)
            else:
                if label is not None:
                    raise ValueError(f"Invalid label value. Got {label}")
                output = output.reshape(batch_size, -1, *output.shape[1:])

        return self.__packup_output__(output, label)