|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
import numpy as np
|
|
|
|
|
|
def mixup_data(x, y, alpha=1.0, device='cuda'):
    """
    Mixup data augmentation.

    Blends each sample of the batch with the sample found at the same
    position in a random permutation of that batch.

    Args:
        x: Input batch tensor; dimension 0 is the batch dimension.
        y: Target batch, indexed along dimension 0 with the same permutation.
        alpha: Mixup parameter (higher = more mixing). When ``alpha > 0`` the
            mixing ratio is drawn from ``np.random.beta(alpha, alpha)``;
            otherwise no mixing happens (``lam = 1``).
        device: Kept only for backward compatibility and no longer consulted.
            The permutation is created on ``x.device`` so the function works
            for CPU tensors, ``'cuda:N'`` strings and ``torch.device`` objects
            without string comparisons.

    Returns:
        mixed_x: Mixed input batch, ``lam * x + (1 - lam) * x[index]``.
        y_a, y_b: Original targets and permuted targets for loss calculation.
        lam: Mixing ratio (the int ``1`` when ``alpha <= 0``).
    """
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1

    batch_size = x.size(0)

    # Build the permutation where the data lives instead of trusting the
    # `device` argument: the old `device == 'cuda'` check called `.cuda()`
    # even for CPU inputs and missed values such as 'cuda:0'.
    index = torch.randperm(batch_size, device=x.device)

    # `x[index]` equals `x[index, :]` for >= 2-D inputs and also accepts
    # 1-D batches, which the explicit trailing slice rejected.
    mixed_x = lam * x + (1 - lam) * x[index]

    # Targets may live on a different device than the inputs; index them
    # with a permutation on their own device (no-op when devices match).
    target_index = index.to(y.device) if isinstance(y, torch.Tensor) else index
    y_a, y_b = y, y[target_index]

    return mixed_x, y_a, y_b, lam
|
|
|
|
|
|
def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """
    Mixup loss calculation.

    Evaluates the loss of the predictions against both target sets and
    combines the two values, weighting them by the mixing ratio.

    Args:
        criterion: Loss function called as ``criterion(pred, target)``.
        pred: Model predictions.
        y_a, y_b: The two target sets the mixed inputs were built from.
        lam: Mixing ratio; weight of the ``y_a`` loss. The ``y_b`` loss is
            weighted by ``1 - lam``.

    Returns:
        Mixed loss ``lam * loss(pred, y_a) + (1 - lam) * loss(pred, y_b)``.
    """
    loss_a = criterion(pred, y_a)
    loss_b = criterion(pred, y_b)
    return lam * loss_a + (1 - lam) * loss_b
|
|
|
|
|
|
class MixupTrainer:
    """
    Mixup training wrapper.

    Holds a model, optimizer, loss function, target device and mixup
    ``alpha`` and runs training passes on mixup-augmented batches.
    """

    def __init__(self, model, optimizer, criterion, device, alpha=0.2):
        """
        Store the training components.

        Args:
            model: Model to train; called as ``model(data)``.
            optimizer: Optimizer providing ``zero_grad()`` and ``step()``.
            criterion: Loss function called as ``criterion(output, target)``.
            device: Device each batch is moved to via ``.to(device)``.
            alpha: Mixup parameter forwarded to ``mixup_data``.
        """
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion
        self.device = device
        self.alpha = alpha

    def train_step(self, dataloader):
        """
        Train for one full pass over ``dataloader`` with mixup.

        Every batch is moved to ``self.device``, mixed with ``mixup_data``,
        and used for one optimizer update on the ``mixup_criterion`` loss.

        Args:
            dataloader: Iterable yielding ``(data, target)`` batches. It does
                not need to support ``len()``; batches are counted while
                iterating.

        Returns:
            Tuple ``(avg_loss, accuracy)``: the mean per-batch loss and the
            mixup-weighted accuracy in percent, both as Python floats.
            ``(0.0, 0.0)`` is returned when the dataloader yields nothing.
        """
        self.model.train()

        total_loss = 0.0
        correct = 0.0
        total = 0
        num_batches = 0

        for data, target in dataloader:
            data, target = data.to(self.device), target.to(self.device)

            data, target_a, target_b, lam = mixup_data(data, target, self.alpha, self.device)

            self.optimizer.zero_grad()
            output = self.model(data)
            loss = mixup_criterion(self.criterion, output, target_a, target_b, lam)
            loss.backward()
            self.optimizer.step()

            total_loss += loss.item()
            num_batches += 1

            # Predicted class = index of the largest output along dim 1;
            # detach so the bookkeeping never touches the autograd graph.
            _, predicted = torch.max(output.detach(), 1)
            total += target.size(0)

            # A prediction counts as `lam` correct when it matches the first
            # target and `1 - lam` correct when it matches the permuted one.
            hits_a = predicted.eq(target_a).sum().item()
            hits_b = predicted.eq(target_b).sum().item()
            correct += lam * hits_a + (1 - lam) * hits_b

        # Guard both divisions so an empty dataloader (or only empty batches)
        # reports zeros instead of raising ZeroDivisionError.
        avg_loss = total_loss / num_batches if num_batches else 0.0
        accuracy = 100. * correct / total if total else 0.0

        return avg_loss, float(accuracy)
|
|
|
|