import torch
import torch.nn as nn
import numpy as np

|
def mixup_data(x, y, alpha=1.0, device='cuda'):
    """Apply mixup augmentation to a batch (Zhang et al., ICLR 2018).

    Args:
        x: Input batch tensor; the first dimension is the batch dimension.
        y: Target batch tensor.
        alpha: Beta-distribution parameter (higher = more mixing).
            alpha <= 0 disables mixing (lam = 1).
        device: Kept for backward compatibility with existing callers;
            the permutation index is now created directly on x's device,
            which also works for torch.device objects and named devices
            such as 'cuda:1' that the old string comparison mishandled.

    Returns:
        mixed_x: Convex combination lam * x + (1 - lam) * x[perm].
        y_a, y_b: Original and permuted targets for the mixed loss.
        lam: Mixing ratio drawn from Beta(alpha, alpha).
    """
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1.0

    batch_size = x.size(0)
    # Create the permutation on x's own device. The original
    # `if device == 'cuda'` string test put the index on the CPU for
    # torch.device objects or specific devices like 'cuda:1', causing a
    # device mismatch in the indexing below.
    index = torch.randperm(batch_size, device=x.device)

    # x[index] (instead of x[index, :]) also supports 1-D inputs;
    # identical behavior for 2-D and higher.
    mixed_x = lam * x + (1 - lam) * x[index]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam
|
| |
|
def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Compute the mixup loss as a convex combination of two losses.

    Args:
        criterion: Loss function taking (predictions, targets).
        pred: Model predictions on the mixed inputs.
        y_a, y_b: The two target batches that were mixed together.
        lam: Mixing ratio in [0, 1].

    Returns:
        lam-weighted sum of the loss against y_a and the loss against y_b.
    """
    loss_a = criterion(pred, y_a)
    loss_b = criterion(pred, y_b)
    return lam * loss_a + (1 - lam) * loss_b
|
| |
|
class MixupTrainer:
    """Training wrapper that applies mixup augmentation each batch.

    Attributes:
        model: The nn.Module being trained.
        optimizer: Optimizer updating the model's parameters.
        criterion: Per-batch loss function (e.g. CrossEntropyLoss).
        device: Device batches are moved to before the forward pass.
        alpha: Mixup Beta-distribution parameter forwarded to mixup_data.
    """

    def __init__(self, model, optimizer, criterion, device, alpha=0.2):
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion
        self.device = device
        self.alpha = alpha

    def train_step(self, dataloader):
        """Run one full pass over `dataloader` with mixup training.

        Despite the name (kept for backward compatibility), this iterates
        the entire dataloader, i.e. one epoch.

        Args:
            dataloader: Iterable yielding (data, target) batch pairs.

        Returns:
            (avg_loss, accuracy): mean batch loss and mixup-weighted
            accuracy as a percentage. Returns (0.0, 0.0) for an empty
            dataloader instead of raising ZeroDivisionError.
        """
        self.model.train()
        total_loss = 0.0
        correct = 0.0
        total = 0

        for data, target in dataloader:
            data, target = data.to(self.device), target.to(self.device)

            data, target_a, target_b, lam = mixup_data(
                data, target, self.alpha, self.device)

            self.optimizer.zero_grad()
            output = self.model(data)
            loss = mixup_criterion(
                self.criterion, output, target_a, target_b, lam)
            loss.backward()
            self.optimizer.step()

            total_loss += loss.item()

            # Mixup accuracy: credit each prediction proportionally to how
            # much of each original target the mixed input contained.
            # detach() replaces the deprecated .data attribute access, and
            # accumulating via .item() keeps `correct` a plain float so the
            # return value never depends on tensor methods.
            predicted = output.detach().argmax(dim=1)
            total += target.size(0)
            correct += (lam * predicted.eq(target_a).sum().item()
                        + (1 - lam) * predicted.eq(target_b).sum().item())

        if total == 0:
            # Empty dataloader: nothing trained; avoid division by zero.
            return 0.0, 0.0

        avg_loss = total_loss / len(dataloader)
        accuracy = 100.0 * correct / total
        return avg_loss, accuracy
|
| |
|