Add full reproducible thyroid ResNet-18 experiment: weights, scripts, configs, calibration, locked threshold, test eval w/ CIs, figures, data exploration, README, LOG
45af8e1 verified | #!/usr/bin/env python | |
| """ | |
| Shared library for the agentic thyroid ResNet-18 experiment. | |
| Centralizes everything that train.py / evaluate.py / evaluate_external.py must | |
| share so that preprocessing, model construction, calibration, and thresholding | |
| are guaranteed identical across training, validation, test, and external use. | |
| Positive class = Malignant (label 1). Benign = 0. | |
| """ | |
| import json | |
| import os | |
| import random | |
| from dataclasses import dataclass, field, asdict | |
| from pathlib import Path | |
| from typing import Optional, List, Tuple | |
| import numpy as np | |
# Per-channel normalization statistics used as the default mean/std.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Label encoding: the positive class (index 1) is Malignant, Benign is 0.
CLASS_TO_IDX = {"Benign": 0, "Malignant": 1}
# Inverse lookup, derived from CLASS_TO_IDX so the two can never disagree.
IDX_TO_CLASS = {index: name for name, index in CLASS_TO_IDX.items()}
| # --------------------------------------------------------------------------- # | |
| # Reproducibility | |
| # --------------------------------------------------------------------------- # | |
def set_determinism(seed: int, strict: bool = True):
    """Seed the Python, NumPy and PyTorch RNGs and configure cuDNN.

    Args:
        seed: Value passed to every seeding call and exported as
            ``PYTHONHASHSEED``.
        strict: When true, request deterministic kernels (cuDNN deterministic
            mode, benchmark off, ``torch.use_deterministic_algorithms``).
            When false, only cuDNN benchmarking is switched on; deterministic
            settings made by an earlier call are not reverted.
    """
    import torch

    # NOTE(review): exporting PYTHONHASHSEED here presumably only influences
    # interpreters started afterwards (e.g. worker processes) - confirm.
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seed_fn in (random.seed, np.random.seed,
                    torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)

    if not strict:
        torch.backends.cudnn.benchmark = True
        return

    # cuBLAS workspace config required for deterministic matmul on CUDA.
    # setdefault keeps any value the user already exported.
    os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    try:
        # Prefer warnings over hard errors for ops lacking a deterministic kernel.
        torch.use_deterministic_algorithms(True, warn_only=True)
    except Exception:
        # Fallback when the warn_only form is rejected.
        torch.use_deterministic_algorithms(True)
def seed_worker(worker_id):
    """Re-seed NumPy and Python ``random`` from the current torch seed.

    The seed is ``torch.initial_seed()`` reduced modulo 2**32 so that it fits
    NumPy's seed range. ``worker_id`` is accepted but not used.
    Presumably intended as a DataLoader ``worker_init_fn`` - confirm at the
    call site.
    """
    import torch

    derived = torch.initial_seed() % 2 ** 32
    for reseed in (np.random.seed, random.seed):
        reseed(derived)
def collect_env_info():
    """Return a dict of package versions and hardware/CUDA settings for logging.

    Never raises for a missing or broken package: a torch failure is recorded
    under ``"torch_error"`` (keys gathered before the failure are kept), and
    any other package that cannot be imported is reported as ``None``.
    """
    report = {}

    try:
        import torch

        report["torch"] = torch.__version__
        cuda_ok = torch.cuda.is_available()
        report["cuda_available"] = cuda_ok
        report["cuda_version"] = torch.version.cuda
        cudnn = torch.backends.cudnn
        report["cudnn_version"] = cudnn.version() if cudnn.is_available() else None
        report["cudnn_deterministic"] = cudnn.deterministic
        report["cudnn_benchmark"] = cudnn.benchmark
        if cuda_ok:
            # Only device 0 is described; gpu_count covers the rest.
            report["gpu_name"] = torch.cuda.get_device_name(0)
            report["gpu_count"] = torch.cuda.device_count()
            device_props = torch.cuda.get_device_properties(0)
            report["gpu_total_mem_gb"] = round(device_props.total_memory / 1e9, 2)
    except Exception as exc:
        report["torch_error"] = repr(exc)

    for package in ("torchvision", "timm", "sklearn", "numpy", "PIL", "trackio"):
        try:
            # "?" marks a package that imports but exposes no __version__.
            version = getattr(__import__(package), "__version__", "?")
        except Exception:
            version = None
        report[package] = version

    report["cublas_workspace_config"] = os.environ.get("CUBLAS_WORKSPACE_CONFIG")
    report["pythonhashseed"] = os.environ.get("PYTHONHASHSEED")
    return report
