Transformers
jespark commited on
Commit
21900f4
·
verified ·
1 Parent(s): 7ebc30a

Upload custom_transforms.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. custom_transforms.py +1078 -0
custom_transforms.py ADDED
@@ -0,0 +1,1078 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import List, Tuple, Optional, Dict
3
+ import os
4
+
5
+ import torch
6
+ from torch import Tensor
7
+ import numpy as np
8
+ import PIL.Image
9
+ import random
10
+ from io import BytesIO
11
+ import cv2
12
+ import numpy as np
13
+
14
+ from torchvision.transforms import functional as F, InterpolationMode
15
+ import torchvision.transforms as T
16
+
17
+ __all__ = ["AutoAugmentPolicy", "AutoAugment", "RandAugment", "TrivialAugmentWide", "AugMix"]
18
+
19
def get_dimensions(img):
    """Return ``(channels, height, width)`` of a PIL image or image tensor.

    Fixed: torchvision's ``F.get_image_size`` returns ``[width, height]``,
    but the original unpacked it as ``height, width``, transposing the two
    for non-square images (mis-placing cutout centers and translate ranges).
    """
    width, height = F.get_image_size(img)
    channels = F.get_image_num_channels(img)
    return channels, height, width
23
+
24
def cutout(img, pad_size, replace=0):
    """Apply cutout (https://arxiv.org/abs/1708.04552) to ``img``.

    PyTorch port of Google's big_vision cutout: a square mask of up to
    (2*pad_size x 2*pad_size) pixels, centered at a uniformly random
    location (clipped at the image borders), is filled with ``replace``.

    Args:
        img: PIL image, or uint8 tensor of shape (C, H, W).
        pad_size: half the side length of the square mask.
        replace: pixel value used to fill the masked area.

    Returns:
        Image of the same kind as the input (PIL in, PIL out; tensor in,
        tensor out).
    """
    was_pil = F._is_pil_image(img)
    if was_pil:
        # Tensor ops below need a tensor; remember to convert back at the end.
        img = F.pil_to_tensor(img)
        assert img.dtype == torch.uint8, "PIL to tensor image is expected to have torch.unit8 as dtype."

    channels, height, width = get_dimensions(img)

    # Uniformly random mask center.
    center_y = torch.randint(low=0, high=height, size=(1,)).item()
    center_x = torch.randint(low=0, high=width, size=(1,)).item()

    # Border padding around the (clipped) mask region.
    pad_bottom = max(0, center_y - pad_size)
    pad_top = max(0, height - center_y - pad_size)
    pad_left = max(0, center_x - pad_size)
    pad_right = max(0, width - center_x - pad_size)

    mask_h = height - (pad_bottom + pad_top)
    mask_w = width - (pad_left + pad_right)

    # Zeros where the cutout goes, ones everywhere else; F.pad's last-dim
    # order is (left, right, top, bottom).
    mask = torch.nn.functional.pad(
        torch.zeros((mask_h, mask_w), dtype=img.dtype, device=img.device),
        (pad_left, pad_right, pad_top, pad_bottom),
        value=1,
    )
    mask = mask.unsqueeze(dim=0)
    mask = torch.tile(mask, (channels, 1, 1))

    filler = torch.ones_like(img, dtype=img.dtype, device=img.device) * replace
    img = torch.where(mask == 0, filler, img)

    return F.to_pil_image(img) if was_pil else img
78
+
79
def solarize_add(img, addition=0, threshold=128):
    """Add ``addition`` (expected in [-128, 128]) to every pixel strictly
    below ``threshold``, clipping results to [0, 255].

    PyTorch re-implementation of Google's big_vision ``solarize_add``.
    Accepts a PIL image or a uint8 tensor; returns the same kind.
    """
    was_pil = F._is_pil_image(img)
    if was_pil:
        img = F.pil_to_tensor(img)  # convert to tensor for pytorch operations
        assert img.dtype == torch.uint8, "PIL to tensor image is expected to have torch.unit8 as dtype."
    # Work in int space so the addition cannot wrap around in uint8.
    shifted = torch.clamp(img.to(torch.int) + addition, min=0, max=255).to(img.dtype)
    img = torch.where(img < threshold, shifted, img)
    return F.to_pil_image(img) if was_pil else img
105
+
106
def chroma_drop(img):
    """Flatten one randomly chosen chroma channel of a PIL image.

    Works in YCbCr space: either Cb or Cr (50/50) is replaced by the
    neutral value 128, then the image is converted back to RGB.
    """
    ycbcr = img.convert("YCbCr")
    Y, Cb, Cr = ycbcr.split()
    if torch.rand(1).item() > 0.5:
        Cr = Cr.point(lambda _unused: 128)
    else:
        Cb = Cb.point(lambda _unused: 128)
    merged = PIL.Image.merge("YCbCr", (Y, Cb, Cr))
    return merged.convert("RGB")
115
+
116
def auto_saturation_separate(img):
    """Stretch the chroma (Cb/Cr) of a PIL image toward full range, scaling
    values above and below the neutral point separately (they represent
    opposite hues), using the joint extrema of both channels.
    """
    ycbcr = img.convert("YCbCr")
    Y, Cb, Cr = ycbcr.split()
    Cbmin, Cbmax = Cb.getextrema()
    Crmin, Crmax = Cr.getextrema()
    Cmin = min(Cbmin, Crmin)
    Cmax = max(Cbmax, Crmax)

    def stretch(i):
        # Upper half (i > 127) stretches toward 255 when Cmax exceeds the
        # neutral point; lower half stretches toward 0 when Cmin is below it.
        if i > 127:
            return (i - 128) / (Cmax - 128) * 127 + 128 if Cmax > 128 else i
        return (i - Cmin) / (127 - Cmin) * 127 if Cmin < 127 else i

    Cb = Cb.point(stretch)
    Cr = Cr.point(stretch)
    return PIL.Image.merge("YCbCr", (Y, Cb, Cr)).convert("RGB")
