File size: 11,938 Bytes
0b69a1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
import torch
from torch import Tensor
from torchvision.transforms import ColorJitter as _ColorJitter
import torchvision.transforms.functional as TF
import numpy as np
from typing import Tuple, Union, Optional, Callable


def _crop(
    image: Tensor,
    label: Tensor,
    top: int,
    left: int,
    height: int,
    width: int,
) -> Tuple[Tensor, Tensor]:
    """Crop a window out of ``image`` and re-express ``label`` relative to that window.

    Column 0 of ``label`` is treated as the horizontal coordinate (it is shifted by ``left``
    and compared against ``width``) and column 1 as the vertical coordinate (shifted by
    ``top`` and compared against ``height``). Rows that fall outside the window are dropped.

    The caller's ``label`` tensor is left untouched: the shift is applied to a copy.

    Args:
        image: Image forwarded to ``TF.crop``.
        label: Tensor indexed as ``label[:, 0]`` / ``label[:, 1]``; may be empty.
        top: Vertical offset of the window's upper edge.
        left: Horizontal offset of the window's left edge.
        height: Window height.
        width: Window width.

    Returns:
        The cropped image and the surviving label rows in window coordinates.
    """
    image = TF.crop(image, top, left, height, width)
    if len(label) > 0:
        # Work on a copy: the in-place shifts below would otherwise corrupt the caller's tensor.
        label = label.clone()
        label[:, 0] -= left
        label[:, 1] -= top
        xs, ys = label[:, 0], label[:, 1]
        inside = (xs >= 0) & (xs < width) & (ys >= 0) & (ys < height)
        label = label[inside]

    return image, label


def _resize(
    image: Tensor,
    label: Tensor,
    height: int,
    width: int,
) -> Tuple[Tensor, Tensor]:
    """Resize ``image`` to ``(height, width)`` and rescale ``label`` to match.

    When the image already has the requested size, both inputs are returned unchanged.
    Otherwise the image is resized with bicubic interpolation (antialiasing enabled),
    column 0 of ``label`` is scaled by ``width / image_width`` and column 1 by
    ``height / image_height``, and both columns are clamped into the new image bounds.

    The caller's ``label`` tensor is left untouched: rescaling is applied to a copy.

    Args:
        image: Tensor whose last two dimensions are read as (height, width).
        label: Tensor indexed as ``label[:, 0]`` / ``label[:, 1]``; may be empty.
        height: Target height.
        width: Target width.

    Returns:
        The resized image and the rescaled label.
    """
    image_height, image_width = image.shape[-2:]
    if image_height == height and image_width == width:
        return image, label  # nothing to do

    image = TF.resize(image, (height, width), interpolation=TF.InterpolationMode.BICUBIC, antialias=True)
    if len(label) > 0:
        # Work on a copy: the column assignments below would otherwise corrupt the caller's tensor.
        label = label.clone()
        label[:, 0] = label[:, 0] * width / image_width
        label[:, 1] = label[:, 1] * height / image_height
        # Keep every coordinate on a valid pixel index of the resized image.
        label[:, 0] = label[:, 0].clamp(min=0, max=width - 1)
        label[:, 1] = label[:, 1].clamp(min=0, max=height - 1)

    return image, label


class RandomCrop(object):
    """Cut a randomly positioned window of fixed size ``(h, w)`` out of an image.

    The window position is drawn with ``torch.randint`` and the actual cropping of the
    image and label is delegated to ``_crop``.
    """

    def __init__(self, size: Tuple[int, int]) -> None:
        self.size = size
        assert len(self.size) == 2, f"size should be a tuple (h, w), got {self.size}."

    def __call__(self, image: Tensor, label: Tensor) -> Tuple[Tensor, Tensor]:
        """Crop ``image`` and ``label`` with a freshly sampled window.

        Raises:
            AssertionError: If the crop size exceeds the image size in either dimension.
        """
        target_h, target_w = self.size
        full_h, full_w = image.shape[-2:]
        fits = target_h <= full_h and target_w <= full_w
        assert fits, \
            f"crop size should be no larger than image size, got crop size {self.size} and image size {image.shape}."

        # randint's upper bound is exclusive, hence the +1 so the last valid offset can be drawn.
        max_top = full_h - target_h
        max_left = full_w - target_w
        top = torch.randint(0, max_top + 1, (1,)).item()
        left = torch.randint(0, max_left + 1, (1,)).item()
        return _crop(image, label, top, left, target_h, target_w)


