# esc50-model/src/models/traintransformer.py
# Transformer training for the ESC-50 model.
import os
import time

import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from tqdm import tqdm

class TransformerTrainer:
def __init__(
self,
model,
train_loader,
val_loader,
num_epochs=50,
learning_rate=1e-4,
weight_decay=1e-4,
warmup_epochs=5,
checkpoint_dir="models/transformer/checkpoints",
device="cuda"
):
self.model = model.to(device)
self.train_loader = train_loader
self.val_loader = val_loader
self.num_epochs = num_epochs
self.device = device
self.checkpoint_dir = checkpoint_dir
os.makedirs(checkpoint_dir, exist_ok=True)
self.optimizer = AdamW(
model.parameters(),
lr=learning_rate,
weight_decay=weight_decay,
betas=(0.9, 0.999)
)
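        # LR schedule: a linear warmup over the first `warmup_epochs`
        # epochs (handled by _warmup_lr) hands off to cosine annealing,
        # which decays the LR from `learning_rate` down to eta_min over
        # the remaining epochs.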
self.scheduler = CosineAnnealingLR(
self.optimizer,
            T_max=num_epochs - warmup_epochs,
eta_min=1e-6
)
self.warmup_epochs = warmup_epochs
self.base_lr = learning_rate
self.criterion = nn.CrossEntropyLoss()
        self.train_losses = []
        self.val_losses = []
        self.train_accs = []
        self.val_accs = []
self.best_val_acc = 0
self.best_epoch = 0
    def _warmup_lr(self, epoch):
        """Linearly ramp the learning rate over the first warmup epochs."""
if epoch < self.warmup_epochs:
lr = self.base_lr * (epoch + 1) / self.warmup_epochs
for param_group in self.optimizer.param_groups:
param_group['lr'] = lr
def train_epoch(self, epoch):
self.model.train()
total_loss = 0
correct = 0
total = 0
        pbar = tqdm(self.train_loader, desc=f"Epoch {epoch}/{self.num_epochs}")
        for batch_idx, (data, target) in enumerate(pbar):
data, target = data.to(self.device), target.to(self.device)
self.optimizer.zero_grad()
output = self.model(data)
loss = self.criterion(output, target)
loss.backward()
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
self.optimizer.step()
pred = output.argmax(dim=1)
correct += pred.eq(target).sum().item()
total_loss += loss.item()
            total += target.size(0)
pbar.set_postfix({
'loss': total_loss / (batch_idx + 1),
'acc': 100. * correct / total,
'lr': self.optimizer.param_groups[0]['lr']
})
avg_loss = total_loss / len(self.train_loader)
avg_acc = 100. * correct / total
return avg_loss, avg_acc
def validate(self):
self.model.eval()
total_loss = 0
correct = 0
total = 0
with torch.no_grad():
for data, target in tqdm(self.val_loader, desc='Validation'):
data, target = data.to(self.device), target.to(self.device)
output = self.model(data)
loss = self.criterion(output, target)
total_loss += loss.item()
pred = output.argmax(dim=1)
correct += pred.eq(target).sum().item()
total += target.size(0)
avg_loss = total_loss / len(self.val_loader)
avg_acc = 100. * correct / total
return avg_loss, avg_acc
def save_checkpoint(self, epoch, val_acc, is_best=False):
"""Save model checkpoint."""
checkpoint = {
'epoch': epoch,
'model_state_dict': self.model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'scheduler_state_dict': self.scheduler.state_dict(),
'val_acc': val_acc,
'train_losses': self.train_losses,
'val_losses': self.val_losses,
'train_accs': self.train_accs,
'val_accs': self.val_accs,
}
# Save latest checkpoint
path = os.path.join(self.checkpoint_dir, 'checkpoint_latest.pth')
torch.save(checkpoint, path)
# Save best checkpoint
if is_best:
path = os.path.join(self.checkpoint_dir, 'checkpoint_best.pth')
torch.save(checkpoint, path)
print(f'✓ Saved best model with val_acc: {val_acc:.2f}%')
def load_checkpoint(self, checkpoint_path):
"""Load model checkpoint."""
checkpoint = torch.load(checkpoint_path, map_location=self.device)
self.model.load_state_dict(checkpoint['model_state_dict'])
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
self.train_losses = checkpoint['train_losses']
self.val_losses = checkpoint['val_losses']
self.train_accs = checkpoint['train_accs']
self.val_accs = checkpoint['val_accs']
print(f'✓ Loaded checkpoint from epoch {checkpoint["epoch"]}')
return checkpoint['epoch']
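    # Minimal resume sketch (the trainer construction here is illustrative):
    #   trainer = TransformerTrainer(model, train_loader, val_loader)
    #   trainer.train(resume_from="models/transformer/checkpoints/checkpoint_latest.pth")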
def train(self, resume_from=None):
"""
Main training loop.
Args:
resume_from: Path to checkpoint to resume from
Returns:
Best validation accuracy
"""
start_epoch = 1
if resume_from:
start_epoch = self.load_checkpoint(resume_from) + 1
print(f'\nStarting training for {self.num_epochs} epochs')
print(f'Device: {self.device}')
print(f'Training samples: {len(self.train_loader.dataset)}')
print(f'Validation samples: {len(self.val_loader.dataset)}')
print('-' * 60)
start_time = time.time()
for epoch in range(start_epoch, self.num_epochs + 1):
# Warmup learning rate
if epoch <= self.warmup_epochs:
self._warmup_lr(epoch - 1)
# Train
train_loss, train_acc = self.train_epoch(epoch)
# Validate
val_loss, val_acc = self.validate()
# Update scheduler (after warmup)
if epoch > self.warmup_epochs:
self.scheduler.step()
# Track metrics
self.train_losses.append(train_loss)
self.val_losses.append(val_loss)
self.train_accs.append(train_acc)
self.val_accs.append(val_acc)
# Print epoch summary
print(f'\nEpoch {epoch}/{self.num_epochs}:')
print(f' Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
print(f' Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')
print(f' LR: {self.optimizer.param_groups[0]["lr"]:.6f}')
# Save checkpoint
is_best = val_acc > self.best_val_acc
if is_best:
self.best_val_acc = val_acc
self.best_epoch = epoch
self.save_checkpoint(epoch, val_acc, is_best)
# Early stopping check (optional)
if epoch - self.best_epoch > 30:
                print('\nEarly stopping: no improvement for 30 epochs')
break
elapsed_time = time.time() - start_time
print(f'\n{"="*60}')
print(f'Training completed in {elapsed_time/3600:.2f} hours')
print(f'Best validation accuracy: {self.best_val_acc:.2f}% at epoch {self.best_epoch}')
print(f'{"="*60}')
return self.best_val_acc
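

if __name__ == "__main__":
    # Smoke-test sketch: the model and data below are stand-ins, assumed
    # for illustration only; plug in the project's real transformer and
    # ESC-50 dataloaders. The shapes (1 x 128 x 216 log-mel inputs, 50
    # classes) are assumptions, not requirements of the trainer.
    from torch.utils.data import DataLoader, TensorDataset

    dummy_x = torch.randn(64, 1, 128, 216)   # fake spectrogram batch
    dummy_y = torch.randint(0, 50, (64,))    # fake ESC-50 labels
    loader = DataLoader(TensorDataset(dummy_x, dummy_y), batch_size=16)

    # Stand-in model; the real transformer lives elsewhere in src/models.
    model = nn.Sequential(nn.Flatten(), nn.Linear(128 * 216, 50))

    trainer = TransformerTrainer(
        model,
        train_loader=loader,
        val_loader=loader,
        num_epochs=2,
        warmup_epochs=1,
        device="cuda" if torch.cuda.is_available() else "cpu",
    )
    trainer.train()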