| # --------------------------------------------------------------------------- # | |
| # Preprocessing / augmentation | |
| # --------------------------------------------------------------------------- # | |
@dataclass
class PreprocessConfig:
    """Locked preprocessing config saved with the final model.

    Eval/inference path is fully deterministic: resize to image_size, ToTensor,
    Normalize with the given mean/std. No augmentation at eval time.

    The class must be a dataclass: the fields below use ``field(...)``,
    ``to_dict`` relies on ``asdict`` and callers construct it with keyword
    arguments.
    """

    image_size: int = 224
    mean: List[float] = field(default_factory=lambda: list(IMAGENET_MEAN))
    std: List[float] = field(default_factory=lambda: list(IMAGENET_STD))
    interpolation: str = "bilinear"  # 'bilinear' (torchvision) or 'bicubic' (timm a1/a2/a3)

    def to_dict(self):
        """Return the config as a plain dict (suitable for JSON)."""
        return asdict(self)

    @staticmethod
    def from_dict(d):
        """Build a config from a dict, silently ignoring unknown keys.

        Keys missing from ``d`` fall back to the field defaults.
        """
        fields = {"image_size", "mean", "std", "interpolation"}
        return PreprocessConfig(**{k: v for k, v in d.items() if k in fields})
def _interp(name):
    """Translate an interpolation name into a torchvision InterpolationMode.

    Only 'bilinear' and 'bicubic' are known; any other name raises KeyError.
    """
    from torchvision.transforms import InterpolationMode

    modes = {
        "bilinear": InterpolationMode.BILINEAR,
        "bicubic": InterpolationMode.BICUBIC,
    }
    return modes[name]
def build_eval_transform(pp: PreprocessConfig):
    """Deterministic eval/inference transform (NO augmentation).

    Pipeline: resize to a square of side ``pp.image_size`` using
    ``pp.interpolation``, convert to tensor, normalize with ``pp.mean`` /
    ``pp.std``.
    """
    import torchvision.transforms as T

    side = pp.image_size
    steps = [
        T.Resize((side, side), interpolation=_interp(pp.interpolation)),
        T.ToTensor(),
        T.Normalize(pp.mean, pp.std),
    ]
    return T.Compose(steps)
# Augmentation policies understood by build_train_transform.
_AUG_POLICIES = ("none", "flip_only", "medical_default", "medical_strong", "clahe")


class _AddSpeckle:
    """Multiplicative noise on a tensor: ``x + x * sigma * N(0, 1)`` with probability ``p``.

    Lives at module level (not inside build_train_transform) so that a
    transform pipeline containing it can be pickled, which data-loader worker
    processes started with the "spawn" method require.
    """

    def __init__(self, sigma=0.05, p=0.2):
        self.sigma, self.p = sigma, p

    def __call__(self, x):
        import torch
        if random.random() < self.p:
            return x + x * (self.sigma * torch.randn_like(x))
        return x


class _CLAHE:
    """Best-effort CLAHE on the grayscale image, returned as RGB.

    If OpenCV is unavailable or the conversion fails, the input image is
    returned unchanged and a RuntimeWarning is emitted, so that a silently
    disabled CLAHE ablation does not go unnoticed.
    """

    def __call__(self, img):
        import numpy as _np
        from PIL import Image
        try:
            import cv2
            arr = _np.asarray(img.convert("L"))
            cl = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)).apply(arr)
            return Image.fromarray(cl).convert("RGB")
        except Exception as exc:
            import warnings
            warnings.warn(
                f"CLAHE preprocessing skipped ({exc!r}); image passed through unchanged.",
                RuntimeWarning,
                stacklevel=2,
            )
            return img


