|
|
import numpy as np
|
|
|
import torch
|
|
|
import torchvision.transforms as transforms
|
|
|
|
|
|
|
|
|
|
|
|
class TwoCropTransform:
    """Turn one input sample into a pair of views.

    The first view comes from the caller-supplied ``transform``. The second
    comes from a fixed random-augmentation pipeline built here:
    ``RandomResizedCrop(img_size)`` -> ``RandomHorizontalFlip`` ->
    ``ColorJitter(0.8, 0.8, 0.8, 0.2)`` applied with p=0.8 ->
    ``RandomGrayscale(p=0.2)`` -> ``ToTensor``.

    Args:
        transform: callable applied to the input to produce the first view.
        img_size: target size passed to ``RandomResizedCrop`` for the second
            view.
    """

    def __init__(self, transform, img_size):
        self.transform = transform
        self.img_size = img_size

        # brightness / contrast / saturation / hue jitter strengths
        jitter = transforms.ColorJitter(
            brightness=0.8, contrast=0.8, saturation=0.8, hue=0.2
        )
        pipeline = [
            transforms.RandomResizedCrop(size=self.img_size),
            transforms.RandomHorizontalFlip(),
            transforms.RandomApply([jitter], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.ToTensor(),
        ]
        self.data_transforms = transforms.Compose(pipeline)

    def __call__(self, x):
        """Return ``[transform(x), data_transforms(x)]``.

        The caller's transform is applied first. The order matters when both
        draw from a shared RNG.
        """
        # NOTE(review): x is presumably a PIL image (required by ToTensor /
        # RandomResizedCrop in older torchvision) — confirm against the dataset.
        first_view = self.transform(x)
        second_view = self.data_transforms(x)
        return [first_view, second_view]
|
|
|
|
|
|
|
|
|
def rotation(input):
    """Rotate each image in a batch by a random multiple of 90 degrees.

    Rotation labels are drawn by shuffling ``[0, 1, 2, 3]`` repeated
    ``batch // 4 + 1`` times. The shuffle uses NumPy's global RNG, and the
    first ``batch`` entries are kept, so each label appears at most
    ``batch // 4 + 1`` times.

    Args:
        input: 4-D tensor whose dim 0 is the batch. Rotation happens in the
            plane of dims 2 and 3 (presumably ``(N, C, H, W)`` — confirm with
            callers). A 90/270-degree rotation only fits back into the output
            when those two dims are equal, so non-square inputs raise a
            ``RuntimeError`` whenever such a rotation is drawn.

    Returns:
        ``(image, target)``:

        - ``image`` has the same shape, dtype and device as ``input``.
          ``image[i]`` is ``input[i]`` rotated ``target[i]`` times, in the
          direction of ``torch.rot90``.
        - ``target`` is an int64 tensor of shape ``(batch,)`` on
          ``input.device``.

        ``input`` itself is not modified, and gradients flow back to it.
    """
    batch = input.shape[0]

    # Same RNG consumption as before: one permutation of 4 * (batch // 4 + 1) labels.
    labels = np.random.permutation([0, 1, 2, 3] * (batch // 4 + 1))[:batch]
    target = torch.tensor(labels, device=input.device).long()

    # Label-0 samples are already correct in the copy. Each remaining rotation
    # class is handled with one batched rot90 instead of one op per sample.
    # The indices come from the host-side NumPy labels, so no GPU -> CPU sync
    # is needed to read rotation counts.
    image = input.clone()
    for k in (1, 2, 3):
        idx = np.flatnonzero(labels == k)
        if idx.size == 0:
            continue
        idx = torch.as_tensor(idx, device=input.device)
        image[idx] = torch.rot90(input[idx], k, [2, 3])

    return image, target
|
|
|
|