131
+
132
+
133
def auto_saturation(img):
    """Linearly stretch the chroma channels (Cb/Cr) of a PIL image to span
    [0, 255], using the joint extrema of both channels.
    """
    ycbcr = img.convert("YCbCr")
    Y, Cb, Cr = ycbcr.split()
    Cbmin, Cbmax = Cb.getextrema()
    Crmin, Crmax = Cr.getextrema()
    Cmin = min(Cbmin, Crmin)
    Cmax = max(Cbmax, Crmax)
    span = Cmax - Cmin

    def rescale(i):
        # Guard against a flat channel (span == 0) to avoid division by zero.
        return (i - Cmin) / span * 255 if span != 0 else i

    Cb = Cb.point(rescale)
    Cr = Cr.point(rescale)
    return PIL.Image.merge("YCbCr", (Y, Cb, Cr)).convert("RGB")
144
+
145
def _apply_op(
    img: Tensor, op_name: str, magnitude: float, interpolation: InterpolationMode, fill: Optional[List[float]]
):
    """Apply a single named augmentation op to ``img`` at the given magnitude.

    Args:
        img: PIL image or tensor.
        op_name: one of the op names produced by the augmentation spaces below.
        magnitude: op strength; sign already applied by the caller.
        interpolation: interpolation mode for geometric ops.
        fill: fill value for the area outside geometric transforms.

    Raises:
        ValueError: if ``op_name`` is not recognized.
    """
    if op_name == "ShearX":
        # magnitude should be arctan(magnitude)
        # official autoaug: (1, level, 0, 0, 1, 0)
        # https://github.com/tensorflow/models/blob/dd02069717128186b88afa8d857ce57d17957f03/research/autoaugment/augmentation_transforms.py#L290
        # compared to
        # torchvision: (1, tan(level), 0, 0, 1, 0)
        # https://github.com/pytorch/vision/blob/0c2373d0bba3499e95776e7936e207d8a1676e65/torchvision/transforms/functional.py#L976
        img = F.affine(
            img,
            angle=0.0,
            translate=[0, 0],
            scale=1.0,
            shear=[math.degrees(math.atan(magnitude)), 0.0],
            interpolation=interpolation,
            fill=fill,
            center=[0, 0],
        )
    elif op_name == "ShearY":
        # magnitude should be arctan(magnitude); see ShearX above.
        img = F.affine(
            img,
            angle=0.0,
            translate=[0, 0],
            scale=1.0,
            shear=[0.0, math.degrees(math.atan(magnitude))],
            interpolation=interpolation,
            fill=fill,
            center=[0, 0],
        )
    elif op_name == "TranslateX":
        img = F.affine(
            img,
            angle=0.0,
            translate=[int(magnitude), 0],
            scale=1.0,
            interpolation=interpolation,
            shear=[0.0, 0.0],
            fill=fill,
        )
    elif op_name == "TranslateY":
        img = F.affine(
            img,
            angle=0.0,
            translate=[0, int(magnitude)],
            scale=1.0,
            interpolation=interpolation,
            shear=[0.0, 0.0],
            fill=fill,
        )
    elif op_name == "Rotate":
        img = F.rotate(img, magnitude, interpolation=interpolation, fill=fill)
    elif op_name == "Brightness":
        img = F.adjust_brightness(img, 1.0 + magnitude)
    elif op_name == "Color":
        img = F.adjust_saturation(img, 1.0 + magnitude)
    elif op_name == "Contrast":
        img = F.adjust_contrast(img, 1.0 + magnitude)
    elif op_name == "Sharpness":
        img = F.adjust_sharpness(img, 1.0 + magnitude)
    elif op_name == "Posterize":
        img = F.posterize(img, int(magnitude))
    elif op_name == "Solarize":
        img = F.solarize(img, magnitude)
    elif op_name == "AutoContrast":
        img = F.autocontrast(img)
    elif op_name == "Equalize":
        img = F.equalize(img)
    elif op_name == "Invert":
        img = F.invert(img)
    elif op_name == "Identity":
        pass
    elif op_name == "Cutout":  # added
        # Fixed: callers (e.g. RandAugment_bv) default ``fill`` to None, and
        # cutout multiplies a tensor by ``replace`` — None would crash, so
        # fall back to filling with 0.
        img = cutout(img, int(magnitude), replace=0 if fill is None else fill)
    elif op_name == "SolarizeAdd":  # added
        img = solarize_add(img, int(magnitude))
    elif op_name == "Grayscale":  # added v2
        img = F.to_grayscale(img, num_output_channels=3)
    elif op_name == "ChromaDrop":
        img = chroma_drop(img)
    elif op_name == "AutoSaturation":
        img = auto_saturation(img)  # dct-equivalent
    elif op_name == "AutoSaturation_old":  # for compatibility purposes
        img = auto_saturation(img)
    elif op_name == "Rotate90":  # magnitude is +- 90
        img = F.rotate(img, magnitude, interpolation=interpolation, fill=fill)
    else:
        raise ValueError(f"The provided operator {op_name} is not recognized.")
    return img
