Image Classification
English
zhanwang's picture
update
377dccd verified
# Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import torch
import torch.nn.functional as F
def rand_bbox(size, lam):
    """Sample a random rectangular box covering roughly ``1 - lam`` of the area.

    The box is centred on a uniformly drawn point and then clipped to the
    extents, so boxes near a border can be smaller than requested.

    Args:
        size: Shape of a 4-D batch. Only ``size[2]`` and ``size[3]`` are read.
        lam: Mixing coefficient; the box side lengths scale with
            ``sqrt(1 - lam)``.

    Returns:
        Tuple ``(bbx1, bby1, bbx2, bby2)``. The ``bbx`` pair bounds
        dimension 2 and the ``bby`` pair bounds dimension 3.
    """
    # NOTE: "W" is read from dimension 2 and "H" from dimension 3. The names
    # are kept as in the original code; the returned box is consistent with
    # that choice (bbx -> dim 2, bby -> dim 3).
    W = size[2]
    H = size[3]
    cut_rat = np.sqrt(1. - lam)
    # ``np.int`` was removed in NumPy 1.24; the builtin ``int`` truncates
    # towards zero in exactly the same way.
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    # Box centre, drawn uniformly over the extents.
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    return bbx1, bby1, bbx2, bby2
def cutmix_data(x, y, alpha=1.0, cutmix_prob=0.5, device=0):
    """Apply CutMix to a batch: paste a random box from shuffled samples.

    ``x`` is modified in place and also returned.

    Args:
        x: 4-D batch tensor; a box over dimensions 2 and 3 is overwritten.
        y: Labels, indexable by a permutation tensor along the batch axis.
        alpha: Parameter of the ``Beta(alpha, alpha)`` distribution from which
            the mixing coefficient is drawn. Must be positive.
        cutmix_prob: Unused by this function; kept for interface
            compatibility.
        device: Unused; kept for interface compatibility. The permutation
            index now follows ``x.device`` instead.

    Returns:
        Tuple ``(x, y_a, y_b, lam)`` where ``y_a`` are the original labels,
        ``y_b`` the labels of the samples the box was taken from, and ``lam``
        the fraction of each image that was left untouched.
    """
    assert (alpha > 0)
    # generate mixed sample
    lam = np.random.beta(alpha, alpha)

    batch_size = x.size()[0]
    # Keep the index on the same device as the data. The previous code sent it
    # to ``device`` whenever CUDA was available, which breaks for CPU batches
    # on CUDA machines and for batches living on another GPU.
    index = torch.randperm(batch_size).to(x.device)

    y_a, y_b = y, y[index]
    bbx1, bby1, bbx2, bby2 = rand_bbox(x.size(), lam)
    # ``x[index, ...]`` is an advanced-indexing copy, so the in-place write
    # cannot read pixels that were already overwritten.
    x[:, :, bbx1:bbx2, bby1:bby2] = x[index, :, bbx1:bbx2, bby1:bby2]

    # adjust lambda to exactly match pixel ratio
    lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (x.size()[-1] * x.size()[-2]))
    return x, y_a, y_b, lam
def normalize(x, mean, std):
    """Standardise a 4-D batch as ``(x - mean) / std`` along dimension 1.

    Args:
        x: 4-D tensor.
        mean: Sequence of offsets, one entry per index of dimension 1.
        std: Sequence of scales, one entry per index of dimension 1.

    Returns:
        A new tensor; ``x`` itself is not modified.
    """
    assert x.ndim == 4

    def _broadcastable(values):
        # Insert singleton axes at positions 0, 2 and 3 so a 1-D sequence
        # lines up with dimension 1 of ``x``, then follow ``x`` to its device.
        stat = torch.tensor(values)
        for axis in (0, 2, 3):
            stat = stat.unsqueeze(axis)
        return stat.to(x.device)

    offset = _broadcastable(mean)
    scale = _broadcastable(std)
    return (x - offset) / scale
def random_flip(x):
    """Mirror a random subset of samples along the last dimension.

    Each sample is chosen independently with probability 0.5. The batch is
    modified in place and the same tensor object is returned.

    Args:
        x: 4-D tensor with samples along dimension 0.

    Returns:
        ``x`` itself, with the chosen samples reversed along dimension 3.
    """
    assert x.ndim == 4
    # One uniform draw per sample decides whether it gets mirrored.
    chosen = torch.rand(x.shape[0]) < 0.5
    mirrored = torch.flip(x[chosen], dims=(3,))
    x[chosen] = mirrored
    return x