class Resize(object):
    """Resize an image to a fixed ``(h, w)``; the work is delegated to ``_resize``."""

    def __init__(self, size: Tuple[int, int]) -> None:
        self.size = size
        assert len(self.size) == 2, f"size should be a tuple (h, w), got {self.size}."

    def __call__(self, image: Tensor, label: Tensor) -> Tuple[Tensor, Tensor]:
        """Return the image and label after resizing to the configured size."""
        target_height = self.size[0]
        target_width = self.size[1]
        return _resize(image, label, target_height, target_width)


class Resize2Multiple(object):
    """
    Resize the image to the nearest size that satisfies:
        img_h = window_h + stride_h * n_h
        img_w = window_w + stride_w * n_w
    with n_h, n_w non-negative integers. The resizing itself is delegated to ``_resize``.
    """
    def __init__(
        self,
        window_size: Tuple[int, int],
        stride: Tuple[int, int],
    ) -> None:
        # A bare number is expanded to a square (value, value) pair.
        if isinstance(window_size, (int, float)):
            window_size = (int(window_size), int(window_size))
        window_size = tuple(window_size)
        if isinstance(stride, (int, float)):
            stride = (int(stride), int(stride))
        stride = tuple(stride)

        assert len(window_size) == 2, f"window_size should be a tuple (h, w), got {window_size}."
        assert len(stride) == 2, f"stride should be a tuple (h, w), got {stride}."
        assert all(s > 0 for s in window_size), f"window_size should be positive, got {window_size}."
        assert all(s > 0 for s in stride), f"stride should be positive, got {stride}."
        assert stride[0] <= window_size[0] and stride[1] <= window_size[1], f"stride should be no larger than window_size, got {stride} and {window_size}."

        self.window_size = window_size
        self.stride = stride

    def __call__(self, image: Tensor, label: Tensor) -> Tuple[Tensor, Tensor]:
        """Return the inputs unchanged when the size already fits, otherwise resize them."""
        height, width = image.shape[-2:]
        win_h, win_w = self.window_size
        step_h, step_w = self.stride

        # Number of strides beyond the first window, rounded to nearest and never negative.
        steps_h = max(round((height - win_h) / step_h), 0)
        steps_w = max(round((width - win_w) / step_w), 0)
        target_h = int(steps_h * step_h + win_h)
        target_w = int(steps_w * step_w + win_w)

        if target_h == height and target_w == width:
            return image, label
        return _resize(image, label, target_h, target_w)


class ZeroPad2Multiple(object):
    """
    Zero-pad the image up to the smallest size that satisfies:
        img_h = window_h + stride_h * n_h
        img_w = window_w + stride_w * n_w
    with n_h, n_w non-negative integers. The label is returned as-is.
    """
    def __init__(
        self,
        window_size: Tuple[int, int],
        stride: Tuple[int, int],
    ) -> None:
        # A bare number is expanded to a square (value, value) pair.
        if isinstance(window_size, (int, float)):
            window_size = (int(window_size), int(window_size))
        window_size = tuple(window_size)
        if isinstance(stride, (int, float)):
            stride = (int(stride), int(stride))
        stride = tuple(stride)

        assert len(window_size) == 2, f"window_size should be a tuple (h, w), got {window_size}."
        assert len(stride) == 2, f"stride should be a tuple (h, w), got {stride}."
        assert all(s > 0 for s in window_size), f"window_size should be positive, got {window_size}."
        assert all(s > 0 for s in stride), f"stride should be positive, got {stride}."
        assert stride[0] <= window_size[0] and stride[1] <= window_size[1], f"stride should be no larger than window_size, got {stride} and {window_size}."

        self.window_size = window_size
        self.stride = stride

    def __call__(self, image: Tensor, label: Tensor) -> Tuple[Tensor, Tensor]:
        """Return the inputs unchanged when the size already fits, otherwise pad the image."""
        height, width = image.shape[-2:]
        win_h, win_w = self.window_size
        step_h, step_w = self.stride

        # Number of strides beyond the first window, rounded up and never negative.
        steps_h = max(np.ceil((height - win_h) / step_h), 0)
        steps_w = max(np.ceil((width - win_w) / step_w), 0)
        new_height = int(steps_h * step_h + win_h)
        new_width = int(steps_w * step_w + win_w)

        if new_height == height and new_width == width:
            return image, label

        assert new_height >= height and new_width >= width, f"new size should be no less than the original size, got {new_height} and {new_width}."
        extra_rows = new_height - height
        extra_cols = new_width - width
        # Padding goes on the right and bottom only, so label coordinates stay valid.
        padded = TF.pad(image, (0, 0, extra_cols, extra_rows), fill=0)
        return padded, label


