| from __future__ import annotations | |
| import os | |
| from dataclasses import dataclass, field | |
| from typing import List | |
| import torch | |
@dataclass
class Config:
    """Paths, hyperparameters and runtime switches for one training run.

    This is a dataclass: every attribute below is an instance field with the
    shown default, and any of them can be overridden by keyword when
    constructing (``Config(seed=7)``) or by plain attribute assignment
    afterwards. The ``@dataclass`` decorator is required for ``seeds`` to work:
    its ``field(default_factory=...)`` default is only turned into a fresh
    per-instance list by the dataclass machinery.
    """

    # ── Data paths ────────────────────────────────────────────────────────
    csv_path: str = "../../data/train_val.csv"
    image_dir: str = "../../data/images"
    submission_test_dir: str = "../../data/test_images"
    output_dir: str = "results"
    results_log_path: str = "results_log.csv"  # global run log (one row per training run)

    # ── Reproducibility ──────────────────────────────────────────────────
    seed: int = 42

    # ── Image / DataLoader ───────────────────────────────────────────────
    # NOTE(review): the default backbone below is "rad-dino", for which the
    # architecture notes recommend img_size=518 — confirm 300 is intended.
    img_size: int = 300  # torchxrayvision DenseNet-121 native resolution
    batch_size: int = 32
    num_workers: int = 4

    # ── Train / val / test split (stratified; disjoint image rows) ───────
    val_size: float = 0.15
    test_size: float = 0.10

    # ── Training schedule (two-stage) ────────────────────────────────────
    frozen_epochs: int = 3  # stage 1: head-only warmup
    finetune_epochs: int = 22  # stage 2: full unfreeze with cosine LR
    # Stage 2: linear LR warmup (all param groups) before cosine decay. 0 = disabled.
    # Clamped to < finetune_epochs at runtime.
    finetune_warmup_epochs: int = 0
    early_stop_patience: int = 6  # early stop when val checkpoint metric plateaus (stage 2)
    # Metric for best checkpoint + early stopping in stage 2 (finetune):
    #   "composite"   — 0.5·val_AUC + 0.25·val_sens + 0.25·val_spec (threshold 0.5)
    #   "auc"         — val ROC-AUC only
    #   "sensitivity" — val sensitivity at threshold 0.5 (maximise recall of positives)
    checkpoint_metric: str = "composite"
    # BCE positive-class weight: 0 = disabled.
    # If > 0: pos_weight = scale * (n_neg / n_pos) on the *training* split (computed once).
    # scale=1.0 balances errors by inverse frequency; 0.5 is a gentler boost (often safer).
    bce_pos_weight_scale: float = 0.0
    # How many backbone blocks to keep frozen in stage 2 (0 = unfreeze all):
    #   DenseNet-121 : 0–4 dense block groups
    #   RAD-DINO ViT : 0–12 transformer blocks (recommended: 8)
    frozen_blocks: int = 0

    # ── Optimiser ────────────────────────────────────────────────────────
    head_lr: float = 3e-4  # classifier LR (both stages)
    backbone_lr: float = 1e-4  # features LR (stage 2 only)
    weight_decay: float = 1e-4
    grad_clip: float = 1.0

    # ── Data augmentation ────────────────────────────────────────────────
    # Mixup: interpolates two samples and their labels in every training batch.
    # mixup_alpha > 0 enables it; λ ~ Beta(α, α). 0 = disabled.
    # Typical range: 0.2 – 0.4.
    mixup_alpha: float = 0.0
    # Label smoothing: prevents overconfidence by softening hard {0,1} targets.
    # y_smooth = y*(1-ε) + 0.5*ε. 0 = disabled. Typical range: 0.05 – 0.15.
    label_smoothing: float = 0.0

    # ── Architecture ─────────────────────────────────────────────────────
    # Options: "densenet121" | "rad-dino" | "mobilenet_v3_large" | "efficientnet_b0" | "efficientnet_b3"
    #   densenet121        — torchxrayvision DenseNet-121, pretrained on ~1M chest X-rays (recommended)
    #   rad-dino           — microsoft/rad-dino, DINOv2 ViT-B/14 pretrained on ~1M chest X-rays;
    #                        use img_size=518 (native: 37×37 patches at 14 px); 12 frozen_blocks max
    #   mobilenet_v3_large — torchvision MobileNetV3-Large, pretrained on ImageNet (faster, lighter)
    #   efficientnet_b0    — torchvision EfficientNet-B0, pretrained on ImageNet (good accuracy/size trade-off)
    #   efficientnet_b3    — torchvision EfficientNet-B3, pretrained on ImageNet (higher accuracy, more params)
    backbone: str = "rad-dino"

    # ── Ensemble ─────────────────────────────────────────────────────────
    # True:  train one model per entry in `seeds` and average predictions
    # False: train a single model using only `seed` (faster experimentation)
    use_ensemble: bool = True

    # ── Multi-seed ensemble ──────────────────────────────────────────────
    # default_factory gives every Config instance its own list, so mutating
    # one instance's seeds never leaks into another.
    seeds: List[int] = field(default_factory=lambda: [42, 7, 2024])

    # ── Loss function ────────────────────────────────────────────────────
    # False: standard BCE | True: 0.5*BCE + 0.5*(1 - soft_composite)
    use_composite_loss: bool = False
    # Blend weight α: α·BCE + (1-α)·(1-soft_composite). 0 = pure composite, 1 = pure BCE.
    composite_loss_alpha: float = 0.5
    # Temperature for the pairwise-sigmoid soft-AUC term (higher → sharper ranking signal)
    composite_loss_gamma: float = 1.0
    # SoftCompositeLoss: σ(thr_temp·logit) approximates I[logit>0] (aligns with prob 0.5 threshold)
    composite_thr_temperature: float = 6.0
    # If fewer hard positives or negatives than this in a batch, skip composite term (BCE only).
    composite_min_class_per_batch: int = 2

    # ── Inference ────────────────────────────────────────────────────────
    tta_passes: int = 6  # number of deterministic TTA transforms (max 6)
    n_bootstrap: int = 1000  # bootstrap iterations for threshold stabilisation

    # ── Device (auto-detected) ───────────────────────────────────────────
    # This default is evaluated once, when the class body executes, and is
    # then shared as the default for every instance: CUDA if available,
    # otherwise Apple MPS if available, otherwise CPU.
    device: str = (
        "cuda" if torch.cuda.is_available()
        else "mps" if torch.backends.mps.is_available()
        else "cpu"
    )

    def setup(self) -> "Config":
        """Create ``output_dir`` (including parents) if missing.

        Returns:
            This same instance, so the call can be chained onto construction.
        """
        os.makedirs(self.output_dir, exist_ok=True)
        return self
# Module-wide shared configuration instance — import it and read it directly,
# or overwrite its fields before training starts. It is built and set up at
# import time, so importing this module creates the default output directory.
CFG = Config()
CFG.setup()