238
+
239
+
240
+
241
class RandAugment_bv(torch.nn.Module):
    r"""RandAugment data augmentation method based on
    `"RandAugment: Practical automated data augmentation with a reduced search space"
    <https://arxiv.org/abs/1909.13719>`_.

    Re-implementation of Google's big_vision RandAugment in PyTorch.

    If the image is a torch Tensor, it should be of type torch.uint8 and is
    expected to have [..., 1 or 3, H, W] shape. If it is a PIL Image, it is
    expected to be in mode "L" or "RGB".

    Args:
        num_ops (int): Number of augmentation transformations to apply sequentially.
        magnitude (int): Magnitude for all the transformations.
        num_magnitude_bins (int): The number of different magnitude values.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is
            ``InterpolationMode.NEAREST``.
        fill (sequence or number, optional): Pixel fill value for the area outside
            the transformed image.
        ops_list (list of str, optional): Op names to sample from; ``None``
            selects the default big_vision op set.
    """

    # Default big_vision op set, used when ops_list is None.
    _DEFAULT_OPS = ["AutoContrast", "Equalize", "Invert", "Rotate", "Posterize",
                    "Solarize", "SolarizeAdd", "Color", "Contrast", "Brightness",
                    "Sharpness", "ShearX", "ShearY", "Cutout", "TranslateX", "TranslateY"]

    def __init__(
        self,
        num_ops: int = 2,
        magnitude: int = 10,
        num_magnitude_bins: int = 11,
        interpolation: InterpolationMode = InterpolationMode.NEAREST,
        fill: Optional[List[float]] = None,
        ops_list: Optional[List[str]] = None,
    ) -> None:
        super().__init__()
        self.num_ops = num_ops
        self.magnitude = magnitude
        self.num_magnitude_bins = num_magnitude_bins
        self.interpolation = interpolation
        self.fill = fill
        # Fixed: the original used a mutable list literal as the default
        # argument (shared across all instances) and tested it with
        # ``ops_list == None``; use the None-sentinel idiom instead.
        self.ops_list = list(self._DEFAULT_OPS) if ops_list is None else ops_list

    def _augmentation_space(self, num_bins: int, image_size: Tuple[int, int]) -> Dict[str, Tuple[Tensor, bool]]:
        """Map op name -> (magnitude bins, is_signed) for the given image size."""
        return {
            # op_name: (magnitudes, signed)
            "AutoContrast": (torch.tensor(0.0), False),
            "Equalize": (torch.tensor(0.0), False),
            "Invert": (torch.tensor(0.0), False),  # added
            "Rotate": (torch.linspace(0.0, 30.0, num_bins), True),
            "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int(), False),
            "Solarize": (torch.linspace(255.0, 0.0, num_bins), False),
            "SolarizeAdd": (torch.linspace(0, 110, num_bins), False),  # added
            "Color": (torch.linspace(0.0, 0.9, num_bins), True),
            "Contrast": (torch.linspace(0.0, 0.9, num_bins), True),
            "Brightness": (torch.linspace(0.0, 0.9, num_bins), True),
            "Sharpness": (torch.linspace(0.0, 0.9, num_bins), True),
            "ShearX": (torch.linspace(0.0, 0.3, num_bins), True),
            "ShearY": (torch.linspace(0.0, 0.3, num_bins), True),
            "Cutout": (torch.linspace(0, 40, num_bins), False),  # added
            "TranslateX": (torch.linspace(0.0, 150.0 / 336.0 * image_size[1], num_bins), True),
            "TranslateY": (torch.linspace(0.0, 150.0 / 336.0 * image_size[0], num_bins), True),
            "Grayscale": (torch.tensor(0.0), False),
            "ChromaDrop": (torch.tensor(0.0), False),
            "AutoSaturation": (torch.tensor(0.0), False),
            "AutoSaturation_old": (torch.tensor(0.0), False),
            "Rotate90": (torch.tensor(90.0), True),
        }

    def forward(self, img: Tensor) -> Tensor:
        """
        img (PIL Image or Tensor): Image to be transformed.

        Returns:
            PIL Image or Tensor: Transformed image.
        """
        fill = self.fill
        channels, height, width = get_dimensions(img)

        op_meta = self._augmentation_space(self.num_magnitude_bins, (height, width))
        for _ in range(self.num_ops):
            # Sample an op uniformly, then its magnitude (and a random sign
            # for signed ops).
            op_index = int(torch.randint(len(self.ops_list), (1,)).item())
            op_name = self.ops_list[op_index]
            magnitudes, signed = op_meta[op_name]
            magnitude = float(magnitudes[self.magnitude].item()) if magnitudes.ndim > 0 else 0.0
            if signed and torch.randint(2, (1,)):
                magnitude *= -1.0
            img = _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill)

        return img

    def __repr__(self) -> str:
        s = (
            f"{self.__class__.__name__}("
            f"num_ops={self.num_ops}"
            f", magnitude={self.magnitude}"
            f", num_magnitude_bins={self.num_magnitude_bins}"
            f", interpolation={self.interpolation}"
            f", fill={self.fill}"
            f")"
        )
        return s
352
+
353
+
354
class ToTensor_range(torch.nn.Module):
    r"""Convert a PIL image to a tensor whose values span a given range.

    Args:
        val_min: minimum value after conversion.
        val_max: maximum value after conversion.
        dtype: dtype after conversion (default: ``torch.float32``).

    Returns:
        Converted torch Tensor.
    """

    def __init__(
        self,
        val_min: float = -1.,
        val_max: float = 1.,
        dtype=torch.float32,
    ) -> None:
        super().__init__()
        self.val_min = val_min
        self.val_max = val_max
        self.dtype = dtype

    def forward(self, img) -> Tensor:
        """
        img (PIL Image): Image to be transformed.

        Returns:
            Tensor: Converted image.
        """
        # NOTE(review): only PIL inputs are converted/scaled here; the
        # source's indentation suggested the scaling belongs to the PIL
        # branch (to_tensor already normalizes to [0, 1]) — confirm that
        # raw-tensor inputs are intentionally passed through untouched.
        if F._is_pil_image(img):
            scaled = F.to_tensor(img)           # to_tensor normalizes data to (0,1)
            scaled = scaled.to(self.dtype)      # convert dtype
            # Map [0, 1] linearly onto [val_min, val_max].
            img = self.val_min + scaled * (self.val_max - self.val_min)

        return img

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}("
            f"val_min={self.val_min}"
            f", val_max={self.val_max}"
            f", dtype={self.dtype}"
            f")"
        )
