| | import random |
| |
|
| | import torch |
| | import torchvision.transforms as T |
| | import torchvision.transforms.functional as tf |
| | import torch.nn.functional as F |
| |
|
| |
|
class RandomResizedCrop(T.RandomResizedCrop):
    """Random resized crop that applies one shared crop window to a list of images."""

    def __init__(self, *args, **kwargs):
        # All constructor arguments are forwarded untouched to T.RandomResizedCrop.
        super().__init__(*args, **kwargs)

    def __call__(self, x, dtypes):
        """Crop and resize every image in ``x`` with the same random window.

        WARNING: written against torchvision v0.11; this replaces
        T.RandomResizedCrop.__call__.

        Args:
            x: list of images; the crop window is sampled from ``x[0]`` only.
            dtypes: per-image data type tags (not used by this transform).

        Returns:
            A new list with one resized crop per input image.
        """
        # The window is sampled once so every image stays aligned with the others.
        top, left, height, width = self.get_params(x[0], self.scale, self.ratio)
        cropped = []
        for img in x:
            cropped.append(
                tf.resized_crop(img, top, left, height, width, self.size, self.interpolation)
            )
        return cropped
| |
|
class ColorJitter(T.ColorJitter):
    """Color jitter that only alters images tagged as ``'albedo'``."""

    def __init__(self, *args, **kwargs):
        # All constructor arguments are forwarded untouched to T.ColorJitter.
        super().__init__(*args, **kwargs)

    def forward(self, x, dtypes):
        """Apply one randomly sampled jitter to every ``'albedo'`` image in ``x``.

        Args:
            x: list of images.
            dtypes: per-image data type tags, zipped with ``x``.

        Returns:
            A new list; images whose tag is not ``'albedo'`` are passed through
            unchanged.
        """
        order, brightness, contrast, saturation, hue = self.get_params(
            self.brightness, self.contrast, self.saturation, self.hue
        )

        # Position in this table is the operation id yielded by ``order``.
        adjustments = (
            (tf.adjust_brightness, brightness),
            (tf.adjust_contrast, contrast),
            (tf.adjust_saturation, saturation),
            (tf.adjust_hue, hue),
        )

        results = []
        for img, dtype in zip(x, dtypes):
            if dtype == 'albedo':
                for op_id in order:
                    adjust, factor = adjustments[int(op_id)]
                    # A factor of None means that adjustment is disabled.
                    if factor is not None:
                        img = adjust(img, factor)
            results.append(img)
        return results
| |
|
class RandomHorizontalFlip(T.RandomHorizontalFlip):
    """Horizontal flip applied jointly to a list of images with probability ``p``."""

    def __init__(self, *args, **kwargs):
        # All constructor arguments are forwarded untouched to T.RandomHorizontalFlip.
        super().__init__(*args, **kwargs)

    def flip_x(self, img, dtype):
        """Horizontally flip one image, negating channel 0 for ``'normals'``.

        Args:
            img: image to flip; for ``'normals'`` it must be a tensor, since
                channel 0 is indexed and negated.
            dtype: data type tag of ``img``.

        Returns:
            The flipped image. The input is never modified.
        """
        if dtype == 'normals':
            # Negate on a copy: an in-place ``img[0] *= -1`` would silently
            # corrupt the tensor still held by the caller.
            img = img.clone()
            img[0] *= -1
        return tf.hflip(img)

    def forward(self, x, dtypes):
        """Flip every image in ``x``, or none of them.

        Args:
            x: list of images.
            dtypes: per-image data type tags, zipped with ``x``.

        Returns:
            A new list of flipped images when the draw succeeds, otherwise ``x``
            itself.
        """
        # A single draw decides for the whole list so the images stay aligned.
        if torch.rand(1) < self.p:
            return [self.flip_x(img, dtype) for img, dtype in zip(x, dtypes)]
        return x