def build_train_transform(pp: "PreprocessConfig", policy: str = "medical_default"):
    """Training augmentation. Medically plausible ultrasound augmentations only.

    Policies:
      none            : eval transform (no augmentation) - baseline ablation.
      flip_only       : horizontal flip only.
      medical_default : flip + mild affine(rot<=10,trans5%,scale0.9-1.1) + mild
                        brightness/contrast + occasional light gaussian blur.
      medical_strong  : medical_default + mild speckle/gaussian noise +
                        narrow random-resized-crop (scale 0.8-1.0), which
                        replaces the plain resize.
      clahe           : CLAHE applied first, then the medical_default steps
                        WITHOUT the gaussian blur (ablation).

    Explicitly AVOIDED: vertical flip, large rotation (>15deg), aggressive crop
    (<0.8 scale), shear, heavy blur, any color/HSV jitter beyond mild
    brightness/contrast - all of which distort ultrasound texture or nodule
    morphology (per MediAug arXiv:2504.18983 and thyroid-US best practice).

    Args:
        pp: Preprocessing config supplying image_size, mean, std and
            interpolation.
        policy: One of the policy names above.

    Returns:
        A torchvision ``Compose`` ending in ToTensor + Normalize.

    Raises:
        ValueError: If ``policy`` is not a known policy name. This is checked
            before anything else so a typo fails fast.
    """
    if policy not in _AUG_POLICIES:
        raise ValueError(f"Unknown augmentation policy: {policy}")
    if policy == "none":
        return build_eval_transform(pp)

    import torchvision.transforms as T

    interp = _interp(pp.interpolation)
    resize = T.Resize((pp.image_size, pp.image_size), interpolation=interp)
    flip = T.RandomHorizontalFlip(0.5)
    affine = T.RandomApply([T.RandomAffine(degrees=10, translate=(0.05, 0.05),
                                           scale=(0.9, 1.1), interpolation=interp)], p=0.5)
    jitter = T.ColorJitter(brightness=0.15, contrast=0.15)
    blur = T.RandomApply([T.GaussianBlur(3, sigma=(0.1, 1.0))], p=0.2)

    # Image-space steps; the order within each policy is part of the recipe.
    if policy == "flip_only":
        steps = [resize, flip]
    elif policy == "medical_default":
        steps = [resize, flip, affine, jitter, blur]
    elif policy == "medical_strong":
        crop = T.RandomResizedCrop(pp.image_size, scale=(0.8, 1.0), ratio=(0.9, 1.1),
                                   interpolation=interp)
        steps = [crop, flip, affine, jitter, blur]
    else:  # "clahe"
        steps = [_CLAHE(), resize, flip, affine, jitter]

    steps.append(T.ToTensor())
    if policy == "medical_strong":
        # Speckle works on the [0, 1] tensor, before normalization.
        steps.append(_AddSpeckle(sigma=0.05, p=0.2))
    steps.append(T.Normalize(pp.mean, pp.std))
    return T.Compose(steps)
| # --------------------------------------------------------------------------- # | |
| # Dataset | |
| # --------------------------------------------------------------------------- # | |
class ThyroidImageFolder:
    """Lightweight ImageFolder that also returns the image filename id.

    Layout: <root>/<Benign|Malignant>/<id>.png
    Returns (tensor, label, image_id).

    Only ``*.png`` files directly inside each class directory are indexed,
    sorted by path within each class. Class directories that do not exist are
    skipped.

    Attributes:
        root: Dataset root as a ``Path``.
        transform: Callable applied to the RGB PIL image in ``__getitem__``.
        samples: List of ``(path, label, image_id)`` tuples.
        targets: List of integer labels aligned with ``samples``.
    """

    def __init__(self, root, transform):
        # Import only to fail fast when Pillow is missing. The module is NOT
        # stored on the instance: a module attribute cannot be pickled, which
        # would break data-loader workers started with the "spawn" method.
        from PIL import Image  # noqa: F401
        self.root = Path(root)
        self.transform = transform
        self.samples: List[Tuple[Path, int, str]] = []
        for cls, idx in CLASS_TO_IDX.items():
            class_dir = self.root / cls
            if class_dir.is_dir():
                self.samples.extend(
                    (p, idx, p.stem) for p in sorted(class_dir.glob("*.png"))
                )
        if not self.samples:
            raise RuntimeError(f"No images found under {root}")
        self.targets = [label for _, label, _ in self.samples]

    @property
    def Image(self):
        """The ``PIL.Image`` module, resolved lazily (kept for backward compatibility)."""
        from PIL import Image
        return Image

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, i):
        path, label, img_id = self.samples[i]
        with self.Image.open(path) as im:
            # convert() returns a new image, so the file can be closed here.
            im = im.convert("RGB")
        x = self.transform(im)
        return x, label, img_id
