Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch.optim import Adam, AdamW | |
| from torch.optim.lr_scheduler import CosineAnnealingLR, OneCycleLR | |
| import os | |
| import numpy as np | |
| from tqdm import tqdm | |
| import time | |
class TransformerTrainer:
    """Supervised classification trainer with LR warmup and cosine decay.

    The trainer optimises ``model`` with AdamW and cross-entropy loss. The
    learning rate is ramped linearly during the first ``warmup_epochs``
    epochs and is then handed over to a ``CosineAnnealingLR`` schedule.
    After every epoch the latest state is checkpointed, and a separate copy
    is kept for the epoch with the highest validation accuracy.

    Each batch yielded by the loaders is unpacked as ``(data, target)``;
    ``model(data)`` is fed to ``nn.CrossEntropyLoss`` together with
    ``target`` and predictions are taken with ``argmax(dim=1)``.
    """

    def __init__(
        self,
        model,
        train_loader,
        val_loader,
        num_epochs=50,
        learning_rate=1e-4,
        weight_decay=1e-4,
        warmup_epochs=5,
        checkpoint_dir="models/transformer/checkpoints",
        device="cuda",
        early_stopping_patience=30,
    ):
        """Set up optimiser, scheduler, loss and metric history.

        Args:
            model: Module to train; it is moved to ``device``.
            train_loader: Iterable of ``(data, target)`` training batches.
            val_loader: Iterable of ``(data, target)`` validation batches.
            num_epochs: Total number of epochs (warmup included).
            learning_rate: Peak learning rate reached at the end of warmup.
            weight_decay: AdamW weight decay.
            warmup_epochs: Number of linear-warmup epochs.
            checkpoint_dir: Directory for checkpoints (created if missing).
            device: Device the model and batches are moved to.
            early_stopping_patience: Stop once this many epochs have passed
                since the best validation accuracy.
        """
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.num_epochs = num_epochs
        self.device = device
        self.checkpoint_dir = checkpoint_dir
        os.makedirs(checkpoint_dir, exist_ok=True)

        self.optimizer = AdamW(
            self.model.parameters(),
            lr=learning_rate,
            weight_decay=weight_decay,
            betas=(0.9, 0.999),
        )
        # The cosine schedule only covers the post-warmup epochs. Clamp to 1
        # so that warmup_epochs >= num_epochs cannot yield a zero/negative
        # period.
        self.scheduler = CosineAnnealingLR(
            self.optimizer,
            T_max=max(1, num_epochs - warmup_epochs),
            eta_min=1e-6,
        )
        self.warmup_epochs = warmup_epochs
        self.base_lr = learning_rate
        self.early_stopping_patience = early_stopping_patience
        self.criterion = nn.CrossEntropyLoss()

        # Per-epoch history. These names are the ones train(),
        # save_checkpoint() and load_checkpoint() read and write.
        self.train_losses = []
        self.val_losses = []
        self.train_accs = []
        self.val_accs = []
        self.best_val_acc = 0
        self.best_epoch = 0

    def warmup_lr(self, epoch):
        """Apply the linear warmup learning rate for a zero-based ``epoch``.

        Sets every parameter group's LR to
        ``base_lr * (epoch + 1) / warmup_epochs``. Does nothing once
        ``epoch >= warmup_epochs``.
        """
        if epoch < self.warmup_epochs:
            lr = self.base_lr * (epoch + 1) / self.warmup_epochs
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = lr

    def train_epoch(self, epoch):
        """Run one optimisation pass over ``train_loader``.

        Gradients are clipped to a max norm of 1.0 before each optimiser
        step.

        Args:
            epoch: Epoch number, used only for the progress-bar label.

        Returns:
            ``(average batch loss, accuracy in percent)``.
        """
        self.model.train()
        total_loss = 0.0
        correct = 0
        total = 0

        # Keep a handle on the bar so running metrics can be displayed.
        pbar = tqdm(self.train_loader, desc=f"Epoch {epoch}/{self.num_epochs}")
        for batch_idx, (data, target) in enumerate(pbar):
            data, target = data.to(self.device), target.to(self.device)

            self.optimizer.zero_grad()
            output = self.model(data)
            loss = self.criterion(output, target)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.optimizer.step()

            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
            total_loss += loss.item()
            total += target.size(0)  # number of samples in this batch

            pbar.set_postfix({
                'loss': total_loss / (batch_idx + 1),
                'acc': 100. * correct / max(total, 1),
                'lr': self.optimizer.param_groups[0]['lr'],
            })

        # max(..., 1) keeps an empty loader from dividing by zero.
        avg_loss = total_loss / max(len(self.train_loader), 1)
        avg_acc = 100. * correct / max(total, 1)
        return avg_loss, avg_acc

    def validate(self):
        """Evaluate the model on ``val_loader`` without gradient tracking.

        Returns:
            ``(average batch loss, accuracy in percent)``.
        """
        self.model.eval()
        total_loss = 0.0
        correct = 0
        total = 0

        with torch.no_grad():
            for data, target in tqdm(self.val_loader, desc='Validation'):
                data, target = data.to(self.device), target.to(self.device)
                output = self.model(data)
                loss = self.criterion(output, target)
                total_loss += loss.item()
                pred = output.argmax(dim=1)
                correct += pred.eq(target).sum().item()
                total += target.size(0)

        # max(..., 1) keeps an empty loader from dividing by zero.
        avg_loss = total_loss / max(len(self.val_loader), 1)
        avg_acc = 100. * correct / max(total, 1)
        return avg_loss, avg_acc

    def save_checkpoint(self, epoch, val_acc, is_best=False):
        """Write the training state to ``checkpoint_latest.pth``.

        When ``is_best`` is true the same state is also written to
        ``checkpoint_best.pth``.

        Args:
            epoch: Epoch number that just finished.
            val_acc: Validation accuracy of that epoch.
            is_best: Whether this epoch holds the best accuracy so far.
        """
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'val_acc': val_acc,
            'train_losses': self.train_losses,
            'val_losses': self.val_losses,
            'train_accs': self.train_accs,
            'val_accs': self.val_accs,
            # Stored so that a resumed run keeps its best-model bookkeeping
            # and early-stopping reference point.
            'best_val_acc': self.best_val_acc,
            'best_epoch': self.best_epoch,
        }

        # Save latest checkpoint
        path = os.path.join(self.checkpoint_dir, 'checkpoint_latest.pth')
        torch.save(checkpoint, path)

        # Save best checkpoint
        if is_best:
            path = os.path.join(self.checkpoint_dir, 'checkpoint_best.pth')
            torch.save(checkpoint, path)
            print(f'✓ Saved best model with val_acc: {val_acc:.2f}%')

    def load_checkpoint(self, checkpoint_path):
        """Restore model, optimiser, scheduler and history from a checkpoint.

        Args:
            checkpoint_path: Path to a file written by ``save_checkpoint``.

        Returns:
            The epoch number stored in the checkpoint.
        """
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.train_losses = checkpoint['train_losses']
        self.val_losses = checkpoint['val_losses']
        self.train_accs = checkpoint['train_accs']
        self.val_accs = checkpoint['val_accs']

        if 'best_val_acc' in checkpoint and 'best_epoch' in checkpoint:
            self.best_val_acc = checkpoint['best_val_acc']
            self.best_epoch = checkpoint['best_epoch']
        else:
            # Checkpoint without best-model fields: replay the accuracy
            # history with the same strict-improvement rule train() uses
            # (history entry i corresponds to epoch i + 1).
            best_acc, best_epoch = 0, 0
            for epoch_number, acc in enumerate(self.val_accs, start=1):
                if acc > best_acc:
                    best_acc, best_epoch = acc, epoch_number
            self.best_val_acc = best_acc
            self.best_epoch = best_epoch

        print(f'✓ Loaded checkpoint from epoch {checkpoint["epoch"]}')
        return checkpoint['epoch']

    def train(self, resume_from=None):
        """
        Main training loop.

        Args:
            resume_from: Path to checkpoint to resume from

        Returns:
            Best validation accuracy
        """
        start_epoch = 1
        if resume_from:
            start_epoch = self.load_checkpoint(resume_from) + 1

        print(f'\nStarting training for {self.num_epochs} epochs')
        print(f'Device: {self.device}')
        print(f'Training samples: {len(self.train_loader.dataset)}')
        print(f'Validation samples: {len(self.val_loader.dataset)}')
        print('-' * 60)

        start_time = time.time()
        for epoch in range(start_epoch, self.num_epochs + 1):
            # Warmup learning rate (warmup_lr expects a zero-based index)
            if epoch <= self.warmup_epochs:
                self.warmup_lr(epoch - 1)

            # Train
            train_loss, train_acc = self.train_epoch(epoch)

            # Validate
            val_loss, val_acc = self.validate()

            # Update scheduler (after warmup)
            if epoch > self.warmup_epochs:
                self.scheduler.step()

            # Track metrics
            self.train_losses.append(train_loss)
            self.val_losses.append(val_loss)
            self.train_accs.append(train_acc)
            self.val_accs.append(val_acc)

            # Print epoch summary
            print(f'\nEpoch {epoch}/{self.num_epochs}:')
            print(f' Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
            print(f' Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')
            print(f' LR: {self.optimizer.param_groups[0]["lr"]:.6f}')

            # Save checkpoint
            is_best = val_acc > self.best_val_acc
            if is_best:
                self.best_val_acc = val_acc
                self.best_epoch = epoch
            self.save_checkpoint(epoch, val_acc, is_best)

            # Early stopping check (optional)
            if epoch - self.best_epoch > self.early_stopping_patience:
                print(
                    f'\nEarly stopping: no improvement for '
                    f'{self.early_stopping_patience} epochs'
                )
                break

        elapsed_time = time.time() - start_time
        print(f'\n{"="*60}')
        print(f'Training completed in {elapsed_time/3600:.2f} hours')
        print(f'Best validation accuracy: {self.best_val_acc:.2f}% at epoch {self.best_epoch}')
        print(f'{"="*60}')
        return self.best_val_acc