"""
Artist Style Embedding - Trainer
"""
from pathlib import Path
from typing import Dict
from collections import defaultdict
import torch
import torch.nn as nn
from torch.cuda.amp import GradScaler, autocast
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, LinearLR, SequentialLR
from tqdm import tqdm
try:
import wandb
WANDB_AVAILABLE = True
except ImportError:
WANDB_AVAILABLE = False
class AverageMeter:
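    """Tracks the current value, running sum, count, and average of a scalar metric."""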
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
class Trainer:
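    """Training and validation loop for the artist style embedding model.

    Handles mixed-precision training, LR scheduling, checkpointing,
    early stopping, and optional Weights & Biases logging.
    """
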
def __init__(self, model, loss_fn, train_loader, val_loader, config, artist_to_idx):
self.model = model
self.loss_fn = loss_fn
self.train_loader = train_loader
self.val_loader = val_loader
self.config = config
self.artist_to_idx = artist_to_idx
self.idx_to_artist = {v: k for k, v in artist_to_idx.items()}
self.device = torch.device(config.train.device)
self.model = self.model.to(self.device)
self.loss_fn = self.loss_fn.to(self.device)
self.optimizer = self._create_optimizer()
self.scheduler = self._create_scheduler()
self.use_amp = config.train.use_amp
self.scaler = GradScaler() if self.use_amp else None
self.save_dir = Path(config.train.save_dir)
self.save_dir.mkdir(parents=True, exist_ok=True)
self.current_epoch = 0
self.global_step = 0
self.best_metric = 0.0
self.patience_counter = 0
self.use_wandb = WANDB_AVAILABLE and config.train.wandb_project
if self.use_wandb:
wandb.init(
project=config.train.wandb_project,
name=config.train.wandb_run_name,
config={'model': config.model.__dict__, 'train': config.train.__dict__}
)
def _create_optimizer(self):
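        # Parameter groups: the pretrained backbone uses a scaled-down LR,
        # the new heads and the ArcFace weight use the base LR, and the
        # center-loss centers are updated at half the base LR.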
backbone_params = self.model.encoder.get_backbone_params()
head_params = self.model.encoder.get_head_params()
arcface_params = [self.model.arcface_weight]
loss_params = list(self.loss_fn.center_loss.parameters())
return AdamW([
{'params': backbone_params, 'lr': self.config.train.learning_rate * self.config.train.backbone_lr_multiplier},
{'params': head_params, 'lr': self.config.train.learning_rate},
{'params': arcface_params, 'lr': self.config.train.learning_rate},
{'params': loss_params, 'lr': self.config.train.learning_rate * 0.5},
], weight_decay=self.config.train.weight_decay)
def _create_scheduler(self):
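        # Linear warmup for the first `warmup_epochs`, then cosine annealing
        # with warm restarts over the remaining epochs.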
warmup = LinearLR(self.optimizer, start_factor=0.01, end_factor=1.0, total_iters=self.config.train.warmup_epochs)
main = CosineAnnealingWarmRestarts(self.optimizer, T_0=self.config.train.epochs - self.config.train.warmup_epochs, eta_min=self.config.train.min_lr)
return SequentialLR(self.optimizer, [warmup, main], milestones=[self.config.train.warmup_epochs])
def train_epoch(self) -> Dict[str, float]:
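        """Run one training epoch and return the averaged loss components."""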
self.model.train()
loss_meters = defaultdict(AverageMeter)
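        # Keep the backbone frozen for the first few epochs so the freshly
        # initialised heads can stabilise before full fine-tuning.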
if self.current_epoch < self.config.model.freeze_backbone_epochs:
self.model.encoder.freeze_backbone()
else:
self.model.encoder.unfreeze_backbone()
pbar = tqdm(self.train_loader, desc=f"Epoch {self.current_epoch}")
for batch in pbar:
full = batch['full'].to(self.device)
face = batch['face'].to(self.device)
eye = batch['eye'].to(self.device)
has_face = batch['has_face'].to(self.device)
has_eye = batch['has_eye'].to(self.device)
labels = batch['label'].to(self.device)
with autocast(enabled=self.use_amp):
output = self.model(full, face, eye, has_face, has_eye)
loss, loss_dict = self.loss_fn(output['embeddings'], output['cosine'], labels)
self.optimizer.zero_grad()
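            # With AMP, gradients are scaled: unscale them before clipping so
            # the norm threshold applies to the true gradient magnitudes.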
if self.use_amp:
self.scaler.scale(loss).backward()
self.scaler.unscale_(self.optimizer)
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.train.max_grad_norm)
self.scaler.step(self.optimizer)
self.scaler.update()
else:
loss.backward()
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.train.max_grad_norm)
self.optimizer.step()
for k, v in loss_dict.items():
loss_meters[k].update(v, full.size(0))
pbar.set_postfix({'loss': f"{loss_meters['loss_total'].avg:.4f}"})
self.global_step += 1
if self.global_step % self.config.train.log_every_n_steps == 0 and self.use_wandb:
wandb.log({f"train/{k}": v.avg for k, v in loss_meters.items()}, step=self.global_step)
return {k: v.avg for k, v in loss_meters.items()}
@torch.no_grad()
def validate(self) -> Dict[str, float]:
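        """Evaluate on the validation set; return top-1/top-5 accuracy and loss averages."""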
self.model.eval()
total_correct = 0
total_samples = 0
total_correct_top5 = 0
loss_meters = defaultdict(AverageMeter)
for batch in tqdm(self.val_loader, desc="Validation"):
full = batch['full'].to(self.device)
face = batch['face'].to(self.device)
eye = batch['eye'].to(self.device)
has_face = batch['has_face'].to(self.device)
has_eye = batch['has_eye'].to(self.device)
labels = batch['label'].to(self.device)
with autocast(enabled=self.use_amp):
output = self.model(full, face, eye, has_face, has_eye)
loss, loss_dict = self.loss_fn(output['embeddings'], output['cosine'], labels)
# Top-1 accuracy
preds = output['cosine'].argmax(dim=1)
total_correct += (preds == labels).sum().item()
# Top-5 accuracy
_, top5_preds = output['cosine'].topk(5, dim=1)
top5_correct = top5_preds.eq(labels.view(-1, 1).expand_as(top5_preds))
total_correct_top5 += top5_correct.any(dim=1).sum().item()
total_samples += labels.size(0)
for k, v in loss_dict.items():
loss_meters[k].update(v, full.size(0))
accuracy = total_correct / total_samples if total_samples > 0 else 0
accuracy_top5 = total_correct_top5 / total_samples if total_samples > 0 else 0
metrics = {
'accuracy': accuracy,
'accuracy_top5': accuracy_top5,
}
metrics.update({k: v.avg for k, v in loss_meters.items()})
if self.use_wandb:
wandb.log({f"val/{k}": v for k, v in metrics.items()}, step=self.global_step)
return metrics
def save_checkpoint(self, filename: str, is_best: bool = False):
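        """Save the full training state; also write `best_model.pt` when `is_best` is True."""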
checkpoint = {
'epoch': self.current_epoch,
'global_step': self.global_step,
'model_state_dict': self.model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'scheduler_state_dict': self.scheduler.state_dict(),
'best_metric': self.best_metric,
'config': {'model': self.config.model.__dict__, 'train': self.config.train.__dict__},
'artist_to_idx': self.artist_to_idx,
}
if self.use_amp:
checkpoint['scaler_state_dict'] = self.scaler.state_dict()
torch.save(checkpoint, self.save_dir / filename)
if is_best:
torch.save(checkpoint, self.save_dir / 'best_model.pt')
def load_checkpoint(self, path: str):
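        """Restore model, optimizer, scheduler, and AMP scaler state from a checkpoint."""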
        # weights_only=False: the checkpoint also stores config dicts and the
        # artist label mapping, not just tensors.
        checkpoint = torch.load(path, map_location=self.device, weights_only=False)
self.model.load_state_dict(checkpoint['model_state_dict'])
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        # Resume from the epoch after the one stored in the checkpoint,
        # otherwise the saved epoch would be repeated.
        self.current_epoch = checkpoint['epoch'] + 1
self.global_step = checkpoint['global_step']
self.best_metric = checkpoint['best_metric']
if self.use_amp and 'scaler_state_dict' in checkpoint:
self.scaler.load_state_dict(checkpoint['scaler_state_dict'])
print(f"Loaded checkpoint from epoch {self.current_epoch}")
def train(self):
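        """Main training loop with validation, checkpointing, and early stopping."""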
print(f"Training for {self.config.train.epochs} epochs on {self.device}")
print(f"Artists: {len(self.artist_to_idx)}")
for epoch in range(self.current_epoch, self.config.train.epochs):
self.current_epoch = epoch
train_metrics = self.train_epoch()
print(f"\nEpoch {epoch} - Train Loss: {train_metrics['loss_total']:.4f}")
val_metrics = self.validate()
print(f"Epoch {epoch} - Val Loss: {val_metrics['loss_total']:.4f}, "
f"Acc: {val_metrics['accuracy']:.4f}, "
f"Top5: {val_metrics['accuracy_top5']:.4f}")
self.scheduler.step()
# Best model by accuracy
is_best = val_metrics['accuracy'] > self.best_metric
if is_best:
self.best_metric = val_metrics['accuracy']
self.patience_counter = 0
else:
self.patience_counter += 1
if (epoch + 1) % self.config.train.save_every_n_epochs == 0:
self.save_checkpoint(f'checkpoint_epoch_{epoch}.pt', is_best)
elif is_best:
self.save_checkpoint('best_model.pt', is_best=True)
if self.patience_counter >= self.config.train.patience:
print(f"Early stopping at epoch {epoch}")
break
self.save_checkpoint('final_model.pt')
if self.use_wandb:
wandb.finish()
print(f"Training complete. Best Accuracy: {self.best_metric:.4f}")