File size: 2,741 Bytes
a6eed2b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 |
import torch
import torch.nn as nn
import numpy as np
def mixup_data(x, y, alpha=1.0, device='cuda'):
    """
    Mixup data augmentation.

    Blends every sample with a randomly chosen partner from the same batch:
    ``mixed_x = lam * x + (1 - lam) * x[perm]``.

    Args:
        x: Input batch; dimension 0 is the batch dimension (any rank >= 1).
        y: Target batch, aligned with ``x`` along dimension 0.
        alpha: Beta(alpha, alpha) concentration (higher = more mixing).
            ``alpha <= 0`` disables mixing (``lam == 1.0``).
        device: Kept for backward compatibility and ignored. The permutation
            is always created on ``x.device``, so it matches the data even
            when ``device`` is a ``torch.device``, ``'cuda:N'``, or ``x`` lives
            on a non-default GPU.

    Returns:
        mixed_x: Mixed input batch.
        y_a, y_b: Original targets and the targets of the mixing partners,
            for use with ``mixup_criterion``.
        lam: Mixing ratio as a Python float.
    """
    # float() so callers get a plain scalar rather than np.float64 / int.
    lam = float(np.random.beta(alpha, alpha)) if alpha > 0 else 1.0
    batch_size = x.size(0)
    # Create the permutation on the data's own device; `.cuda()` would use the
    # current default GPU, which need not be where `x` lives.
    index = torch.randperm(batch_size, device=x.device)
    # `x[index]` indexes dim 0 for tensors of any rank (x[index, :] fails on 1-D).
    mixed_x = lam * x + (1 - lam) * x[index]
    # `.to` is a no-op when y already shares x's device.
    y_a, y_b = y, y[index.to(y.device)]
    return mixed_x, y_a, y_b, lam
def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """
    Mixup loss calculation.

    Evaluates ``criterion`` against both target sets and returns their
    convex combination ``lam * loss(y_a) + (1 - lam) * loss(y_b)``.

    Args:
        criterion: Callable invoked as ``criterion(pred, target)``.
        pred: Model predictions on the mixed batch.
        y_a, y_b: The two target sets (as returned by ``mixup_data``).
        lam: Weight given to the ``y_a`` loss; ``1 - lam`` goes to ``y_b``.

    Returns:
        The weighted (mixed) loss.
    """
    loss_first = criterion(pred, y_a)
    loss_second = criterion(pred, y_b)
    weight_second = 1 - lam
    return lam * loss_first + weight_second * loss_second
class MixupTrainer:
    """
    Mixup training wrapper.

    Bundles a model, optimizer and loss so that a call to ``train_step`` runs
    one pass over a dataloader with mixup applied to every batch.
    """

    def __init__(self, model, optimizer, criterion, device, alpha=0.2):
        """
        Args:
            model: Model to train; called as ``model(data)`` and put into
                training mode with ``model.train()``.
            optimizer: Object providing ``zero_grad()`` and ``step()``.
            criterion: Loss callable invoked as ``criterion(output, target)``.
            device: Device each batch is moved to (via ``.to(device)``).
            alpha: Mixup Beta(alpha, alpha) parameter forwarded to
                ``mixup_data``.
        """
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion
        self.device = device
        self.alpha = alpha

    def train_step(self, dataloader):
        """
        Run one training pass (epoch) over ``dataloader`` with mixup.

        Despite the name, this iterates over every batch the loader yields and
        takes one optimizer step per batch.

        Args:
            dataloader: Iterable of ``(data, target)`` batches, where
                ``target`` holds class indices (it is compared against the
                argmax of the model output). ``__len__`` is not required.

        Returns:
            (avg_loss, accuracy): mean mixup loss per batch, and the
            lam-weighted accuracy in percent. Both are ``0.0`` when the loader
            yields no batches (or no samples).
        """
        self.model.train()
        total_loss = 0.0
        correct = 0.0
        total = 0
        num_batches = 0
        for data, target in dataloader:
            data, target = data.to(self.device), target.to(self.device)
            # Apply mixup
            data, target_a, target_b, lam = mixup_data(data, target, self.alpha, self.device)
            self.optimizer.zero_grad()
            output = self.model(data)
            loss = mixup_criterion(self.criterion, output, target_a, target_b, lam)
            loss.backward()
            self.optimizer.step()

            total_loss += loss.item()
            num_batches += 1
            # Inputs are mixed, so "correct" is credited lam / (1 - lam) for
            # matching target_a / target_b respectively.
            predicted = output.detach().argmax(dim=1)
            total += target.size(0)
            correct += (lam * predicted.eq(target_a).sum().item()
                        + (1 - lam) * predicted.eq(target_b).sum().item())

        # Guard against empty loaders instead of raising ZeroDivisionError.
        avg_loss = total_loss / num_batches if num_batches else 0.0
        accuracy = 100.0 * correct / total if total else 0.0
        return avg_loss, float(accuracy)
|