class RandomResizedCrop(object):
    """Randomly crop an image and resize the crop to a fixed output size.

    The crop window is the output size multiplied by a factor drawn uniformly from
    ``scale``, so the window always has the output aspect ratio. A factor below 1 zooms
    in, a factor above 1 zooms out.
    """

    def __init__(
        self,
        size: Tuple[int, int],
        scale: Tuple[float, float] = (0.75, 1.25),
    ) -> None:
        self.size = size
        self.scale = scale
        assert len(self.size) == 2, f"size should be a tuple (h, w), got {self.size}."
        assert 0 < self.scale[0] <= self.scale[1], f"scale should satisfy 0 < scale[0] <= scale[1], got {self.scale}."

    def __call__(self, image: Tensor, label: Tensor) -> Tuple[Tensor, Tensor]:
        """Sample a scale factor and window position, then crop and resize image and label."""
        target_h, target_w = self.size
        low, high = self.scale[0], self.scale[1]
        factor = torch.empty(1).uniform_(low, high).item()

        src_h, src_w = image.shape[-2:]
        win_h = int(target_h * factor)
        win_w = int(target_w * factor)

        if win_h > src_h or win_w > src_w:
            # The window does not fit: enlarge the image first, using one ratio for both
            # axes so that its aspect ratio is kept.
            ratio = max(win_h / src_h, win_w / src_w)
            # +1 guards against int() truncation leaving the image smaller than the window.
            src_h, src_w = int(src_h * ratio) + 1, int(src_w * ratio) + 1
            image, label = _resize(image, label, src_h, src_w)

        top = torch.randint(0, src_h - win_h + 1, (1,)).item()
        left = torch.randint(0, src_w - win_w + 1, (1,)).item()

        image, label = _crop(image, label, top, left, win_h, win_w)
        return _resize(image, label, target_h, target_w)
        

class RandomHorizontalFlip(object):
    """Mirror an image left-to-right with probability ``p`` and mirror the label with it.

    Column 0 of the label is treated as the horizontal coordinate. The caller's label
    tensor is never modified; when a flip happens a mirrored copy is returned.
    """

    def __init__(self, p: float = 0.5) -> None:
        self.p = p
        assert 0 <= self.p <= 1, f"p should be in range [0, 1], got {self.p}."

    def __call__(self, image: Tensor, label: Tensor) -> Tuple[Tensor, Tensor]:
        """Return the (possibly flipped) image and label."""
        if torch.rand(1) < self.p:
            image = TF.hflip(image)

            if len(label) > 0:
                # Work on a copy: the column assignment below would otherwise corrupt the caller's tensor.
                label = label.clone()
                last_column = image.shape[-1] - 1
                # If width is 256, then 0 -> 255, 1 -> 254, 2 -> 253, etc.
                label[:, 0] = last_column - label[:, 0]
                label[:, 0] = label[:, 0].clamp(min=0, max=last_column)

        return image, label
    

class ColorJitter(object):
    """Pair-style wrapper around torchvision's ``ColorJitter``.

    Only the image is jittered; the label is handed back untouched.
    """

    def __init__(
        self,
        brightness: Union[float, Tuple[float, float]] = 0.4,
        contrast: Union[float, Tuple[float, float]] = 0.4,
        saturation: Union[float, Tuple[float, float]] = 0.4,
        hue: Union[float, Tuple[float, float]] = 0.2,
    ) -> None:
        jitter_settings = dict(
            brightness=brightness,
            contrast=contrast,
            saturation=saturation,
            hue=hue,
        )
        self.color_jitter = _ColorJitter(**jitter_settings)

    def __call__(self, image: Tensor, label: Tensor) -> Tuple[Tensor, Tensor]:
        """Return the colour-jittered image together with the unchanged label."""
        jittered = self.color_jitter(image)
        return jittered, label
    

