agentic_thyroid_model / thyroid_lib.py
Johnyquest7's picture
Add full reproducible thyroid ResNet-18 experiment: weights, scripts, configs, calibration, locked threshold, test eval w/ CIs, figures, data exploration, README, LOG
45af8e1 verified
Raw
History Blame Contribute Delete
21.3 kB
#!/usr/bin/env python
"""
Shared library for the agentic thyroid ResNet-18 experiment.
Centralizes everything that train.py / evaluate.py / evaluate_external.py must
share so that preprocessing, model construction, calibration, and thresholding
are guaranteed identical across training, validation, test, and external use.
Positive class = Malignant (label 1). Benign = 0.
"""
import json
import os
import random
from dataclasses import dataclass, field, asdict
from pathlib import Path
from typing import Optional, List, Tuple
import numpy as np
# Per-channel normalization statistics; these are the default mean/std of
# PreprocessConfig and of the torchvision backbone's preprocessing.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
# Label encoding (positive class = Malignant) and its inverse mapping.
CLASS_TO_IDX = {"Benign": 0, "Malignant": 1}
IDX_TO_CLASS = {0: "Benign", 1: "Malignant"}
# --------------------------------------------------------------------------- #
# Reproducibility
# --------------------------------------------------------------------------- #
def set_determinism(seed: int, strict: bool = True):
    """Seed every RNG in use and pick reproducible or fast GPU kernels.

    Args:
        seed: Value given to Python's ``random``, NumPy and PyTorch (CPU and
            all CUDA devices); it is also written to ``PYTHONHASHSEED``.
        strict: When true, request deterministic cuDNN behaviour, a cuBLAS
            workspace config, and PyTorch deterministic algorithms. When
            false, only cuDNN autotuning (``benchmark``) is switched on.
    """
    import torch

    # NOTE(review): the hash seed of the running interpreter is fixed at
    # start-up; presumably this is exported for child processes and logging.
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seeder in (random.seed, np.random.seed,
                   torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

    if not strict:
        torch.backends.cudnn.benchmark = True
        return

    # cuBLAS workspace config required for deterministic matmul on CUDA;
    # setdefault keeps a value the user already exported.
    os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    try:
        torch.use_deterministic_algorithms(True, warn_only=True)
    except Exception:
        # Retry without ``warn_only`` if the first form is rejected.
        torch.use_deterministic_algorithms(True)
def seed_worker(worker_id):
    """Re-seed NumPy and ``random`` from the current PyTorch initial seed.

    The seed is ``torch.initial_seed()`` reduced modulo 2**32. ``worker_id``
    is accepted but not used; the signature presumably matches a DataLoader
    ``worker_init_fn`` (confirm against the caller).
    """
    import torch

    derived_seed = torch.initial_seed() % (2 ** 32)
    for reseed in (np.random.seed, random.seed):
        reseed(derived_seed)
def collect_env_info():
    """Gather package versions and CUDA/cuDNN settings into a dict for logging.

    Nothing here raises: a failure while probing torch is stored under
    ``"torch_error"`` (keys filled before the failure are kept), and any
    auxiliary package that cannot be imported is reported as ``None``.
    """
    report = {}

    # Keys are assigned one at a time so a mid-way failure keeps what was
    # already collected.
    try:
        import torch

        report["torch"] = torch.__version__
        report["cuda_available"] = torch.cuda.is_available()
        report["cuda_version"] = torch.version.cuda
        if torch.backends.cudnn.is_available():
            report["cudnn_version"] = torch.backends.cudnn.version()
        else:
            report["cudnn_version"] = None
        report["cudnn_deterministic"] = torch.backends.cudnn.deterministic
        report["cudnn_benchmark"] = torch.backends.cudnn.benchmark
        if torch.cuda.is_available():
            report["gpu_name"] = torch.cuda.get_device_name(0)
            report["gpu_count"] = torch.cuda.device_count()
            first_gpu = torch.cuda.get_device_properties(0)
            report["gpu_total_mem_gb"] = round(first_gpu.total_memory / 1e9, 2)
    except Exception as err:
        report["torch_error"] = repr(err)

    # "?" marks a package that imports but exposes no __version__.
    for pkg in ("torchvision", "timm", "sklearn", "numpy", "PIL", "trackio"):
        try:
            report[pkg] = getattr(__import__(pkg), "__version__", "?")
        except Exception:
            report[pkg] = None

    for key, env_var in (("cublas_workspace_config", "CUBLAS_WORKSPACE_CONFIG"),
                         ("pythonhashseed", "PYTHONHASHSEED")):
        report[key] = os.environ.get(env_var)
    return report
# --------------------------------------------------------------------------- #
# Preprocessing / augmentation
# --------------------------------------------------------------------------- #
@dataclass
class PreprocessConfig:
    """Preprocessing settings that are stored alongside the final model.

    The eval/inference transform built from this config is deterministic:
    resize to ``image_size`` x ``image_size``, ``ToTensor``, then ``Normalize``
    with ``mean``/``std``. No augmentation is applied at eval time.
    """

    image_size: int = 224
    mean: List[float] = field(default_factory=lambda: list(IMAGENET_MEAN))
    std: List[float] = field(default_factory=lambda: list(IMAGENET_STD))
    # 'bilinear' (torchvision) or 'bicubic' (timm a1/a2/a3)
    interpolation: str = "bilinear"

    def to_dict(self):
        """Return the config as a plain dict (via ``dataclasses.asdict``)."""
        return asdict(self)

    @staticmethod
    def from_dict(d):
        """Build a config from ``d``, ignoring keys that are not config fields.

        Fields absent from ``d`` keep their defaults.
        """
        known = ("image_size", "mean", "std", "interpolation")
        kwargs = {name: d[name] for name in known if name in d}
        return PreprocessConfig(**kwargs)
def _interp(name):
    """Map ``'bilinear'`` / ``'bicubic'`` to torchvision's ``InterpolationMode``.

    Any other name raises ``KeyError``.
    """
    from torchvision.transforms import InterpolationMode

    modes = dict(bilinear=InterpolationMode.BILINEAR,
                 bicubic=InterpolationMode.BICUBIC)
    return modes[name]
def build_eval_transform(pp: PreprocessConfig):
    """Build the augmentation-free transform used for eval and inference.

    Steps, in order: resize to a square of side ``pp.image_size`` with the
    configured interpolation, convert to tensor, normalize with
    ``pp.mean`` / ``pp.std``.
    """
    import torchvision.transforms as T

    side = pp.image_size
    steps = [
        T.Resize((side, side), interpolation=_interp(pp.interpolation)),
        T.ToTensor(),
        T.Normalize(pp.mean, pp.std),
    ]
    return T.Compose(steps)
class _AddSpeckle:
    """Add multiplicative Gaussian noise to a tensor with probability ``p``.

    Defined at module level (not inside ``build_train_transform``) so that
    transform pipelines containing it can be pickled, which DataLoader worker
    processes started with the ``spawn`` method require.
    """

    def __init__(self, sigma=0.05, p=0.2):
        self.sigma, self.p = sigma, p

    def __call__(self, x):
        import torch
        if random.random() < self.p:
            return x + x * (self.sigma * torch.randn_like(x))
        return x


class _CLAHE:
    """Apply OpenCV CLAHE to the grayscale image and return it as RGB.

    Best effort: if OpenCV is missing or the conversion fails, the input
    image is returned unchanged. A warning is emitted the first time that
    happens in a process so the "clahe" ablation cannot silently degrade
    into a run without CLAHE.
    """

    _warned = False

    def __call__(self, img):
        try:
            import cv2
            from PIL import Image
            arr = np.asarray(img.convert("L"))
            cl = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)).apply(arr)
            return Image.fromarray(cl).convert("RGB")
        except Exception as exc:
            if not _CLAHE._warned:
                import warnings
                _CLAHE._warned = True
                warnings.warn(
                    f"CLAHE preprocessing unavailable ({exc!r}); "
                    "images are passed through unchanged.",
                    RuntimeWarning,
                )
            return img


