| |
| |
| |
| |
|
|
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
|
|
|
|
def rand_bbox(size, lam):
    """Sample a random CutMix box for a 4-D batch of the given size.

    Args:
        size: Sequence of at least four ints; only ``size[2]`` and
            ``size[3]`` (the two trailing axes) are read.
        lam: Mixing coefficient. The box covers roughly a ``1 - lam``
            fraction of the trailing-axes area before clipping.

    Returns:
        ``(bbx1, bby1, bbx2, bby2)`` where ``bbx1:bbx2`` is a range along
        axis 2 and ``bby1:bby2`` a range along axis 3, both clipped to the
        axis bounds.
    """
    # Naming follows the original code: W is the extent of axis 2 and H the
    # extent of axis 3; callers slice axis 2 with bbx* and axis 3 with bby*.
    W = size[2]
    H = size[3]

    # Side ratio so that the (unclipped) box area is (1 - lam) of the total.
    cut_rat = np.sqrt(1.0 - lam)
    # `np.int` was removed in NumPy 1.24; the builtin truncates identically.
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    # Uniformly sampled box centre.
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    # Clip so the box never leaves the image; this may shrink it.
    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    return bbx1, bby1, bbx2, bby2
|
|
|
|
def cutmix_data(x, y, alpha=1.0, cutmix_prob=0.5, device=0):
    """Apply CutMix to a batch: paste a random box from shuffled samples.

    NOTE: ``x`` is modified in place (the box region is overwritten) and
    the same tensor object is returned.

    Args:
        x: 4-D batch tensor; the box is cut along axes 2 and 3.
        y: Labels indexable by a permutation tensor (``y[index]``).
        alpha: Positive parameter of the ``Beta(alpha, alpha)`` distribution
            from which the initial mixing coefficient is drawn.
        cutmix_prob: Unused by this function; kept so existing call sites
            remain valid.
        device: Retained for backward compatibility. The permutation index
            is now placed on ``x.device`` instead, so CPU batches work on
            CUDA-capable machines and batches on any GPU work.

    Returns:
        ``(x, y_a, y_b, lam)``: the mixed batch, the original labels, the
        labels of the pasted samples, and the fraction of each image that
        still belongs to ``y_a`` (recomputed from the clipped box).
    """
    assert alpha > 0, "alpha must be positive"

    lam = np.random.beta(alpha, alpha)

    batch_size = x.size()[0]
    # Draw on CPU (keeps the CPU RNG stream), then follow the data's device.
    index = torch.randperm(batch_size).to(x.device)

    y_a, y_b = y, y[index]
    bbx1, bby1, bbx2, bby2 = rand_bbox(x.size(), lam)
    # The right-hand side is materialised by advanced indexing before the
    # assignment, so reading and writing the same tensor is safe here.
    x[:, :, bbx1:bbx2, bby1:bby2] = x[index, :, bbx1:bbx2, bby1:bby2]

    # The box may have been clipped at the border, so recompute lam from the
    # area that was actually replaced.
    lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (x.size()[-1] * x.size()[-2]))
    return x, y_a, y_b, lam
|
|
|
|
def normalize(x, mean, std):
    """Channel-wise standardisation ``(x - mean) / std`` of a 4-D batch.

    Args:
        x: 4-D tensor; statistics are broadcast along axis 1.
        mean: Per-channel means (sequence, array or tensor), or a single
            scalar applied to every channel.
        std: Per-channel standard deviations, same accepted forms as
            ``mean``.

    Returns:
        A new tensor with the same shape as ``x``.
    """
    assert len(x.shape) == 4
    # `as_tensor` avoids the copy-construct warning for tensor inputs, and
    # reshaping to (1, C, 1, 1) also accepts scalars (C == 1).
    mean_t = torch.as_tensor(mean, device=x.device).reshape(1, -1, 1, 1)
    std_t = torch.as_tensor(std, device=x.device).reshape(1, -1, 1, 1)
    return (x - mean_t) / std_t
|
|
|
|
def random_flip(x):
    """Flip each sample of a 4-D batch along its last axis with p = 0.5.

    The input tensor is left untouched; a new tensor is returned. (The
    previous implementation overwrote the flipped samples inside ``x``,
    silently altering the caller's batch.)

    Args:
        x: 4-D tensor; axis 0 is the batch axis, axis 3 is flipped.

    Returns:
        A tensor of the same shape in which each sample is either the
        original or its mirror image along axis 3.
    """
    assert len(x.shape) == 4
    # One coin toss per sample, drawn on CPU and then moved next to the data.
    flip_mask = (torch.rand(x.shape[0]) < 0.5).to(x.device)
    # Broadcast the per-sample decision over the remaining three axes.
    return torch.where(flip_mask.view(-1, 1, 1, 1), x.flip(3), x)
