MaybeRichard's picture
Upload folder using huggingface_hub
b8fae22 verified
Raw
History Blame Contribute Delete
7.81 kB
"""DDP + AMP training loop.
* DDP: launched via torchrun; uses DistributedSampler. Single-GPU also works.
* AMP: bf16 (A100+) / fp16 (V100, with GradScaler) / fp32.
* Best checkpoint chosen by mean foreground Dice on the val split.
"""
from __future__ import annotations
import os
import math
import json
import time
import numpy as np
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from .distributed import (is_dist, is_main, get_rank, get_world_size,
all_gather_object, print_main)
from .losses import build_loss
from ..metrics.metrics import per_image_metrics
from ..data.loaders import build_dataset, build_loader
# Maps the ``cfg.amp`` string to the dtype handed to ``torch.autocast``.
# Any value that is not a key here (e.g. "fp32") means no autocast is used.
_AMP_DTYPE = {"bf16": torch.bfloat16, "fp16": torch.float16}
def build_optimizer(cfg, params):
    """Create the optimizer selected by ``cfg.optimizer``.

    ``"sgd"`` yields SGD with Nesterov momentum 0.9; every other value yields
    AdamW. Both read ``cfg.lr`` and ``cfg.weight_decay``.
    """
    shared = {"lr": cfg.lr, "weight_decay": cfg.weight_decay}
    if cfg.optimizer != "sgd":
        return torch.optim.AdamW(params, **shared)
    return torch.optim.SGD(params, momentum=0.9, nesterov=True, **shared)
def lr_at(cfg, epoch: int) -> float:
    """Return the learning rate to use for the zero-based ``epoch``.

    The first ``cfg.warmup_epochs`` epochs ramp linearly up to ``cfg.lr``.
    After that ``cfg.scheduler`` selects the decay: ``"poly"`` (power 0.9),
    ``"cosine"`` (half cosine down to 0), or anything else for a constant
    ``cfg.lr``.

    Epochs at or beyond ``cfg.epochs`` are treated as the end of the schedule,
    so the result is always a real, non-increasing value after warmup.
    """
    warmup = cfg.warmup_epochs
    if epoch < warmup:
        return cfg.lr * (epoch + 1) / max(1, warmup)

    total = max(1, cfg.epochs - warmup)
    # Clamp to the end of the schedule: without this, "poly" would raise a
    # negative base to a fractional power (yielding a complex number) and
    # "cosine" would start climbing again once epoch exceeds cfg.epochs.
    e = min(epoch - warmup, total)

    if cfg.scheduler == "poly":
        return cfg.lr * (1 - e / total) ** 0.9
    if cfg.scheduler == "cosine":
        return cfg.lr * 0.5 * (1 + math.cos(math.pi * e / total))
    return cfg.lr
