vrevar
Add application file
04c78c7
import random
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as tf
import torch.nn.functional as F
class RandomResizedCrop(T.RandomResizedCrop):
    """Random resized crop that applies one shared crop window to several images."""

    def __init__(self, *args, **kwargs):
        # Every argument is forwarded unchanged to T.RandomResizedCrop.
        super().__init__(*args, **kwargs)

    def __call__(self, x, dtypes):
        """WARNING: torchvision v0.11. Wrapper to T.RandomResizedCrop.__call__

        The crop window is sampled once, from the first image in ``x``, and
        the same window is then applied to every image in the list.

        Args:
            x: indexable sequence of images accepted by ``tf.resized_crop``.
            dtypes: per-image data-type tags; not used by this transform.

        Returns:
            A new list with one cropped and resized image per element of ``x``.
        """
        top, left, height, width = self.get_params(x[0], self.scale, self.ratio)
        cropped = []
        for img in x:
            cropped.append(
                tf.resized_crop(img, top, left, height, width,
                                self.size, self.interpolation)
            )
        return cropped
class ColorJitter(T.ColorJitter):
    """Color jitter that is applied only to images tagged ``'albedo'``."""

    def __init__(self, *args, **kwargs):
        # Every argument is forwarded unchanged to T.ColorJitter.
        super().__init__(*args, **kwargs)

    def forward(self, x, dtypes):
        """Jitter the albedo images with a single shared set of random factors.

        Args:
            x: iterable of images.
            dtypes: iterable of data-type tags, paired element-wise with ``x``.

        Returns:
            A new list in the same order as ``x``. Images whose tag is
            ``'albedo'`` are jittered; all others are passed through as-is.
        """
        order, brightness, contrast, saturation, hue = self.get_params(
            self.brightness, self.contrast, self.saturation, self.hue)

        # Operation id -> (adjustment function, sampled factor).
        table = {
            0: (tf.adjust_brightness, brightness),
            1: (tf.adjust_contrast, contrast),
            2: (tf.adjust_saturation, saturation),
            3: (tf.adjust_hue, hue),
        }

        # Resolve the sampled order once; a factor of None means the
        # corresponding adjustment is skipped.
        steps = []
        for op_id in order:
            entry = table.get(int(op_id))
            if entry is not None and entry[1] is not None:
                steps.append(entry)

        result = []
        for img, dtype in zip(x, dtypes):
            if dtype == 'albedo':
                for adjust, factor in steps:
                    img = adjust(img, factor)
            result.append(img)
        return result
class RandomHorizontalFlip(T.RandomHorizontalFlip):
    """Horizontal flip applied jointly to a list of images.

    Images tagged ``'normals'`` additionally have channel 0 negated so the
    stored vectors stay consistent with the mirrored image.
    """

    def __init__(self, *args, **kwargs):
        # Every argument is forwarded unchanged to T.RandomHorizontalFlip.
        super().__init__(*args, **kwargs)

    def flip_x(self, img, dtype):
        """Return a horizontally flipped copy of ``img``.

        Args:
            img: image to flip. For ``'normals'`` it must support ``clone()``
                and channel indexing (presumably a channel-first tensor --
                confirm against the caller).
            dtype: data-type tag of ``img``.

        Returns:
            The flipped image. The input is never modified.
        """
        if dtype == 'normals':
            # Work on a copy: negating channel 0 in place would silently
            # corrupt the caller's tensor (e.g. a cached dataset sample).
            img = img.clone()
            img[0] *= -1
        return tf.hflip(img)

    def forward(self, x, dtypes):
        """Flip every image in ``x`` with probability ``self.p``.

        Args:
            x: iterable of images.
            dtypes: iterable of data-type tags, paired element-wise with ``x``.

        Returns:
            A new list of flipped images, or ``x`` itself when no flip is drawn.
        """
        if torch.rand(1) < self.p:
            return [self.flip_x(img, dtype) for img, dtype in zip(x, dtypes)]
        return x
