cardio-deploy
Deploy CardioScan inference 2026-04-24T10:51:24Z
41d6ec3
from __future__ import annotations
import os
from dataclasses import dataclass, field
from typing import List
import torch
def _detect_device() -> str:
    """Return the preferred torch device string: "cuda", then "mps", then "cpu".

    ``torch.backends.mps`` is looked up with ``getattr`` because PyTorch
    builds without the MPS backend module do not expose that attribute, and an
    unguarded access would raise ``AttributeError`` while this module imports.
    """
    if torch.cuda.is_available():
        return "cuda"
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"
    return "cpu"


@dataclass
class Config:
    """Paths and hyperparameters for one training / inference run.

    Every field has a default, so ``Config()`` is always constructible.
    Field order is part of the interface (the generated ``__init__`` accepts
    positional arguments), so new fields should be appended, not inserted.
    Constructing an instance has no filesystem side effects; call
    :meth:`setup` to create ``output_dir``.
    """

    # ── Data paths ────────────────────────────────────────────────────────
    csv_path: str = "../../data/train_val.csv"
    image_dir: str = "../../data/images"
    submission_test_dir: str = "../../data/test_images"
    output_dir: str = "results"
    results_log_path: str = "results_log.csv"  # global run log (one row per training run)

    # ── Reproducibility ──────────────────────────────────────────────────
    seed: int = 42

    # ── Image / DataLoader ───────────────────────────────────────────────
    # NOTE(review): the default `backbone` in this class is "rad-dino", whose
    # notes recommend img_size=518, while 300 is described here as the
    # DenseNet-121 resolution — confirm which pairing is intended.
    img_size: int = 300  # torchxrayvision DenseNet-121 native resolution
    batch_size: int = 32
    num_workers: int = 4

    # ── Train / val / test split (stratified; disjoint image rows) ───────
    val_size: float = 0.15
    test_size: float = 0.10

    # ── Training schedule (two-stage) ────────────────────────────────────
    frozen_epochs: int = 3  # stage 1: head-only warmup
    finetune_epochs: int = 22  # stage 2: full unfreeze with cosine LR
    # Stage 2: linear LR warmup (all param groups) before cosine decay. 0 = disabled.
    # Clamped to < finetune_epochs at runtime.
    finetune_warmup_epochs: int = 0
    early_stop_patience: int = 6  # early stop when val checkpoint metric plateaus (stage 2)
    # Metric for best checkpoint + early stopping in stage 2 (finetune):
    #   "composite"   — 0.5·val_AUC + 0.25·val_sens + 0.25·val_spec (threshold 0.5)
    #   "auc"         — val ROC-AUC only
    #   "sensitivity" — val sensitivity at threshold 0.5 (maximise recall of positives)
    checkpoint_metric: str = "composite"
    # BCE positive-class weight: 0 = disabled.
    # If > 0: pos_weight = scale * (n_neg / n_pos) on the *training* split (computed once).
    # scale=1.0 balances errors by inverse frequency; 0.5 is a gentler boost (often safer).
    bce_pos_weight_scale: float = 0.0
    # How many backbone blocks to keep frozen in stage 2 (0 = unfreeze all):
    #   DenseNet-121 : 0–4 dense block groups
    #   RAD-DINO ViT : 0–12 transformer blocks (recommended: 8)
    # NOTE(review): the default here is 0 although 8 is recommended for the
    # default "rad-dino" backbone — confirm this is deliberate.
    frozen_blocks: int = 0

    # ── Optimiser ────────────────────────────────────────────────────────
    head_lr: float = 3e-4  # classifier LR (both stages)
    backbone_lr: float = 1e-4  # features LR (stage 2 only)
    weight_decay: float = 1e-4
    grad_clip: float = 1.0

    # ── Data augmentation ────────────────────────────────────────────────
    # Mixup: interpolates two samples and their labels in every training batch.
    # mixup_alpha > 0 enables it; λ ~ Beta(α, α). 0 = disabled.
    # Typical range: 0.2 – 0.4.
    mixup_alpha: float = 0.0
    # Label smoothing: prevents overconfidence by softening hard {0,1} targets.
    # y_smooth = y*(1-ε) + 0.5*ε. 0 = disabled. Typical range: 0.05 – 0.15.
    label_smoothing: float = 0.0

    # ── Architecture ─────────────────────────────────────────────────────
    # Options: "densenet121" | "rad-dino" | "mobilenet_v3_large" | "efficientnet_b0" | "efficientnet_b3"
    #   densenet121        — torchxrayvision DenseNet-121, pretrained on ~1M chest X-rays (recommended)
    #   rad-dino           — microsoft/rad-dino, DINOv2 ViT-B/14 pretrained on ~1M chest X-rays;
    #                        use img_size=518 (native: 37×37 patches at 14 px); 12 frozen_blocks max
    #   mobilenet_v3_large — torchvision MobileNetV3-Large, pretrained on ImageNet (faster, lighter)
    #   efficientnet_b0    — torchvision EfficientNet-B0, pretrained on ImageNet (good accuracy/size trade-off)
    #   efficientnet_b3    — torchvision EfficientNet-B3, pretrained on ImageNet (higher accuracy, more params)
    backbone: str = "rad-dino"

    # ── Ensemble ─────────────────────────────────────────────────────────
    # True:  train one model per entry in `seeds` and average predictions
    # False: train a single model using only `seed` (faster experimentation)
    use_ensemble: bool = True

    # ── Multi-seed ensemble ──────────────────────────────────────────────
    # default_factory gives every Config instance its own list object.
    seeds: List[int] = field(default_factory=lambda: [42, 7, 2024])

    # ── Loss function ─────────────────────────────────────────────────────
    # False: standard BCE | True: α·BCE + (1-α)·(1 - soft_composite),
    # with α = composite_loss_alpha (0.5/0.5 at the default).
    use_composite_loss: bool = False
    # Blend weight α: α·BCE + (1-α)·(1-soft_composite). 0 = pure composite, 1 = pure BCE.
    composite_loss_alpha: float = 0.5
    # Temperature for the pairwise-sigmoid soft-AUC term (higher → sharper ranking signal)
    composite_loss_gamma: float = 1.0
    # SoftCompositeLoss: σ(thr_temp·logit) approximates I[logit>0] (aligns with prob 0.5 threshold)
    composite_thr_temperature: float = 6.0
    # If fewer hard positives or negatives than this in a batch, skip composite term (BCE only).
    composite_min_class_per_batch: int = 2

    # ── Inference ────────────────────────────────────────────────────────
    tta_passes: int = 6  # number of deterministic TTA transforms (max 6)
    n_bootstrap: int = 1000  # bootstrap iterations for threshold stabilisation

    # ── Device (auto-detected) ───────────────────────────────────────────
    # Evaluated once, when the class body runs; every instance shares the
    # same default unless `device` is passed explicitly.
    device: str = _detect_device()

    def setup(self) -> "Config":
        """Create ``output_dir`` (including parents) and return ``self``.

        Safe to call repeatedly: an existing directory is not an error.
        Returning ``self`` allows chaining, e.g. ``Config().setup()``.
        """
        os.makedirs(self.output_dir, exist_ok=True)
        return self
# Module-level shared configuration: import `CFG` and read it directly, or
# reassign its fields before training starts. Building it here also runs
# setup(), so `output_dir` is created as a side effect of importing this module.
CFG = Config()
CFG.setup()