def class_counts(targets):
    """Count negative (0) and positive (1) labels.

    Iterates ``targets`` exactly once, so one-shot iterables (generators,
    iterators) are counted correctly. Labels other than 0 and 1 are ignored.

    Returns:
        ``(n_neg, n_pos)`` as plain ints.
    """
    n_neg = 0
    n_pos = 0
    for label in targets:
        if label == 1:
            n_pos += 1
        elif label == 0:
            n_neg += 1
    return n_neg, n_pos
| # --------------------------------------------------------------------------- # | |
| # Model | |
| # --------------------------------------------------------------------------- # | |
def build_model(backbone: str, freeze_stage: int = 0, dropout: float = 0.0):
    """Build a single-logit ResNet-18 classifier.

    backbone:
      'torchvision'            -> torchvision resnet18 ImageNet1K_V1
      'timm:resnet18.a1_in1k'  -> any timm tag after 'timm:'
    freeze_stage: 0 = full fine-tune; 1 = freeze stem+layer1; 2 = +layer2;
      3 or more = +layer3 (layer4 and the head are never frozen here).
    dropout: for torchvision, a Dropout layer is placed before the final
      Linear only when dropout > 0; for timm it is passed as ``drop_rate``.

    Returns (model, preprocess_config). For timm backbones the preprocessing
    config is read from the model's resolved data config.
    Raises ValueError for an unrecognized backbone string.
    """
    import torch.nn as nn

    if backbone == "torchvision":
        from torchvision.models import resnet18, ResNet18_Weights

        net = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
        # Replace the 1000-way ImageNet head with a single-logit head.
        head = nn.Linear(net.fc.in_features, 1)
        if dropout > 0:
            head = nn.Sequential(nn.Dropout(dropout), head)
        net.fc = head
        config = PreprocessConfig(image_size=224, mean=list(IMAGENET_MEAN),
                                  std=list(IMAGENET_STD), interpolation="bilinear")
        _freeze_resnet(net, freeze_stage)
        return net, config

    if backbone.startswith("timm:"):
        import timm
        from timm.data import resolve_model_data_config

        tag = backbone[len("timm:"):]
        net = timm.create_model(tag, pretrained=True, num_classes=1, drop_rate=dropout)
        data_cfg = resolve_model_data_config(net)
        mean = list(data_cfg.get("mean", IMAGENET_MEAN))
        std = list(data_cfg.get("std", IMAGENET_STD))
        interp_name = data_cfg.get("interpolation", "bicubic")
        if interp_name not in ("bilinear", "bicubic"):
            # Only these two modes are supported downstream; default to bicubic.
            interp_name = "bicubic"
        side = data_cfg.get("input_size", (3, 224, 224))[-1]
        config = PreprocessConfig(image_size=int(side), mean=mean, std=std,
                                  interpolation=interp_name)
        _freeze_timm_resnet(net, freeze_stage)
        return net, config

    raise ValueError(f"Unknown backbone: {backbone}")
def _freeze_resnet(model, stage):
    """Disable gradients for the early stages of a torchvision-style ResNet.

    stage <= 0 freezes nothing; 1 freezes conv1/bn1/layer1; 2 adds layer2;
    3 or more adds layer3. Only ``requires_grad`` is changed; nothing else on
    the modules is touched.
    """
    if stage <= 0:
        return

    groups = (
        (1, ("conv1", "bn1", "layer1")),
        (2, ("layer2",)),
        (3, ("layer3",)),
    )
    # Resolve every module first so a missing attribute fails before any
    # parameter has been modified.
    frozen_modules = [
        getattr(model, attr)
        for min_stage, attrs in groups
        if stage >= min_stage
        for attr in attrs
    ]
    for module in frozen_modules:
        for param in module.parameters():
            param.requires_grad = False
