"""Training entrypoint. Single GPU: python framework/train.py --dataset cvc_clinicdb --arch unet ... Multi-GPU : torchrun --nproc_per_node=4 framework/train.py --dataset ... --arch ... """ from __future__ import annotations import os import sys # allow `python framework/train.py` (add repo root to path) sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) import torch import cv2 # Each DataLoader worker single-threaded for OpenCV; parallelism comes from num_workers. # Without this, cv2 spawns an nproc-sized (~384) thread pool per worker, whose per-op # dispatch overhead starves the GPU at high resolution (768) -> ~4x slower epochs. cv2.setNumThreads(1) from framework.config import Config from framework.engine.distributed import setup_distributed, cleanup_distributed, set_seed, print_main from framework.models.registry import build_model, required_img_size from framework.engine.trainer import Trainer def main(): cfg = Config.from_args() # some backbones require a fixed input size req = required_img_size(cfg.arch) if req and cfg.img_size != req: print_main(f"[info] arch '{cfg.arch}' requires img_size={req}; overriding {cfg.img_size}.") cfg.img_size = req local_rank = setup_distributed() set_seed(cfg.seed, rank=local_rank) # peek dataset to get in/out channels before building the model from framework.data.loaders import build_dataset probe = build_dataset(cfg, "train") in_ch, n_cls = probe.in_channels, probe.num_classes print_main(f"[data] {cfg.dataset}/{cfg.protocol}: in_channels={in_ch} num_classes={n_cls} " f"train={len(probe)}") model = build_model(cfg.arch, in_channels=in_ch, num_classes=n_cls, img_size=cfg.img_size, encoder=cfg.encoder, encoder_weights=cfg.encoder_weights, pretrained_ckpt=cfg.pretrained_ckpt) print_main(f"[model] {cfg.arch} params={sum(p.numel() for p in model.parameters())/1e6:.1f}M " f"amp={cfg.amp}") trainer = Trainer(cfg, model, local_rank) trainer.fit() cleanup_distributed() if __name__ == "__main__": main()