402
+
403
def apply_PILJPEG(img, quality):
    """JPEG-compress ``img`` in memory with PIL and decode it back to RGB."""
    stream = BytesIO()
    img.save(stream, format="JPEG", quality=quality)
    stream.seek(0)  # rewind so the encoded bytes can be read back
    return PIL.Image.open(stream).convert("RGB")
409
+
410
def apply_cv2JPEG(img, quality):
    """JPEG-compress ``img`` in memory with OpenCV and decode it back to a PIL image."""
    # PIL arrays are RGB; OpenCV expects BGR, so flip the channel axis.
    bgr = np.array(img)[:, :, ::-1]
    _result, encoded = cv2.imencode('.jpg', bgr, [int(cv2.IMWRITE_JPEG_QUALITY), quality])
    decoded = cv2.imdecode(encoded, 1)
    # Flip back to RGB for PIL.
    return PIL.Image.fromarray(decoded[:, :, ::-1])
418
+
419
def apply_randomJPEG(img, quality):
    """Apply in-memory JPEG compression with a randomly chosen backend
    (PIL or OpenCV, 50/50)."""
    if random.random() < 0.5:
        return apply_PILJPEG(img, quality)
    return apply_cv2JPEG(img, quality)
425
+
426
def resize_with_random_intpl(img, size):
    """Resize ``img`` to ``size`` using a randomly chosen interpolation mode
    (bilinear or bicubic)."""
    # Only these two modes are used; others (NEAREST/LANCZOS/HAMMING/BOX)
    # were disabled in the original implementation.
    candidates = [InterpolationMode.BILINEAR, InterpolationMode.BICUBIC]
    return F.resize(img, size, interpolation=random.choice(candidates))
437
+
438
class RandomResizeWithRandomIntpl(torch.nn.Module):
    r"""Resize a PIL image to a random size drawn from ``size_range``, using
    a random interpolation mode (see :func:`resize_with_random_intpl`).
    """

    def __init__(
        self,
        # Fixed annotation: this is a (min, max) tuple, not an int; and
        # __init__ returns None, not a Tensor.
        size_range: Tuple[int, int] = (112, 448),
    ) -> None:
        super().__init__()
        self.size_range = size_range

    def forward(self, img):
        """
        Args:
            img: PIL image to be transformed.

        Returns:
            Resized PIL image.
        """
        assert F._is_pil_image(img), "Input should be a PIL image (RandomResizeWithRandomIntpl transform)"
        target = random.randint(self.size_range[0], self.size_range[1])
        return resize_with_random_intpl(img, target)

    def __repr__(self) -> str:
        # Fixed: the original built a malformed f-string and never returned
        # it, so repr() raised "TypeError: __repr__ returned non-string".
        return f"{self.__class__.__name__}(size_range={self.size_range})"
469
+
470
class ResizeWithRandomIntpl(torch.nn.Module):
    r"""Resize a PIL image to a fixed ``size`` using a random interpolation
    mode (see :func:`resize_with_random_intpl`).
    """

    def __init__(
        self,
        size: int,
    ) -> None:  # fixed annotation: __init__ returns None, not Tensor
        super().__init__()
        self.size = size

    def forward(self, img):
        """
        Args:
            img: PIL image to be transformed.

        Returns:
            Resized PIL image.
        """
        assert F._is_pil_image(img), "Input should be a PIL image (ResizeWithRandomIntpl transform)"
        return resize_with_random_intpl(img, self.size)

    def __repr__(self) -> str:
        s = (
            f"{self.__class__.__name__}("
            f" size={self.size}"
            f")"
        )
        return s
502
+
503
class RRCWithRandomIntpl(T.RandomResizedCrop):
    r"""RandomResizedCrop variant that samples a random interpolation mode
    on every call.
    """

    def __init__(
        self,
        size: int,
        scale: Tuple[float, float] = (0.08, 1.0),
        ratio: Tuple[float, float] = (3. / 4., 4. / 3.),
    ) -> Tensor:
        super().__init__(size=size, scale=scale, ratio=ratio)
        self.size = size
        self.scale = scale
        self.ratio = ratio
        # Candidate modes sampled uniformly per call in forward().
        self.intp_list = [
            InterpolationMode.NEAREST,
            InterpolationMode.BILINEAR,
            InterpolationMode.BICUBIC,
            InterpolationMode.LANCZOS,
            InterpolationMode.HAMMING,
            InterpolationMode.BOX,
        ]

    def forward(self, img) -> Tensor:
        """
        Args:
            img: PIL image to be transformed.

        Returns:
            Randomly cropped and resized image.
        """
        assert F._is_pil_image(img), "Input should be a PIL image (RRCWithRandomIntpl transform)"
        top, left, crop_h, crop_w = self.get_params(img, self.scale, self.ratio)
        mode = random.choice(self.intp_list)
        return F.resized_crop(img, top, left, crop_h, crop_w, self.size, interpolation=mode)

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}("
            f" size={self.size}"
            f", scale={self.scale}"
            f", ratio={self.ratio}"
            f")"
        )
