File size: 7,594 Bytes
2a25b9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
""""by lyuwenyu
"""


import torch 
import torch.nn as nn 

import torchvision
torchvision.disable_beta_transforms_warning()

# The torchvision tensor-subclass module was published as ``datapoints`` in older
# releases and as ``tv_tensors`` in newer ones. ``_HAS_DATAPOINTS`` records which
# API was imported, because a few class/keyword names differ between them
# (e.g. ``BoundingBox``/``spatial_size`` vs ``BoundingBoxes``/``canvas_size``).
try:
    from torchvision import datapoints as _datapoints
    _HAS_DATAPOINTS = True
except ImportError:
    # Only a missing module should trigger the fallback; any other error while
    # importing ``datapoints`` indicates a real problem and is allowed to surface.
    from torchvision import tv_tensors as _datapoints
    _HAS_DATAPOINTS = False

import torchvision.transforms.v2 as T
import torchvision.transforms.v2.functional as F

from PIL import Image 
from typing import Any, Dict, List, Optional

from src.core import register, GLOBAL_CONFIG


# Only ``Compose`` is part of the star-export surface; the transforms below are
# registered via ``register`` instead.
__all__ = ['Compose', ]


# Stock torchvision v2 transforms registered unchanged (register presumably records
# them in GLOBAL_CONFIG, which Compose consults when resolving dict ops by 'type').
RandomPhotometricDistort = register(T.RandomPhotometricDistort)
RandomZoomOut = register(T.RandomZoomOut)
# RandomIoUCrop is not registered directly here: a subclass adding an application
# probability ``p`` is defined and registered further down in this module.
RandomHorizontalFlip = register(T.RandomHorizontalFlip)
Resize = register(T.Resize)

if not hasattr(T, 'ToImageTensor'):
    # ``ToImageTensor`` is absent from this torchvision; take the first available
    # replacement, checked in order ToImage -> ToTensor -> PILToTensor.
    _BaseToImageTensor = next(
        filter(None, (getattr(T, _attr, None) for _attr in ('ToImage', 'ToTensor', 'PILToTensor'))),
        None,
    )
    if _BaseToImageTensor is None:
        raise AttributeError(
            'torchvision.transforms.v2 is missing ToImageTensor/ToImage/ToTensor/PILToTensor; please update torchvision.'
        )

    @register
    class ToImageTensor(_BaseToImageTensor):
        """Registered stand-in for ``ToImageTensor`` built on the replacement found above."""
else:
    ToImageTensor = register(T.ToImageTensor)

if not hasattr(T, 'ConvertDtype'):
    # ``ConvertDtype`` is absent from this torchvision; wrap ``ToDtype`` instead.
    _BaseConvertDtype = getattr(T, 'ToDtype', None)
    if _BaseConvertDtype is None:
        raise AttributeError('torchvision.transforms.v2 is missing ConvertDtype/ToDtype; please update torchvision.')

    @register
    class ConvertDtype(_BaseConvertDtype):
        """``ToDtype`` registered under the ``ConvertDtype`` name.

        Defaults to ``dtype=torch.float32`` with ``scale=True`` so value rescaling is
        enabled unless a config explicitly turns it off.
        """

        def __init__(self, dtype: torch.dtype = torch.float32, scale: bool = True) -> None:
            super().__init__(dtype=dtype, scale=scale)
else:
    ConvertDtype = register(T.ConvertDtype)

# Resolve the box-sanitizing transform across torchvision versions: it is named
# ``SanitizeBoundingBox`` in some releases and ``SanitizeBoundingBoxes`` in others.
_BaseSanitizeBoundingBox = getattr(T, 'SanitizeBoundingBox', None)
if _BaseSanitizeBoundingBox is None:
    _BaseSanitizeBoundingBox = getattr(T, 'SanitizeBoundingBoxes', None)
    if _BaseSanitizeBoundingBox is None:
        raise AttributeError(
            'torchvision.transforms.v2 is missing SanitizeBoundingBox/SanitizeBoundingBoxes; please update torchvision.'
        )


@register
class SanitizeBoundingBox(_BaseSanitizeBoundingBox):
    """Sanitize boxes while leaving the target's ``t_gt`` entry untouched.

    ``t_gt`` holds a full-image mask that must not be indexed with the per-box
    validity mask the base transform applies. When the second positional input is a
    dict containing ``t_gt``, that entry is removed before delegating and re-attached
    to the returned target dict afterwards. All other calls delegate unchanged.
    """

    def forward(self, *inputs):
        carries_t_gt = len(inputs) >= 2 and isinstance(inputs[1], dict) and "t_gt" in inputs[1]
        if not carries_t_gt:
            return super().forward(*inputs)

        first, target, *rest = inputs
        t_gt = target["t_gt"]
        # Shallow copy without ``t_gt`` so the caller's dict is not modified.
        stripped_target = {key: value for key, value in target.items() if key != "t_gt"}
        outputs = super().forward(first, stripped_target, *rest)

        can_restore = isinstance(outputs, tuple) and len(outputs) >= 2 and isinstance(outputs[1], dict)
        if not can_restore:
            return outputs
        outputs[1]["t_gt"] = t_gt
        return outputs

