File size: 8,674 Bytes
f71ac1d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
"""Basic data augmentation class."""

from __future__ import annotations

from collections.abc import Callable, Sequence
from typing import TypeVar, no_type_check

import torch

from vis4d.common.dict import get_dict_nested, set_dict_nested
from vis4d.data.typing import DictData

TFunctor = TypeVar("TFunctor", bound=object)  # pylint: disable=invalid-name
TransformFunction = Callable[[list[DictData]], list[DictData]]


class Transform:
    """Transforms Decorator.

    This class stores which `in_keys` are input to a transformation function
    and which `out_keys` are overwritten in the data dictionary by the output
    of this transformation.
    Nested keys in the data dictionary can be accessed via key.subkey1.subkey2
    If any of `in_keys` is 'data', the full data dictionary will be forwarded
    to the transformation.
    If the only entry in `out_keys` is 'data', the full data dictionary will
    be updated with the return value of the transformation.
    For the case of multi-sensor data, the sensors that the transform should be
    applied to can be set via the 'sensors' attribute. By default, we assume
    a transformation is applied to all sensors.
    This class will add an 'apply_to_data' method to a given Functor which is
    used to call it on a batch (list) of DictData objects. NOTE: This is an
    issue for static checking and is not recognized by pylint. It will usually
    be called in the compose() function and will not be called directly.

    Example:
        >>> @Transform(in_keys="images", out_keys="images")
        >>> class MyTransform:
        >>>     def __call__(self, images: list[np.array]) -> list[np.array]:
        >>>         images = do_something(images)
        >>>         return images
        >>> my_transform = MyTransform()
        >>> batch = my_transform.apply_to_data(batch)
    """

    def __init__(
        self,
        in_keys: Sequence[str] | str,
        out_keys: Sequence[str] | str,
        sensors: Sequence[str] | str | None = None,
        same_on_batch: bool = True,
    ) -> None:
        """Creates an instance of Transform.

        Args:
            in_keys (Sequence[str] | str): Specifies one or multiple input
                keys of the data dictionary whose values are passed, in
                order, as positional arguments to the transformation. A
                single string is wrapped into a one-element list.
            out_keys (Sequence[str] | str): Specifies one or multiple output
                keys of the data dictionary which are overwritten with the
                return value(s) of the transformation. A single string is
                wrapped into a one-element list.
            sensors (Sequence[str] | str | None, optional): Specifies the
                sensors this transformation should be applied to. If None, it
                will be applied to all available sensors. Defaults to None.
            same_on_batch (bool, optional): Whether to call the
                transformation once on the whole batch (True) or once per
                batch element (False). Defaults to True.
        """
        if isinstance(in_keys, str):
            in_keys = [in_keys]
        self.in_keys = in_keys

        if isinstance(out_keys, str):
            out_keys = [out_keys]
        self.out_keys = out_keys

        if isinstance(sensors, str):
            sensors = [sensors]
        self.sensors = sensors

        self.same_on_batch = same_on_batch

    @no_type_check
    def __call__(self, transform: TFunctor) -> TFunctor:
        """Add in_keys / out_keys / sensors / apply_to_data attributes.

        The functor's `__init__` is replaced by a wrapper that accepts the
        additional keyword arguments `in_keys`, `out_keys`, `sensors` and
        `same_on_batch` (defaulting to the values given to this decorator),
        calls the original `__init__` and then attaches those attributes and
        an `apply_to_data` callable to the instance.

        Args:
            transform (TFunctor): A given Functor.

        Returns:
            TFunctor: The decorated Functor.
        """
        original_init = transform.__init__

        def apply_to_data(
            self_, input_batch: list[DictData]
        ) -> list[DictData]:
            """Wrap function with a handler for input / output keys.

            We use the specified in_keys in order to extract the positional
            input arguments of a function from the data dictionary, and the
            out_keys to replace the corresponding values in the output
            dictionary.
            """

            def _transform_fn(batch: list[DictData]) -> list[DictData]:
                in_batch = []
                for key in self_.in_keys:
                    # Optionally allow the function to get the full data
                    # dict as aux input and set default value to None if
                    # key is not found.
                    key_data = [
                        (
                            get_dict_nested(
                                data, key.split("."), allow_missing=True
                            )
                            if key != "data"
                            else data
                        )
                        for data in batch
                    ]
                    if any(d is None for d in key_data):
                        # If any of the data in the batch is None, replace
                        # the input of the key with None.
                        in_batch.append(None)
                    else:
                        in_batch.append(key_data)

                result = self_(*in_batch)

                if len(self_.out_keys) == 1:
                    if self_.out_keys[0] == "data":
                        # The functor returned the full (updated) batch.
                        return result
                    # Single output: wrap so it can be zipped with out_keys.
                    result = [result]

                for key, values in zip(self_.out_keys, result):
                    if values is None:
                        continue
                    for data, value in zip(batch, values):
                        if value is not None:
                            set_dict_nested(data, key.split("."), value)
                return batch

            if self_.sensors is not None:
                if self_.same_on_batch:
                    for sensor in self_.sensors:
                        batch_sensor = _transform_fn(
                            [d[sensor] for d in input_batch]
                        )
                        for i, d in enumerate(batch_sensor):
                            input_batch[i][sensor] = d
                else:
                    for data in input_batch:
                        for sensor in self_.sensors:
                            # _transform_fn returns a one-element list here;
                            # unwrap it so the sensor entry stays a dict.
                            data[sensor] = _transform_fn([data[sensor]])[0]
            elif self_.same_on_batch:
                input_batch = _transform_fn(input_batch)
            else:
                for i, data in enumerate(input_batch):
                    input_batch[i] = _transform_fn([data])[0]

            return input_batch

        def init(
            *args,
            in_keys: Sequence[str] = self.in_keys,
            out_keys: Sequence[str] = self.out_keys,
            sensors: Sequence[str] | None = self.sensors,
            same_on_batch: bool = self.same_on_batch,
            **kwargs,
        ):
            # args[0] is the instance being constructed.
            self_ = args[0]
            original_init(*args, **kwargs)
            self_.in_keys = in_keys
            self_.out_keys = out_keys
            self_.sensors = sensors
            self_.same_on_batch = same_on_batch
            self_.apply_to_data = lambda *args, **kwargs: apply_to_data(
                self_, *args, **kwargs
            )

        transform.__init__ = init
        return transform


