| """DDP + AMP training loop. |
| |
| * DDP: launched via torchrun; uses DistributedSampler. Single-GPU also works. |
| * AMP: bf16 (A100+) / fp16 (V100, with GradScaler) / fp32. |
| * Best checkpoint chosen by mean foreground Dice on the val split. |
| """ |
| from __future__ import annotations |
|
|
| import os |
| import math |
| import json |
| import time |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| from torch.nn.parallel import DistributedDataParallel as DDP |
|
|
| from .distributed import (is_dist, is_main, get_rank, get_world_size, |
| all_gather_object, print_main) |
| from .losses import build_loss |
| from ..metrics.metrics import per_image_metrics |
| from ..data.loaders import build_dataset, build_loader |
|
|
|
|
# Autocast dtype for each recognised ``cfg.amp`` setting. Any value that is
# not a key here is treated by Trainer as "AMP off" (plain fp32).
_AMP_DTYPE = dict(bf16=torch.bfloat16, fp16=torch.float16)
|
|
|
|
def build_optimizer(cfg, params):
    """Create the optimizer selected by ``cfg.optimizer``.

    Args:
        cfg: Config object exposing ``optimizer``, ``lr`` and ``weight_decay``.
        params: Iterable of parameters (or parameter groups) to optimize.

    Returns:
        ``torch.optim.SGD`` (momentum 0.9, Nesterov) when ``cfg.optimizer``
        is ``"sgd"``; ``torch.optim.AdamW`` for every other value.
    """
    shared = {"lr": cfg.lr, "weight_decay": cfg.weight_decay}
    if cfg.optimizer != "sgd":
        # Anything that is not explicitly "sgd" falls back to AdamW.
        return torch.optim.AdamW(params, **shared)
    return torch.optim.SGD(params, momentum=0.9, nesterov=True, **shared)
|
|
|
|
def lr_at(cfg, epoch: int) -> float:
    """Return the learning rate to use for the given (0-based) epoch.

    The schedule is a linear warmup over ``cfg.warmup_epochs`` epochs
    (reaching ``cfg.lr`` on the last warmup epoch), followed by the decay
    selected by ``cfg.scheduler``:

    * ``"poly"``   -- ``lr * (1 - t) ** 0.9``
    * ``"cosine"`` -- ``lr * 0.5 * (1 + cos(pi * t))``
    * anything else -- constant ``cfg.lr``

    where ``t`` is the fraction of the post-warmup epochs already elapsed.

    Args:
        cfg: Config exposing ``lr``, ``warmup_epochs``, ``epochs`` and
            ``scheduler``.
        epoch: 0-based epoch index.

    Returns:
        The learning rate as a real number.
    """
    if epoch < cfg.warmup_epochs:
        return cfg.lr * (epoch + 1) / max(1, cfg.warmup_epochs)

    total = max(1, cfg.epochs - cfg.warmup_epochs)
    # Clamp to the end of the schedule. Without this, an epoch past
    # ``cfg.epochs`` makes the poly base negative (a negative number raised
    # to 0.9 is a complex value in Python) and makes the cosine curve climb
    # back up towards the peak learning rate.
    e = min(epoch - cfg.warmup_epochs, total)

    if cfg.scheduler == "poly":
        return cfg.lr * (1 - e / total) ** 0.9
    if cfg.scheduler == "cosine":
        return cfg.lr * 0.5 * (1 + math.cos(math.pi * e / total))
    return cfg.lr