# More stock torchvision v2 transforms, registered unchanged.
RandomCrop = register(T.RandomCrop)
Normalize = register(T.Normalize)

# Version-neutral aliases for the torchvision tensor subclasses, taken from
# whichever module was imported as ``_datapoints`` (``datapoints`` or ``tv_tensors``).
_Image = _datapoints.Image
_Video = _datapoints.Video
_Mask = _datapoints.Mask
_BBoxFormat = _datapoints.BoundingBoxFormat
# The bounding-box class is ``BoundingBox`` in the ``datapoints`` API and
# ``BoundingBoxes`` in ``tv_tensors``.
_BBoxType = _datapoints.BoundingBox if _HAS_DATAPOINTS else _datapoints.BoundingBoxes


def _make_bounding_box(data, format, spatial_size):
    """Wrap ``data`` in the bounding-box tensor class of the active torchvision API.

    ``spatial_size`` is passed as ``spatial_size=`` to ``datapoints.BoundingBox``
    or as ``canvas_size=`` to ``tv_tensors.BoundingBoxes``, whichever is in use.
    """
    if not _HAS_DATAPOINTS:
        return _datapoints.BoundingBoxes(data, format=format, canvas_size=spatial_size)
    return _datapoints.BoundingBox(data, format=format, spatial_size=spatial_size)


def _bbox_spatial_size(bbox):
    if hasattr(bbox, "spatial_size"):
        return bbox.spatial_size
    if hasattr(bbox, "canvas_size"):
        return bbox.canvas_size
    raise AttributeError("Bounding box object has neither spatial_size nor canvas_size.")



@register
class Compose(T.Compose):
    """``T.Compose`` built from a list of op specs.

    Each entry of ``ops`` is either:
      * a dict with a ``'type'`` key naming a registered transform; the remaining
        keys are passed as keyword arguments to that transform's constructor, or
      * an already-constructed ``nn.Module`` transform, used as-is.

    When ``ops`` is None the pipeline consists of a single ``EmptyTransform``.

    Raises:
        ValueError: if an entry is neither a dict nor an ``nn.Module``.
    """

    def __init__(self, ops) -> None:
        transforms = []
        if ops is not None:
            for op in ops:
                if isinstance(op, dict):
                    # Work on a shallow copy so the caller's config dict keeps its
                    # 'type' key and can be reused (e.g. building the pipeline twice).
                    kwargs = dict(op)
                    name = kwargs.pop('type')
                    transform = getattr(GLOBAL_CONFIG[name]['_pymodule'], name)(**kwargs)
                    transforms.append(transform)
                elif isinstance(op, nn.Module):
                    transforms.append(op)
                else:
                    raise ValueError(
                        f'Unsupported transform spec of type {type(op).__name__!r}; '
                        'expected a dict with a "type" key or an nn.Module.'
                    )
        else:
            transforms = [EmptyTransform(), ]

        super().__init__(transforms=transforms)


