#!/usr/bin/env python
"""
Shared library for the agentic thyroid ResNet-18 experiment.

Centralizes everything that train.py / evaluate.py / evaluate_external.py must
share so that preprocessing, model construction, calibration, and thresholding
are guaranteed identical across training, validation, test, and external use.

Positive class = Malignant (label 1). Benign = 0.
"""
import importlib
import json
import os
import random
import warnings
from dataclasses import dataclass, field, asdict
from pathlib import Path
from typing import Optional, List, Tuple

import numpy as np

IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
CLASS_TO_IDX = {"Benign": 0, "Malignant": 1}
IDX_TO_CLASS = {0: "Benign", 1: "Malignant"}


# --------------------------------------------------------------------------- #
# Reproducibility
# --------------------------------------------------------------------------- #
def set_determinism(seed: int, strict: bool = True):
    """Set all RNG seeds and (optionally) enforce deterministic algorithms.

    Seeds Python's ``random``, NumPy, and torch (CPU + all CUDA devices) and
    exports PYTHONHASHSEED. With ``strict=True`` cuDNN is put in deterministic
    mode and torch's deterministic-algorithm switch is enabled; otherwise
    cuDNN benchmarking is turned on.
    """
    import torch

    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    if not strict:
        torch.backends.cudnn.benchmark = True
        return

    # cuBLAS workspace config required for deterministic matmul on CUDA.
    # setdefault: a value already exported by the user is left untouched.
    os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    try:
        torch.use_deterministic_algorithms(True, warn_only=True)
    except TypeError:
        # This torch build does not accept the `warn_only` keyword.
        torch.use_deterministic_algorithms(True)


def seed_worker(worker_id):
    """Reseed NumPy and ``random`` from torch's current initial seed.

    The signature matches what DataLoader expects for ``worker_init_fn``;
    ``worker_id`` itself is not used.
    """
    import torch

    worker_seed = torch.initial_seed() % 2 ** 32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


def collect_env_info():
    """Return a dict of package versions and hardware/CUDA settings for logging.

    Never raises: a failure while querying torch is recorded under
    ``"torch_error"``, and any package that cannot be imported maps to None.
    """
    info = {}
    try:
        import torch

        info["torch"] = torch.__version__
        info["cuda_available"] = torch.cuda.is_available()
        info["cuda_version"] = torch.version.cuda
        info["cudnn_version"] = (
            torch.backends.cudnn.version() if torch.backends.cudnn.is_available() else None
        )
        info["cudnn_deterministic"] = torch.backends.cudnn.deterministic
        info["cudnn_benchmark"] = torch.backends.cudnn.benchmark
        if torch.cuda.is_available():
            info["gpu_name"] = torch.cuda.get_device_name(0)
            info["gpu_count"] = torch.cuda.device_count()
            props = torch.cuda.get_device_properties(0)
            info["gpu_total_mem_gb"] = round(props.total_memory / 1e9, 2)
    except Exception as e:  # best-effort logging helper: record, don't crash
        info["torch_error"] = repr(e)

    for mod in ["torchvision", "timm", "sklearn", "numpy", "PIL", "trackio"]:
        try:
            m = importlib.import_module(mod)
            info[mod] = getattr(m, "__version__", "?")
        except Exception:  # missing or broken optional package
            info[mod] = None

    info["cublas_workspace_config"] = os.environ.get("CUBLAS_WORKSPACE_CONFIG")
    info["pythonhashseed"] = os.environ.get("PYTHONHASHSEED")
    return info