def build_train_transform(pp: PreprocessConfig, policy: str = "medical_default"):
    """Training augmentation. Medically plausible ultrasound augmentations only.

    Policies:
        none            : eval transform (no augmentation) — baseline ablation.
        flip_only       : resize + horizontal flip only.
        medical_default : flip + mild affine(rot<=10,trans5%,scale0.9-1.1) + mild
                          brightness/contrast + occasional light gaussian blur.
        medical_strong  : medical_default with the resize replaced by a narrow
                          random-resized-crop (scale 0.8-1.0), plus mild
                          speckle/gaussian noise applied after ToTensor.
        clahe           : CLAHE applied as preprocessing, then the
                          medical_default steps WITHOUT the gaussian blur
                          (ablation).

    Explicitly AVOIDED: vertical flip, large rotation (>15deg), aggressive crop
    (<0.8 scale), shear, heavy blur, any color/HSV jitter beyond mild
    brightness/contrast — all of which distort ultrasound texture or nodule
    morphology (per MediAug arXiv:2504.18983 and thyroid-US best practice).

    Args:
        pp: Preprocessing config supplying image size, interpolation name and
            normalization statistics.
        policy: One of the policy names listed above.

    Returns:
        A ``torchvision.transforms.Compose`` pipeline.

    Raises:
        KeyError: if ``pp.interpolation`` is not a supported name.
        ValueError: if ``policy`` is not one of the names above.
    """
    import torchvision.transforms as T

    interp = _interp(pp.interpolation)
    norm = T.Normalize(pp.mean, pp.std)
    if policy == "none":
        return build_eval_transform(pp)
    if policy == "flip_only":
        return T.Compose([
            T.Resize((pp.image_size, pp.image_size), interpolation=interp),
            T.RandomHorizontalFlip(0.5),
            T.ToTensor(), norm,
        ])
    if policy == "medical_default":
        return T.Compose([
            T.Resize((pp.image_size, pp.image_size), interpolation=interp),
            T.RandomHorizontalFlip(0.5),
            T.RandomApply([T.RandomAffine(degrees=10, translate=(0.05, 0.05),
                                          scale=(0.9, 1.1), interpolation=interp)], p=0.5),
            T.ColorJitter(brightness=0.15, contrast=0.15),
            T.RandomApply([T.GaussianBlur(3, sigma=(0.1, 1.0))], p=0.2),
            T.ToTensor(), norm,
        ])
    if policy == "medical_strong":
        return T.Compose([
            T.RandomResizedCrop(pp.image_size, scale=(0.8, 1.0), ratio=(0.9, 1.1),
                                interpolation=interp),
            T.RandomHorizontalFlip(0.5),
            T.RandomApply([T.RandomAffine(degrees=10, translate=(0.05, 0.05),
                                          scale=(0.9, 1.1), interpolation=interp)], p=0.5),
            T.ColorJitter(brightness=0.15, contrast=0.15),
            T.RandomApply([T.GaussianBlur(3, sigma=(0.1, 1.0))], p=0.2),
            T.ToTensor(),
            # Noise is added on the [0, 1] tensor, before normalization.
            _AddSpeckle(sigma=0.05, p=0.2),
            norm,
        ])
    if policy == "clahe":
        return T.Compose([
            # CLAHE runs on the PIL image before any resizing.
            _CLAHE(),
            T.Resize((pp.image_size, pp.image_size), interpolation=interp),
            T.RandomHorizontalFlip(0.5),
            T.RandomApply([T.RandomAffine(degrees=10, translate=(0.05, 0.05),
                                          scale=(0.9, 1.1), interpolation=interp)], p=0.5),
            T.ColorJitter(brightness=0.15, contrast=0.15),
            T.ToTensor(), norm,
        ])
    raise ValueError(f"Unknown augmentation policy: {policy}")