class RandomVerticalFlip(T.RandomVerticalFlip):
    """Vertical flip applied jointly to a list of images.

    Images tagged ``'normals'`` additionally have channel 1 negated so the
    stored vectors stay consistent with the mirrored image.
    """

    def __init__(self, *args, **kwargs):
        # Every argument is forwarded unchanged to T.RandomVerticalFlip.
        super().__init__(*args, **kwargs)

    def flip_y(self, img, dtype):
        """Return a vertically flipped copy of ``img``.

        Args:
            img: image to flip. For ``'normals'`` it must support ``clone()``
                and channel indexing (presumably a channel-first tensor --
                confirm against the caller).
            dtype: data-type tag of ``img``.

        Returns:
            The flipped image. The input is never modified.
        """
        if dtype == 'normals':
            # Work on a copy: negating channel 1 in place would silently
            # corrupt the caller's tensor (e.g. a cached dataset sample).
            img = img.clone()
            img[1] *= -1
        return tf.vflip(img)

    def forward(self, x, dtypes):
        """Flip every image in ``x`` with probability ``self.p``.

        Args:
            x: iterable of images.
            dtypes: iterable of data-type tags, paired element-wise with ``x``.

        Returns:
            A new list of flipped images, or ``x`` itself when no flip is drawn.
        """
        if torch.rand(1) < self.p:
            return [self.flip_y(img, dtype) for img, dtype in zip(x, dtypes)]
        return x
def deg0(x, y, z):
    """Stack the three components unchanged: (x, y, z)."""
    components = (x, y, z)
    return torch.stack(components)
def deg90(x, y, z):
    """Stack the components as (-y, x, z)."""
    components = (-y, x, z)
    return torch.stack(components)
def deg180(x, y, z):
    """Stack the components as (-x, -y, z)."""
    components = (-x, -y, z)
    return torch.stack(components)
def deg270(x, y, z):
    """Stack the components as (y, -x, z)."""
    components = (y, -x, z)
    return torch.stack(components)
class RandomIncrementRotate:
    """Randomly rotate a list of images by a multiple of 90 degrees."""

    def __init__(self, p):
        """Store the probability ``p`` of applying a rotation."""
        self.p = p
        self.angles = [0, 90, 180, 270]
        # Per-angle remapping of the surface-normal components.
        self.f = dict(zip(self.angles, (deg0, deg90, deg180, deg270)))

    def rotate(self, img, theta, dtype):
        """Rotate ``img`` by ``theta`` degrees.

        For images tagged ``'normals'`` the image is first unpacked along its
        leading axis into three components, which are remapped with the
        function registered for ``theta``.
        """
        if dtype == 'normals':
            remap = self.f[theta]
            img = remap(*img)
        return tf.rotate(img, theta)

    def __call__(self, x, dtypes):
        """Rotate every image by one shared random angle with probability ``p``.

        Args:
            x: iterable of images.
            dtypes: iterable of data-type tags, paired element-wise with ``x``.

        Returns:
            A new list of rotated images, or ``x`` itself when no rotation
            is drawn.
        """
        if not (torch.rand(1) < self.p):
            return x
        theta = random.choice(self.angles)
        rotated = []
        for img, dtype in zip(x, dtypes):
            rotated.append(self.rotate(img, theta, dtype))
        return rotated
class NormalizeGeometry:
    """Convert images tagged ``'normals'`` into unit-length vectors."""

    def normalize(self, img, dtype):
        """Remap and normalize ``img`` if it is tagged ``'normals'``.

        Any other tag returns ``img`` unchanged (the very same object).
        """
        if dtype != 'normals':
            return img
        # Map [0, 1] -> [-1, 1].
        remapped = 2 * img - 1
        # Scale each vector along dim 0 to unit length.
        return F.normalize(remapped, dim=0)

    def __call__(self, x, dtypes):
        """Apply ``normalize`` to each (image, tag) pair and return a new list."""
        normalized = []
        for img, dtype in zip(x, dtypes):
            normalized.append(self.normalize(img, dtype))
        return normalized