# --------------------------------------------------------------------------- #
# Preprocessing / augmentation
# --------------------------------------------------------------------------- #
@dataclass
class PreprocessConfig:
    """Locked preprocessing config saved with the final model.

    Eval/inference path is fully deterministic: resize to image_size, ToTensor,
    Normalize with the given mean/std. No augmentation at eval time.
    """
    image_size: int = 224
    mean: List[float] = field(default_factory=lambda: list(IMAGENET_MEAN))
    std: List[float] = field(default_factory=lambda: list(IMAGENET_STD))
    interpolation: str = "bilinear"  # 'bilinear' (torchvision) or 'bicubic' (timm a1/a2/a3)

    def to_dict(self):
        """Return the config as a plain (JSON-serializable) dict."""
        return asdict(self)

    @staticmethod
    def from_dict(d):
        """Build a config from a dict, ignoring keys that are not config fields."""
        known = {"image_size", "mean", "std", "interpolation"}
        return PreprocessConfig(**{k: v for k, v in d.items() if k in known})


def _interp(name):
    """Map 'bilinear' / 'bicubic' to the torchvision InterpolationMode (KeyError otherwise)."""
    from torchvision.transforms import InterpolationMode

    return {"bilinear": InterpolationMode.BILINEAR,
            "bicubic": InterpolationMode.BICUBIC}[name]


def build_eval_transform(pp: PreprocessConfig):
    """Deterministic eval/inference transform (NO augmentation)."""
    import torchvision.transforms as T

    return T.Compose([
        T.Resize((pp.image_size, pp.image_size), interpolation=_interp(pp.interpolation)),
        T.ToTensor(),
        T.Normalize(pp.mean, pp.std),
    ])


class _AddSpeckle:
    """With probability ``p`` add multiplicative gaussian noise: x + x * sigma * N(0, 1).

    Applied to a tensor (after ToTensor). Defined at module level so that a
    transform pipeline containing it can be pickled for DataLoader workers.
    """

    def __init__(self, sigma=0.05, p=0.2):
        self.sigma, self.p = sigma, p

    def __call__(self, x):
        if random.random() < self.p:
            import torch

            return x + x * (self.sigma * torch.randn_like(x))
        return x


class _CLAHE:
    """Apply OpenCV CLAHE to the grayscale version of a PIL image, returned as RGB.

    Best-effort: if OpenCV is unavailable or the call fails, the input image is
    returned unchanged, but a RuntimeWarning is emitted so that a 'clahe' run
    cannot silently degrade into a run without CLAHE.
    Defined at module level so the transform pipeline can be pickled.
    """

    def __init__(self, clip_limit=2.0, tile_grid_size=(8, 8)):
        self.clip_limit = clip_limit
        self.tile_grid_size = tuple(tile_grid_size)

    def __call__(self, img):
        from PIL import Image

        try:
            import cv2
        except ImportError:
            warnings.warn("cv2 is not installed; CLAHE preprocessing is being skipped",
                          RuntimeWarning)
            return img
        try:
            arr = np.asarray(img.convert("L"))
            cl = cv2.createCLAHE(clipLimit=self.clip_limit,
                                 tileGridSize=self.tile_grid_size).apply(arr)
            return Image.fromarray(cl).convert("RGB")
        except Exception as exc:  # deliberate best-effort fallback, but not silent
            warnings.warn(f"CLAHE failed ({exc!r}); returning the image unchanged",
                          RuntimeWarning)
            return img