class RandomGrayscale(object):
    """Convert an image to grayscale with probability ``p``.

    The conversion asks ``TF.rgb_to_grayscale`` for 3 output channels. The label is
    returned unchanged.
    """

    def __init__(self, p: float = 0.1) -> None:
        self.p = p
        assert 0 <= self.p <= 1, f"p should be in range [0, 1], got {self.p}."

    def __call__(self, image: Tensor, label: Tensor) -> Tuple[Tensor, Tensor]:
        """Return the (possibly grayscale) image and the unchanged label."""
        convert = torch.rand(1) < self.p
        if convert:
            image = TF.rgb_to_grayscale(image, num_output_channels=3)

        return image, label
    

class GaussianBlur(object):
    """Blur an image with a Gaussian kernel; the label is returned unchanged.

    Args:
        kernel_size: Kernel size forwarded to ``TF.gaussian_blur``.
        sigma: Either a single number (fixed standard deviation) or a ``(min, max)``
            range from which one standard deviation is drawn uniformly on every call.
            This follows the convention of torchvision's ``GaussianBlur`` transform,
            whose default range ``(0.1, 2.0)`` is reused here.
    """

    def __init__(self, kernel_size: int, sigma: Tuple[float, float] = (0.1, 2.0)) -> None:
        self.kernel_size = kernel_size
        self.sigma = sigma

    def __call__(self, image: Tensor, label: Tensor) -> Tuple[Tensor, Tensor]:
        """Return the blurred image and the unchanged label."""
        if isinstance(self.sigma, (int, float)):
            sigma = float(self.sigma)
        else:
            # ``TF.gaussian_blur`` reads a 2-sequence as fixed (sigma_x, sigma_y), so passing the
            # (min, max) range straight through would yield a constant anisotropic blur.
            # Sample one value from the range and use it for both axes instead.
            sigma_min, sigma_max = self.sigma
            sigma = torch.empty(1).uniform_(sigma_min, sigma_max).item()

        return TF.gaussian_blur(image, self.kernel_size, [sigma, sigma]), label


class RandomApply(object):
    """Apply each transform of a sequence independently, each with its own probability.

    Args:
        transforms: Callables invoked as ``transform(image, label)`` that return an
            ``(image, label)`` pair.
        p: Either one probability shared by every transform (int or float), or a
            sequence holding one probability per transform. Every value must lie
            in [0, 1].

    Raises:
        AssertionError: If a probability is outside [0, 1] or the number of
            probabilities differs from the number of transforms.
    """

    def __init__(self, transforms: Tuple[Callable, ...], p: Union[float, Tuple[float, ...]] = 0.5) -> None:
        self.transforms = transforms
        # Accept ints as well as floats so that p=0 / p=1 are broadcast like any other scalar
        # probability instead of failing when iterated below.
        p = [p] * len(transforms) if isinstance(p, (int, float)) else p
        assert all(0 <= p_ <= 1 for p_ in p), f"p should be in range [0, 1], got {p}."
        assert len(p) == len(transforms), f"p should be a float or a tuple of floats with the same length as transforms, got {p}."
        self.p = p

    def __call__(self, image: Tensor, label: Tensor) -> Tuple[Tensor, Tensor]:
        """Run the transforms in order, drawing one random number per transform."""
        for transform, p in zip(self.transforms, self.p):
            if torch.rand(1) < p:
                image, label = transform(image, label)

        return image, label


class PepperSaltNoise(object):
    """Overwrite random image elements with 1 ("salt") or 0 ("pepper").

    One uniform random number is drawn per image element. Elements whose draw is below
    ``saltiness`` are set to 1; elements whose draw is above ``1 - spiciness`` are set
    to 0. The label is returned unchanged.
    """

    def __init__(self, saltiness: float = 0.001, spiciness: float = 0.001) -> None:
        self.saltiness = saltiness
        self.spiciness = spiciness
        assert 0 <= self.saltiness <= 1, f"saltiness should be in range [0, 1], got {self.saltiness}."
        assert 0 <= self.spiciness <= 1, f"spiciness should be in range [0, 1], got {self.spiciness}."

    def __call__(self, image: Tensor, label: Tensor) -> Tuple[Tensor, Tensor]:
        """Return the noisy image and the unchanged label."""
        draws = torch.rand_like(image)
        salt_mask = draws < self.saltiness
        pepper_mask = draws > 1 - self.spiciness

        salted = torch.where(salt_mask, 1., image)
        # Pepper is applied last, so it wins wherever both masks are set.
        noisy = torch.where(pepper_mask, 0., salted)
        return noisy, label