544
+
545
class JPEGinMemory(torch.nn.Module):
    r"""In-memory JPEG round-trip of a PIL image at a random quality drawn
    from ``quality_range``.

    ``method`` selects the codec backend: "pil", "cv", or "cv,pil" to pick
    one at random per call.
    """

    def __init__(
        self,
        quality_range=(30, 100),
        method: str = "cv,pil",
        dtype=torch.float32,
    ) -> Tensor:
        super().__init__()
        self.quality_range = quality_range
        # Normalize to a lowercase list of backend names.
        self.method = method.lower().split(',')
        self.dtype = dtype

    def forward(self, img) -> Tensor:
        """
        Args:
            img: PIL image to be transformed.

        Returns:
            JPEG-round-tripped PIL image.
        """
        assert F._is_pil_image(img), "Input should be a PIL image (ResizeAndJPEGinMemory transform)"
        use_cv = "cv" in self.method
        use_pil = "pil" in self.method
        if use_cv and use_pil:
            img = apply_randomJPEG(img, random.randint(self.quality_range[0], self.quality_range[1]))
        elif use_cv:
            img = apply_cv2JPEG(img, random.randint(self.quality_range[0], self.quality_range[1]))
        elif use_pil:
            img = apply_PILJPEG(img, random.randint(self.quality_range[0], self.quality_range[1]))
        return img

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}("
            f", quality_range={self.quality_range}"
            f", dtype={self.dtype}"
            f")"
        )
587
+
588
class ResizeAndJPEGinMemory(torch.nn.Module):
    r"""Resize a PIL image (short side to ``size``, aspect ratio preserved)
    and JPEG-round-trip it in memory at a fixed ``quality``.

    ``method`` selects the codec backend: "pil", "cv", or "cv,pil" to pick
    one at random per call.
    """

    def __init__(
        self,
        size: int,
        quality: int = 95,
        method: str = "cv,pil",
        dtype=torch.float32,
    ) -> Tensor:
        super().__init__()
        self.size = size
        self.quality = quality
        # Normalize to a lowercase list of backend names.
        self.method = method.lower().split(',')
        self.dtype = dtype

    def forward(self, img) -> Tensor:
        """
        Args:
            img: PIL image to be transformed.

        Returns:
            Resized, JPEG-round-tripped PIL image.
        """
        assert F._is_pil_image(img), "Input should be a PIL image (ResizeAndJPEGinMemory transform)"
        # With an int size, F.resize maps the SMALLER side to ``size`` and
        # keeps the aspect ratio — this behavior is relied upon here; verify
        # it still holds if torchvision is upgraded.
        img = F.resize(img, self.size, interpolation=InterpolationMode.BILINEAR)
        use_cv = "cv" in self.method
        use_pil = "pil" in self.method
        if use_cv and use_pil:
            img = apply_randomJPEG(img, self.quality)
        elif use_cv:
            img = apply_cv2JPEG(img, self.quality)
        elif use_pil:
            img = apply_PILJPEG(img, self.quality)
        return img

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}("
            f" size={self.size}"
            f", quality={self.quality}"
            f", dtype={self.dtype}"
            f")"
        )
635
+
636
class StochasticJPEG(torch.nn.Module):
    r"""Stochastically apply several rounds of (optional) random resized crop
    followed by in-memory JPEG compression.

    At least ``num_jpeg[0]`` rounds always run; each extra round (up to
    ``num_jpeg[1]``) runs with probability ``jpeg_p``.
    """

    def __init__(
        self,
        size: int,                                        # final output size
        quality: Tuple[int, int] = (50, 100),             # JPEG quality range
        num_jpeg: Tuple[int, int] = (1, 5),               # (min, max) number of JPEG rounds
        jpeg_p: float = 0.5,                              # probability of each extra JPEG round
        rrc_p: float = 0.5,                               # probability of random resized crop (currently unused)
        rrc_scale: Tuple[float, float] = (0.75, 1.0),     # random resized crop scale
        rrc_ratio: Tuple[float, float] = (3. / 4., 4. / 3.),  # random resized crop ratio
        no_rrc: bool = False,                             # if True, no random resized crop is applied
        dtype: type = torch.float32,
    ) -> None:
        """Initialize the StochasticJPEG transform.

        Args:
            size: final output size.
            quality: JPEG quality range (low, high).
            num_jpeg: (min, max) number of JPEG rounds to apply.
            jpeg_p: probability of applying each extra JPEG round.
            rrc_p: probability of applying the random resized crop.
            rrc_scale: random resized crop scale range.
            rrc_ratio: random resized crop aspect-ratio range.
            no_rrc: if True, random resized crop is disabled entirely.
            dtype: tensor dtype (stored, not used in forward).
        """
        super().__init__()
        self.size = size
        self.quality = quality
        self.num_jpeg = num_jpeg
        self.jpeg_p = jpeg_p
        # NOTE(review): rrc_p is stored but never consulted — the crop runs
        # before every JPEG round unconditionally. Confirm whether it should
        # gate self.rrc in forward().
        self.rrc_p = rrc_p
        self.rrc = torch.nn.Identity() if no_rrc else T.RandomResizedCrop(
            size=size, scale=rrc_scale, ratio=rrc_ratio, interpolation=InterpolationMode.BILINEAR
        )
        self.dtype = dtype

    def forward(self, img) -> Tensor:
        """
        Args:
            img: PIL image to be transformed.

        Returns:
            Transformed PIL image.
        """
        assert F._is_pil_image(img), "Input should be a PIL image (StochasticJPEG transform)"

        count = self.num_jpeg[0]
        for _ in range(self.num_jpeg[0]):  # apply min number of jpegs and RRC first
            img = self.rrc(img)
            img = apply_randomJPEG(img, random.randint(self.quality[0], self.quality[1]))

        while count < self.num_jpeg[1]:
            # Fixed: the original tested ``self.p``, an attribute that is
            # never set (__init__ stores ``jpeg_p``), so any extra round
            # raised AttributeError.
            if random.random() < self.jpeg_p:  # apply more jpegs with set probability
                img = self.rrc(img)
                img = apply_randomJPEG(img, random.randint(self.quality[0], self.quality[1]))
                count += 1
            else:
                break

        return img