class Trainer:
    """Training / validation driver with optional DDP and mixed precision.

    Builds the train and val datasets and loaders from ``cfg``, wraps the
    model in DistributedDataParallel when ``is_dist()`` is true, and keeps
    ``best.pth`` for the highest validation Dice seen so far.
    """

    def __init__(self, cfg, model: nn.Module, local_rank: int):
        """Set up data, model, loss, optimizer and AMP state.

        Args:
            cfg: experiment configuration; the attributes read here are
                ``loss``, ``amp``, ``resume`` plus whatever
                ``build_dataset`` / ``build_loader`` / ``build_optimizer``
                consume. ``cfg.out_dir()`` gives the output directory.
            model: network to train; moved to the selected device.
            local_rank: CUDA device index for this process.
        """
        self.cfg = cfg
        if torch.cuda.is_available():
            self.device = torch.device("cuda", local_rank)
        else:
            self.device = torch.device("cpu")
        self.local_rank = local_rank

        self.train_ds = build_dataset(cfg, "train")
        self.val_ds = build_dataset(cfg, "val")
        self.num_classes = self.train_ds.num_classes
        self.train_loader = build_loader(cfg, "train", self.train_ds)
        self.val_loader = build_loader(cfg, "val", self.val_ds)

        self.model = model.to(self.device)
        if is_dist():
            self.model = DDP(self.model, device_ids=[local_rank], output_device=local_rank,
                             find_unused_parameters=False)

        self.criterion = build_loss(cfg.loss).to(self.device)
        self.optimizer = build_optimizer(cfg, self.model.parameters())

        self.amp = cfg.amp
        self.use_amp = self.amp in _AMP_DTYPE
        # Loss scaling is only needed for fp16, and only exists on CUDA here.
        # Disabling it explicitly on CPU avoids the "CUDA is not available"
        # warning while keeping the same effective behaviour.
        self.scaler = torch.amp.GradScaler(
            "cuda", enabled=(self.amp == "fp16" and self.device.type == "cuda"))

        self.best = -1.0
        self.best_epoch = -1
        self.start_epoch = 0

        self.out_dir = cfg.out_dir()
        if is_main():
            os.makedirs(self.out_dir, exist_ok=True)
            cfg.to_yaml(os.path.join(self.out_dir, "config.yaml"))
        if cfg.resume:
            self._load(cfg.resume)

    # ---- checkpoint ----
    def _bare(self):
        """Return the underlying module, unwrapping DDP if it is present."""
        return self.model.module if isinstance(self.model, DDP) else self.model

    def _save(self, name: str, epoch: int):
        """Write a checkpoint called ``name`` into ``self.out_dir`` (main rank only).

        The file is written to a temporary sibling first and then moved into
        place, so an interrupted save never leaves a truncated checkpoint
        under the final name.
        """
        if not is_main():
            return
        path = os.path.join(self.out_dir, name)
        tmp_path = path + ".tmp"
        torch.save({
            "epoch": epoch,
            "model": self._bare().state_dict(),
            "optimizer": self.optimizer.state_dict(),
            # Empty dict when the scaler is disabled; _load skips it then.
            "scaler": self.scaler.state_dict(),
            "best": self.best,
            "best_epoch": self.best_epoch,
            "num_classes": self.num_classes,
            "config": self.cfg.__dict__,
        }, tmp_path)
        os.replace(tmp_path, path)

    def _load(self, path: str):
        """Restore model, optimizer, scaler and progress counters from ``path``.

        Keys other than ``"model"`` are optional so that older checkpoints
        (without scaler state or best-epoch bookkeeping) still load.
        """
        # SECURITY: weights_only=False unpickles arbitrary objects (the saved
        # "config" entry is a plain __dict__). Only resume from checkpoints
        # you trust.
        ckpt = torch.load(path, map_location="cpu", weights_only=False)
        self._bare().load_state_dict(ckpt["model"])
        if "optimizer" in ckpt:
            self.optimizer.load_state_dict(ckpt["optimizer"])
        # An empty scaler state comes from a disabled GradScaler and cannot be
        # loaded, so only restore a non-empty one.
        if ckpt.get("scaler"):
            self.scaler.load_state_dict(ckpt["scaler"])
        self.best = ckpt.get("best", -1.0)
        self.start_epoch = ckpt.get("epoch", -1) + 1
        # Fall back to "the epoch before resuming" for checkpoints that did
        # not record when the best score was reached.
        self.best_epoch = ckpt.get("best_epoch", self.start_epoch - 1)
        print_main(f"[resume] from {path} at epoch {self.start_epoch}")

    # ---- loops ----
    def _autocast(self):
        """Return the autocast context for forward passes.

        Autocast is active only when ``cfg.amp`` names a known dtype and the
        trainer runs on CUDA; otherwise a disabled context for the current
        device type is returned.
        """
        if self.use_amp and self.device.type == "cuda":
            return torch.autocast("cuda", dtype=_AMP_DTYPE[self.amp])
        return torch.autocast(self.device.type, enabled=False)

    def train_one_epoch(self, epoch: int):
        """Run one pass over the training loader and log the mean loss."""
        self.model.train()
        if is_dist():
            # Reseed the distributed sampler so each epoch sees a new shuffle.
            self.train_loader.sampler.set_epoch(epoch)

        lr = lr_at(self.cfg, epoch)
        for group in self.optimizer.param_groups:
            group["lr"] = lr

        clip = self.cfg.grad_clip
        scaled = self.scaler.is_enabled()
        running, n = 0.0, 0
        t0 = time.time()
        for batch in self.train_loader:
            img = batch["image"].to(self.device, non_blocking=True)
            msk = batch["mask"].to(self.device, non_blocking=True)
            self.optimizer.zero_grad(set_to_none=True)
            with self._autocast():
                logits = self.model(img)
                loss = self.criterion(logits, msk)

            if scaled:
                self.scaler.scale(loss).backward()
                if clip > 0:
                    # Gradients must be unscaled before their norm is clipped.
                    self.scaler.unscale_(self.optimizer)
                    nn.utils.clip_grad_norm_(self.model.parameters(), clip)
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                loss.backward()
                if clip > 0:
                    nn.utils.clip_grad_norm_(self.model.parameters(), clip)
                self.optimizer.step()

            # Weight by batch size so the epoch mean is per-sample.
            running += loss.item() * img.size(0)
            n += img.size(0)

        if is_main():
            print_main(f"[ep {epoch:03d}] loss={running/max(1,n):.4f} "
                       f"lr={self.optimizer.param_groups[0]['lr']:.2e} "
                       f"({time.time()-t0:.1f}s)")

    @torch.no_grad()
    def validate(self) -> float:
        """Return the mean per-image Dice over the validation loader.

        Per-image records are collected on every rank and merged through
        ``all_gather_object``. NaN Dice values are ignored; ``0.0`` is
        returned when no valid value remains.
        """
        self.model.eval()
        records = []
        for batch in self.val_loader:
            img = batch["image"].to(self.device, non_blocking=True)
            msk = batch["mask"].numpy()
            with self._autocast():
                logits = self.model(img)
            pred = logits.argmax(1).cpu().numpy()
            for i in range(pred.shape[0]):
                records.append(per_image_metrics(
                    pred[i], msk[i], self.num_classes,
                    include_background=self.cfg.include_background,
                    compute_hd95=False))

        gathered = all_gather_object(records)
        flat = [r for part in gathered for r in part]
        dices = np.array([r["dice"] for r in flat], dtype=np.float64)
        dices = dices[~np.isnan(dices)]
        return float(dices.mean()) if dices.size else 0.0

    def fit(self):
        """Train from ``self.start_epoch`` to ``cfg.epochs`` with periodic validation.

        Saves ``best.pth`` on every validation improvement, ``last.pth`` after
        every epoch, and ``epoch{N}.pth`` every ``cfg.save_interval`` epochs.
        Stops early once ``cfg.patience`` epochs have passed without
        improvement (and at least ``cfg.min_epochs`` epochs have run).
        """
        val_interval = self.cfg.val_interval
        for epoch in range(self.start_epoch, self.cfg.epochs):
            self.train_one_epoch(epoch)

            is_last = (epoch + 1 == self.cfg.epochs)
            # A falsy val_interval (0 / None) means "validate on the final
            # epoch only" instead of raising on the modulo.
            do_val = is_last or bool(val_interval and (epoch + 1) % val_interval == 0)
            if do_val:
                dice = self.validate()
                if dice > self.best:
                    self.best = dice
                    # Updated before saving so best.pth records its own epoch.
                    self.best_epoch = epoch
                    self._save("best.pth", epoch)
                print_main(f"[ep {epoch:03d}] val_dice={dice:.4f} "
                           f"(best={self.best:.4f} @ep{self.best_epoch})")

                # early stopping: stop if val Dice hasn't improved for `patience` epochs
                stale = epoch - self.best_epoch
                if (self.cfg.patience > 0 and (epoch + 1) >= self.cfg.min_epochs
                        and stale >= self.cfg.patience):
                    print_main(f"[early-stop] no val improvement for {stale} epochs "
                               f"(patience={self.cfg.patience}); best={self.best:.4f} @ep{self.best_epoch}")
                    # The break skips the end-of-epoch save below.
                    self._save("last.pth", epoch)
                    break

            if self.cfg.save_interval and (epoch + 1) % self.cfg.save_interval == 0:
                self._save(f"epoch{epoch+1}.pth", epoch)
            self._save("last.pth", epoch)

        print_main(f"[done] best val_dice={self.best:.4f} @ep{self.best_epoch} -> {self.out_dir}/best.pth")