| |
|
class RandomVerticalFlip(T.RandomVerticalFlip):
    """Vertical flip applied jointly to a list of images with probability ``p``."""

    def __init__(self, *args, **kwargs):
        # All constructor arguments are forwarded untouched to T.RandomVerticalFlip.
        super().__init__(*args, **kwargs)

    def flip_y(self, img, dtype):
        """Vertically flip one image, negating channel 1 for ``'normals'``.

        Args:
            img: image to flip; for ``'normals'`` it must be a tensor, since
                channel 1 is indexed and negated.
            dtype: data type tag of ``img``.

        Returns:
            The flipped image. The input is never modified.
        """
        if dtype == 'normals':
            # Negate on a copy: an in-place ``img[1] *= -1`` would silently
            # corrupt the tensor still held by the caller.
            img = img.clone()
            img[1] *= -1
        return tf.vflip(img)

    def forward(self, x, dtypes):
        """Flip every image in ``x``, or none of them.

        Args:
            x: list of images.
            dtypes: per-image data type tags, zipped with ``x``.

        Returns:
            A new list of flipped images when the draw succeeds, otherwise ``x``
            itself.
        """
        # A single draw decides for the whole list so the images stay aligned.
        if torch.rand(1) < self.p:
            return [self.flip_y(img, dtype) for img, dtype in zip(x, dtypes)]
        return x
| |
|
def deg0(x, y, z):
    """Stack the three components unchanged as ``(x, y, z)`` (the 0-degree case)."""
    components = (x, y, z)
    return torch.stack(components)
def deg90(x, y, z):
    """Stack the components as ``(-y, x, z)`` (the 90-degree case)."""
    components = (-y, x, z)
    return torch.stack(components)
def deg180(x, y, z):
    """Stack the components as ``(-x, -y, z)`` (the 180-degree case)."""
    components = (-x, -y, z)
    return torch.stack(components)
def deg270(x, y, z):
    """Stack the components as ``(y, -x, z)`` (the 270-degree case)."""
    components = (y, -x, z)
    return torch.stack(components)
| |
|
class RandomIncrementRotate:
    """Rotate a list of images together by a random multiple of 90 degrees."""

    def __init__(self, p):
        """Store the rotation probability ``p`` and the angle lookup tables."""
        self.p = p
        self.angles = [0, 90, 180, 270]

        # Maps each angle to the function that remaps the channels of a
        # 'normals' image for that rotation.
        self.f = {0: deg0, 90: deg90, 180: deg180, 270: deg270}

    def rotate(self, img, theta, dtype):
        """Rotate one image by ``theta`` degrees.

        For ``'normals'`` the channels are remapped first through the function
        registered for ``theta``, then the image itself is rotated.
        """
        if dtype == 'normals':
            remap_channels = self.f[theta]
            img = remap_channels(*img)
        return tf.rotate(img, theta)

    def __call__(self, x, dtypes):
        """Rotate every image in ``x`` by one shared random angle, or return ``x``.

        Args:
            x: list of images.
            dtypes: per-image data type tags, zipped with ``x``.

        Returns:
            A new list of rotated images when the draw succeeds, otherwise ``x``
            itself.
        """
        if not (torch.rand(1) < self.p):
            return x
        # The angle is chosen once so all images receive the same rotation.
        theta = random.choice(self.angles)
        rotated = []
        for img, dtype in zip(x, dtypes):
            rotated.append(self.rotate(img, theta, dtype))
        return rotated
| |
|
class NormalizeGeometry:
    """Rescale and unit-normalize the ``'normals'`` images of a list."""

    def normalize(self, img, dtype):
        """Return ``img`` unchanged unless it is tagged ``'normals'``.

        A ``'normals'`` image is mapped through ``2 * img - 1`` and then
        L2-normalized along dimension 0.
        """
        if dtype != 'normals':
            return img
        # Presumably the stored values are encoded in [0, 1]; this recentres
        # them around zero before normalizing.
        recentred = 2 * img - 1
        # Unit length along dim 0 (the stacked components).
        return F.normalize(recentred, dim=0)

    def __call__(self, x, dtypes):
        """Apply :meth:`normalize` to each image of ``x`` with its tag from ``dtypes``."""
        normalized = []
        for img, dtype in zip(x, dtypes):
            normalized.append(self.normalize(img, dtype))
        return normalized