703
+
704
class RandomJPEG(torch.nn.Module):
    """
    Apply in-memory JPEG compression with probability ``p``, using a quality
    drawn uniformly from ``quality_list``.

    Args:
        quality_list: (low, high) range of JPEG quality values.
        p: probability of applying JPEG.
    """
    def __init__(
        self,
        quality_list: tuple = (30, 100),
        p: float = 0.5,
    ):
        super().__init__()
        self.quality_list = quality_list
        self.p = p

    def forward(self, img):
        # Skip entirely with probability 1 - p.
        if random.random() >= self.p:
            return img
        quality = random.randint(self.quality_list[0], self.quality_list[1])
        return apply_randomJPEG(img, quality)
724
+
725
class RandomGaussianBlur(torch.nn.Module):
    """
    Apply Gaussian blur with probability ``p``, using a sigma drawn
    uniformly from ``sigma``.

    Args:
        p: probability of applying the blur.
        sigma: (low, high) range for the Gaussian sigma.
    """
    def __init__(
        self,
        p: float = 0.5,
        sigma: Tuple[float, float] = (0.0, 3.0),
    ):
        super().__init__()
        self.p = p
        self.sigma = sigma

    def forward(self, img):
        if random.random() >= self.p:
            return img
        chosen_sigma = random.uniform(self.sigma[0], self.sigma[1])
        # Kernel covers +-4 sigma, matching scipy's default truncation
        # (https://github.com/scipy/scipy/blob/v1.13.1/scipy/ndimage/_filters.py#L286-L390).
        # NOTE(review): a draw of exactly 0.0 would pass sigma=0 to
        # F.gaussian_blur — confirm torchvision accepts it.
        kernel_size = 1 + 2 * round(chosen_sigma * 4.0)
        return F.gaussian_blur(img, kernel_size=kernel_size, sigma=chosen_sigma)
747
+
748
class RandomPaddingAndResize(torch.nn.Module):
    r"""Randomly pad a PIL image on each side, then resize it back to its
    original resolution.
    """

    def __init__(
        self,
        pad_percentage_range=(0.1, 0.1),   # max padding fraction for x (width) and y (height)
        padding_value_range=(0, 255),      # range for the random constant fill value
    ) -> None:
        super().__init__()
        self.pad_percentage_range = pad_percentage_range
        self.padding_value_range = padding_value_range

    def forward(self, img) -> Tensor:
        """
        Args:
            img: PIL image to be transformed.

        Returns:
            Padded-and-resized PIL image at the original resolution.
        """
        assert F._is_pil_image(img), "Input should be a PIL image (ResizeAndJPEGinMemory transform)"
        width, height = img.size  # PIL's .size is (width, height)
        # Each side gets an independent padding fraction up to half the max.
        pad_x_l = random.uniform(0, self.pad_percentage_range[0] / 2)  # left
        pad_x_r = random.uniform(0, self.pad_percentage_range[0] / 2)  # right
        pad_y_t = random.uniform(0, self.pad_percentage_range[1] / 2)  # top
        pad_y_b = random.uniform(0, self.pad_percentage_range[1] / 2)  # bottom
        pad_fill = random.randint(int(self.padding_value_range[0]), int(self.padding_value_range[1]))
        img = F.pad(
            img,
            (int(pad_x_l * width), int(pad_y_t * height), int(pad_x_r * width), int(pad_y_b * height)),
            fill=pad_fill,
            padding_mode='constant',
        )
        # Fixed: the original passed ``img.size`` (width, height) directly,
        # but torchvision's resize expects (h, w) — non-square images came
        # back with transposed dimensions.
        return F.resize(img, (height, width), interpolation=InterpolationMode.BILINEAR)

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}("
            f", pad_percentage_range={self.pad_percentage_range}"
            f", padding_value_range={self.padding_value_range}"
            f")"
        )
790
+
791
class RandomCutout(T.RandomErasing):
    r"""RandomErasing variant that fills the erased patch with a single random
    value drawn uniformly from ``value_range``.

    Accepts PIL images or tensors; a PIL input is returned as PIL.
    """
    def __init__(
        self,
        p=0.5,
        scale=(0.02, 0.33),
        ratio=(0.3, 3.3),
        value_range=(0, 255),
    ):
        super().__init__(p=p, scale=scale, ratio=ratio)
        self.value_range = value_range

    def forward(self, img):
        convert_to_pil = False
        if F._is_pil_image(img):
            img = F.pil_to_tensor(img)
            convert_to_pil = True
        if torch.rand(1) < self.p:
            # get_params expects the fill as a list of floats. The original
            # carried str/list/else branches here, but random.randint always
            # returns an int, so they were unreachable (as was the channel-
            # count check on a single-element list) — removed.
            fill = [float(random.randint(self.value_range[0], self.value_range[1]))]
            x, y, h, w, v = self.get_params(img, self.scale, self.ratio, fill)
            img = F.erase(img, x, y, h, w, v)
        if convert_to_pil:
            img = F.to_pil_image(img)
        return img
