"""DDP + AMP training loop. * DDP: launched via torchrun; uses DistributedSampler. Single-GPU also works. * AMP: bf16 (A100+) / fp16 (V100, with GradScaler) / fp32. * Best checkpoint chosen by mean foreground Dice on the val split. """ from __future__ import annotations import os import math import json import time import numpy as np import torch import torch.nn as nn from torch.nn.parallel import DistributedDataParallel as DDP from .distributed import (is_dist, is_main, get_rank, get_world_size, all_gather_object, print_main) from .losses import build_loss from ..metrics.metrics import per_image_metrics from ..data.loaders import build_dataset, build_loader _AMP_DTYPE = {"bf16": torch.bfloat16, "fp16": torch.float16} def build_optimizer(cfg, params): if cfg.optimizer == "sgd": return torch.optim.SGD(params, lr=cfg.lr, momentum=0.9, weight_decay=cfg.weight_decay, nesterov=True) return torch.optim.AdamW(params, lr=cfg.lr, weight_decay=cfg.weight_decay) def lr_at(cfg, epoch: int) -> float: if epoch < cfg.warmup_epochs: return cfg.lr * (epoch + 1) / max(1, cfg.warmup_epochs) e = epoch - cfg.warmup_epochs total = max(1, cfg.epochs - cfg.warmup_epochs) if cfg.scheduler == "poly": return cfg.lr * (1 - e / total) ** 0.9 if cfg.scheduler == "cosine": return cfg.lr * 0.5 * (1 + math.cos(math.pi * e / total)) return cfg.lr class Trainer: def __init__(self, cfg, model: nn.Module, local_rank: int): self.cfg = cfg self.device = torch.device("cuda", local_rank) if torch.cuda.is_available() else torch.device("cpu") self.local_rank = local_rank self.train_ds = build_dataset(cfg, "train") self.val_ds = build_dataset(cfg, "val") self.num_classes = self.train_ds.num_classes self.train_loader = build_loader(cfg, "train", self.train_ds) self.val_loader = build_loader(cfg, "val", self.val_ds) self.model = model.to(self.device) if is_dist(): self.model = DDP(self.model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=False) self.criterion = build_loss(cfg.loss).to(self.device) self.optimizer = build_optimizer(cfg, self.model.parameters()) self.amp = cfg.amp self.use_amp = self.amp in _AMP_DTYPE self.scaler = torch.amp.GradScaler("cuda", enabled=(self.amp == "fp16")) self.best = -1.0 self.start_epoch = 0 self.out_dir = cfg.out_dir() if is_main(): os.makedirs(self.out_dir, exist_ok=True) cfg.to_yaml(os.path.join(self.out_dir, "config.yaml")) if cfg.resume: self._load(cfg.resume) # ---- checkpoint ---- def _bare(self): return self.model.module if is_dist() else self.model def _save(self, name: str, epoch: int): if not is_main(): return torch.save({ "epoch": epoch, "model": self._bare().state_dict(), "optimizer": self.optimizer.state_dict(), "best": self.best, "num_classes": self.num_classes, "config": self.cfg.__dict__, }, os.path.join(self.out_dir, name)) def _load(self, path: str): ckpt = torch.load(path, map_location="cpu", weights_only=False) self._bare().load_state_dict(ckpt["model"]) if "optimizer" in ckpt: self.optimizer.load_state_dict(ckpt["optimizer"]) self.best = ckpt.get("best", -1.0) self.start_epoch = ckpt.get("epoch", -1) + 1 print_main(f"[resume] from {path} at epoch {self.start_epoch}") # ---- loops ---- def _autocast(self): if self.use_amp: return torch.autocast("cuda", dtype=_AMP_DTYPE[self.amp]) return torch.autocast("cuda", enabled=False) def train_one_epoch(self, epoch: int): self.model.train() if is_dist(): self.train_loader.sampler.set_epoch(epoch) for g in self.optimizer.param_groups: g["lr"] = lr_at(self.cfg, epoch) running, n = 0.0, 0 t0 = time.time() for it, batch in enumerate(self.train_loader): img = batch["image"].to(self.device, non_blocking=True) msk = batch["mask"].to(self.device, non_blocking=True) self.optimizer.zero_grad(set_to_none=True) with self._autocast(): logits = self.model(img) loss = self.criterion(logits, msk) if self.amp == "fp16": self.scaler.scale(loss).backward() if self.cfg.grad_clip > 0: self.scaler.unscale_(self.optimizer) nn.utils.clip_grad_norm_(self.model.parameters(), self.cfg.grad_clip) self.scaler.step(self.optimizer) self.scaler.update() else: loss.backward() if self.cfg.grad_clip > 0: nn.utils.clip_grad_norm_(self.model.parameters(), self.cfg.grad_clip) self.optimizer.step() running += loss.item() * img.size(0) n += img.size(0) if is_main(): print_main(f"[ep {epoch:03d}] loss={running/max(1,n):.4f} " f"lr={self.optimizer.param_groups[0]['lr']:.2e} " f"({time.time()-t0:.1f}s)") @torch.no_grad() def validate(self) -> float: self.model.eval() records = [] for batch in self.val_loader: img = batch["image"].to(self.device, non_blocking=True) msk = batch["mask"].numpy() with self._autocast(): logits = self.model(img) pred = logits.argmax(1).cpu().numpy() for i in range(pred.shape[0]): records.append(per_image_metrics( pred[i], msk[i], self.num_classes, include_background=self.cfg.include_background, compute_hd95=False)) gathered = all_gather_object(records) flat = [r for part in gathered for r in part] dices = np.array([r["dice"] for r in flat], dtype=np.float64) dices = dices[~np.isnan(dices)] return float(dices.mean()) if dices.size else 0.0 def fit(self): best_epoch = self.start_epoch - 1 for epoch in range(self.start_epoch, self.cfg.epochs): self.train_one_epoch(epoch) do_val = ((epoch + 1) % self.cfg.val_interval == 0) or (epoch + 1 == self.cfg.epochs) if do_val: dice = self.validate() if dice > self.best: self.best = dice best_epoch = epoch self._save("best.pth", epoch) print_main(f"[ep {epoch:03d}] val_dice={dice:.4f} " f"(best={self.best:.4f} @ep{best_epoch})") # early stopping: stop if val Dice hasn't improved for `patience` epochs if (self.cfg.patience > 0 and (epoch + 1) >= self.cfg.min_epochs and (epoch - best_epoch) >= self.cfg.patience): print_main(f"[early-stop] no val improvement for {epoch - best_epoch} epochs " f"(patience={self.cfg.patience}); best={self.best:.4f} @ep{best_epoch}") self._save("last.pth", epoch) break if self.cfg.save_interval and (epoch + 1) % self.cfg.save_interval == 0: self._save(f"epoch{epoch+1}.pth", epoch) self._save("last.pth", epoch) print_main(f"[done] best val_dice={self.best:.4f} @ep{best_epoch} -> {self.out_dir}/best.pth")