@register
class EmptyTransform(T.Transform):
    """Identity transform, used as the placeholder pipeline when no ops are configured.

    Returns the positional inputs unchanged: the whole tuple when more than one input
    is given, otherwise the single input unwrapped.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, *inputs):
        if len(inputs) > 1:
            return inputs
        return inputs[0]


@register
class PadToSize(T.Pad):
    """Pad inputs on the right and bottom up to a fixed ``spatial_size``.

    ``spatial_size`` (an int is expanded to a square) is compared element-wise with
    the size reported for the first transformed input, which torchvision gives as
    ``[height, width]``. The padding ``[left=0, top=0, right=dw, bottom=dh]`` is
    applied to every transformed input. When the call returns a tuple whose second
    element is a dict, that padding is also stored in it under ``'padding'`` as a
    tensor. Inputs larger than ``spatial_size`` yield negative padding; this is not
    special-cased here.
    """

    _transformed_types = (
        Image.Image,
        _Image,
        _Video,
        _Mask,
        _BBoxType,
    )

    def __init__(self, spatial_size, fill=0, padding_mode='constant') -> None:
        if isinstance(spatial_size, int):
            spatial_size = (spatial_size, spatial_size)

        # Start with zero padding; the real amount is computed per call in _get_params.
        super().__init__(0, fill, padding_mode)
        self.spatial_size = spatial_size

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        # Not every torchvision release that provides F.get_size still provides
        # F.get_spatial_size; prefer the latter when present, else use get_size.
        get_size = getattr(F, 'get_spatial_size', None) or F.get_size
        sz = get_size(flat_inputs[0])
        h, w = self.spatial_size[0] - sz[0], self.spatial_size[1] - sz[1]
        # Remembered on the instance so __call__ can attach it to the target dict.
        self.padding = [0, 0, w, h]
        return dict(padding=self.padding)

    def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        # Public-name alias used by newer torchvision ``Transform.forward``.
        return self._get_params(flat_inputs)

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        fill = self._fill[type(inpt)]
        padding = params['padding']
        return F.pad(inpt, padding=padding, fill=fill, padding_mode=self.padding_mode)  # type: ignore[arg-type]

    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        # Public-name alias used by newer torchvision ``Transform.forward``.
        return self._transform(inpt, params)

    def __call__(self, *inputs: Any) -> Any:
        outputs = super().forward(*inputs)
        # Multi-input calls come back as a tuple. A single input (PIL image, tensor
        # or dict) comes back bare and must not be indexed like a tuple; a PIL image
        # has no len() and a dict has no key 1.
        if isinstance(outputs, (tuple, list)) and len(outputs) > 1 and isinstance(outputs[1], dict):
            outputs[1]['padding'] = torch.tensor(self.padding)
        return outputs


@register
class RandomIoUCrop(T.RandomIoUCrop):
    """``T.RandomIoUCrop`` that is applied only with probability ``p``.

    All constructor arguments except ``p`` are forwarded positionally to the
    torchvision base class. Since ``torch.rand`` draws from [0, 1), the default
    ``p=1.0`` means the crop is always applied.
    """

    def __init__(
        self,
        min_scale: float = 0.3,
        max_scale: float = 1,
        min_aspect_ratio: float = 0.5,
        max_aspect_ratio: float = 2,
        sampler_options: Optional[List[float]] = None,
        trials: int = 40,
        p: float = 1.0,
    ):
        super().__init__(
            min_scale,
            max_scale,
            min_aspect_ratio,
            max_aspect_ratio,
            sampler_options,
            trials,
        )
        self.p = p

    def __call__(self, *inputs: Any) -> Any:
        roll = torch.rand(1)
        if roll >= self.p:
            # Skipped: return the inputs in the same shape a transform call would.
            return inputs if len(inputs) > 1 else inputs[0]
        return super().forward(*inputs)


@register
class ConvertBox(T.Transform):
    """Convert bounding boxes to another coordinate format and/or normalize them.

    Args:
        out_fmt: target box format, one of ``'xyxy'``, ``'xywh'`` or ``'cxcywh'``.
            An empty string (the default) keeps the input format.
        normalize: if True, divide coordinates by the box canvas size, reversed and
            tiled so that x values are divided by width and y values by height.

    Raises:
        ValueError: at construction time if ``out_fmt`` is non-empty and unsupported.
    """

    _transformed_types = (
        _BBoxType,
    )

    def __init__(self, out_fmt='', normalize=False) -> None:
        super().__init__()
        self.out_fmt = out_fmt
        self.normalize = normalize

        # Box formats understood by both torchvision.ops.box_convert and BoundingBoxFormat.
        self.data_fmt = {
            'xyxy': _BBoxFormat.XYXY,
            'xywh': _BBoxFormat.XYWH,
            'cxcywh': _BBoxFormat.CXCYWH
        }
        # Fail fast on a bad config instead of erroring on the first sample.
        if self.out_fmt and self.out_fmt not in self.data_fmt:
            raise ValueError(
                f'Unsupported out_fmt {self.out_fmt!r}; expected one of '
                f'{sorted(self.data_fmt)} or an empty string to keep the input format.'
            )

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        if self.out_fmt:
            spatial_size = _bbox_spatial_size(inpt)
            in_fmt = inpt.format.value.lower()
            inpt = torchvision.ops.box_convert(inpt, in_fmt=in_fmt, out_fmt=self.out_fmt)
            # Re-wrap so the result carries the new format and the original canvas size.
            inpt = _make_bounding_box(inpt, format=self.data_fmt[self.out_fmt], spatial_size=spatial_size)

        if self.normalize:
            spatial_size = _bbox_spatial_size(inpt)
            # torchvision stores the size as (height, width); reversing and tiling gives
            # (w, h, w, h). Build it on the boxes' device so non-CPU boxes also work.
            scale = torch.tensor(spatial_size[::-1], device=inpt.device).tile(2)[None]
            inpt = inpt / scale

        return inpt

    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        # Public-name alias used by newer torchvision ``Transform.forward``.
        return self._transform(inpt, params)