def _freeze_timm_resnet(model, stage):
    """Disable gradients for early ResNet stages, matched by parameter name.

    Same stage semantics as the torchvision variant (1: conv1/bn1/layer1,
    2: +layer2, 3 or more: +layer3), but parameters are selected by the prefix
    of their ``named_parameters()`` name rather than by module attribute.
    """
    if stage <= 0:
        return

    prefix_groups = (
        (1, ("conv1", "bn1", "layer1")),
        (2, ("layer2",)),
        (3, ("layer3",)),
    )
    prefixes = tuple(
        prefix
        for min_stage, group in prefix_groups
        if stage >= min_stage
        for prefix in group
    )
    for param_name, param in model.named_parameters():
        # str.startswith accepts a tuple: true if any prefix matches.
        if param_name.startswith(prefixes):
            param.requires_grad = False
| # --------------------------------------------------------------------------- # | |
| # Loss | |
| # --------------------------------------------------------------------------- # | |
def build_loss(name: str, pos_weight: Optional[float], focal_gamma: float = 2.0,
               focal_alpha: float = 0.5):
    """Build the training criterion for single-logit binary classification.

    Args:
        name: ``"bce"`` for BCE-with-logits or ``"focal"`` for focal loss.
        pos_weight: Positive-class weight for BCE, or None for unweighted.
            Ignored by the focal loss. Any scalar convertible with ``float()``
            is accepted; it is always stored as a float32 tensor so that an
            int or NumPy float64 argument cannot change the loss dtype.
        focal_gamma: Focusing exponent of the focal loss.
        focal_alpha: Weight of the positive class in the focal loss; negatives
            get ``1 - focal_alpha``.

    Returns:
        An ``nn.Module`` called as ``loss(logits, targets)``.

    Raises:
        ValueError: If ``name`` is not a known loss.
    """
    import torch
    import torch.nn as nn

    if name == "bce":
        pw = None
        if pos_weight is not None:
            pw = torch.tensor([float(pos_weight)], dtype=torch.float32)
        return nn.BCEWithLogitsLoss(pos_weight=pw)

    if name == "focal":
        class FocalLoss(nn.Module):
            """Binary focal loss on logits, averaged over all elements."""

            def __init__(self, gamma, alpha):
                super().__init__()
                self.gamma, self.alpha = gamma, alpha

            def forward(self, logits, targets):
                logits = logits.view(-1)
                targets = targets.view(-1).float()
                p = torch.sigmoid(logits)
                ce = nn.functional.binary_cross_entropy_with_logits(
                    logits, targets, reduction="none")
                # p_t is the probability assigned to the true class.
                p_t = p * targets + (1 - p) * (1 - targets)
                alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
                loss = alpha_t * (1 - p_t) ** self.gamma * ce
                return loss.mean()

        return FocalLoss(focal_gamma, focal_alpha)

    raise ValueError(f"Unknown loss: {name}")
| # --------------------------------------------------------------------------- # | |
| # Inference: collect logits/probs/labels/ids | |
| # --------------------------------------------------------------------------- # | |
def collect_logits(model, loader, device, amp=False):
    """Run the model over a loader and gather raw logits, labels and ids.

    Args:
        model: Module producing one logit per sample. It is switched to eval
            mode and left in eval mode on return.
        loader: Iterable of ``(inputs, labels, ids)`` batches.
        device: Target device as a string (``"cuda"``, ``"cuda:0"``, ``"cpu"``)
            or a ``torch.device``.
        amp: Use float16 autocast. Only takes effect on CUDA devices.

    Returns:
        ``(logits, labels, ids)``: a float array of logits, an int array of
        labels and a flat list of ids, all in loader order.
    """
    import torch

    model.eval()
    logits_all, labels_all, ids_all = [], [], []
    # Compare the device *type*: `device == "cuda"` is False for torch.device
    # objects and for indexed strings such as "cuda:0", which silently
    # disabled autocast.
    use_ac = bool(amp) and torch.device(device).type == "cuda"
    with torch.no_grad():
        for x, y, ids in loader:
            x = x.to(device, non_blocking=True)
            if use_ac:
                with torch.autocast(device_type="cuda", dtype=torch.float16):
                    out = model(x).view(-1)
            else:
                out = model(x).view(-1)
            # Cast back to float32 so downstream NumPy code sees one dtype.
            logits_all.append(out.float().cpu().numpy())
            labels_all.append(np.asarray(y))
            ids_all.extend(list(ids))
    return (np.concatenate(logits_all), np.concatenate(labels_all).astype(int), ids_all)