# --------------------------------------------------------------------------- #
# Dataset
# --------------------------------------------------------------------------- #
class ThyroidImageFolder:
    """Lightweight ImageFolder that also returns the image filename id.

    Layout: <root>/<Benign|Malignant>/<id>.png
    Returns (tensor, label, image_id).

    Samples are listed class by class in ``CLASS_TO_IDX`` order, with files
    sorted by path inside each class. Only ``*.png`` files are picked up.
    The instance holds only paths, labels and the transform (no module
    objects), so it can be pickled for DataLoader workers as long as the
    transform itself is picklable.

    Attributes:
        root: Dataset root as a ``Path``.
        transform: Callable applied to the RGB PIL image.
        samples: List of ``(path, label, image_id)`` tuples.
        targets: List of labels, aligned with ``samples``.

    Raises:
        RuntimeError: if no image is found under ``root``.
    """

    def __init__(self, root, transform):
        # Import only to fail fast when Pillow is missing. The module is
        # deliberately not stored on the instance: modules cannot be pickled.
        from PIL import Image  # noqa: F401
        self.root = Path(root)
        self.transform = transform
        self.samples: List[Tuple[Path, int, str]] = []
        for cls, idx in CLASS_TO_IDX.items():
            d = self.root / cls
            if d.is_dir():
                self.samples.extend((p, idx, p.stem) for p in sorted(d.glob("*.png")))
        if not self.samples:
            raise RuntimeError(f"No images found under {root}")
        self.targets = [s[1] for s in self.samples]

    @property
    def Image(self):
        """``PIL.Image`` module, resolved lazily (kept for backward compatibility)."""
        from PIL import Image
        return Image

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, i):
        path, label, img_id = self.samples[i]
        # Convert inside the context manager so the file handle is released
        # before the transform runs.
        with self.Image.open(path) as im:
            im = im.convert("RGB")
        x = self.transform(im)
        return x, label, img_id