def random_grayscale(x, prob=0.2):
    """Replace a random subset of samples by their grayscale version.

    Each sample is chosen independently with probability ``prob``. For the
    chosen samples the three entries of dimension 1 are combined with weights
    (0.299, 0.587, 0.114) and the result is copied back into all three. The
    batch is modified in place and the same tensor object is returned.

    Args:
        x: 4-D tensor with three entries along dimension 1.
        prob: Per-sample probability of being converted.

    Returns:
        ``x`` itself.
    """
    assert x.ndim == 4
    chosen = torch.rand(x.shape[0]) < prob
    # Weights shaped (1, 3, 1, 1) so they broadcast over batch and both
    # trailing dimensions.
    weights = torch.tensor([[0.299, 0.587, 0.114]]).unsqueeze(2).unsqueeze(2)
    weights = weights.to(x.device)
    luminance = (x[chosen] * weights).sum(1, keepdim=True)
    x[chosen] = luminance.repeat_interleave(3, 1)
    return x
def random_crop(x, padding):
    """Zero-pad a batch and cut a randomly shifted window of the original size.

    Every sample gets its own independent shift along both trailing
    dimensions. Works for any channel count and for non-square inputs; the
    previous version hard-coded three channels and built its masks from the
    width only.

    Args:
        x: 4-D tensor ``(batch, channels, height, width)``.
        padding: Number of zeros added on each side before cropping. Must be
            positive, because ``torch.randint`` rejects an empty range.

    Returns:
        A new tensor with the same shape as ``x``.
    """
    assert len(x.shape) == 4
    batch, channels, height, width = x.shape

    # Offsets are drawn from [-padding, padding): the upper bound of
    # ``torch.randint`` is exclusive. Draw order (x first, then y) is kept so
    # seeded runs reproduce the original crops.
    crop_x = torch.randint(-padding, padding, size=(batch,))
    crop_y = torch.randint(-padding, padding, size=(batch,))
    # Window start in padded coordinates, shaped (batch, 1, 1) for broadcast.
    x_start = (crop_x + padding).view(-1, 1, 1)
    y_start = (crop_y + padding).view(-1, 1, 1)

    padded = F.pad(x, (padding, padding, padding, padding))

    cols = torch.arange(width + 2 * padding).view(1, 1, -1)
    rows = torch.arange(height + 2 * padding).view(1, -1, 1)
    keep_cols = (cols >= x_start) & (cols < x_start + width)    # (batch, 1, W + 2p)
    keep_rows = (rows >= y_start) & (rows < y_start + height)   # (batch, H + 2p, 1)

    # The same spatial window applies to every channel of a sample. Boolean
    # selection walks the tensor in row-major order, so reshaping the flat
    # result restores the (batch, channels, height, width) layout.
    window = (keep_rows & keep_cols).unsqueeze(1).expand_as(padded)
    return padded[window].reshape(batch, channels, height, width)
class soft_aug():
    """Light augmentation: random crop (padding 4), random flip, normalise.

    Instances are callable on a 4-D batch tensor.
    """

    def __init__(self, mean, std):
        """Store the statistics later passed to ``normalize``."""
        self.mean, self.std = mean, std

    def __call__(self, x):
        """Return the cropped, randomly flipped and normalised batch."""
        cropped = random_crop(x, 4)
        flipped = random_flip(cropped)
        return normalize(flipped, self.mean, self.std)
class strong_aug():
    """Heavier augmentation built on a per-image torchvision pipeline.

    A call flips the batch at random, runs every sample through
    ``ToPILImage -> RandomResizedCrop -> ColorJitter (p=0.8) -> ToTensor``,
    re-stacks the results, applies random grayscale and finally normalises.
    """

    def __init__(self, size, mean, std):
        """Build the per-image pipeline and store the normalisation stats.

        Args:
            size: Output size handed to ``RandomResizedCrop``.
            mean: Statistics later passed to ``normalize``.
            std: Statistics later passed to ``normalize``.
        """
        # Imported lazily so torchvision is only required when this class is
        # actually instantiated.
        from torchvision import transforms

        colour_jitter = transforms.RandomApply(
            [transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)],
            p=0.8,
        )
        steps = [
            transforms.ToPILImage(),
            transforms.RandomResizedCrop(size=size, scale=(0.2, 1.)),
            colour_jitter,
            transforms.ToTensor(),
        ]
        self.transform = transforms.Compose(steps)
        self.mean = mean
        self.std = std

    def __call__(self, x):
        """Augment a batch and return the normalised result.

        ``x`` is handed to ``random_flip`` before anything else, so the
        caller's tensor is the one that gets flipped.
        """
        flip = random_flip(x)
        # The torchvision pipeline works on single images, so samples are
        # processed one at a time and stacked back into a batch.
        per_sample = [self.transform(sample) for sample in flip]
        batch = torch.stack(per_sample)
        grayed = random_grayscale(batch)
        return normalize(grayed, self.mean, self.std)