| import torch |
| import wandb |
| from omegaconf import OmegaConf |
| from timm.utils import ModelEmaV2 |
| from torch import nn |
| from torch.amp import GradScaler, autocast |
| from torchvision.transforms import v2 |
| from tqdm import tqdm |
|
|
| from .metrics import MetricTracker |
| from .utils import EarlyStopping, save_checkpoint |
|
|
|
|
class Trainer:
    """Training / validation driver with AMP, gradient accumulation, EMA and MixUp/CutMix.

    Args:
        model: Module being trained.
        train_loader: Iterable of ``(images, targets)`` training batches.
        val_loader: Iterable of ``(images, targets)`` validation batches.
        criterion: Loss callable, invoked as ``criterion(outputs, targets)``.
            1-D (class-index) targets are one-hot encoded to float first.
        optimizer: Torch optimizer. Param group 0 is reported as the
            "backbone" LR and param group 1 (if present) as the "head" LR.
        scheduler: LR scheduler, stepped according to ``config.scheduler.name``:
            "cosine_warmup" once per optimizer step, "step"/"cosine" once per
            epoch, "plateau" after validation with the primary mAP.
        config: Config object read via attribute access (``training``,
            ``model``, ``augmentation``, ``scheduler``, ``logging`` sections).
        device: ``torch.device`` that batches are moved to.
    """

    def __init__(
        self,
        model,
        train_loader,
        val_loader,
        criterion,
        optimizer,
        scheduler,
        config,
        device,
    ):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.criterion = criterion
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.config = config
        self.device = device

        # mode="max": a higher validation mAP counts as an improvement.
        self.early_stopping = EarlyStopping(
            patience=config.training.early_stopping_patience, mode="max"
        )

        # With enabled=False the scaler passes values through unchanged, so the
        # same code path serves full- and mixed-precision training.
        self.scaler = GradScaler(device.type, enabled=config.training.mixed_precision)

        # EMA is optional: a missing or empty `training.ema` section disables it.
        ema_cfg = getattr(config.training, "ema", None)
        self.use_ema = bool(ema_cfg and ema_cfg.enabled)
        if self.use_ema:
            ema_decay = getattr(ema_cfg, "decay", 0.9999)
            self.model_ema = ModelEmaV2(self.model, decay=ema_decay, device=device)
        else:
            self.model_ema = None

        self.num_classes = config.model.num_classes

        # Batch-level MixUp/CutMix, applied to a random fraction
        # (`augmentation.prob`) of training batches.
        self.use_mixup = config.augmentation.prob > 0
        if self.use_mixup:
            cutmix = v2.CutMix(
                alpha=config.augmentation.cutmix_alpha, num_classes=self.num_classes
            )
            mixup = v2.MixUp(
                alpha=config.augmentation.mixup_alpha, num_classes=self.num_classes
            )
            # One of the two transforms is chosen at random on each call.
            self.cutmix_or_mixup = v2.RandomChoice([cutmix, mixup])

        self.train_metrics = MetricTracker(num_classes=self.num_classes, device=device)
        self.val_metrics = MetricTracker(num_classes=self.num_classes, device=device)
        if self.use_ema:
            self.val_ema_metrics = MetricTracker(
                num_classes=self.num_classes, device=device
            )

    # ------------------------------------------------------------------ helpers

    def _primary_eval_mode(self):
        """Return which validation metrics ("current" or "ema") drive model selection.

        "ema" is only possible when EMA is active; otherwise the live model's
        metrics are used. This means a disabled or absent EMA section never
        triggers a lookup of metrics that were not computed.

        Raises:
            ValueError: if EMA is active and ``training.ema.eval_mode`` is
                neither "current" nor "ema".
        """
        if not self.use_ema:
            return "current"
        mode = getattr(self.config.training.ema, "eval_mode", None) or "current"
        if mode not in ("current", "ema"):
            raise ValueError(
                f"training.ema.eval_mode must be 'current' or 'ema', got {mode!r}"
            )
        return mode

    def _current_lrs(self):
        """Return ``(backbone_lr, head_lr)`` read from the optimizer's param groups.

        Param group 0 is the backbone and group 1 the head. With a single
        param group, both values are that group's LR.
        """
        lrs = [group["lr"] for group in self.optimizer.param_groups]
        head_lr = lrs[1] if len(lrs) > 1 else lrs[0]
        return lrs[0], head_lr

    def _freeze_batchnorm(self):
        """Switch every BatchNorm1d/2d/3d module of the model to eval mode."""
        for module in self.model.modules():
            if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
                module.eval()

    def _optimizer_step(self):
        """Apply the accumulated gradients and run the per-step bookkeeping.

        Steps, in order:
        1. Optionally unscale and clip gradients (when ``training.clip_grad_norm > 0``).
        2. Step the optimizer through the GradScaler and update the scale.
        3. Zero the gradients.
        4. Step a "cosine_warmup" scheduler.
        5. Update the EMA weights, if EMA is enabled.
        """
        clip_norm = self.config.training.clip_grad_norm
        if clip_norm > 0:
            # Unscale first so the clipping threshold applies to the true
            # (unscaled) gradient norm.
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), clip_norm)

        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad()

        # The warmup-cosine schedule advances once per optimizer step.
        if self.config.scheduler.name == "cosine_warmup":
            self.scheduler.step()

        if self.use_ema:
            self.model_ema.update(self.model)

    @staticmethod
    def _format_metrics(metrics):
        """Format the loss / mAP / accuracy entries of a metrics mapping for printing."""
        return (
            f"Loss: {metrics['loss']:.4f}, mAP: {metrics['mAP']:.4f}, "
            f"Acc: {metrics['accuracy']:.4f}"
        )

    def _epoch_log_dict(self, epoch, train_metrics, val_metrics, backbone_lr, head_lr):
        """Build the per-epoch logging dict (train, LR and validation scalars)."""
        log_dict = {
            "train/loss": train_metrics["loss"],
            "train/mAP": train_metrics["mAP"],
            "train/accuracy": train_metrics["accuracy"],
            "lr/backbone": backbone_lr,
            "lr/head": head_lr,
            "epoch": epoch,
        }
        current = val_metrics["current"]
        log_dict.update(
            {
                "val/loss": current["loss"],
                "val/mAP": current["mAP"],
                "val/accuracy": current["accuracy"],
            }
        )
        if self.use_ema:
            ema = val_metrics["ema"]
            log_dict.update(
                {
                    "val/ema_loss": ema["loss"],
                    "val/ema_mAP": ema["mAP"],
                    "val/ema_accuracy": ema["accuracy"],
                }
            )
        return log_dict

    def _print_epoch_summary(self, epoch, train_metrics, val_metrics, backbone_lr, head_lr):
        """Print a human-readable end-of-epoch summary to stdout."""
        print(f"\nEpoch {epoch} Summary:")
        print(f"LR: Backbone: {backbone_lr:.2e} | Head: {head_lr:.2e}")
        print(f"Train - {self._format_metrics(train_metrics)}")
        if self.use_ema:
            print(f"Val (Current) - {self._format_metrics(val_metrics['current'])}")
            print(f"Val (EMA) - {self._format_metrics(val_metrics['ema'])}")
        else:
            print(f"Val - {self._format_metrics(val_metrics['current'])}")

    def _checkpoint_state(self, epoch, primary_map):
        """Collect everything needed to resume training into a checkpoint dict."""
        return {
            "epoch": epoch,
            "state_dict": self.model.state_dict(),
            "state_dict_ema": self.model_ema.module.state_dict()
            if self.use_ema
            else None,
            "optimizer": self.optimizer.state_dict(),
            "scheduler": self.scheduler.state_dict()
            if self.scheduler is not None
            else None,
            "scaler": self.scaler.state_dict(),
            "early_stopping": {
                "best_score": self.early_stopping.best_score,
                "counter": self.early_stopping.counter,
                "early_stop": self.early_stopping.early_stop,
            },
            "rng_states": {
                "torch": torch.get_rng_state(),
                "cuda": torch.cuda.get_rng_state_all()
                if torch.cuda.is_available()
                else None,
            },
            "val_mAP": primary_map,
            "config": OmegaConf.to_yaml(self.config),
            "wandb_run_id": wandb.run.id if wandb.run is not None else None,
        }

    # ---------------------------------------------------------------- main loop

    def train_one_epoch(self, epoch):
        """Run one training epoch and return the training metrics.

        Gradients are accumulated over ``training.gradient_accumulation_steps``
        batches before each optimizer step. A trailing, incomplete
        accumulation window is flushed at the end of the epoch. "step" and
        "cosine" schedulers are stepped once after the epoch.

        Args:
            epoch: Epoch number, used for the progress-bar label.

        Returns:
            The result of ``self.train_metrics.compute()``. ``fit`` reads its
            "loss", "mAP" and "accuracy" entries.

        Raises:
            ValueError: if ``training.gradient_accumulation_steps`` is < 1.
        """
        accum_steps = self.config.training.gradient_accumulation_steps
        if accum_steps < 1:
            raise ValueError(
                "training.gradient_accumulation_steps must be >= 1, "
                f"got {accum_steps}"
            )

        self.model.train()
        if self.config.model.freeze_bn:
            self._freeze_batchnorm()

        self.train_metrics.reset()
        # Start from clean gradients so nothing accumulated outside this method
        # leaks into the first optimizer step.
        self.optimizer.zero_grad()
        # Number of backward passes whose gradients have not been applied yet.
        pending = 0

        pbar = tqdm(self.train_loader, desc=f"Epoch {epoch} [Train]")
        for images, targets in pbar:
            images, targets = images.to(self.device), targets.to(self.device)

            is_mixed = False
            if self.use_mixup and torch.rand(1).item() < self.config.augmentation.prob:
                images, targets = self.cutmix_or_mixup(images, targets)
                is_mixed = True
            if targets.ndim == 1:
                targets = torch.nn.functional.one_hot(
                    targets, num_classes=self.num_classes
                ).float()

            with autocast(
                device_type=self.device.type,
                enabled=self.config.training.mixed_precision,
            ):
                outputs = self.model(images)
                loss = self.criterion(outputs, targets)
            # Divide so the gradients summed over the window average the losses.
            loss = loss / accum_steps

            self.scaler.scale(loss).backward()
            pending += 1
            if pending == accum_steps:
                self._optimizer_step()
                pending = 0

            # Undo the accumulation scaling for reporting.
            batch_loss = loss.item() * accum_steps
            # For mixed batches the tracker is told to skip metrics
            # (skip_metrics=True); how it treats that flag is defined in MetricTracker.
            self.train_metrics.update(
                outputs.detach(),
                targets.detach(),
                loss=batch_loss,
                skip_metrics=is_mixed,
            )

            pbar.set_postfix({"loss": f"{batch_loss:.4f}"})
            if self.config.logging.use_wandb:
                wandb.log({"train/batch_loss": batch_loss})

        # Apply gradients from a final, incomplete accumulation window.
        # `pending` stays 0 for an empty loader, so no spurious step happens.
        if pending:
            self._optimizer_step()

        metrics = self.train_metrics.compute()

        # Epoch-level schedulers advance once per epoch.
        if self.config.scheduler.name in ("step", "cosine"):
            self.scheduler.step()

        return metrics

    def validate(self, epoch):
        """Evaluate the live model (and its EMA copy, if enabled) on the validation set.

        A "plateau" scheduler is stepped with the primary validation mAP.

        Args:
            epoch: Epoch number, used for the progress-bar label.

        Returns:
            dict with key "current" (live-model metrics) and, when EMA is
            enabled, key "ema" (EMA-model metrics).
        """
        self.model.eval()
        self.val_metrics.reset()
        if self.use_ema:
            self.model_ema.module.eval()
            self.val_ema_metrics.reset()

        pbar = tqdm(self.val_loader, desc=f"Epoch {epoch} [Val]")
        with torch.no_grad():
            for images, targets in pbar:
                images, targets = images.to(self.device), targets.to(self.device)
                if targets.ndim == 1:
                    targets = torch.nn.functional.one_hot(
                        targets, num_classes=self.num_classes
                    ).float()

                with autocast(
                    device_type=self.device.type,
                    enabled=self.config.training.mixed_precision,
                ):
                    outputs = self.model(images)
                    loss = self.criterion(outputs, targets)
                    if self.use_ema:
                        ema_outputs = self.model_ema.module(images)
                        ema_loss = self.criterion(ema_outputs, targets)

                self.val_metrics.update(
                    outputs.detach(), targets.detach(), loss=loss.detach()
                )
                postfix = {"loss": f"{loss.item():.4f}"}
                if self.use_ema:
                    self.val_ema_metrics.update(
                        ema_outputs.detach(), targets.detach(), loss=ema_loss.detach()
                    )
                    postfix["ema_loss"] = f"{ema_loss.item():.4f}"
                pbar.set_postfix(postfix)

        metrics = {"current": self.val_metrics.compute()}
        if self.use_ema:
            metrics["ema"] = self.val_ema_metrics.compute()

        if self.config.scheduler.name == "plateau":
            primary_map = metrics[self._primary_eval_mode()]["mAP"]
            self.scheduler.step(primary_map)

        return metrics

    def fit(self, start_epoch=1):
        """Train from ``start_epoch`` through ``training.epochs`` (inclusive).

        Each epoch performs these steps:
        1. Train and validate.
        2. Log to stdout and, if enabled, to wandb.
        3. Feed the primary validation mAP to early stopping.
        4. Save a checkpoint, flagged as best when early stopping reports an
           improvement.

        The loop stops as soon as ``early_stopping.early_stop`` is set.
        """
        eval_mode = self._primary_eval_mode()

        for epoch in range(start_epoch, self.config.training.epochs + 1):
            train_metrics = self.train_one_epoch(epoch)
            val_metrics = self.validate(epoch)
            backbone_lr, head_lr = self._current_lrs()

            if self.config.logging.use_wandb:
                wandb.log(
                    self._epoch_log_dict(
                        epoch, train_metrics, val_metrics, backbone_lr, head_lr
                    )
                )
            self._print_epoch_summary(
                epoch, train_metrics, val_metrics, backbone_lr, head_lr
            )

            primary_map = val_metrics[eval_mode]["mAP"]
            is_best = self.early_stopping(primary_map)
            if is_best:
                print(f"Epoch {epoch} is the new best model. mAP: {primary_map:.4f}")

            save_checkpoint(
                self._checkpoint_state(epoch, primary_map),
                is_best,
                self.config.logging.checkpoint_dir,
            )

            if self.early_stopping.early_stop:
                print(f"Early stopping triggered at epoch {epoch}")
                break

        print("Training complete!")
|
|