def class_counts(targets):
    """Count labels equal to 0 and to 1.

    Returns:
        ``(n_neg, n_pos)`` — note the negative count comes first. Labels other
        than 0 or 1 are ignored.
    """
    # Positives are counted first, then negatives, each in its own pass.
    n_pos = len([t for t in targets if t == 1])
    n_neg = len([t for t in targets if t == 0])
    return n_neg, n_pos
# --------------------------------------------------------------------------- #
# Model
# --------------------------------------------------------------------------- #
def build_model(backbone: str, freeze_stage: int = 0, dropout: float = 0.0):
    """Create a single-logit classifier together with its preprocessing config.

    backbone:
        'torchvision'            -> torchvision resnet18 ImageNet1K_V1
        'timm:resnet18.a1_in1k'  -> any timm tag after 'timm:'
    freeze_stage: 0 = full fine-tune; 1 = freeze stem+layer1; 2 = +layer2;
        3 = +layer3.
    dropout: for torchvision, a Dropout layer is placed before the final
        Linear only when ``dropout > 0``; for timm it is passed as ``drop_rate``.

    Returns (model, preprocess_config).
    Raises ValueError for an unrecognized ``backbone``.
    """
    import torch.nn as nn

    if backbone == "torchvision":
        from torchvision.models import resnet18, ResNet18_Weights

        model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
        # Replace the ImageNet head with a one-logit head.
        head = nn.Linear(model.fc.in_features, 1)
        if dropout > 0:
            model.fc = nn.Sequential(nn.Dropout(dropout), head)
        else:
            model.fc = head
        preprocess = PreprocessConfig(
            image_size=224,
            mean=list(IMAGENET_MEAN),
            std=list(IMAGENET_STD),
            interpolation="bilinear",
        )
        _freeze_resnet(model, freeze_stage)
        return model, preprocess

    if backbone.startswith("timm:"):
        import timm
        from timm.data import resolve_model_data_config

        tag = backbone[len("timm:"):]
        model = timm.create_model(tag, pretrained=True, num_classes=1, drop_rate=dropout)
        # Take preprocessing from the model's own data config, falling back
        # to ImageNet statistics / bicubic / 224 when a key is absent.
        data_cfg = resolve_model_data_config(model)
        interp_name = data_cfg.get("interpolation", "bicubic")
        if interp_name not in ("bilinear", "bicubic"):
            interp_name = "bicubic"
        side = data_cfg.get("input_size", (3, 224, 224))[-1]
        preprocess = PreprocessConfig(
            image_size=int(side),
            mean=list(data_cfg.get("mean", IMAGENET_MEAN)),
            std=list(data_cfg.get("std", IMAGENET_STD)),
            interpolation=interp_name,
        )
        _freeze_timm_resnet(model, freeze_stage)
        return model, preprocess

    raise ValueError(f"Unknown backbone: {backbone}")