| |
|
class RandomCrop(T.RandomCrop):
    """Random crop that applies one shared crop window to a list of images."""

    def __init__(self, *args, **kwargs):
        # All constructor arguments are forwarded untouched to T.RandomCrop.
        super().__init__(*args, **kwargs)

    def forward(self, x, dtypes):
        """Crop every image in ``x`` with the same randomly sampled window.

        Args:
            x: list of images that must all report the same image size.
            dtypes: per-image data type tags (not used by this transform).

        Returns:
            A new list with one crop per input image.

        Raises:
            AssertionError: if any image size differs from that of ``x[0]``.
        """
        reference = x[0]
        expected_size = tf.get_image_size(reference)
        # A shared window only makes sense when all images have the same size.
        for img in x:
            assert tf.get_image_size(img) == expected_size
        top, left, height, width = self.get_params(reference, self.size)
        cropped = []
        for img in x:
            cropped.append(tf.crop(img, top, left, height, width))
        return cropped
| |
|
class CenterCrop:
    """Center-crop every image of a list to a fixed size."""

    def __init__(self, size):
        """Remember the target ``size`` passed on to ``tf.center_crop``."""
        self.size = size

    def __call__(self, x, dtypes):
        """Return a new list holding the center crop of each image in ``x``.

        ``dtypes`` is accepted for interface compatibility and is not used.
        """
        cropped = []
        for img in x:
            cropped.append(tf.center_crop(img, self.size))
        return cropped
| |
|
class Resize(T.Resize):
    """Resize transform that operates on a list of images."""

    def __init__(self, *args, **kwargs):
        # All constructor arguments are forwarded untouched to T.Resize.
        super().__init__(*args, **kwargs)

    def forward(self, x, dtypes):
        """Resize each image in ``x`` through the parent ``T.Resize.forward``.

        ``dtypes`` is accepted for interface compatibility and is not used.
        """
        # Bind the parent implementation once, outside the loop.
        resize_one = super().forward
        resized = []
        for img in x:
            resized.append(resize_one(img))
        return resized
| |
|
class Identity:
    """No-op transform: hands the image list back untouched."""

    def __call__(self, x, dtypes):
        """Return ``x`` itself; ``dtypes`` is ignored."""
        return x
| |
|
class ToTensor:
    """Convert every image of a list with ``tf.to_tensor``."""

    def __call__(self, x, dtypes):
        """Return a new list holding ``tf.to_tensor(img)`` for each image in ``x``.

        ``dtypes`` is accepted for interface compatibility and is not used.
        """
        tensors = []
        for img in x:
            tensors.append(tf.to_tensor(img))
        return tensors
| |
|
class Pipeline:
    """Chain of list-level transforms applied in order to a group of images.

    Each transform is called as ``f(x, dtypes)`` and must return the new list
    of images, which is then fed to the next transform.
    """

    # The only data type tags a pipeline accepts.
    DATA_TYPES = ['input', 'normals', 'albedo']

    def __init__(self, *transforms, dtypes=None):
        """Build a pipeline.

        Args:
            *transforms: callables invoked as ``f(x, dtypes)``, in order.
            dtypes: one tag from ``DATA_TYPES`` per image that will be passed
                to ``__call__``. Required despite the ``None`` default.

        Raises:
            TypeError: if ``dtypes`` is omitted.
            AssertionError: if any tag is not in ``DATA_TYPES``.
        """
        if dtypes is None:
            # Fail with an actionable message instead of the opaque
            # "'NoneType' object is not iterable" raised by iterating None.
            raise TypeError(
                "Pipeline requires the 'dtypes' keyword argument: one of %r per image"
                % (Pipeline.DATA_TYPES,)
            )
        assert all(d in Pipeline.DATA_TYPES for d in dtypes), (
            'dtypes must only contain values from %r, got %r'
            % (Pipeline.DATA_TYPES, dtypes)
        )
        self.dtypes = dtypes
        self.transforms = transforms

    def __call__(self, x):
        """Run every transform over ``x`` and return the final list.

        Args:
            x: list of images, one per entry of ``self.dtypes``. Each image
                must expose ``.shape``; everything after the first dimension
                must match across images.

        Returns:
            The list produced by the last transform (``x`` itself when the
            pipeline has no transforms).

        Raises:
            AssertionError: if the number of images differs from the number of
                dtypes, or if the trailing shapes differ.
        """
        assert len(self.dtypes) == len(x), (
            'expected %d images (one per dtype), got %d' % (len(self.dtypes), len(x))
        )
        # The first dimension is excluded from the comparison (presumably the
        # channel count, which may differ between data types).
        assert all(y.shape[1:] == x[0].shape[1:] for y in x), (
            'all images must share the same shape after the first dimension'
        )
        for f in self.transforms:
            x = f(x, self.dtypes)
        return x