def compose(transforms: list[TFunctor]) -> TransformFunction:
    """Chain several transforms into one callable.

    Each element of `transforms` must expose an `apply_to_data` method (as
    added by the Transform decorator). The returned function feeds a batch
    through every transform in list order, passing the output of one as the
    input of the next.

    Args:
        transforms (list[TFunctor]): Transforms to apply, in order.

    Returns:
        TransformFunction: Function mapping a batch to the transformed batch.
    """

    def _preprocess_func(batch: list[DictData]) -> list[DictData]:
        current = batch
        # The list is read at call time, so later changes to it are honored.
        for transform in transforms:
            current = transform.apply_to_data(current)  # type: ignore
        return current

    return _preprocess_func


@Transform("data", "data")
class RandomApply:
    """Apply a sequence of transformations at random.

    On every call a single random draw decides whether the whole sequence of
    transformations is applied to the batch or the batch is returned as is.
    """

    def __init__(
        self, transforms: list[TFunctor], probability: float = 0.5
    ) -> None:
        """Creates an instance of RandomApply.

        Args:
            transforms (list[TFunctor]): Transformations to run, in order,
                when the random draw succeeds. Each must provide an
                `apply_to_data` method.
            probability (float, optional): Probability of running the
                transformations. Defaults to 0.5.
        """
        self.probability = probability
        self.transforms = transforms

    def __call__(self, batch: list[DictData]) -> list[DictData]:
        """Run the transformations on the batch with the given probability."""
        should_apply = bool(torch.rand(1) < self.probability)
        if not should_apply:
            return batch

        for transform in self.transforms:
            batch = transform.apply_to_data(batch)  # type: ignore
        return batch