def _freeze_resnet(model, stage):
    """Disable gradients on the early stages of a torchvision-style ResNet.

    ``stage >= 1`` freezes ``conv1``, ``bn1`` and ``layer1``; ``stage >= 2``
    also ``layer2``; ``stage >= 3`` also ``layer3``. ``stage <= 0`` is a no-op.
    Only ``requires_grad`` is changed.
    """
    if stage <= 0:
        return
    plan = (
        (1, ("conv1", "bn1", "layer1")),
        (2, ("layer2",)),
        (3, ("layer3",)),
    )
    # Resolve every module first, then flip the flags.
    selected = [getattr(model, attr)
                for threshold, attrs in plan if stage >= threshold
                for attr in attrs]
    for module in selected:
        for param in module.parameters():
            param.requires_grad = False
def _freeze_timm_resnet(model, stage):
    """Disable gradients on early ResNet stages, matching parameters by name.

    A parameter is frozen when its name starts with one of the selected
    prefixes: ``conv1`` / ``bn1`` / ``layer1`` for ``stage >= 1``, plus
    ``layer2`` for ``stage >= 2`` and ``layer3`` for ``stage >= 3``.
    ``stage <= 0`` is a no-op.
    """
    if stage <= 0:
        return
    plan = (
        (1, ("conv1", "bn1", "layer1")),
        (2, ("layer2",)),
        (3, ("layer3",)),
    )
    prefixes = tuple(prefix
                     for threshold, group in plan if stage >= threshold
                     for prefix in group)
    for param_name, param in model.named_parameters():
        if param_name.startswith(prefixes):
            param.requires_grad = False
# --------------------------------------------------------------------------- #
# Loss
# --------------------------------------------------------------------------- #
def build_loss(name: str, pos_weight: Optional[float], focal_gamma: float = 2.0,
               focal_alpha: float = 0.5):
    """Build the binary training criterion.

    Args:
        name: ``"bce"`` for ``BCEWithLogitsLoss`` or ``"focal"`` for a binary
            focal loss on logits.
        pos_weight: Positive-class weight for ``"bce"``; ``None`` disables it.
            Not used by ``"focal"``.
        focal_gamma: Focusing exponent applied to ``(1 - p_t)``.
        focal_alpha: Weight of positive targets; negatives get ``1 - alpha``.

    Returns:
        An ``nn.Module`` taking ``(logits, targets)``. The focal loss flattens
        both inputs and returns the mean over elements.

    Raises:
        ValueError: if ``name`` is not a known loss.
    """
    import torch
    import torch.nn as nn
    if name == "bce":
        pw = None
        if pos_weight is not None:
            # Force float32: without an explicit dtype an int weight yields an
            # int64 tensor and a NumPy float64 scalar yields a float64 one.
            pw = torch.tensor([float(pos_weight)], dtype=torch.float32)
        return nn.BCEWithLogitsLoss(pos_weight=pw)
    if name == "focal":
        class FocalLoss(nn.Module):
            def __init__(self, gamma, alpha):
                super().__init__()
                self.gamma, self.alpha = gamma, alpha

            def forward(self, logits, targets):
                logits = logits.view(-1)
                targets = targets.view(-1).float()
                p = torch.sigmoid(logits)
                ce = nn.functional.binary_cross_entropy_with_logits(
                    logits, targets, reduction="none")
                # p_t is the probability assigned to the true class.
                p_t = p * targets + (1 - p) * (1 - targets)
                alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
                loss = alpha_t * (1 - p_t) ** self.gamma * ce
                return loss.mean()
        return FocalLoss(focal_gamma, focal_alpha)
    raise ValueError(f"Unknown loss: {name}")