def build_train_transform(pp: PreprocessConfig, policy: str = "medical_default"):
    """Training augmentation. Medically plausible ultrasound augmentations only.

    Policies:
      none            : eval transform (no augmentation) — baseline ablation.
      flip_only       : horizontal flip only.
      medical_default : flip + mild affine(rot<=10,trans5%,scale0.9-1.1)
                        + mild brightness/contrast + occasional light gaussian blur.
      medical_strong  : medical_default + mild speckle/gaussian noise
                        + narrow random-resized-crop (scale 0.8-1.0), which
                        replaces the plain resize.
      clahe           : CLAHE applied as preprocessing, then flip + mild affine
                        + mild brightness/contrast (ablation).
                        NOTE(review): unlike medical_default this pipeline has
                        no gaussian-blur step — confirm that is intended.

    Explicitly AVOIDED: vertical flip, large rotation (>15deg), aggressive crop
    (<0.8 scale), shear, heavy blur, any color/HSV jitter beyond mild
    brightness/contrast — all of which distort ultrasound texture or nodule
    morphology (per MediAug arXiv:2504.18983 and thyroid-US best practice).

    Raises ValueError for an unknown policy name.
    """
    import torchvision.transforms as T

    if policy == "none":
        return build_eval_transform(pp)

    interp = _interp(pp.interpolation)

    # Building blocks shared by several policies.
    resize = T.Resize((pp.image_size, pp.image_size), interpolation=interp)
    flip = T.RandomHorizontalFlip(0.5)
    affine = T.RandomApply(
        [T.RandomAffine(degrees=10, translate=(0.05, 0.05), scale=(0.9, 1.1),
                        interpolation=interp)],
        p=0.5,
    )
    jitter = T.ColorJitter(brightness=0.15, contrast=0.15)
    blur = T.RandomApply([T.GaussianBlur(3, sigma=(0.1, 1.0))], p=0.2)
    to_tensor = T.ToTensor()
    norm = T.Normalize(pp.mean, pp.std)

    if policy == "flip_only":
        return T.Compose([resize, flip, to_tensor, norm])

    if policy == "medical_default":
        return T.Compose([resize, flip, affine, jitter, blur, to_tensor, norm])

    if policy == "medical_strong":
        crop = T.RandomResizedCrop(pp.image_size, scale=(0.8, 1.0), ratio=(0.9, 1.1),
                                   interpolation=interp)
        # Speckle operates on the tensor, before normalization.
        return T.Compose([crop, flip, affine, jitter, blur, to_tensor,
                          _AddSpeckle(sigma=0.05, p=0.2), norm])

    if policy == "clahe":
        return T.Compose([_CLAHE(), resize, flip, affine, jitter, to_tensor, norm])

    raise ValueError(f"Unknown augmentation policy: {policy}")


# --------------------------------------------------------------------------- #
# Dataset
# --------------------------------------------------------------------------- #
class ThyroidImageFolder:
    """Lightweight ImageFolder that also returns the image filename id.

    Layout: {root}/{class_name}/{image_id}.png, where class_name is a key of
    CLASS_TO_IDX. Only ``*.png`` files are collected, sorted within each class.
    Returns (tensor, label, image_id); image_id is the file stem.

    Raises RuntimeError if no image is found under ``root``.
    """

    def __init__(self, root, transform):
        from PIL import Image  # noqa: F401  (fail fast if Pillow is missing)

        self.root = Path(root)
        self.transform = transform
        self.samples: List[Tuple[Path, int, str]] = []
        for cls, idx in CLASS_TO_IDX.items():
            class_dir = self.root / cls
            if class_dir.is_dir():
                self.samples.extend((p, idx, p.stem) for p in sorted(class_dir.glob("*.png")))
        if not self.samples:
            raise RuntimeError(f"No images found under {root}")
        self.targets = [s[1] for s in self.samples]

    @property
    def Image(self):
        """The ``PIL.Image`` module, resolved lazily.

        Kept as an attribute-style accessor for backward compatibility; it is
        not stored on the instance because module objects cannot be pickled,
        which would break DataLoader workers started with 'spawn'.
        """
        from PIL import Image

        return Image

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, i):
        path, label, img_id = self.samples[i]
        with self.Image.open(path) as im:
            # convert() returns a new, fully loaded image, so it remains
            # usable after the file handle is closed.
            im = im.convert("RGB")
        x = self.transform(im)
        return x, label, img_id


def class_counts(targets):
    """Return (n_negative, n_positive): counts of labels equal to 0 and to 1."""
    n_pos = int(sum(1 for t in targets if t == 1))
    n_neg = int(sum(1 for t in targets if t == 0))
    return n_neg, n_pos