class RandomCrop(T.RandomCrop):
    """Random crop that applies one shared crop window to several images."""

    def __init__(self, *args, **kwargs):
        # Every argument is forwarded unchanged to T.RandomCrop.
        super().__init__(*args, **kwargs)

    def forward(self, x, dtypes):
        """Crop every image in ``x`` with the same randomly sampled window.

        All images must report the same size through ``tf.get_image_size``.
        The window is sampled from the first image.

        Args:
            x: indexable sequence of images.
            dtypes: per-image data-type tags; not used by this transform.

        Returns:
            A new list with one cropped image per element of ``x``.
        """
        reference_size = tf.get_image_size(x[0])
        for candidate in x:
            assert tf.get_image_size(candidate) == reference_size
        top, left, height, width = self.get_params(x[0], self.size)
        cropped = []
        for img in x:
            cropped.append(tf.crop(img, top, left, height, width))
        return cropped
class CenterCrop:
    """Center-crop every image in a list to a fixed size."""

    def __init__(self, size):
        """Store the ``size`` that is passed on to ``tf.center_crop``."""
        self.size = size

    def __call__(self, x, dtypes):
        """Return a new list holding the center crop of each image in ``x``.

        ``dtypes`` is accepted for interface uniformity and is not used.
        """
        cropped = []
        for img in x:
            cropped.append(tf.center_crop(img, self.size))
        return cropped
class Resize(T.Resize):
    """Resize applied independently to every image in a list."""

    def __init__(self, *args, **kwargs):
        # Every argument is forwarded unchanged to T.Resize.
        super().__init__(*args, **kwargs)

    def forward(self, x, dtypes):
        """Return a new list with each image in ``x`` resized by ``T.Resize.forward``.

        ``dtypes`` is accepted for interface uniformity and is not used.
        """
        resize_one = super().forward
        resized = []
        for img in x:
            resized.append(resize_one(img))
        return resized
class Identity:
    """No-op transform: hands its input back untouched."""

    def __call__(self, x, dtypes):
        """Return ``x`` itself; ``dtypes`` is ignored."""
        unchanged = x
        return unchanged
class ToTensor:
    """Convert every image in a list with ``tf.to_tensor``."""

    def __call__(self, x, dtypes):
        """Return a new list holding ``tf.to_tensor(img)`` for each image in ``x``.

        ``dtypes`` is accepted for interface uniformity and is not used.
        """
        tensors = []
        for img in x:
            tensors.append(tf.to_tensor(img))
        return tensors
class Pipeline:
    """Chain of transforms applied jointly to a list of images.

    Each transform is called as ``transform(images, dtypes)`` and must return
    the (possibly new) list of images that is fed to the next transform.
    """

    # Data-type tags a pipeline accepts.
    DATA_TYPES = ['input', 'normals', 'albedo']

    def __init__(self, *transforms, dtypes=None):
        """Build a pipeline.

        Args:
            *transforms: callables taking ``(images, dtypes)``, applied in order.
            dtypes: one tag from ``DATA_TYPES`` per image, in the same order
                as the images later passed to ``__call__``. Required.

        Raises:
            TypeError: if ``dtypes`` is not provided.
            AssertionError: if a tag is not in ``DATA_TYPES``.
        """
        if dtypes is None:
            # Previously this surfaced as an opaque
            # "'NoneType' object is not iterable".
            raise TypeError(
                "Pipeline requires the keyword argument `dtypes`: one of "
                f"{Pipeline.DATA_TYPES} per image"
            )
        assert all(d in Pipeline.DATA_TYPES for d in dtypes), \
            f"every dtype must be one of {Pipeline.DATA_TYPES}, got {dtypes!r}"
        self.dtypes = dtypes
        self.transforms = transforms

    def __call__(self, x):
        """Run every transform over ``x`` in order.

        Args:
            x: sequence of images, one per configured dtype. Every element
                must expose ``.shape``, and all elements must agree on
                ``shape[1:]``.

        Returns:
            The output of the last transform (``x`` itself when the
            pipeline has no transforms).

        Raises:
            AssertionError: if the number of images differs from the number
                of dtypes, or the images disagree on ``shape[1:]``.
        """
        assert len(self.dtypes) == len(x), \
            f"expected {len(self.dtypes)} images, got {len(x)}"
        assert all(y.shape[1:] == x[0].shape[1:] for y in x), \
            "all images must share the same shape[1:]"
        for f in self.transforms:
            x = f(x, self.dtypes)
        return x