| # --------------------------------------------------------------------------- # | |
| # Calibration (temperature scaling) | |
| # --------------------------------------------------------------------------- # | |
def fit_temperature(val_logits: np.ndarray, val_labels: np.ndarray) -> float:
    """Fit single-parameter temperature on validation logits (minimize NLL).

    The temperature starts at 1.0 and is optimized with one L-BFGS ``step``
    (up to 200 inner iterations) on the mean binary cross-entropy of
    ``logits / T``. The value is clamped to at least 1e-3 both inside the
    objective and in the returned float.
    """
    import torch
    import torch.nn as nn

    z = torch.tensor(val_logits, dtype=torch.float32)
    y = torch.tensor(val_labels, dtype=torch.float32)
    temperature = nn.Parameter(torch.ones(1))
    optimizer = torch.optim.LBFGS([temperature], lr=0.01, max_iter=200)

    def nll_closure():
        # L-BFGS re-evaluates the objective through this closure.
        optimizer.zero_grad()
        scaled = z / temperature.clamp(min=1e-3)
        nll = nn.functional.binary_cross_entropy_with_logits(scaled, y)
        nll.backward()
        return nll

    optimizer.step(nll_closure)
    return float(temperature.detach().clamp(min=1e-3).item())
def apply_temperature(logits: np.ndarray, T: float) -> np.ndarray:
    """Return calibrated probabilities ``sigmoid(logits / T)``.

    Args:
        logits: Raw logits (array or array-like).
        T: Temperature; must be strictly positive.

    Raises:
        ValueError: If ``T`` is not > 0. A zero temperature would divide by
            zero and a negative one would invert the ranking of the scores.
    """
    if not T > 0:
        raise ValueError(f"Temperature must be positive, got {T!r}")
    scaled = np.asarray(logits) / T
    # exp() overflows to inf for very negative logits; the quotient is still
    # the correct 0.0, so only the spurious warning is silenced.
    with np.errstate(over="ignore"):
        return 1.0 / (1.0 + np.exp(-scaled))
def sigmoid(logits: np.ndarray) -> np.ndarray:
    """Element-wise logistic function ``1 / (1 + exp(-logits))``.

    Accepts arrays or array-likes. Values are identical to the plain formula;
    the only difference is that the overflow warning for very negative logits
    is suppressed (the result there is the correct 0.0).
    """
    z = np.asarray(logits)
    with np.errstate(over="ignore"):
        return 1.0 / (1.0 + np.exp(-z))
| # --------------------------------------------------------------------------- # | |
| # Calibration metrics | |
| # --------------------------------------------------------------------------- # | |
def expected_calibration_error(y_true, y_prob, n_bins=15):
    """Expected calibration error of positive-class probabilities.

    Probabilities are grouped into ``n_bins`` equal-width bins on [0, 1]. Each
    non-empty bin contributes ``|mean(y_true) - mean(y_prob)|`` weighted by the
    fraction of samples it holds.

    Bins are ``(lo, hi]`` except the first, which is ``[0, hi]`` so that
    probabilities of exactly 0.0 are counted instead of falling outside every
    bin.

    Returns:
        The ECE as a float (0.0 for empty input).
    """
    y_true = np.asarray(y_true)
    y_prob = np.asarray(y_prob)
    n_total = len(y_prob)
    edges = np.linspace(0, 1, n_bins + 1)
    ece = 0.0
    for k in range(n_bins):
        lo, hi = edges[k], edges[k + 1]
        above_lo = (y_prob >= lo) if k == 0 else (y_prob > lo)
        in_bin = above_lo & (y_prob <= hi)
        count = in_bin.sum()
        if count > 0:
            gap = abs(y_true[in_bin].mean() - y_prob[in_bin].mean())
            ece += (count / n_total) * gap
    return float(ece)
