| """ |
| Advanced augmentations: MixUp, CutMix at batch level. |
| |
| These are applied after the dataloader returns a batch, |
| not as part of the per-image transform pipeline. |
| """ |
|
|
| import torch |
| import numpy as np |
|
|
|
|
def mixup_data(x: torch.Tensor, y: torch.Tensor, alpha: float = 0.2) -> tuple:
    """MixUp: blend every example with a randomly chosen partner from the batch.

    Args:
        x: Input batch; its first dimension is treated as the batch dimension.
        y: Targets, re-ordered along their first dimension with the same
            permutation that is applied to ``x``.
        alpha: Used for both shape parameters of the Beta distribution the
            mixing weight is drawn from. ``alpha <= 0`` turns mixing off.

    Returns:
        ``(mixed_x, y_a, y_b, lam)`` with ``mixed_x = lam * x + (1 - lam) *
        x[perm]``, ``y_a = y`` and ``y_b = y[perm]``. When ``alpha <= 0`` the
        inputs come back untouched as ``(x, y, y, 1.0)``.
    """
    if alpha <= 0:
        return x, y, y, 1.0

    # Draw the mixing weight first and the permutation second so that seeded
    # runs consume the NumPy and torch random streams in a fixed order.
    weight = np.random.beta(alpha, alpha)
    shuffle = torch.randperm(x.size(0), device=x.device)

    partner_x = x[shuffle]
    partner_y = y[shuffle]
    blended = weight * x + (1 - weight) * partner_x
    return blended, y, partner_y, weight
|
|
|
|
def cutmix_data(x: torch.Tensor, y: torch.Tensor, alpha: float = 1.0) -> tuple:
    """CutMix: paste one random rectangle from a shuffled partner into each image.

    A single rectangle is sampled per call and shared by the whole batch; each
    sample receives that region from a randomly permuted partner.

    Args:
        x: 4-D input batch (the shape is unpacked into exactly four sizes);
            presumably ``(batch, channels, height, width)`` -- confirm against
            the caller. The tensor itself is not modified; a clone is edited.
        y: Targets, re-ordered along their first dimension with the same
            permutation that selects the partner images.
        alpha: Used for both shape parameters of the Beta distribution that
            sets the target area of the rectangle. ``alpha <= 0`` turns the
            augmentation off.

    Returns:
        ``(x_cut, y_a, y_b, lam)`` where ``y_a = y``, ``y_b = y[perm]`` and
        ``lam`` is the fraction of each image that was left unchanged, always
        as a built-in ``float``. When ``alpha <= 0`` the inputs come back
        untouched as ``(x, y, y, 1.0)``.
    """
    if alpha <= 0:
        return x, y, y, 1.0

    # Keep the draw order (beta, randperm, randint, randint) fixed so seeded
    # runs stay reproducible.
    lam = np.random.beta(alpha, alpha)
    index = torch.randperm(x.size(0), device=x.device)

    _, _, h, w = x.shape
    # A rectangle with side ratio sqrt(1 - lam) covers (1 - lam) of the area.
    cut_ratio = np.sqrt(1.0 - lam)
    half_h = int(h * cut_ratio) // 2
    half_w = int(w * cut_ratio) // 2

    # Centre of the rectangle, sampled uniformly over the image.
    cy = int(np.random.randint(h))
    cx = int(np.random.randint(w))

    # Clamp the box to the image. Plain Python ints keep the slice bounds and
    # the returned lam free of NumPy scalar types, matching the built-in 1.0
    # returned when the augmentation is disabled.
    y1, y2 = max(cy - half_h, 0), min(cy + half_h, h)
    x1, x2 = max(cx - half_w, 0), min(cx + half_w, w)

    x_cut = x.clone()
    x_cut[:, :, y1:y2, x1:x2] = x[index, :, y1:y2, x1:x2]

    # Integer halving and clamping can shrink the box, so recompute lam from
    # the area that was actually replaced.
    lam = 1.0 - (y2 - y1) * (x2 - x1) / (h * w)

    return x_cut, y, y[index], lam
|
|
|
|
def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Combine the losses against both target sets for a MixUp/CutMix batch.

    ``criterion`` is called as ``criterion(pred, target)``, once with ``y_a``
    and once with ``y_b``; the two results are weighted by ``lam`` and
    ``1 - lam`` respectively and summed.
    """
    loss_a = criterion(pred, y_a)
    loss_b = criterion(pred, y_b)
    return lam * loss_a + (1 - lam) * loss_b
|
|