# --------------------------------------------------------------------------- #
# Model
# --------------------------------------------------------------------------- #
def build_model(backbone: str, freeze_stage: int = 0, dropout: float = 0.0):
    """Build a single-logit ResNet-18 classifier.

    backbone:
      'torchvision'            -> torchvision resnet18 ImageNet1K_V1
      'timm:resnet18.a1_in1k'  -> any timm tag after 'timm:'
    freeze_stage: 0 = full fine-tune; 1 = freeze stem+layer1; 2 = +layer2;
      3 or more = +layer3. layer4 and the classification head are never frozen.
    dropout: dropout probability before the final linear layer (torchvision)
      or timm's ``drop_rate``.

    Returns (model, preprocess_config). Raises ValueError for an unknown backbone.
    """
    import torch.nn as nn

    if backbone == "torchvision":
        from torchvision.models import resnet18, ResNet18_Weights

        model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
        in_f = model.fc.in_features
        head = nn.Linear(in_f, 1)
        # Only wrap the head in a Sequential when dropout is actually used.
        model.fc = nn.Sequential(nn.Dropout(dropout), head) if dropout > 0 else head
        pp = PreprocessConfig(image_size=224, mean=list(IMAGENET_MEAN),
                              std=list(IMAGENET_STD), interpolation="bilinear")
        _freeze_resnet(model, freeze_stage)
        return model, pp

    if backbone.startswith("timm:"):
        import timm
        from timm.data import resolve_model_data_config

        tag = backbone.split("timm:", 1)[1]
        model = timm.create_model(tag, pretrained=True, num_classes=1, drop_rate=dropout)
        cfg = resolve_model_data_config(model)
        mean = list(cfg.get("mean", IMAGENET_MEAN))
        std = list(cfg.get("std", IMAGENET_STD))
        interp = cfg.get("interpolation", "bicubic")
        if interp not in ("bilinear", "bicubic"):
            # Only these two modes are supported by _interp(); fall back to bicubic.
            interp = "bicubic"
        size = cfg.get("input_size", (3, 224, 224))[-1]
        pp = PreprocessConfig(image_size=int(size), mean=mean, std=std, interpolation=interp)
        _freeze_timm_resnet(model, freeze_stage)
        return model, pp

    raise ValueError(f"Unknown backbone: {backbone}")


def _freeze_resnet(model, stage):
    """Disable gradients for the stem / early layers of a torchvision ResNet.

    Only ``requires_grad`` is changed; module train/eval mode is not touched.
    """
    if stage <= 0:
        return
    to_freeze = []
    if stage >= 1:
        to_freeze += [model.conv1, model.bn1, model.layer1]
    if stage >= 2:
        to_freeze += [model.layer2]
    if stage >= 3:
        to_freeze += [model.layer3]
    for module in to_freeze:
        for p in module.parameters():
            p.requires_grad = False


def _freeze_timm_resnet(model, stage):
    """Same staging as _freeze_resnet, selecting parameters by name prefix."""
    if stage <= 0:
        return
    name_prefixes = []
    if stage >= 1:
        name_prefixes += ["conv1", "bn1", "layer1"]
    if stage >= 2:
        name_prefixes += ["layer2"]
    if stage >= 3:
        name_prefixes += ["layer3"]
    prefixes = tuple(name_prefixes)  # str.startswith accepts a tuple of options
    for n, p in model.named_parameters():
        if n.startswith(prefixes):
            p.requires_grad = False