# --------------------------------------------------------------------------- #
# Inference: collect logits/probs/labels/ids
# --------------------------------------------------------------------------- #
def collect_logits(model, loader, device, amp=False):
    """Run ``model`` over ``loader`` in eval mode and gather the outputs.

    Args:
        model: Module whose output is flattened to one logit per sample.
        loader: Iterable of ``(inputs, labels, ids)`` batches.
        device: Target device as a string (``"cpu"``, ``"cuda"``, ``"cuda:0"``)
            or ``torch.device``.
        amp: Use float16 autocast; only takes effect on a CUDA device.

    Returns:
        ``(logits, labels, ids)``: float logits and int labels as 1-D NumPy
        arrays, and the ids as a flat list in iteration order.

    Note:
        The model is left in eval mode.
    """
    import torch
    model.eval()
    logits_all, labels_all, ids_all = [], [], []
    # Compare the device *type*: a plain ``device == "cuda"`` test is false
    # for spellings such as "cuda:0" or a torch.device instance, which would
    # silently disable autocast.
    use_ac = bool(amp) and torch.device(device).type == "cuda"
    with torch.no_grad():
        for x, y, ids in loader:
            x = x.to(device, non_blocking=True)
            if use_ac:
                with torch.autocast(device_type="cuda", dtype=torch.float16):
                    out = model(x).view(-1)
            else:
                out = model(x).view(-1)
            # Cast back to float32 so half-precision outputs do not leak out.
            logits_all.append(out.float().cpu().numpy())
            labels_all.append(np.asarray(y))
            ids_all.extend(list(ids))
    return (np.concatenate(logits_all), np.concatenate(labels_all).astype(int), ids_all)
# --------------------------------------------------------------------------- #
# Calibration (temperature scaling)
# --------------------------------------------------------------------------- #
def fit_temperature(val_logits: np.ndarray, val_labels: np.ndarray) -> float:
    """Fit single-parameter temperature on validation logits (minimize NLL).

    Args:
        val_logits: Uncalibrated logits of the positive class.
        val_labels: Binary labels (0/1) with the same shape as ``val_logits``.

    Returns:
        The fitted temperature ``T`` (never below ``1e-3``); calibrated
        probabilities are ``sigmoid(logits / T)``.

    Raises:
        ValueError: if ``val_logits`` is empty (the NLL would be NaN).
    """
    import torch
    import torch.nn as nn
    logits = torch.tensor(val_logits, dtype=torch.float32)
    labels = torch.tensor(val_labels, dtype=torch.float32)
    if logits.numel() == 0:
        raise ValueError("fit_temperature requires at least one validation logit")
    min_temperature = 1e-3
    # Optimize log(T) so the temperature stays positive without relying on
    # the clamp, which has zero gradient once it is active.
    log_t = nn.Parameter(torch.zeros(1))
    # A fixed small step without line search takes only a fraction of each
    # quasi-Newton step and can stop short of the minimum within max_iter;
    # the strong-Wolfe line search chooses the step length itself.
    opt = torch.optim.LBFGS([log_t], lr=1.0, max_iter=200,
                            line_search_fn="strong_wolfe")
    bce = nn.BCEWithLogitsLoss()

    def closure():
        opt.zero_grad()
        temperature = torch.exp(log_t).clamp(min=min_temperature)
        loss = bce(logits / temperature, labels)
        loss.backward()
        return loss

    opt.step(closure)
    return float(torch.exp(log_t.detach()).clamp(min=min_temperature).item())
def apply_temperature(logits: np.ndarray, T: float) -> np.ndarray:
    """Return calibrated probabilities ``sigmoid(logits / T)``."""
    scaled = logits / T
    denominator = 1.0 + np.exp(-scaled)
    return 1.0 / denominator
def sigmoid(logits: np.ndarray) -> np.ndarray:
    """Element-wise logistic function ``1 / (1 + exp(-logits))``."""
    denominator = 1.0 + np.exp(-logits)
    return 1.0 / denominator
# --------------------------------------------------------------------------- #
# Calibration metrics
# --------------------------------------------------------------------------- #
def expected_calibration_error(y_true, y_prob, n_bins=15):
    """Expected calibration error over equal-width probability bins.

    Each non-empty bin contributes ``|mean(y_true) - mean(y_prob)|`` weighted
    by the fraction of samples it holds.

    Args:
        y_true: Binary labels (0/1).
        y_prob: Predicted probabilities of the positive class, in [0, 1].
        n_bins: Number of equal-width bins on [0, 1].

    Returns:
        The ECE as a float; ``0.0`` for empty input.
    """
    y_true = np.asarray(y_true)
    y_prob = np.asarray(y_prob)
    total = len(y_prob)
    if total == 0:
        return 0.0
    edges = np.linspace(0, 1, n_bins + 1)
    ece = 0.0
    for i, (lo, hi) in enumerate(zip(edges[:-1], edges[1:])):
        # Bins are (lo, hi], except the first, which is closed on the left so
        # that probabilities exactly equal to 0.0 are not dropped.
        above_lo = (y_prob >= lo) if i == 0 else (y_prob > lo)
        in_bin = above_lo & (y_prob <= hi)
        count = int(in_bin.sum())
        if count:
            gap = abs(y_true[in_bin].mean() - y_prob[in_bin].mean())
            ece += (count / total) * gap
    return float(ece)