|
|
|
|
class Trainer:
    """Training driver for single-process or DDP runs with optional AMP.

    Builds the train/val datasets and loaders, wraps the model in DDP when
    ``is_dist()`` is true, and keeps track of the best validation Dice seen
    so far. Checkpoints (``best.pth``, ``last.pth``, ``epoch{N}.pth``) are
    written by the main rank only, inside ``cfg.out_dir()``.
    """

    def __init__(self, cfg, model: nn.Module, local_rank: int):
        """Set up data, model, loss, optimizer and AMP state.

        Args:
            cfg: Run configuration. This class reads ``loss``, ``amp``,
                ``resume``, ``epochs``, ``grad_clip``, ``include_background``,
                ``val_interval``, ``save_interval``, ``patience`` and
                ``min_epochs`` from it, calls ``out_dir()`` / ``to_yaml()``,
                and passes it on to the dataset/loader/optimizer builders.
            model: Network to train; moved to the selected device here.
            local_rank: CUDA device index for this process (unused for the
                device choice when CUDA is unavailable).
        """
        self.cfg = cfg
        if torch.cuda.is_available():
            self.device = torch.device("cuda", local_rank)
        else:
            self.device = torch.device("cpu")
        self.local_rank = local_rank
        on_cuda = self.device.type == "cuda"

        self.train_ds = build_dataset(cfg, "train")
        self.val_ds = build_dataset(cfg, "val")
        self.num_classes = self.train_ds.num_classes
        self.train_loader = build_loader(cfg, "train", self.train_ds)
        self.val_loader = build_loader(cfg, "val", self.val_ds)

        self.model = model.to(self.device)
        if is_dist():
            # DDP requires device_ids/output_device to be None for CPU modules.
            self.model = DDP(
                self.model,
                device_ids=[local_rank] if on_cuda else None,
                output_device=local_rank if on_cuda else None,
                find_unused_parameters=False,
            )

        self.criterion = build_loss(cfg.loss).to(self.device)
        self.optimizer = build_optimizer(cfg, self.model.parameters())
        self.amp = cfg.amp
        self.use_amp = self.amp in _AMP_DTYPE
        # Loss scaling is only needed for fp16, and only exists on CUDA here;
        # a disabled scaler is a transparent pass-through.
        self.scaler = torch.amp.GradScaler(
            "cuda", enabled=(self.amp == "fp16" and on_cuda))
        self.best = -1.0
        self.best_epoch = -1
        self.start_epoch = 0
        self.out_dir = cfg.out_dir()
        if is_main():
            os.makedirs(self.out_dir, exist_ok=True)
            cfg.to_yaml(os.path.join(self.out_dir, "config.yaml"))
        if cfg.resume:
            self._load(cfg.resume)

    def _bare(self):
        """Return the underlying module, unwrapping DDP when distributed."""
        return self.model.module if is_dist() else self.model

    def _save(self, name: str, epoch: int):
        """Write a checkpoint called ``name`` into ``self.out_dir``.

        No-op on non-main ranks. The file is written to a temporary path and
        then renamed, so an interrupted write cannot leave a truncated
        checkpoint under the final name.

        Args:
            name: File name of the checkpoint (e.g. ``"last.pth"``).
            epoch: 0-based index of the epoch that just finished.
        """
        if not is_main():
            return
        state = {
            "epoch": epoch,
            "model": self._bare().state_dict(),
            "optimizer": self.optimizer.state_dict(),
            # Empty dict when the scaler is disabled (non-fp16 runs).
            "scaler": self.scaler.state_dict(),
            "best": self.best,
            "best_epoch": self.best_epoch,
            "num_classes": self.num_classes,
            "config": self.cfg.__dict__,
        }
        path = os.path.join(self.out_dir, name)
        tmp_path = path + ".tmp"
        torch.save(state, tmp_path)
        os.replace(tmp_path, path)

    def _load(self, path: str):
        """Restore model/optimizer/scaler state and bookkeeping from ``path``.

        Checkpoints that lack the ``optimizer``, ``scaler``, ``best``,
        ``best_epoch`` or ``epoch`` entries are accepted; missing entries
        fall back to their fresh-run defaults.
        """
        # SECURITY NOTE: weights_only=False unpickles arbitrary Python
        # objects. Only resume from checkpoints you trust.
        ckpt = torch.load(path, map_location="cpu", weights_only=False)
        self._bare().load_state_dict(ckpt["model"])
        if "optimizer" in ckpt:
            self.optimizer.load_state_dict(ckpt["optimizer"])
        # A disabled scaler saves an empty state, which an enabled scaler
        # refuses to load, so skip empty/missing entries.
        scaler_state = ckpt.get("scaler")
        if scaler_state and self.scaler.is_enabled():
            self.scaler.load_state_dict(scaler_state)
        self.best = ckpt.get("best", -1.0)
        self.start_epoch = ckpt.get("epoch", -1) + 1
        # Older checkpoints have no "best_epoch"; treat the epoch before the
        # resume point as the best one so early stopping starts counting anew.
        self.best_epoch = ckpt.get("best_epoch", self.start_epoch - 1)
        print_main(f"[resume] from {path} at epoch {self.start_epoch}")

    def _autocast(self):
        """Return the autocast context manager for forward passes.

        Mixed precision is active only when ``cfg.amp`` names a known dtype
        and the trainer runs on CUDA; otherwise a disabled context for the
        current device type is returned (plain fp32).
        """
        if self.use_amp and self.device.type == "cuda":
            return torch.autocast("cuda", dtype=_AMP_DTYPE[self.amp])
        return torch.autocast(self.device.type, enabled=False)

    def train_one_epoch(self, epoch: int):
        """Run one pass over the training loader and log the mean loss.

        Args:
            epoch: 0-based epoch index; selects the learning rate and seeds
                the distributed sampler's shuffling.
        """
        self.model.train()
        if is_dist():
            self.train_loader.sampler.set_epoch(epoch)
        # The learning rate is updated once per epoch, for every param group.
        lr = lr_at(self.cfg, epoch)
        for group in self.optimizer.param_groups:
            group["lr"] = lr

        running, n = 0.0, 0
        t0 = time.time()
        for batch in self.train_loader:
            img = batch["image"].to(self.device, non_blocking=True)
            msk = batch["mask"].to(self.device, non_blocking=True)
            self.optimizer.zero_grad(set_to_none=True)
            with self._autocast():
                logits = self.model(img)
                loss = self.criterion(logits, msk)
            if self.scaler.is_enabled():
                self.scaler.scale(loss).backward()
                if self.cfg.grad_clip > 0:
                    # Gradients must be unscaled before their norm is clipped.
                    self.scaler.unscale_(self.optimizer)
                    nn.utils.clip_grad_norm_(self.model.parameters(), self.cfg.grad_clip)
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                loss.backward()
                if self.cfg.grad_clip > 0:
                    nn.utils.clip_grad_norm_(self.model.parameters(), self.cfg.grad_clip)
                self.optimizer.step()
            # Sample-weighted sum so the logged value is a per-sample mean.
            running += loss.item() * img.size(0)
            n += img.size(0)
        if is_main():
            print_main(f"[ep {epoch:03d}] loss={running/max(1,n):.4f} "
                       f"lr={self.optimizer.param_groups[0]['lr']:.2e} "
                       f"({time.time()-t0:.1f}s)")

    @torch.no_grad()
    def validate(self) -> float:
        """Evaluate on the validation loader and return the mean Dice.

        Per-image metric records from this process are combined with those
        of the other processes through ``all_gather_object``; the ``"dice"``
        entry of every record is averaged, ignoring NaN values.

        Returns:
            Mean Dice over all gathered records, or ``0.0`` when there is no
            non-NaN value.
        """
        self.model.eval()
        records = []
        for batch in self.val_loader:
            img = batch["image"].to(self.device, non_blocking=True)
            msk = batch["mask"].numpy()
            with self._autocast():
                logits = self.model(img)
            # Predicted class = argmax over dim 1 of the model output.
            pred = logits.argmax(1).cpu().numpy()
            for i in range(pred.shape[0]):
                records.append(per_image_metrics(
                    pred[i], msk[i], self.num_classes,
                    include_background=self.cfg.include_background,
                    compute_hd95=False))
        gathered = all_gather_object(records)
        flat = [r for part in gathered for r in part]
        dices = np.array([r["dice"] for r in flat], dtype=np.float64)
        dices = dices[~np.isnan(dices)]
        return float(dices.mean()) if dices.size else 0.0

    def fit(self):
        """Train from ``self.start_epoch`` up to ``cfg.epochs``.

        Validation runs every ``cfg.val_interval`` epochs (a falsy interval
        disables periodic validation) and always after the final epoch. A
        new best Dice triggers ``best.pth``. Training stops early once
        ``cfg.patience`` epochs have passed without improvement, but not
        before ``cfg.min_epochs`` epochs. ``last.pth`` is refreshed after
        every epoch; ``epoch{N}.pth`` every ``cfg.save_interval`` epochs.
        """
        val_every = self.cfg.val_interval
        for epoch in range(self.start_epoch, self.cfg.epochs):
            self.train_one_epoch(epoch)
            # Guard the modulo the same way save_interval is guarded below,
            # so val_interval == 0 does not raise ZeroDivisionError.
            periodic = bool(val_every) and (epoch + 1) % val_every == 0
            if periodic or (epoch + 1 == self.cfg.epochs):
                dice = self.validate()
                if dice > self.best:
                    self.best = dice
                    self.best_epoch = epoch
                    self._save("best.pth", epoch)
                print_main(f"[ep {epoch:03d}] val_dice={dice:.4f} "
                           f"(best={self.best:.4f} @ep{self.best_epoch})")

            stale = epoch - self.best_epoch
            if (self.cfg.patience > 0 and (epoch + 1) >= self.cfg.min_epochs
                    and stale >= self.cfg.patience):
                print_main(f"[early-stop] no val improvement for {stale} epochs "
                           f"(patience={self.cfg.patience}); best={self.best:.4f} @ep{self.best_epoch}")
                # The regular end-of-epoch save below is skipped by the break.
                self._save("last.pth", epoch)
                break
            if self.cfg.save_interval and (epoch + 1) % self.cfg.save_interval == 0:
                self._save(f"epoch{epoch+1}.pth", epoch)
            self._save("last.pth", epoch)
        print_main(f"[done] best val_dice={self.best:.4f} @ep{self.best_epoch} -> {self.out_dir}/best.pth")
|
|