|
|
""" |
|
|
Artist Style Embedding - Trainer |
|
|
""" |
|
|
from pathlib import Path |
|
|
from typing import Dict |
|
|
from collections import defaultdict |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from torch.cuda.amp import GradScaler, autocast |
|
|
from torch.optim import AdamW |
|
|
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, LinearLR, SequentialLR |
|
|
from tqdm import tqdm |
|
|
import numpy as np |
|
|
|
|
|
# Weights & Biases is an optional dependency: record whether it can be imported
# so the trainer can skip remote logging when it is absent.
try:
    import wandb
except ImportError:
    WANDB_AVAILABLE = False
else:
    WANDB_AVAILABLE = True
|
|
|
|
|
|
|
|
class AverageMeter:
    """Track the most recent value and a running weighted mean of a scalar."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics back to zero."""
        self.val = 0    # most recent value passed to update()
        self.avg = 0    # sum / count (0 while nothing has been counted)
        self.sum = 0    # weighted sum of all values seen
        self.count = 0  # total weight seen

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` (e.g. a batch size) and refresh the mean.

        Args:
            val: The new value.
            n: Weight of ``val`` in the running mean; defaults to 1.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # A zero-weight update before any weight has been seen would otherwise
        # divide by zero; keep the "empty" average of 0 in that case.
        self.avg = self.sum / self.count if self.count else 0
|
|
|
|
|
|
|
|
class Trainer:
    """Train and validate the artist style embedding model.

    Owns the AdamW optimizer, the warmup + cosine-restart LR schedule, the
    optional AMP gradient scaler, checkpointing, early stopping and optional
    Weights & Biases logging.

    Args:
        model: Network exposing ``encoder`` (with ``get_backbone_params``,
            ``get_head_params``, ``freeze_backbone``, ``unfreeze_backbone``) and
            an ``arcface_weight`` parameter. It is called as
            ``model(full, face, eye, has_face, has_eye)`` and must return a
            mapping with ``'embeddings'`` and ``'cosine'`` entries.
        loss_fn: Module with a ``center_loss`` submodule, called as
            ``loss_fn(embeddings, cosine, labels)`` and returning
            ``(loss, loss_dict)``.
        train_loader: Iterable of batch dicts with keys ``'full'``, ``'face'``,
            ``'eye'``, ``'has_face'``, ``'has_eye'`` and ``'label'``.
        val_loader: Iterable of batch dicts with the same keys.
        config: Object with ``model`` and ``train`` sub-configs.
        artist_to_idx: Mapping from artist name to class index.
    """

    def __init__(self, model, loss_fn, train_loader, val_loader, config, artist_to_idx):
        self.model = model
        self.loss_fn = loss_fn
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.config = config
        self.artist_to_idx = artist_to_idx
        self.idx_to_artist = {v: k for k, v in artist_to_idx.items()}

        self.device = torch.device(config.train.device)
        self.model = self.model.to(self.device)
        self.loss_fn = self.loss_fn.to(self.device)

        # The optimizer must exist before the scheduler that wraps it.
        self.optimizer = self._create_optimizer()
        self.scheduler = self._create_scheduler()

        self.use_amp = config.train.use_amp
        self.scaler = GradScaler() if self.use_amp else None

        self.save_dir = Path(config.train.save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)

        # Training progress; restored by load_checkpoint().
        self.current_epoch = 0
        self.global_step = 0
        self.best_metric = 0.0  # best validation top-1 accuracy so far
        self.patience_counter = 0  # epochs since best_metric last improved

        # Coerce to bool: wandb_project is a string (or empty/None), not a flag.
        self.use_wandb = bool(WANDB_AVAILABLE and config.train.wandb_project)
        if self.use_wandb:
            wandb.init(
                project=config.train.wandb_project,
                name=config.train.wandb_run_name,
                config={'model': config.model.__dict__, 'train': config.train.__dict__}
            )

    def _create_optimizer(self):
        """Build AdamW with four parameter groups.

        Groups: backbone (base LR scaled by ``backbone_lr_multiplier``), head
        (base LR), ArcFace weight (base LR) and center-loss parameters (half
        the base LR). All groups share ``config.train.weight_decay``.
        """
        train_cfg = self.config.train
        base_lr = train_cfg.learning_rate

        backbone_params = self.model.encoder.get_backbone_params()
        head_params = self.model.encoder.get_head_params()
        arcface_params = [self.model.arcface_weight]
        loss_params = list(self.loss_fn.center_loss.parameters())

        return AdamW([
            {'params': backbone_params, 'lr': base_lr * train_cfg.backbone_lr_multiplier},
            {'params': head_params, 'lr': base_lr},
            {'params': arcface_params, 'lr': base_lr},
            {'params': loss_params, 'lr': base_lr * 0.5},
        ], weight_decay=train_cfg.weight_decay)

    def _create_scheduler(self):
        """Build a linear warmup followed by cosine annealing with restarts.

        ``train()`` steps the scheduler once per epoch, so both ``total_iters``
        and ``T_0`` are measured in epochs.
        """
        train_cfg = self.config.train
        warmup = LinearLR(self.optimizer, start_factor=0.01, end_factor=1.0,
                          total_iters=train_cfg.warmup_epochs)
        # CosineAnnealingWarmRestarts rejects T_0 < 1. That happens when warmup
        # spans the whole run, so clamp the period to keep such configs usable.
        cosine_period = max(1, train_cfg.epochs - train_cfg.warmup_epochs)
        main = CosineAnnealingWarmRestarts(self.optimizer, T_0=cosine_period,
                                           eta_min=train_cfg.min_lr)
        return SequentialLR(self.optimizer, [warmup, main], milestones=[train_cfg.warmup_epochs])

    def _batch_to_device(self, batch):
        """Move one batch to ``self.device``.

        Returns:
            ``(inputs, labels)``, where ``inputs`` is the tuple
            ``(full, face, eye, has_face, has_eye)`` in model-call order.
        """
        inputs = tuple(batch[key].to(self.device)
                       for key in ('full', 'face', 'eye', 'has_face', 'has_eye'))
        labels = batch['label'].to(self.device)
        return inputs, labels

    @staticmethod
    def _update_meters(meters, loss_dict, batch_size):
        """Accumulate every entry of ``loss_dict`` into ``meters``, weighted by ``batch_size``."""
        for name, value in loss_dict.items():
            # float() accepts plain numbers and single-element tensors alike. If
            # loss_fn reports tensors, this stops the meters from keeping their
            # autograd graphs alive for the whole epoch.
            meters[name].update(float(value), batch_size)

    def train_epoch(self) -> Dict[str, float]:
        """Run one optimisation pass over ``train_loader``.

        The backbone is frozen while ``current_epoch`` is below
        ``config.model.freeze_backbone_epochs`` and unfrozen afterwards.
        Gradients are clipped to ``config.train.max_grad_norm``.

        Returns:
            Epoch averages of every ``loss_dict`` entry, weighted by the leading
            dimension of each batch's ``'full'`` tensor.
        """
        self.model.train()
        loss_meters = defaultdict(AverageMeter)

        if self.current_epoch < self.config.model.freeze_backbone_epochs:
            self.model.encoder.freeze_backbone()
        else:
            self.model.encoder.unfreeze_backbone()

        max_grad_norm = self.config.train.max_grad_norm
        pbar = tqdm(self.train_loader, desc=f"Epoch {self.current_epoch}")

        for batch in pbar:
            inputs, labels = self._batch_to_device(batch)

            with autocast(enabled=self.use_amp):
                output = self.model(*inputs)
                loss, loss_dict = self.loss_fn(output['embeddings'], output['cosine'], labels)

            self.optimizer.zero_grad()
            if self.use_amp:
                self.scaler.scale(loss).backward()
                # Unscale first so clipping operates on the true gradient magnitudes.
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_grad_norm)
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_grad_norm)
                self.optimizer.step()

            self._update_meters(loss_meters, loss_dict, inputs[0].size(0))
            pbar.set_postfix({'loss': f"{loss_meters['loss_total'].avg:.4f}"})

            self.global_step += 1
            if self.global_step % self.config.train.log_every_n_steps == 0 and self.use_wandb:
                wandb.log({f"train/{k}": v.avg for k, v in loss_meters.items()}, step=self.global_step)

        return {k: v.avg for k, v in loss_meters.items()}

    @torch.no_grad()
    def validate(self) -> Dict[str, float]:
        """Evaluate on ``val_loader`` without gradients.

        Predictions are read from ``output['cosine']``. It is assumed to be
        ``(batch, num_classes)`` (argmax/topk run over dim 1); TODO confirm.

        Returns:
            ``'accuracy'`` (top-1), ``'accuracy_top5'`` and sample-weighted
            averages of every ``loss_dict`` entry.
        """
        self.model.eval()

        total_correct = 0
        total_correct_top5 = 0
        total_samples = 0
        loss_meters = defaultdict(AverageMeter)

        for batch in tqdm(self.val_loader, desc="Validation"):
            inputs, labels = self._batch_to_device(batch)

            with autocast(enabled=self.use_amp):
                output = self.model(*inputs)
                _, loss_dict = self.loss_fn(output['embeddings'], output['cosine'], labels)

            cosine = output['cosine']
            total_correct += (cosine.argmax(dim=1) == labels).sum().item()

            # topk raises when k exceeds the number of classes, so cap it.
            k = min(5, cosine.size(1))
            topk_preds = cosine.topk(k, dim=1).indices
            total_correct_top5 += (topk_preds == labels.view(-1, 1)).any(dim=1).sum().item()

            total_samples += labels.size(0)
            self._update_meters(loss_meters, loss_dict, inputs[0].size(0))

        accuracy = total_correct / total_samples if total_samples > 0 else 0
        accuracy_top5 = total_correct_top5 / total_samples if total_samples > 0 else 0

        metrics = {
            'accuracy': accuracy,
            'accuracy_top5': accuracy_top5,
        }
        metrics.update({k: v.avg for k, v in loss_meters.items()})

        if self.use_wandb:
            wandb.log({f"val/{k}": v for k, v in metrics.items()}, step=self.global_step)

        return metrics

    def save_checkpoint(self, filename: str, is_best: bool = False):
        """Write model/optimizer/scheduler/progress state to ``save_dir / filename``.

        ``'epoch'`` stores ``current_epoch``. When called from ``train()``, that
        is the epoch that has just finished.

        Args:
            filename: File name inside ``save_dir``.
            is_best: Also write the same checkpoint to ``save_dir / 'best_model.pt'``.
        """
        checkpoint = {
            'epoch': self.current_epoch,
            'global_step': self.global_step,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'best_metric': self.best_metric,
            'patience_counter': self.patience_counter,
            'config': {'model': self.config.model.__dict__, 'train': self.config.train.__dict__},
            'artist_to_idx': self.artist_to_idx,
        }
        if self.use_amp:
            checkpoint['scaler_state_dict'] = self.scaler.state_dict()

        torch.save(checkpoint, self.save_dir / filename)
        if is_best:
            torch.save(checkpoint, self.save_dir / 'best_model.pt')

    def load_checkpoint(self, path: str):
        """Restore state written by ``save_checkpoint`` so ``train()`` can resume.

        NOTE(review): ``torch.load`` may unpickle arbitrary objects depending on
        the torch version/defaults; only load trusted checkpoint files.
        """
        checkpoint = torch.load(path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

        # 'epoch' is the last completed epoch, and the scheduler was already
        # stepped for it. Resume from the next epoch rather than repeating it.
        self.current_epoch = checkpoint['epoch'] + 1
        self.global_step = checkpoint['global_step']
        self.best_metric = checkpoint['best_metric']
        # Older checkpoints did not store patience; start counting afresh for them.
        self.patience_counter = checkpoint.get('patience_counter', 0)

        if self.use_amp and 'scaler_state_dict' in checkpoint:
            self.scaler.load_state_dict(checkpoint['scaler_state_dict'])
        print(f"Loaded checkpoint from epoch {checkpoint['epoch']}")

    def train(self):
        """Run the full training loop from ``current_epoch`` up to ``config.train.epochs``.

        After each epoch, the loop does the following:
        - validates and steps the LR scheduler once;
        - tracks the best top-1 validation accuracy;
        - saves a periodic checkpoint every ``save_every_n_epochs`` epochs, or
          writes ``best_model.pt`` on improvement;
        - stops early after ``patience`` epochs without improvement.

        A ``final_model.pt`` checkpoint is always written at the end.
        """
        train_cfg = self.config.train
        print(f"Training for {train_cfg.epochs} epochs on {self.device}")
        print(f"Artists: {len(self.artist_to_idx)}")

        for epoch in range(self.current_epoch, train_cfg.epochs):
            self.current_epoch = epoch

            train_metrics = self.train_epoch()
            print(f"\nEpoch {epoch} - Train Loss: {train_metrics['loss_total']:.4f}")

            val_metrics = self.validate()
            print(f"Epoch {epoch} - Val Loss: {val_metrics['loss_total']:.4f}, "
                  f"Acc: {val_metrics['accuracy']:.4f}, "
                  f"Top5: {val_metrics['accuracy_top5']:.4f}")

            self.scheduler.step()

            is_best = val_metrics['accuracy'] > self.best_metric
            if is_best:
                self.best_metric = val_metrics['accuracy']
                self.patience_counter = 0
            else:
                self.patience_counter += 1

            if (epoch + 1) % train_cfg.save_every_n_epochs == 0:
                self.save_checkpoint(f'checkpoint_epoch_{epoch}.pt', is_best)
            elif is_best:
                # Write best_model.pt once; passing is_best=True here would save it twice.
                self.save_checkpoint('best_model.pt')

            if self.patience_counter >= train_cfg.patience:
                print(f"Early stopping at epoch {epoch}")
                break

        self.save_checkpoint('final_model.pt')
        if self.use_wandb:
            wandb.finish()
        print(f"Training complete. Best Accuracy: {self.best_metric:.4f}")