import numpy as np
import torch
import torchvision.transforms as transforms


# Transformations
class TwoCropTransform:
    """Callable that returns two differently transformed views of one sample.

    The first view is produced by the caller-supplied ``transform``; the
    second by a fixed augmentation pipeline built in ``__init__`` (random
    resized crop, horizontal flip, colour jitter, random grayscale and
    conversion to a tensor).
    """

    def __init__(self, transform, img_size):
        """Store ``transform`` and build the second-view pipeline.

        Args:
            transform: Callable applied to the sample to produce the first
                view.
            img_size: Size forwarded to ``transforms.RandomResizedCrop``.
        """
        self.transform = transform
        self.img_size = img_size

        color_jitter = transforms.ColorJitter(
            brightness=0.8, contrast=0.8, saturation=0.8, hue=0.2
        )
        self.data_transforms = transforms.Compose(
            [
                transforms.RandomResizedCrop(size=self.img_size),
                transforms.RandomHorizontalFlip(),
                # Jitter is applied to 80% of samples, grayscale to 20%.
                transforms.RandomApply([color_jitter], p=0.8),
                transforms.RandomGrayscale(p=0.2),
                transforms.ToTensor(),
            ]
        )

    def __call__(self, x):
        """Return ``[self.transform(x), self.data_transforms(x)]``."""
        return [self.transform(x), self.data_transforms(x)]


def rotation(input):
    """Rotate every sample of a batch by a random multiple of 90 degrees.

    Rotation classes 0-3 (number of quarter turns) are drawn by shuffling a
    repeated ``[0, 1, 2, 3]`` list with ``np.random.permutation`` and keeping
    the first ``batch`` entries, so the classes stay roughly balanced and the
    result follows NumPy's global random state.

    Args:
        input: Batched tensor; each ``input[i]`` is rotated in the plane of
            its dims 1 and 2 (presumably an ``(N, C, H, W)`` image batch --
            confirm against callers). The name shadows the builtin but is
            kept for interface compatibility.

    Returns:
        A pair ``(image, target)``: ``image`` is a new tensor like ``input``
        holding the rotated samples, and ``target`` is an int64 tensor on
        ``input.device`` with the quarter-turn count applied to each sample.

    Note:
        An odd number of quarter turns swaps the two rotated dims, so the
        write-back into ``image[i]`` only succeeds for every class when
        those two dims have equal size.
    """
    batch = input.shape[0]

    # Enough repetitions of the four classes to cover the batch, then trim.
    quarter_turns = np.random.permutation([0, 1, 2, 3] * (batch // 4 + 1))[:batch]
    target = torch.tensor(quarter_turns, device=input.device).long()

    # Every sample is overwritten below, so no initial fill is needed.
    image = torch.empty_like(input)

    # Iterate over host-side ints so rot90 gets a plain ``k`` instead of a
    # device tensor element that must be read back for each sample.
    for i, k in enumerate(quarter_turns.tolist()):
        image[i] = torch.rot90(input[i], k, [1, 2])

    return image, target