# --------------------------------------------------------------------------- #
# Loss
# --------------------------------------------------------------------------- #
def build_loss(name: str, pos_weight: Optional[float], focal_gamma: float = 2.0,
               focal_alpha: float = 0.5):
    """Build the training loss.

    name:
      'bce'   -> BCEWithLogitsLoss; ``pos_weight`` (if not None) weights the
                 positive class.
      'focal' -> binary focal loss with ``focal_gamma`` and ``focal_alpha``
                 (alpha weights positives, 1 - alpha negatives); logits and
                 targets are flattened and the mean is returned.
                 ``pos_weight`` is not used by this loss.

    Raises ValueError for an unknown loss name.
    """
    import torch
    import torch.nn as nn

    if name == "bce":
        # float(): an integer weight would otherwise create an int64 tensor.
        pw = torch.tensor([float(pos_weight)]) if pos_weight is not None else None
        return nn.BCEWithLogitsLoss(pos_weight=pw)

    if name == "focal":
        class FocalLoss(nn.Module):
            def __init__(self, gamma, alpha):
                super().__init__()
                self.gamma, self.alpha = gamma, alpha

            def forward(self, logits, targets):
                logits = logits.view(-1)
                targets = targets.view(-1).float()
                p = torch.sigmoid(logits)
                ce = nn.functional.binary_cross_entropy_with_logits(
                    logits, targets, reduction="none")
                # p_t: predicted probability of the true class.
                p_t = p * targets + (1 - p) * (1 - targets)
                alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
                loss = alpha_t * (1 - p_t) ** self.gamma * ce
                return loss.mean()

        return FocalLoss(focal_gamma, focal_alpha)

    raise ValueError(f"Unknown loss: {name}")


# --------------------------------------------------------------------------- #
# Inference: collect logits/probs/labels/ids
# --------------------------------------------------------------------------- #
def collect_logits(model, loader, device, amp=False):
    """Run ``model`` over ``loader`` in eval mode and gather the outputs.

    ``loader`` must yield (inputs, labels, ids) batches. ``device`` may be a
    string ('cuda', 'cuda:0', 'cpu', ...) or a ``torch.device``. float16
    autocast is used only when ``amp`` is true and the device is a CUDA device.

    Returns (logits, labels, ids): float32 logits array, int labels array, and
    a list of ids. The model is left in eval mode.
    """
    import torch

    model.eval()
    logits_all, labels_all, ids_all = [], [], []
    # Compare on the device *type* so that torch.device objects and indexed
    # strings such as "cuda:0" are recognised, not only the literal "cuda".
    use_ac = bool(amp) and torch.device(device).type == "cuda"
    with torch.no_grad():
        for x, y, ids in loader:
            x = x.to(device, non_blocking=True)
            if use_ac:
                with torch.autocast(device_type="cuda", dtype=torch.float16):
                    out = model(x).view(-1)
            else:
                out = model(x).view(-1)
            logits_all.append(out.float().cpu().numpy())
            labels_all.append(np.asarray(y))
            ids_all.extend(list(ids))
    return (np.concatenate(logits_all),
            np.concatenate(labels_all).astype(int),
            ids_all)


# --------------------------------------------------------------------------- #
# Calibration (temperature scaling)
# --------------------------------------------------------------------------- #
def fit_temperature(val_logits: np.ndarray, val_labels: np.ndarray) -> float:
    """Fit single-parameter temperature on validation logits (minimize NLL).

    Optimizes T (initialized to 1) with L-BFGS on the BCE-with-logits loss of
    ``logits / T``. T is clamped to at least 1e-3 both during optimization and
    in the returned value.
    """
    import torch
    import torch.nn as nn

    logits = torch.tensor(val_logits, dtype=torch.float32)
    labels = torch.tensor(val_labels, dtype=torch.float32)
    T = nn.Parameter(torch.ones(1))
    opt = torch.optim.LBFGS([T], lr=0.01, max_iter=200)
    bce = nn.BCEWithLogitsLoss()

    def closure():
        opt.zero_grad()
        loss = bce(logits / T.clamp(min=1e-3), labels)
        loss.backward()
        return loss

    opt.step(closure)
    return float(T.detach().clamp(min=1e-3).item())


def sigmoid(logits: np.ndarray) -> np.ndarray:
    """Element-wise logistic function 1 / (1 + exp(-logits))."""
    z = np.asarray(logits)
    # exp() overflows to inf for very negative logits; 1 / (1 + inf) == 0 is
    # the correct limit, so only the overflow warning is suppressed.
    with np.errstate(over="ignore"):
        return 1.0 / (1.0 + np.exp(-z))