|
|
|
|
def random_grayscale(x, prob=0.2):
    """Convert each sample of a 3-channel batch to grayscale with ``prob``.

    The input tensor is left untouched; a new tensor is returned. (The
    previous implementation overwrote the selected samples inside ``x``.)

    Args:
        x: 4-D tensor with exactly three channels on axis 1
            (presumably in R, G, B order -- confirm against the caller).
        prob: Per-sample probability of being converted.

    Returns:
        A tensor of the same shape; converted samples carry the same
        weighted channel sum in all three channels.
    """
    assert len(x.shape) == 4
    # One draw per sample, made on CPU and then moved next to the data.
    gray_mask = (torch.rand(x.shape[0]) < prob).to(x.device)

    # ITU-R BT.601 luma coefficients, shaped to broadcast over (N, 3, H, W).
    weights = torch.tensor([0.299, 0.587, 0.114], device=x.device).view(1, 3, 1, 1)

    out = x.clone()
    luma = (out[gray_mask] * weights).sum(1, keepdim=True)
    # The single-channel luma broadcasts across the three output channels.
    out[gray_mask] = luma.repeat_interleave(3, 1).to(out.dtype)
    return out
|
|
|
|
def random_crop(x, padding):
    """Zero-pad a 4-D batch and take an independent random crop per sample.

    Each sample is padded by ``padding`` on all four sides of its last two
    axes and a window of the original spatial size is cut out at a random
    offset, i.e. the image is shifted by up to ``padding`` pixels in either
    direction along each axis.

    Fixes over the previous implementation:
      * any channel count works (the output reshape hard-coded 3);
      * non-square inputs work (masks were built from the last axis only);
      * the shift range is symmetric, ``[-padding, padding]`` inclusive
        (``torch.randint``'s upper bound is exclusive, so ``+padding`` was
        never drawn);
      * ``padding == 0`` returns the batch unchanged instead of raising.

    Args:
        x: 4-D tensor of shape ``(N, C, H, W)``.
        padding: Non-negative int, maximum shift in pixels.

    Returns:
        A new contiguous tensor with the same shape as ``x``.
    """
    assert len(x.shape) == 4
    n, _, height, width = x.shape

    # `high` is exclusive, hence the + 1. Drawn on CPU, moved to x.device.
    shift_x = torch.randint(-padding, padding + 1, size=(n,))
    shift_y = torch.randint(-padding, padding + 1, size=(n,))
    x_start = (shift_x + padding).to(x.device)
    y_start = (shift_y + padding).to(x.device)

    padded = F.pad(x, (padding, padding, padding, padding))

    # Per-sample row/column indices of the crop window in the padded image.
    rows = y_start[:, None] + torch.arange(height, device=x.device)  # (N, H)
    cols = x_start[:, None] + torch.arange(width, device=x.device)   # (N, W)
    batch = torch.arange(n, device=x.device)[:, None, None]          # (N, 1, 1)

    # The index tensors broadcast to (N, H, W); because they are separated
    # by a slice, the indexed axes come first and the result is (N, H, W, C).
    cropped = padded[batch, :, rows[:, :, None], cols[:, None, :]]
    return cropped.permute(0, 3, 1, 2).contiguous()
|
|
|
|
class soft_aug:
    """Weak augmentation: random crop (padding 4), random flip, normalise.

    Instances are callables that take a batch and return the augmented,
    normalised batch.
    """

    def __init__(self, mean, std):
        # Statistics forwarded unchanged to `normalize` on every call.
        self.mean, self.std = mean, std

    def __call__(self, x):
        """Return ``normalize(random_flip(random_crop(x, 4)), mean, std)``."""
        cropped = random_crop(x, 4)
        flipped = random_flip(cropped)
        return normalize(flipped, self.mean, self.std)
|
|
|
|
class strong_aug:
    """Strong augmentation built on a per-image torchvision pipeline.

    A call applies, in order: batch-level random flip, the per-image
    pipeline (to PIL, random resized crop, colour jitter with p = 0.8, back
    to tensor), batch-level random grayscale, then normalisation.
    """

    def __init__(self, size, mean, std):
        # Imported lazily so torchvision is only needed when this class is used.
        from torchvision import transforms

        jitter = transforms.RandomApply(
            [transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)],
            p=0.8,
        )
        steps = [
            transforms.ToPILImage(),
            transforms.RandomResizedCrop(size=size, scale=(0.2, 1.)),
            jitter,
            transforms.ToTensor(),
        ]
        self.transform = transforms.Compose(steps)
        self.mean = mean
        self.std = std

    def __call__(self, x):
        """Augment the batch ``x`` sample by sample and normalise it."""
        flipped = random_flip(x)
        # The torchvision pipeline works on single images, so iterate over
        # the batch axis and re-stack the results.
        # NOTE(review): the PIL round trip presumably yields CPU tensors
        # regardless of the device of `x` -- confirm if callers pass CUDA.
        per_image = [self.transform(sample) for sample in flipped]
        batch = torch.stack(per_image)
        return normalize(random_grayscale(batch), self.mean, self.std)
|
|