def brier(y_true, y_prob):
    """Brier score of ``y_prob`` against ``y_true`` as a plain float.

    Thin wrapper over ``sklearn.metrics.brier_score_loss``; inputs are
    converted to NumPy arrays first.
    """
    from sklearn.metrics import brier_score_loss

    labels, probs = np.asarray(y_true), np.asarray(y_prob)
    score = brier_score_loss(labels, probs)
    return float(score)
| # --------------------------------------------------------------------------- # | |
| # Thresholding | |
| # --------------------------------------------------------------------------- # | |
def threshold_for_sensitivity(y_true, y_prob, target_sens=0.95):
    """Highest-specificity threshold achieving sensitivity >= target on these data.

    Candidate thresholds are the operating points returned by
    ``sklearn.metrics.roc_curve``. If no point reaches the target, the point
    with the highest sensitivity is returned and the flag is False.

    Returns (threshold, achieved_sens, achieved_spec, achievable_flag).
    """
    from sklearn.metrics import roc_curve

    labels = np.asarray(y_true)
    scores = np.asarray(y_prob)
    fpr, tpr, thresholds = roc_curve(labels, scores)
    specificity = 1 - fpr

    eligible = np.flatnonzero(tpr >= target_sens)
    achievable = eligible.size > 0
    if achievable:
        # Among points meeting the target, keep the most specific one.
        pick = eligible[np.argmax(specificity[eligible])]
    else:
        pick = int(np.argmax(tpr))
    return (float(thresholds[pick]), float(tpr[pick]),
            float(specificity[pick]), achievable)
def youden_threshold(y_true, y_prob):
    """Threshold maximizing Youden's J (sensitivity + specificity - 1).

    Returns (threshold, sensitivity, specificity) at the first ROC operating
    point where ``tpr - fpr`` is largest.
    """
    from sklearn.metrics import roc_curve

    fpr, tpr, thresholds = roc_curve(np.asarray(y_true), np.asarray(y_prob))
    k = int(np.argmax(tpr - fpr))
    return float(thresholds[k]), float(tpr[k]), float(1 - fpr[k])
| # --------------------------------------------------------------------------- # | |
| # Metrics + bootstrap CIs | |
| # --------------------------------------------------------------------------- # | |
def point_metrics(y_true, y_prob, thr):
    """Compute point estimates of all reported metrics at a fixed threshold.

    A sample is predicted positive when ``y_prob >= thr``.

    Args:
        y_true: Binary labels (1 = positive).
        y_prob: Positive-class probabilities.
        thr: Decision threshold.

    Returns:
        Dict with auroc, accuracy, sensitivity, specificity, ppv, npv, f1,
        brier, ece, the confusion counts (tp/tn/fp/fn), the threshold and the
        sample counts (n, n_pos, n_neg). Ratios whose denominator is zero are
        NaN. AUROC is likewise NaN when ``y_true`` contains fewer than two
        distinct labels, matching how the bootstrap treats an undefined AUROC.
    """
    from sklearn.metrics import roc_auc_score, f1_score, accuracy_score

    y_true = np.asarray(y_true)
    y_prob = np.asarray(y_prob)
    pred = (y_prob >= thr).astype(int)

    tp = int(((pred == 1) & (y_true == 1)).sum())
    tn = int(((pred == 0) & (y_true == 0)).sum())
    fp = int(((pred == 1) & (y_true == 0)).sum())
    fn = int(((pred == 0) & (y_true == 1)).sum())

    nan = float("nan")
    sens = tp / (tp + fn) if (tp + fn) else nan
    spec = tn / (tn + fp) if (tn + fp) else nan
    ppv = tp / (tp + fp) if (tp + fp) else nan
    npv = tn / (tn + fn) if (tn + fn) else nan

    # AUROC is undefined with a single class present; report NaN like the
    # other undefined ratios rather than failing the whole evaluation.
    if np.unique(y_true).size >= 2:
        auroc = float(roc_auc_score(y_true, y_prob))
    else:
        auroc = nan

    return {
        "auroc": auroc,
        "accuracy": float(accuracy_score(y_true, pred)),
        "sensitivity": float(sens),
        "specificity": float(spec),
        "ppv": float(ppv),
        "npv": float(npv),
        "f1": float(f1_score(y_true, pred, zero_division=0)),
        "brier": brier(y_true, y_prob),
        "ece": expected_calibration_error(y_true, y_prob),
        "tp": tp, "tn": tn, "fp": fp, "fn": fn,
        "threshold": float(thr),
        "n": int(len(y_true)),
        "n_pos": int((y_true == 1).sum()),
        "n_neg": int((y_true == 0).sum()),
    }