def apply_temperature(logits: np.ndarray, T: float) -> np.ndarray:
    """Return calibrated probabilities sigmoid(logits / T)."""
    z = np.asarray(logits) / T
    with np.errstate(over="ignore"):  # see sigmoid()
        return 1.0 / (1.0 + np.exp(-z))


# --------------------------------------------------------------------------- #
# Calibration metrics
# --------------------------------------------------------------------------- #
def expected_calibration_error(y_true, y_prob, n_bins=15):
    """Expected calibration error over ``n_bins`` equal-width bins on [0, 1].

    Each bin contributes (fraction of samples in the bin) * |mean label - mean
    probability|. Bins are (lo, hi], except the first, which also includes its
    lower edge so that probabilities exactly equal to 0 are counted.
    Returns 0.0 for empty input.
    """
    y_true = np.asarray(y_true)
    y_prob = np.asarray(y_prob)
    edges = np.linspace(0, 1, n_bins + 1)
    total = len(y_prob)
    ece = 0.0
    for i, (lo, hi) in enumerate(zip(edges[:-1], edges[1:])):
        in_bin = (y_prob > lo) & (y_prob <= hi)
        if i == 0:
            in_bin = in_bin | (y_prob == lo)
        count = in_bin.sum()
        if count > 0:
            ece += (count / total) * abs(y_true[in_bin].mean() - y_prob[in_bin].mean())
    return float(ece)


def brier(y_true, y_prob):
    """Brier score (sklearn ``brier_score_loss``) as a Python float."""
    from sklearn.metrics import brier_score_loss

    return float(brier_score_loss(np.asarray(y_true), np.asarray(y_prob)))


# --------------------------------------------------------------------------- #
# Thresholding
# --------------------------------------------------------------------------- #
def threshold_for_sensitivity(y_true, y_prob, target_sens=0.95):
    """Highest-specificity threshold achieving sensitivity >= target on these data.

    Candidates are the thresholds returned by sklearn ``roc_curve``. If no
    candidate reaches the target, the one with the highest sensitivity is
    returned and the flag is False.

    Returns (threshold, achieved_sens, achieved_spec, achievable_flag).
    """
    from sklearn.metrics import roc_curve

    y_true = np.asarray(y_true)
    y_prob = np.asarray(y_prob)
    # NOTE(review): roc_curve is called with its default drop_intermediate
    # setting, so the candidate set may be pruned — confirm this is acceptable.
    fpr, tpr, thr = roc_curve(y_true, y_prob)
    spec = 1 - fpr
    cand = np.flatnonzero(tpr >= target_sens)
    if cand.size:
        # argmax returns the first maximum among the candidates.
        best = cand[np.argmax(spec[cand])]
        return float(thr[best]), float(tpr[best]), float(spec[best]), True
    best = int(np.argmax(tpr))
    return float(thr[best]), float(tpr[best]), float(spec[best]), False


def youden_threshold(y_true, y_prob):
    """Threshold maximizing Youden's J = TPR - FPR.

    Returns (threshold, sensitivity, specificity).
    """
    from sklearn.metrics import roc_curve

    fpr, tpr, thr = roc_curve(np.asarray(y_true), np.asarray(y_prob))
    best = int(np.argmax(tpr - fpr))
    return float(thr[best]), float(tpr[best]), float(1 - fpr[best])


# --------------------------------------------------------------------------- #
# Metrics + bootstrap CIs
# --------------------------------------------------------------------------- #
def _safe_auroc(y_true, y_prob):
    """AUROC as a float, or NaN when sklearn rejects the input with ValueError."""
    from sklearn.metrics import roc_auc_score

    try:
        return float(roc_auc_score(y_true, y_prob))
    except ValueError:
        return float("nan")