def brier(y_true, y_prob):
    """Brier score of ``y_prob`` against ``y_true`` via scikit-learn, as a float."""
    from sklearn.metrics import brier_score_loss

    labels = np.asarray(y_true)
    probs = np.asarray(y_prob)
    score = brier_score_loss(labels, probs)
    return float(score)
# --------------------------------------------------------------------------- #
# Thresholding
# --------------------------------------------------------------------------- #
def threshold_for_sensitivity(y_true, y_prob, target_sens=0.95):
    """Pick the ROC operating point with the best specificity at a target sensitivity.

    Among the thresholds returned by ``sklearn.metrics.roc_curve`` whose
    sensitivity is at least ``target_sens``, the one with the highest
    specificity is chosen. If none qualifies, the point with the highest
    sensitivity is returned and the flag is ``False``.

    Returns (threshold, achieved_sens, achieved_spec, achievable_flag).
    """
    from sklearn.metrics import roc_curve

    labels = np.asarray(y_true)
    scores = np.asarray(y_prob)
    fpr, tpr, thresholds = roc_curve(labels, scores)
    specificity = 1 - fpr

    eligible = np.where(tpr >= target_sens)[0]
    achievable = eligible.size > 0
    if achievable:
        pick = eligible[np.argmax(specificity[eligible])]
    else:
        pick = int(np.argmax(tpr))
    return (float(thresholds[pick]), float(tpr[pick]),
            float(specificity[pick]), achievable)
def youden_threshold(y_true, y_prob):
    """Threshold maximizing Youden's J (``tpr - fpr``) on the ROC curve.

    Returns (threshold, sensitivity, specificity) at that point.
    """
    from sklearn.metrics import roc_curve

    labels = np.asarray(y_true)
    scores = np.asarray(y_prob)
    fpr, tpr, thresholds = roc_curve(labels, scores)
    pick = int(np.argmax(tpr - fpr))
    return float(thresholds[pick]), float(tpr[pick]), float(1 - fpr[pick])
# --------------------------------------------------------------------------- #
# Metrics + bootstrap CIs
# --------------------------------------------------------------------------- #
def point_metrics(y_true, y_prob, thr):
    """Compute point estimates of the classification metrics at ``thr``.

    A sample is predicted positive when ``y_prob >= thr``.

    Args:
        y_true: Binary labels (0/1).
        y_prob: Predicted probabilities of the positive class.
        thr: Decision threshold.

    Returns:
        Dict with ``auroc``, ``accuracy``, ``sensitivity``, ``specificity``,
        ``ppv``, ``npv``, ``f1``, ``brier``, ``ece``, the confusion counts
        ``tp``/``tn``/``fp``/``fn``, ``threshold``, and the sample counts
        ``n``/``n_pos``/``n_neg``. A rate whose denominator is zero is NaN,
        and so is ``auroc`` when scikit-learn cannot compute it (e.g. only
        one class present).
    """
    from sklearn.metrics import roc_auc_score, f1_score, accuracy_score
    y_true = np.asarray(y_true)
    y_prob = np.asarray(y_prob)
    pred = (y_prob >= thr).astype(int)
    tp = int(((pred == 1) & (y_true == 1)).sum())
    tn = int(((pred == 0) & (y_true == 0)).sum())
    fp = int(((pred == 1) & (y_true == 0)).sum())
    fn = int(((pred == 0) & (y_true == 1)).sum())
    sens = tp / (tp + fn) if (tp + fn) else float("nan")
    spec = tn / (tn + fp) if (tn + fp) else float("nan")
    ppv = tp / (tp + fp) if (tp + fp) else float("nan")
    npv = tn / (tn + fn) if (tn + fn) else float("nan")
    # AUROC is undefined for single-class data; report NaN like the other
    # undefined rates (and like bootstrap_ci does) instead of aborting.
    try:
        auroc = float(roc_auc_score(y_true, y_prob))
    except ValueError:
        auroc = float("nan")
    return {
        "auroc": auroc,
        "accuracy": float(accuracy_score(y_true, pred)),
        "sensitivity": float(sens),
        "specificity": float(spec),
        "ppv": float(ppv),
        "npv": float(npv),
        "f1": float(f1_score(y_true, pred, zero_division=0)),
        "brier": brier(y_true, y_prob),
        "ece": expected_calibration_error(y_true, y_prob),
        "tp": tp, "tn": tn, "fp": fp, "fn": fn,
        "threshold": float(thr),
        "n": int(len(y_true)),
        "n_pos": int((y_true == 1).sum()),
        "n_neg": int((y_true == 0).sum()),
    }
