MaybeRichard's picture
code: complete eval pipeline (7 metrics + per-class + Wilcoxon) + Swin-UNet/TransUNet networks; remove backups/obsolete
1a18f22 verified
Raw
History Blame Contribute Delete
5.4 kB
"""Unified experiment configuration.

A single dataclass drives every run. Values are resolved in priority order
(highest first):

1. command-line flags (argparse);
2. a YAML file given with ``--config``;
3. the dataclass defaults.

The same config object is used by train.py / test.py, so a training run and
its evaluation are guaranteed to agree on dataset, model, image size, etc.
"""
from __future__ import annotations

import argparse
import dataclasses
import warnings
from dataclasses import dataclass, field, asdict
from typing import Optional, List

import yaml
@dataclass
class Config:
    """All hyper-parameters and paths for one segmentation experiment.

    Build one with ``Config.from_args()``: dataclass defaults are overridden by
    a YAML file passed via ``--config``, which are in turn overridden by
    explicit command-line flags.
    """

    # ---- experiment identity ----
    exp_name: str = "default"  # run dir: <out_root>/<exp_name>/<dataset>_<protocol>/<arch>/seed<seed>/
    seed: int = 0
    # ---- data ----
    data_root: str = "dataset/processed_unified"
    dataset: str = "cvc_clinicdb"  # folder name under data_root
    protocol: str = "official"  # e.g. official / fold01 ...
    in_channels: int = 0  # 0 = auto-detect from metadata/first image
    num_classes: int = 0  # 0 = auto-detect from metadata/masks (incl. background)
    img_size: int = 256  # square resize target (Swin/TransUNet need 224)
    # extra synthetic (image,mask) pairs to MERGE into the train split.
    # Points at a dir laid out like a split: <synth_train_dir>/{images,masks}/.
    synth_train_dir: str = ""  # "" = real data only (no generative augmentation)
    # ---- augmentation (conventional baseline tier) ----
    aug: str = "standard"  # none | standard | strong (albumentations online)
    aug_backend: str = "albumentations"  # albumentations | monai
    normalize: str = "auto"  # auto(imagenet for RGB, 0.5 for gray) | imagenet | none
    # ---- model ----
    arch: str = "unet"  # see models/registry.py REGISTRY
    encoder: str = "resnet34"  # SMP encoder name (ignored by non-SMP archs)
    encoder_weights: str = "imagenet"  # imagenet | none
    pretrained_ckpt: str = ""  # ViT/Swin pretrain for transunet/swinunet (optional)
    # ---- optimization ----
    epochs: int = 100
    batch_size: int = 16  # per-GPU batch size
    lr: float = 1e-4
    weight_decay: float = 1e-4
    optimizer: str = "adamw"  # adamw | sgd
    scheduler: str = "poly"  # poly | cosine | none
    warmup_epochs: int = 0
    loss: str = "ce_dice"  # ce_dice | ce | dice
    num_workers: int = 8
    grad_clip: float = 0.0  # 0 = disabled
    # ---- precision / hardware ----
    amp: str = "bf16"  # bf16(A100+) | fp16(V100) | fp32
    # DDP is driven by torchrun env vars (RANK/WORLD_SIZE/LOCAL_RANK); nothing to set here.
    # ---- evaluation / logging ----
    val_interval: int = 5  # epochs between validations
    min_epochs: int = 0  # never early-stop before this many epochs
    patience: int = 0  # early-stop after this many epochs w/o val improvement (0 = off)
    save_interval: int = 0  # 0 = only save best + last
    include_background: bool = False  # include class 0 in reported Dice/IoU
    compute_hd95: bool = True
    out_root: str = "results"
    resume: str = ""  # path to checkpoint to resume from
    visualize: bool = True  # save overlays at test time
    vis_max: int = 32  # max number of overlay images to save

    def out_dir(self) -> str:
        """Return the run directory ``<out_root>/<exp_name>/<dataset>_<protocol>/<arch>/seed<seed>``."""
        return f"{self.out_root}/{self.exp_name}/{self.dataset}_{self.protocol}/{self.arch}/seed{self.seed}"

    def to_yaml(self, path: str) -> None:
        """Write every field to ``path`` as YAML, in declaration order.

        The file is opened as UTF-8 explicitly: ``allow_unicode=True`` lets
        non-ASCII characters through verbatim, which the platform-default
        encoding may not be able to represent.
        """
        with open(path, "w", encoding="utf-8") as fh:
            yaml.safe_dump(asdict(self), fh, sort_keys=False, allow_unicode=True)

    @classmethod
    def _field_types(cls) -> dict:
        """Return ``{field_name: type}`` derived from each field's declared default.

        ``dataclasses.Field.type`` is not usable here: with
        ``from __future__ import annotations`` it is the *string* ``"bool"``,
        ``"int"``, ... rather than the type object. Fields whose default is
        ``None`` (or missing) fall back to ``str``.
        """
        types = {}
        for fld in dataclasses.fields(cls):
            default = fld.default
            if default is None or default is dataclasses.MISSING:
                types[fld.name] = str
            else:
                types[fld.name] = type(default)
        return types

    @staticmethod
    def _coerce(name: str, value, typ: type):
        """Convert a YAML-supplied ``value`` for field ``name`` to ``typ``.

        PyYAML implements YAML 1.1, where e.g. ``1e-4`` (no decimal point)
        loads as a *string*; without this step such a value would end up as a
        ``str`` inside a float field. ``None`` (an empty YAML value) is passed
        through unchanged.

        Raises:
            ValueError: if ``value`` cannot be represented as ``typ``.
        """
        if value is None:
            return None
        if typ is bool:
            if isinstance(value, bool):
                return value
            token = str(value).strip().lower()
            if token in ("true", "yes", "on", "1"):
                return True
            if token in ("false", "no", "off", "0"):
                return False
            raise ValueError(f"config key {name!r}: cannot interpret {value!r} as a boolean")
        if isinstance(value, bool):
            # bool is an int subclass; a YAML yes/no/on/off in a non-bool field is almost surely a mistake.
            raise ValueError(f"config key {name!r}: expected {typ.__name__}, got boolean {value!r}")
        if isinstance(value, typ):
            return value
        if typ is int and isinstance(value, float):
            # Accept 100.0 but refuse to silently truncate 1.5.
            if not value.is_integer():
                raise ValueError(f"config key {name!r}: expected an integer, got {value!r}")
            return int(value)
        try:
            return typ(value)
        except (TypeError, ValueError) as err:
            raise ValueError(f"config key {name!r}: cannot convert {value!r} to {typ.__name__}") from err

    @classmethod
    def from_args(cls, argv: Optional[List[str]] = None) -> "Config":
        """Build a Config from defaults, an optional YAML file, and CLI flags.

        Precedence (highest first): command-line flags, the YAML file named by
        ``--config``, dataclass defaults. Every field ``x`` gets a ``--x`` flag;
        boolean fields additionally get ``--no-x``. YAML keys that are not
        fields are ignored with a warning.

        Args:
            argv: arguments to parse; ``None`` lets argparse use ``sys.argv[1:]``.

        Raises:
            TypeError: if the YAML file's top level is not a mapping.
            ValueError: if a YAML value cannot be converted to its field's type.
            SystemExit: on invalid command-line arguments (argparse behaviour).
        """
        field_types = cls._field_types()

        # First pass: only grab --config so YAML can set defaults that flags then override.
        pre = argparse.ArgumentParser(add_help=False)
        pre.add_argument("--config", type=str, default="")
        known, _ = pre.parse_known_args(argv)

        base = cls()
        if known.config:
            with open(known.config, encoding="utf-8") as fh:
                ydata = yaml.safe_load(fh) or {}
            if not isinstance(ydata, dict):
                raise TypeError(
                    f"{known.config}: top level must be a mapping, got {type(ydata).__name__}"
                )
            unknown = sorted(str(k) for k in ydata if k not in field_types)
            if unknown:
                # Usually a typo; dropping it silently would run a different experiment than intended.
                warnings.warn(f"{known.config}: ignoring unknown config keys {unknown}", stacklevel=2)
            overrides = {
                k: cls._coerce(k, v, field_types[k]) for k, v in ydata.items() if k in field_types
            }
            base = dataclasses.replace(base, **overrides)

        parser = argparse.ArgumentParser(parents=[pre],
                                         description="SegGen unified segmentation framework")
        for fld in dataclasses.fields(cls):
            name = fld.name
            default = getattr(base, name)
            typ = field_types[name]
            if typ is bool:
                # support --flag / --no-flag
                parser.add_argument(f"--{name}", dest=name, action="store_true", default=default)
                parser.add_argument(f"--no-{name}", dest=name, action="store_false")
            else:
                # Parse with the field's declared type, not type(default): a YAML
                # override must not change how the corresponding flag is parsed.
                parser.add_argument(f"--{name}", type=typ, default=default)
        ns = parser.parse_args(argv)
        return cls(**{fld.name: getattr(ns, fld.name) for fld in dataclasses.fields(cls)})