832
+
833
class RandomVisualization(torch.nn.Module):
    r"""
    Randomly visualizes the fully augmented images by saving them at a specified directory.

    With probability ``save_p`` the incoming image is written to ``save_dir``
    until ``max_imgs`` files exist; afterwards either a random existing file
    is overwritten (``overwrite=True``) or saving stops entirely. The image
    itself is always returned unchanged.
    """
    def __init__(
        self,
        save_dir: str = "/nfs/turbo/coe-ahowens-nobackup/jespark/visualizations/fake_img",
        save_p: float = 0.01,
        max_imgs: int = 500,
        overwrite: bool = False,
    ) -> None:
        super().__init__()
        self.save_dir = save_dir
        self.save_p = save_p
        self.max_imgs = max_imgs
        self.overwrite = overwrite
        # Set once the quota is hit with overwrite disabled, so later calls
        # skip the directory scan entirely.
        self.skip_namecheck = False

    def next_available_filename(self, save_dir, max_imgs):
        """Return the next free filename, or False once the quota is full.

        Saved files are named ``visualization_{seq}_{rand:03d}.png`` where
        ``seq`` is a monotonically increasing index.
        """
        imgs = os.listdir(save_dir)
        # BUGFIX: the sequence index is the token after the first "_";
        # the last token is the random 3-digit suffix, so the former
        # split("_")[-1] produced bogus "next" indices.
        seq_nums = [int(img.split("_")[1]) for img in imgs]
        random_int = random.randint(0, 999)
        if len(seq_nums) >= max_imgs:
            if self.overwrite:
                return random.choice(imgs)  # overwrite random file from imgs
            self.skip_namecheck = True
            return False
        if seq_nums:
            return f"visualization_{max(seq_nums) + 1}_{random_int:03d}.png"
        return f"visualization_0_{random_int:03d}.png"

    def forward(self, img) -> Tensor:
        """
        Args:
            img: PIL image to be transformed.

        Returns:
            Tensor: Converted Image (returned unchanged).
        """
        if not self.skip_namecheck and random.random() < self.save_p:
            os.makedirs(self.save_dir, exist_ok=True)
            filename = self.next_available_filename(self.save_dir, self.max_imgs)
            if filename:
                img.save(os.path.join(self.save_dir, filename))
        return img