def bootstrap_ci(y_true, y_prob, thr, n_boot=2000, seed=42):
    """Stratified bootstrap 95% CIs for AUROC, sens, spec, ppv, npv, acc, f1.

    Every replicate resamples positives and negatives separately, with
    replacement and at their original counts, then recomputes the metrics at
    the fixed threshold ``thr``. Undefined values are stored as NaN and
    ignored when taking percentiles.

    Returns:
        Dict mapping each metric name to ``(2.5th, 97.5th)`` percentiles.
    """
    from sklearn.metrics import roc_auc_score, f1_score, accuracy_score

    y_true = np.asarray(y_true)
    y_prob = np.asarray(y_prob)
    rng = np.random.default_rng(seed)
    pos_idx = np.where(y_true == 1)[0]
    neg_idx = np.where(y_true == 0)[0]
    metric_names = ["auroc", "sensitivity", "specificity", "ppv", "npv", "accuracy", "f1"]
    draws = {name: [] for name in metric_names}

    def ratio(numerator, denominator):
        return numerator / denominator if denominator else np.nan

    for _ in range(n_boot):
        # Positives are drawn before negatives; the order fixes the RNG stream.
        drawn_pos = rng.choice(pos_idx, len(pos_idx), replace=True)
        drawn_neg = rng.choice(neg_idx, len(neg_idx), replace=True)
        picked = np.concatenate([drawn_pos, drawn_neg])
        labels = y_true[picked]
        probs = y_prob[picked]
        pred = (probs >= thr).astype(int)

        try:
            auroc = roc_auc_score(labels, probs)
        except Exception:
            auroc = np.nan
        draws["auroc"].append(auroc)

        hit = pred == 1
        truth = labels == 1
        tp = (hit & truth).sum()
        tn = (~hit & (labels == 0)).sum()
        fp = (hit & (labels == 0)).sum()
        fn = (~hit & truth).sum()
        draws["sensitivity"].append(ratio(tp, tp + fn))
        draws["specificity"].append(ratio(tn, tn + fp))
        draws["ppv"].append(ratio(tp, tp + fp))
        draws["npv"].append(ratio(tn, tn + fn))
        draws["accuracy"].append(accuracy_score(labels, pred))
        draws["f1"].append(f1_score(labels, pred, zero_division=0))

    intervals = {}
    for name in metric_names:
        values = np.asarray(draws[name], dtype=float)
        lower = float(np.nanpercentile(values, 2.5))
        upper = float(np.nanpercentile(values, 97.5))
        intervals[name] = (lower, upper)
    return intervals
# --------------------------------------------------------------------------- #
# JSON helpers
# --------------------------------------------------------------------------- #
def save_json(obj, path):
    """Write ``obj`` to ``path`` as 2-space-indented JSON.

    Missing parent directories are created; an existing file is overwritten.
    """
    parent_dir = Path(path).parent
    parent_dir.mkdir(parents=True, exist_ok=True)
    with open(path, "w") as handle:
        json.dump(obj, handle, indent=2)
def load_json(path):
    """Read and parse the JSON document stored at ``path``.

    The file is decoded as UTF-8 instead of the platform default encoding, so
    files containing non-ASCII text load the same way on every system.
    """
    with open(path, encoding="utf-8") as f:
        return json.load(f)