def point_metrics(y_true, y_prob, thr):
    """Point estimates of discrimination, calibration and confusion counts.

    A sample is predicted positive when ``y_prob >= thr``. Ratios whose
    denominator is zero are reported as NaN; AUROC is NaN when sklearn raises
    ValueError for the input.
    """
    from sklearn.metrics import f1_score, accuracy_score

    y_true = np.asarray(y_true)
    y_prob = np.asarray(y_prob)
    pred = (y_prob >= thr).astype(int)
    tp = int(((pred == 1) & (y_true == 1)).sum())
    tn = int(((pred == 0) & (y_true == 0)).sum())
    fp = int(((pred == 1) & (y_true == 0)).sum())
    fn = int(((pred == 0) & (y_true == 1)).sum())

    def ratio(num, den):
        return num / den if den else float("nan")

    return {
        "auroc": _safe_auroc(y_true, y_prob),
        "accuracy": float(accuracy_score(y_true, pred)),
        "sensitivity": float(ratio(tp, tp + fn)),
        "specificity": float(ratio(tn, tn + fp)),
        "ppv": float(ratio(tp, tp + fp)),
        "npv": float(ratio(tn, tn + fn)),
        "f1": float(f1_score(y_true, pred, zero_division=0)),
        "brier": brier(y_true, y_prob),
        "ece": expected_calibration_error(y_true, y_prob),
        "tp": tp, "tn": tn, "fp": fp, "fn": fn,
        "threshold": float(thr),
        "n": int(len(y_true)),
        "n_pos": int((y_true == 1).sum()),
        "n_neg": int((y_true == 0).sum()),
    }


def bootstrap_ci(y_true, y_prob, thr, n_boot=2000, seed=42):
    """Stratified bootstrap 95% CIs for AUROC, sens, spec, ppv, npv, acc, f1.

    Positives and negatives are resampled separately (with replacement, same
    class sizes) in each of ``n_boot`` replicates. Returns a dict mapping each
    metric name to (2.5th percentile, 97.5th percentile), ignoring NaNs.
    """
    from sklearn.metrics import f1_score, accuracy_score

    y_true = np.asarray(y_true)
    y_prob = np.asarray(y_prob)
    rng = np.random.default_rng(seed)
    pos = np.where(y_true == 1)[0]
    neg = np.where(y_true == 0)[0]
    keys = ["auroc", "sensitivity", "specificity", "ppv", "npv", "accuracy", "f1"]
    acc = {k: [] for k in keys}

    for _ in range(n_boot):
        idx = np.concatenate([rng.choice(pos, len(pos), replace=True),
                              rng.choice(neg, len(neg), replace=True)])
        yt = y_true[idx]
        yp = y_prob[idx]
        pred = (yp >= thr).astype(int)
        acc["auroc"].append(_safe_auroc(yt, yp))
        tp = ((pred == 1) & (yt == 1)).sum()
        tn = ((pred == 0) & (yt == 0)).sum()
        fp = ((pred == 1) & (yt == 0)).sum()
        fn = ((pred == 0) & (yt == 1)).sum()
        acc["sensitivity"].append(tp / (tp + fn) if (tp + fn) else np.nan)
        acc["specificity"].append(tn / (tn + fp) if (tn + fp) else np.nan)
        acc["ppv"].append(tp / (tp + fp) if (tp + fp) else np.nan)
        acc["npv"].append(tn / (tn + fn) if (tn + fn) else np.nan)
        acc["accuracy"].append(accuracy_score(yt, pred))
        acc["f1"].append(f1_score(yt, pred, zero_division=0))

    out = {}
    for k in keys:
        v = np.asarray(acc[k], dtype=float)
        out[k] = (float(np.nanpercentile(v, 2.5)), float(np.nanpercentile(v, 97.5)))
    return out


# --------------------------------------------------------------------------- #
# JSON helpers
# --------------------------------------------------------------------------- #
def save_json(obj, path):
    """Write ``obj`` as indented JSON to ``path``, creating parent directories."""
    Path(path).parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(obj, f, indent=2)


def load_json(path):
    """Read and return the JSON document stored at ``path``."""
    with open(path, encoding="utf-8") as f:
        return json.load(f)