"""Unified experiment configuration. A single dataclass drives every run. Values can come from (in priority order): 1. command-line flags (argparse) 2. a YAML file (--config) 3. dataclass defaults. The same config object is used by train.py / test.py so that a training run and its evaluation are guaranteed to agree on dataset, model, image size, etc. """ from __future__ import annotations import argparse import dataclasses from dataclasses import dataclass, field, asdict from typing import Optional, List import yaml @dataclass class Config: # ---- experiment identity ---- exp_name: str = "default" # results////seed/ seed: int = 0 # ---- data ---- data_root: str = "dataset/processed_unified" dataset: str = "cvc_clinicdb" # folder name under data_root protocol: str = "official" # e.g. official / fold01 ... in_channels: int = 0 # 0 = auto-detect from metadata/first image num_classes: int = 0 # 0 = auto-detect from metadata/masks (incl. background) img_size: int = 256 # square resize target (Swin/TransUNet need 224) # extra synthetic (image,mask) pairs to MERGE into the train split. # Points at a dir laid out like a split: /{images,masks}/. synth_train_dir: str = "" # "" = real data only (no generative augmentation) # ---- augmentation (conventional baseline tier) ---- aug: str = "standard" # none | standard | strong (albumentations online) aug_backend: str = "albumentations" # albumentations | monai normalize: str = "auto" # auto(imagenet for RGB, 0.5 for gray) | imagenet | none # ---- model ---- arch: str = "unet" # see models/registry.py REGISTRY encoder: str = "resnet34" # SMP encoder name (ignored by non-SMP archs) encoder_weights: str = "imagenet" # imagenet | none pretrained_ckpt: str = "" # ViT/Swin pretrain for transunet/swinunet (optional) # ---- optimization ---- epochs: int = 100 batch_size: int = 16 # per-GPU batch size lr: float = 1e-4 weight_decay: float = 1e-4 optimizer: str = "adamw" # adamw | sgd scheduler: str = "poly" # poly | cosine | none warmup_epochs: int = 0 loss: str = "ce_dice" # ce_dice | ce | dice num_workers: int = 8 grad_clip: float = 0.0 # 0 = disabled # ---- precision / hardware ---- amp: str = "bf16" # bf16(A100+) | fp16(V100) | fp32 # DDP is driven by torchrun env vars (RANK/WORLD_SIZE/LOCAL_RANK); nothing to set here. # ---- evaluation / logging ---- val_interval: int = 5 # epochs between validations min_epochs: int = 0 # never early-stop before this many epochs patience: int = 0 # early-stop after this many epochs w/o val improvement (0 = off) save_interval: int = 0 # 0 = only save best + last include_background: bool = False # include class 0 in reported Dice/IoU compute_hd95: bool = True out_root: str = "results" resume: str = "" # path to checkpoint to resume from visualize: bool = True # save overlays at test time vis_max: int = 32 # max number of overlay images to save def out_dir(self) -> str: return f"{self.out_root}/{self.exp_name}/{self.dataset}_{self.protocol}/{self.arch}/seed{self.seed}" def to_yaml(self, path: str) -> None: with open(path, "w") as f: yaml.safe_dump(asdict(self), f, sort_keys=False, allow_unicode=True) @classmethod def from_args(cls, argv: Optional[List[str]] = None) -> "Config": # First pass: only grab --config so YAML can set defaults that flags then override. pre = argparse.ArgumentParser(add_help=False) pre.add_argument("--config", type=str, default="") known, _ = pre.parse_known_args(argv) base = cls() if known.config: with open(known.config) as f: ydata = yaml.safe_load(f) or {} base = dataclasses.replace(base, **{k: v for k, v in ydata.items() if k in {f.name for f in dataclasses.fields(cls)}}) p = argparse.ArgumentParser(parents=[pre], description="SegGen unified segmentation framework") for f in dataclasses.fields(cls): default = getattr(base, f.name) if f.type is bool or isinstance(default, bool): # support --flag / --no-flag p.add_argument(f"--{f.name}", dest=f.name, action="store_true", default=default) p.add_argument(f"--no-{f.name}", dest=f.name, action="store_false") else: p.add_argument(f"--{f.name}", type=type(default) if default is not None else str, default=default) ns = p.parse_args(argv) kwargs = {f.name: getattr(ns, f.name) for f in dataclasses.fields(cls)} return cls(**kwargs)