""" Artist Style Embedding - Trainer """ from pathlib import Path from typing import Dict from collections import defaultdict import torch import torch.nn as nn import torch.nn.functional as F from torch.cuda.amp import GradScaler, autocast from torch.optim import AdamW from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, LinearLR, SequentialLR from tqdm import tqdm import numpy as np try: import wandb WANDB_AVAILABLE = True except ImportError: WANDB_AVAILABLE = False class AverageMeter: def __init__(self): self.reset() def reset(self): self.val = 0 self.avg = 0 self.sum = 0 self.count = 0 def update(self, val, n=1): self.val = val self.sum += val * n self.count += n self.avg = self.sum / self.count class Trainer: def __init__(self, model, loss_fn, train_loader, val_loader, config, artist_to_idx): self.model = model self.loss_fn = loss_fn self.train_loader = train_loader self.val_loader = val_loader self.config = config self.artist_to_idx = artist_to_idx self.idx_to_artist = {v: k for k, v in artist_to_idx.items()} self.device = torch.device(config.train.device) self.model = self.model.to(self.device) self.loss_fn = self.loss_fn.to(self.device) self.optimizer = self._create_optimizer() self.scheduler = self._create_scheduler() self.use_amp = config.train.use_amp self.scaler = GradScaler() if self.use_amp else None self.save_dir = Path(config.train.save_dir) self.save_dir.mkdir(parents=True, exist_ok=True) self.current_epoch = 0 self.global_step = 0 self.best_metric = 0.0 self.patience_counter = 0 self.use_wandb = WANDB_AVAILABLE and config.train.wandb_project if self.use_wandb: wandb.init( project=config.train.wandb_project, name=config.train.wandb_run_name, config={'model': config.model.__dict__, 'train': config.train.__dict__} ) def _create_optimizer(self): backbone_params = self.model.encoder.get_backbone_params() head_params = self.model.encoder.get_head_params() arcface_params = [self.model.arcface_weight] loss_params = list(self.loss_fn.center_loss.parameters()) return AdamW([ {'params': backbone_params, 'lr': self.config.train.learning_rate * self.config.train.backbone_lr_multiplier}, {'params': head_params, 'lr': self.config.train.learning_rate}, {'params': arcface_params, 'lr': self.config.train.learning_rate}, {'params': loss_params, 'lr': self.config.train.learning_rate * 0.5}, ], weight_decay=self.config.train.weight_decay) def _create_scheduler(self): warmup = LinearLR(self.optimizer, start_factor=0.01, end_factor=1.0, total_iters=self.config.train.warmup_epochs) main = CosineAnnealingWarmRestarts(self.optimizer, T_0=self.config.train.epochs - self.config.train.warmup_epochs, eta_min=self.config.train.min_lr) return SequentialLR(self.optimizer, [warmup, main], milestones=[self.config.train.warmup_epochs]) def train_epoch(self) -> Dict[str, float]: self.model.train() loss_meters = defaultdict(AverageMeter) if self.current_epoch < self.config.model.freeze_backbone_epochs: self.model.encoder.freeze_backbone() else: self.model.encoder.unfreeze_backbone() pbar = tqdm(self.train_loader, desc=f"Epoch {self.current_epoch}") for batch in pbar: full = batch['full'].to(self.device) face = batch['face'].to(self.device) eye = batch['eye'].to(self.device) has_face = batch['has_face'].to(self.device) has_eye = batch['has_eye'].to(self.device) labels = batch['label'].to(self.device) with autocast(enabled=self.use_amp): output = self.model(full, face, eye, has_face, has_eye) loss, loss_dict = self.loss_fn(output['embeddings'], output['cosine'], labels) self.optimizer.zero_grad() if self.use_amp: self.scaler.scale(loss).backward() self.scaler.unscale_(self.optimizer) torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.train.max_grad_norm) self.scaler.step(self.optimizer) self.scaler.update() else: loss.backward() torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.train.max_grad_norm) self.optimizer.step() for k, v in loss_dict.items(): loss_meters[k].update(v, full.size(0)) pbar.set_postfix({'loss': f"{loss_meters['loss_total'].avg:.4f}"}) self.global_step += 1 if self.global_step % self.config.train.log_every_n_steps == 0 and self.use_wandb: wandb.log({f"train/{k}": v.avg for k, v in loss_meters.items()}, step=self.global_step) return {k: v.avg for k, v in loss_meters.items()} @torch.no_grad() def validate(self) -> Dict[str, float]: self.model.eval() total_correct = 0 total_samples = 0 total_correct_top5 = 0 loss_meters = defaultdict(AverageMeter) for batch in tqdm(self.val_loader, desc="Validation"): full = batch['full'].to(self.device) face = batch['face'].to(self.device) eye = batch['eye'].to(self.device) has_face = batch['has_face'].to(self.device) has_eye = batch['has_eye'].to(self.device) labels = batch['label'].to(self.device) with autocast(enabled=self.use_amp): output = self.model(full, face, eye, has_face, has_eye) loss, loss_dict = self.loss_fn(output['embeddings'], output['cosine'], labels) # Top-1 accuracy preds = output['cosine'].argmax(dim=1) total_correct += (preds == labels).sum().item() # Top-5 accuracy _, top5_preds = output['cosine'].topk(5, dim=1) top5_correct = top5_preds.eq(labels.view(-1, 1).expand_as(top5_preds)) total_correct_top5 += top5_correct.any(dim=1).sum().item() total_samples += labels.size(0) for k, v in loss_dict.items(): loss_meters[k].update(v, full.size(0)) accuracy = total_correct / total_samples if total_samples > 0 else 0 accuracy_top5 = total_correct_top5 / total_samples if total_samples > 0 else 0 metrics = { 'accuracy': accuracy, 'accuracy_top5': accuracy_top5, } metrics.update({k: v.avg for k, v in loss_meters.items()}) if self.use_wandb: wandb.log({f"val/{k}": v for k, v in metrics.items()}, step=self.global_step) return metrics def save_checkpoint(self, filename: str, is_best: bool = False): checkpoint = { 'epoch': self.current_epoch, 'global_step': self.global_step, 'model_state_dict': self.model.state_dict(), 'optimizer_state_dict': self.optimizer.state_dict(), 'scheduler_state_dict': self.scheduler.state_dict(), 'best_metric': self.best_metric, 'config': {'model': self.config.model.__dict__, 'train': self.config.train.__dict__}, 'artist_to_idx': self.artist_to_idx, } if self.use_amp: checkpoint['scaler_state_dict'] = self.scaler.state_dict() torch.save(checkpoint, self.save_dir / filename) if is_best: torch.save(checkpoint, self.save_dir / 'best_model.pt') def load_checkpoint(self, path: str): checkpoint = torch.load(path, map_location=self.device) self.model.load_state_dict(checkpoint['model_state_dict']) self.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) self.scheduler.load_state_dict(checkpoint['scheduler_state_dict']) self.current_epoch = checkpoint['epoch'] self.global_step = checkpoint['global_step'] self.best_metric = checkpoint['best_metric'] if self.use_amp and 'scaler_state_dict' in checkpoint: self.scaler.load_state_dict(checkpoint['scaler_state_dict']) print(f"Loaded checkpoint from epoch {self.current_epoch}") def train(self): print(f"Training for {self.config.train.epochs} epochs on {self.device}") print(f"Artists: {len(self.artist_to_idx)}") for epoch in range(self.current_epoch, self.config.train.epochs): self.current_epoch = epoch train_metrics = self.train_epoch() print(f"\nEpoch {epoch} - Train Loss: {train_metrics['loss_total']:.4f}") val_metrics = self.validate() print(f"Epoch {epoch} - Val Loss: {val_metrics['loss_total']:.4f}, " f"Acc: {val_metrics['accuracy']:.4f}, " f"Top5: {val_metrics['accuracy_top5']:.4f}") self.scheduler.step() # Best model by accuracy is_best = val_metrics['accuracy'] > self.best_metric if is_best: self.best_metric = val_metrics['accuracy'] self.patience_counter = 0 else: self.patience_counter += 1 if (epoch + 1) % self.config.train.save_every_n_epochs == 0: self.save_checkpoint(f'checkpoint_epoch_{epoch}.pt', is_best) elif is_best: self.save_checkpoint('best_model.pt', is_best=True) if self.patience_counter >= self.config.train.patience: print(f"Early stopping at epoch {epoch}") break self.save_checkpoint('final_model.pt') if self.use_wandb: wandb.finish() print(f"Training complete. Best Accuracy: {self.best_metric:.4f}")