887
+
888
class RandomStateAugmentation(torch.nn.Module):
    r"""
    Randomly applies augmentations given in the input.

    Each augmentation in ``auglist`` is applied between its per-augmentation
    minimum and maximum count (inclusive), in a uniformly shuffled order.
    """
    def __init__(
        self,
        resize_size=256,
        crop_size=224,
        auglist="JPEGinMemory,RandomResizeWithRandomIntpl,RandomCrop,RandomHorizontalFlip,RandomVerticalFlip,RRCWithRandomIntpl,RandomRotation,RandomTranslate,RandomShear,RandomPadding",
        min_augs='0',
        max_augs='5',
    ):
        """
        auglist: augmentation lists to apply. Input comma-separated string of augmentations.
        min_augs: minimum number of augmentations to apply. (can be comma-separated string to denote per-augmentation minimum)
        max_augs: maximum number of augmentations to apply. (can be comma-separated string to denote per-augmentation maximum)
        """
        super().__init__()
        self.resize_size = resize_size
        self.crop_size = crop_size

        self.auglist = self.parse_auglist(auglist)
        # convert min_augs and max_augs to appropriate format
        min_augs = self.parse_augnums(min_augs)
        max_augs = self.parse_augnums(max_augs)
        if type(min_augs) == list:
            assert type(max_augs) == list, "max_augs should be list if min_augs is list."
            # BUGFIX: compare against the parsed augmentation list; the
            # original compared against len() of the raw comma-separated
            # *string*, i.e. its character count. generate_randAug_counts
            # indexes min/max by position in self.auglist.
            assert len(min_augs) == len(self.auglist), "min_augs length should be equal to auglist length."
            assert len(max_augs) == len(self.auglist), "max_augs length should be equal to auglist length."
        # broadcast scalar min/max counts to one entry per augmentation
        self.min_augs = [min_augs] * len(self.auglist) if type(min_augs) != list else min_augs
        self.max_augs = [max_augs] * len(self.auglist) if type(max_augs) != list else max_augs

    def parse_augnums(self, augsnum):
        """Parse min_augs/max_augs: a string of integers, optionally comma-separated.

        Returns an int for a single value, otherwise a list of ints.
        """
        augsnum_list = augsnum.split(",")
        if len(augsnum_list) == 1:
            return int(augsnum_list[0])
        return [int(aug) for aug in augsnum_list]

    def parse_auglist(self, auglist):
        """Translate the comma-separated ``auglist`` string into transform modules.

        Unrecognized names are silently skipped.
        """
        auglist_list = auglist.split(",")
        parsed_list = torch.nn.ModuleList()
        for aug_name in auglist_list:
            if aug_name == 'singleJPEG':
                parsed_list.append(ResizeAndJPEGinMemory(size=self.crop_size, quality=95, dtype=torch.float32))
            if aug_name == 'StochasticJPEG':
                parsed_list.append(StochasticJPEG(size=self.crop_size, quality=(75, 100), num_jpeg=(1, 5), jpeg_p=0.5, rrc_p=0.5, rrc_scale=(0.75, 1.0), rrc_ratio=(3./4., 4./3.), no_rrc=False, dtype=torch.float32))
            if aug_name == 'JPEGinMemory':
                parsed_list.append(JPEGinMemory(quality_range=(75, 100), dtype=torch.float32))
            if aug_name == 'RandomResizeWithRandomIntpl':
                # should not resize smaller than crop_size; causes issues with RandomCrop.
                parsed_list.append(RandomResizeWithRandomIntpl(size_range=(self.crop_size+1, round(self.crop_size*1.228))))
            if aug_name == 'RandomCrop':
                parsed_list.append(T.RandomCrop(self.crop_size))
            if aug_name == 'RandomHorizontalFlip':
                parsed_list.append(T.RandomHorizontalFlip())
            if aug_name == 'RandomVerticalFlip':
                parsed_list.append(T.RandomVerticalFlip())
            if aug_name == 'RRCWithRandomIntpl':
                parsed_list.append(RRCWithRandomIntpl(size=self.crop_size, scale=(0.9, 1.0), ratio=(3./4., 4./3.)))
            if aug_name == 'RandomRotation':
                parsed_list.append(T.RandomRotation(15, interpolation=InterpolationMode.BILINEAR))
            if aug_name == 'RandomTranslate':
                parsed_list.append(T.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=None, shear=None, interpolation=InterpolationMode.BILINEAR))
            if aug_name == 'RandomShear':
                parsed_list.append(T.RandomAffine(degrees=0, translate=None, scale=None, shear=(-15, 15, -15, 15), interpolation=InterpolationMode.BILINEAR))
            if aug_name == 'RandomPadding' or aug_name == 'RandomPaddingAndResize':
                parsed_list.append(RandomPaddingAndResize(pad_percentage_range=(0.1, 0.1), padding_value_range=(0, 255)))
            if aug_name == 'RandomCutout':
                parsed_list.append(RandomCutout(p=0.5, scale=(0.02, 0.06), ratio=(0.3, 3.3), value_range=(0, 255)))

        return parsed_list

    def generate_randAug_counts(self):
        """Draw a random required application count for each augmentation."""
        return [random.randint(self.min_augs[i], self.max_augs[i]) for i in range(len(self.auglist))]

    def convert_aug_counts_to_idxList(self, per_aug_counts):
        """Expand per-augmentation counts into a flat index list, e.g. [1,3,2] -> [0,1,1,1,2,2]."""
        idxList = []
        for i, count in enumerate(per_aug_counts):
            idxList += [i] * count
        return idxList

    def check_if_complete(self, count, min_augs):
        """Return True if every augmentation reached its minimum count. (Currently unused.)"""
        if type(min_augs) == list:
            min_augs_list = min_augs
        else:
            min_augs_list = [min_augs] * len(self.auglist)
        for i in range(len(min_augs_list)):
            if count[i] < min_augs_list[i]:
                return False
        return True

    def forward(self, img) -> Tensor:
        """
        Args:
            img: PIL image to be transformed.

        Returns:
            Tensor: Converted Image
        """
        assert F._is_pil_image(img), "Input should be a PIL image (RandomStateAugmentation transform)"
        # Apply each augmentation its drawn number of times, in random order:
        # idxList holds one entry per required application, popped at random.
        idxList = self.convert_aug_counts_to_idxList(self.generate_randAug_counts())
        while len(idxList) > 0:
            randomIdx = idxList.pop(random.randint(0, len(idxList) - 1))
            img = self.auglist[randomIdx](img)
        return img
1010
+
1011
class RandomSignRotation(torch.nn.Module):
    r"""
    Randomly rotates the image by given angle. Randomly changes sign.

    With probability 0.5 the rotation is by ``-angle`` degrees, otherwise by
    ``+angle`` degrees.
    """

    def __init__(
        self,
        angle: int,
        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
    ) -> None:
        # BUGFIX: the constructor was annotated "-> Tensor"; __init__ always
        # returns None.
        super().__init__()
        self.angle = angle
        self.interpolation = interpolation

    def forward(self, img) -> Tensor:
        """
        Args:
            img: PIL image to be transformed.

        Returns:
            Tensor: Converted Image
        """
        # Flip the sign of the rotation angle with probability 0.5.
        angle = -self.angle if random.random() < 0.5 else self.angle
        return F.rotate(img, angle, interpolation=self.interpolation)

    def __repr__(self) -> str:
        s = (
            f"{self.__class__.__name__}("
            f" angle={self.angle}"
            f", interpolation={self.interpolation}"
            f")"
        )
        return s
1048
+
1049
class RandomResize(torch.nn.Module):
    r"""
    Randomly resizes the input. Either up or downsample and then return it to the original size.

    ``resize_percentage`` is the fractional amount of scaling: e.g. 0.3 means
    the image is scaled by 0.7 or 1.3 (equal probability) and then resized
    back, so the output keeps the input's original size but carries the
    resampling degradation.
    """
    def __init__(
        self,
        resize_percentage: float,
        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
    ) -> None:
        # BUGFIX: the constructor was annotated "-> Tensor"; __init__ always
        # returns None.
        super().__init__()
        self.resize_percentage = resize_percentage
        self.interpolation = interpolation

    def forward(self, img) -> Tensor:
        """
        Args:
            img: PIL image to be transformed.

        Returns:
            Tensor: Converted Image (same size as the input).
        """
        if random.random() < 0.5:
            scale = 1.0 - self.resize_percentage
        else:
            scale = 1.0 + self.resize_percentage
        # PIL .size is (width, height) while F.resize expects (height, width).
        width, height = img.size
        img = F.resize(img, (int(height * scale), int(width * scale)), interpolation=self.interpolation)
        img = F.resize(img, (height, width), interpolation=self.interpolation)
        return img
1078
+