def bootstrap_ci(y_true, y_prob, thr, n_boot=2000, seed=42):
    """Stratified bootstrap 95% CIs for AUROC, sens, spec, ppv, npv, acc, f1.

    Each replicate resamples positives and negatives separately, with
    replacement and at their original counts, then recomputes every metric at
    the fixed threshold ``thr``. Undefined values are stored as NaN and
    ignored when taking the 2.5th / 97.5th percentiles.

    Returns:
        Dict mapping metric name to a ``(lower, upper)`` tuple of floats.
    """
    from sklearn.metrics import roc_auc_score, f1_score, accuracy_score

    labels = np.asarray(y_true)
    probs = np.asarray(y_prob)
    rng = np.random.default_rng(seed)
    pos_idx = np.where(labels == 1)[0]
    neg_idx = np.where(labels == 0)[0]
    metric_names = ["auroc", "sensitivity", "specificity", "ppv", "npv", "accuracy", "f1"]
    draws = {name: [] for name in metric_names}

    def ratio(numerator, denominator):
        # NaN marks a metric that is undefined for this replicate.
        return numerator / denominator if denominator else np.nan

    for _ in range(n_boot):
        # Draw order (positives, then negatives) is fixed to keep the RNG
        # stream, and therefore the intervals, reproducible.
        picked_pos = rng.choice(pos_idx, len(pos_idx), replace=True)
        picked_neg = rng.choice(neg_idx, len(neg_idx), replace=True)
        sample = np.concatenate([picked_pos, picked_neg])
        yt = labels[sample]
        yp = probs[sample]
        pred = (yp >= thr).astype(int)

        try:
            auroc = roc_auc_score(yt, yp)
        except Exception:
            auroc = np.nan
        draws["auroc"].append(auroc)

        tp = ((pred == 1) & (yt == 1)).sum()
        tn = ((pred == 0) & (yt == 0)).sum()
        fp = ((pred == 1) & (yt == 0)).sum()
        fn = ((pred == 0) & (yt == 1)).sum()
        draws["sensitivity"].append(ratio(tp, tp + fn))
        draws["specificity"].append(ratio(tn, tn + fp))
        draws["ppv"].append(ratio(tp, tp + fp))
        draws["npv"].append(ratio(tn, tn + fn))
        draws["accuracy"].append(accuracy_score(yt, pred))
        draws["f1"].append(f1_score(yt, pred, zero_division=0))

    intervals = {}
    for name in metric_names:
        values = np.asarray(draws[name], dtype=float)
        lower = float(np.nanpercentile(values, 2.5))
        upper = float(np.nanpercentile(values, 97.5))
        intervals[name] = (lower, upper)
    return intervals
| # --------------------------------------------------------------------------- # | |
| # JSON helpers | |
| # --------------------------------------------------------------------------- # | |
def save_json(obj, path):
    """Write ``obj`` to ``path`` as indented JSON, creating parent directories.

    The object is serialized to a string before the file is opened, so an
    object that cannot be serialized raises TypeError without truncating or
    partially overwriting an existing file.
    """
    text = json.dumps(obj, indent=2)
    Path(path).parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        f.write(text)
def load_json(path):
    """Read and parse the JSON file at ``path``.

    The file is decoded as UTF-8 rather than with the platform's locale
    encoding, so non-ASCII content loads the same way on every machine.
    """
    with open(path, encoding="utf-8") as f:
        return json.load(f)