code: complete eval pipeline (7 metrics + per-class + Wilcoxon) + Swin-UNet/TransUNet networks; remove backups/obsolete
"""Unified experiment configuration.

A single dataclass drives every run. Values can come from (in priority order):
  1. command-line flags (argparse)   2. a YAML file (--config)   3. dataclass defaults.

The same config object is used by train.py / test.py so that a training run and
its evaluation are guaranteed to agree on dataset, model, image size, etc.
"""
| from __future__ import annotations | |
| import argparse | |
| import dataclasses | |
| from dataclasses import dataclass, field, asdict | |
| from typing import Optional, List | |
| import yaml | |
@dataclass
class Config:
    """Flat collection of settings describing one experiment run.

    When built through :meth:`from_args`, command-line flags override values
    read from the ``--config`` YAML file, which in turn override the defaults
    declared on the fields below.
    """

    # ---- experiment identity ----
    exp_name: str = "default"  # <out_root>/<exp_name>/<dataset>_<protocol>/<arch>/seed<seed>
    seed: int = 0

    # ---- data ----
    data_root: str = "dataset/processed_unified"
    dataset: str = "cvc_clinicdb"  # folder name under data_root
    protocol: str = "official"  # e.g. official / fold01 ...
    in_channels: int = 0  # 0 = auto-detect from metadata/first image
    num_classes: int = 0  # 0 = auto-detect from metadata/masks (incl. background)
    img_size: int = 256  # square resize target (Swin/TransUNet need 224)
    # Extra synthetic (image, mask) pairs to MERGE into the train split.
    # Points at a dir laid out like a split: <synth_train_dir>/{images,masks}/.
    synth_train_dir: str = ""  # "" = real data only (no generative augmentation)

    # ---- augmentation (conventional baseline tier) ----
    aug: str = "standard"  # none | standard | strong (albumentations online)
    aug_backend: str = "albumentations"  # albumentations | monai
    normalize: str = "auto"  # auto(imagenet for RGB, 0.5 for gray) | imagenet | none

    # ---- model ----
    arch: str = "unet"  # see models/registry.py REGISTRY
    encoder: str = "resnet34"  # SMP encoder name (ignored by non-SMP archs)
    encoder_weights: str = "imagenet"  # imagenet | none
    pretrained_ckpt: str = ""  # ViT/Swin pretrain for transunet/swinunet (optional)

    # ---- optimization ----
    epochs: int = 100
    batch_size: int = 16  # per-GPU batch size
    lr: float = 1e-4
    weight_decay: float = 1e-4
    optimizer: str = "adamw"  # adamw | sgd
    scheduler: str = "poly"  # poly | cosine | none
    warmup_epochs: int = 0
    loss: str = "ce_dice"  # ce_dice | ce | dice
    num_workers: int = 8
    grad_clip: float = 0.0  # 0 = disabled

    # ---- precision / hardware ----
    amp: str = "bf16"  # bf16(A100+) | fp16(V100) | fp32
    # DDP is driven by torchrun env vars (RANK/WORLD_SIZE/LOCAL_RANK); nothing to set here.

    # ---- evaluation / logging ----
    val_interval: int = 5  # epochs between validations
    min_epochs: int = 0  # never early-stop before this many epochs
    patience: int = 0  # early-stop after this many epochs w/o val improvement (0 = off)
    save_interval: int = 0  # 0 = only save best + last
    include_background: bool = False  # include class 0 in reported Dice/IoU
    compute_hd95: bool = True
    out_root: str = "results"
    resume: str = ""  # path to checkpoint to resume from
    visualize: bool = True  # save overlays at test time
    vis_max: int = 32  # max number of overlay images to save

    # NOTE(review): shown as a plain method in this view. If callers read
    # `cfg.out_dir` as an attribute, a `@property` decorator may have been lost
    # together with the other decorators -- confirm against the call sites.
    def out_dir(self) -> str:
        """Return ``<out_root>/<exp_name>/<dataset>_<protocol>/<arch>/seed<seed>``."""
        return f"{self.out_root}/{self.exp_name}/{self.dataset}_{self.protocol}/{self.arch}/seed{self.seed}"

    def to_yaml(self, path: str) -> None:
        """Write every field to ``path`` as YAML, in declaration order.

        The file is opened as UTF-8 because ``allow_unicode=True`` emits
        non-ASCII characters verbatim, which a locale-dependent default
        encoding may not be able to encode.
        """
        with open(path, "w", encoding="utf-8") as fh:
            yaml.safe_dump(asdict(self), fh, sort_keys=False, allow_unicode=True)

    @classmethod
    def from_args(cls, argv: Optional[List[str]] = None) -> "Config":
        """Build a config from command-line arguments and an optional YAML file.

        Args:
            argv: Argument list to parse; ``None`` lets argparse read
                ``sys.argv[1:]``.

        Returns:
            A new ``Config`` where each field comes from its flag if given,
            otherwise from the ``--config`` YAML file, otherwise from the
            dataclass default.

        Raises:
            TypeError: If the YAML file's top level is not a mapping.
            SystemExit: Raised by argparse on unknown or malformed flags.

        YAML keys that are not ``Config`` fields are ignored. Boolean fields
        get a ``--name`` / ``--no-name`` flag pair.
        """
        # First pass: only grab --config so YAML can set defaults that flags then override.
        pre = argparse.ArgumentParser(add_help=False)
        pre.add_argument("--config", type=str, default="")
        known, _ = pre.parse_known_args(argv)

        specs = dataclasses.fields(cls)
        base = cls()
        if known.config:
            with open(known.config, encoding="utf-8") as fh:
                loaded = yaml.safe_load(fh) or {}
            if not isinstance(loaded, dict):
                raise TypeError(
                    f"--config {known.config!r} must contain a YAML mapping, "
                    f"got {type(loaded).__name__}"
                )
            valid = {spec.name for spec in specs}
            base = dataclasses.replace(
                base, **{key: val for key, val in loaded.items() if key in valid}
            )

        parser = argparse.ArgumentParser(
            parents=[pre],
            description="SegGen unified segmentation framework",
        )
        for spec in specs:
            current = getattr(base, spec.name)
            # Take the flag's type from the *declared* default, not from the
            # YAML-overridden value: YAML may hand back a str or int for a
            # float field, which would otherwise fix the flag to that type.
            declared = current if spec.default is dataclasses.MISSING else spec.default
            # Annotations are plain strings when postponed evaluation is on,
            # so accept both the type object and its name.
            if spec.type in (bool, "bool") or isinstance(declared, bool):
                # support --flag / --no-flag
                parser.add_argument(f"--{spec.name}", dest=spec.name,
                                    action="store_true", default=current)
                parser.add_argument(f"--no-{spec.name}", dest=spec.name,
                                    action="store_false", default=current)
                continue
            caster = type(declared) if declared is not None else str
            if caster is float and isinstance(current, int) and not isinstance(current, bool):
                current = float(current)  # widen a YAML integer given for a float field
            # argparse runs string defaults through `type`, so a YAML string
            # such as "1e-4" for a float field is converted as well.
            parser.add_argument(f"--{spec.name}", type=caster, default=current)

        ns = parser.parse_args(argv)
        return cls(**{spec.name: getattr(ns, spec.name) for spec in specs})