agentic_thyroid_model / finalize.py
Johnyquest7's picture
Add full reproducible thyroid ResNet-18 experiment: weights, scripts, configs, calibration, locked threshold, test eval w/ CIs, figures, data exploration, README, LOG
45af8e1 verified
Raw
History Blame Contribute Delete
12.8 kB
#!/usr/bin/env python
"""
Finalization pipeline (run ONCE after the sweep selects the best model by
validation AUROC):
1. Load the best checkpoint (selected purely on val AUROC).
2. Recompute Valid + Test logits with deterministic eval preprocessing.
3. Calibrate on VALIDATION via temperature scaling; compare uncalibrated vs
calibrated (ECE, Brier, AUROC). Temperature scaling is monotonic, so AUROC
(discrimination) is unchanged; keep the calibrated probabilities only if
validation ECE improves or stays the same (within 1e-6), otherwise fall back
to the uncalibrated probabilities.
4. Select a LOCKED threshold on VALIDATION: highest-specificity threshold with
sensitivity >= 0.95 (primary); Youden's J reported as secondary.
5. Evaluate ONCE on TEST with the probabilities chosen in step 3 (calibrated
or uncalibrated) + the locked threshold.
6. Bootstrap 95% CIs (stratified, 2000 resamples, seed=42).
7. Save all figures (ROC, PR, calibration/reliability, confusion matrices),
tables (markdown + CSV), per-image prediction CSVs (valid + test),
calibration config, threshold config, preprocessing config.
Outputs go to --output_dir (default /app/model_repo).
"""
import argparse
import csv
import json
from pathlib import Path
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import thyroid_lib as L
# Minimum validation sensitivity the locked operating threshold must reach.
TARGET_SENS: float = 0.95
# Number of bootstrap resamples used for the test-set 95% confidence intervals.
N_BOOT: int = 2000
# Seed for the bootstrap resampling, so the reported CIs are reproducible.
BOOT_SEED: int = 42
def load_model(ckpt_path, device):
    """Rebuild the trained classifier from a checkpoint and put it in eval mode.

    Args:
        ckpt_path: Path to the checkpoint file. It must contain the keys
            "backbone", "model_state" and "preprocess"; "freeze_stage"
            (default 0) and "dropout" (default 0.0) are optional.
        device: Torch device string the model is moved to.

    Returns:
        Tuple ``(model, pp, ck)``: the model in eval mode on ``device``, the
        ``L.PreprocessConfig`` restored from the checkpoint, and the raw
        checkpoint dict.
    """
    import torch

    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
    # Python objects. Only load checkpoints produced by this project's own
    # training runs, never files from an untrusted source.
    ck = torch.load(ckpt_path, map_location="cpu", weights_only=False)

    # build_model also returns a second value (presumably a default preprocess
    # config). It is discarded: the preprocessing saved in the checkpoint is
    # the one the model was trained with and is restored below.
    model, _ = L.build_model(
        ck["backbone"],
        freeze_stage=ck.get("freeze_stage", 0),
        dropout=ck.get("dropout", 0.0),
    )
    model.load_state_dict(ck["model_state"])
    model.to(device).eval()
    pp = L.PreprocessConfig.from_dict(ck["preprocess"])
    return model, pp, ck
def get_logits(model, data_dir, split, pp, device):
    """Run deterministic (eval-transform) inference over one dataset split.

    Args:
        model: Classifier in eval mode.
        data_dir: Dataset root directory.
        split: Name of the split sub-folder under ``data_dir`` (e.g. "Valid").
        pp: Preprocess config used to build the eval transform.
        device: Torch device string; host memory is pinned only for "cuda".

    Returns:
        ``(logits, labels, ids)`` as produced by ``L.collect_logits``.
    """
    import torch
    from torch.utils.data import DataLoader

    split_dir = Path(data_dir) / split
    eval_transform = L.build_eval_transform(pp)
    dataset = L.ThyroidImageFolder(split_dir, eval_transform)

    pin = device == "cuda"
    # shuffle=False keeps a fixed order so logits stay aligned with labels/ids.
    loader = DataLoader(
        dataset,
        batch_size=64,
        shuffle=False,
        num_workers=4,
        pin_memory=pin,
    )

    # amp disabled: logits are collected without mixed precision.
    logits, labels, ids = L.collect_logits(model, loader, device, amp=False)
    return logits, labels, ids
def save_per_image_csv(path, ids, labels, probs, thr):
    """Write one CSV row per image with its label, probability and prediction.

    Args:
        path: Destination CSV file (overwritten).
        ids: Image identifiers, in the same order as ``labels`` and ``probs``.
        labels: Integer ground-truth labels (keys of ``L.IDX_TO_CLASS``).
        probs: Probabilities for the positive (malignant) class.
        thr: Decision threshold; an image with ``prob >= thr`` is predicted
            as label 1, otherwise label 0.

    Raises:
        ValueError: If ``ids``, ``labels`` and ``probs`` differ in length.
            Without this check ``zip`` would silently drop trailing rows.
    """
    probs = np.asarray(probs)
    n_rows = len(ids)
    if len(labels) != n_rows or len(probs) != n_rows:
        raise ValueError(
            f"length mismatch: {n_rows} ids, {len(labels)} labels, "
            f"{len(probs)} probabilities"
        )
    pred = (probs >= thr).astype(int)

    # Explicit UTF-8 so non-ASCII image ids are written identically on every
    # platform, regardless of the locale's default encoding.
    with open(path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["image_id", "true_label", "true_class", "probability_malignant",
                         "predicted_label", "predicted_class"])
        for image_id, y, p, pr in zip(ids, labels, probs, pred):
            writer.writerow([image_id, int(y), L.IDX_TO_CLASS[int(y)], f"{p:.6f}",
                             int(pr), L.IDX_TO_CLASS[int(pr)]])
def plot_roc(y, p, path, title):
    """Save a ROC curve, with its AUROC in the legend, to ``path``.

    Args:
        y: Binary ground-truth labels.
        p: Scores or probabilities for the positive class.
        path: Output image path (saved at 150 dpi).
        title: Axes title.
    """
    from sklearn.metrics import roc_auc_score, roc_curve

    false_pos_rate, true_pos_rate, _cutoffs = roc_curve(y, p)
    auroc = roc_auc_score(y, p)

    fig, ax = plt.subplots(figsize=(5, 5))
    ax.plot(false_pos_rate, true_pos_rate, label=f"AUROC = {auroc:.3f}", color="#C44E52")
    # Diagonal = chance-level classifier.
    ax.plot([0, 1], [0, 1], "--", color="gray")
    ax.set_xlabel("1 - Specificity (FPR)")
    ax.set_ylabel("Sensitivity (TPR)")
    ax.set_title(title)
    ax.legend(loc="lower right")

    fig.tight_layout()
    fig.savefig(path, dpi=150)
    plt.close(fig)
def plot_pr(y, p, path, title):
    """Save a precision-recall curve, annotated with average precision, to ``path``.

    Args:
        y: Binary ground-truth labels.
        p: Scores or probabilities for the positive class.
        path: Output image path (saved at 150 dpi).
        title: Axes title.
    """
    from sklearn.metrics import average_precision_score, precision_recall_curve

    precision, recall, _cutoffs = precision_recall_curve(y, p)
    avg_precision = average_precision_score(y, p)

    fig, ax = plt.subplots(figsize=(5, 5))
    ax.plot(recall, precision, label=f"AP = {avg_precision:.3f}", color="#4C72B0")
    ax.set_xlabel("Recall (Sensitivity)")
    ax.set_ylabel("Precision (PPV)")
    ax.set_title(title)
    ax.legend(loc="lower left")

    fig.tight_layout()
    fig.savefig(path, dpi=150)
    plt.close(fig)
def plot_reliability(y, p_uncal, p_cal, path, title):
    """Save a reliability diagram comparing uncalibrated and calibrated probabilities.

    Each series uses 10 equal-width bins. Its legend entry carries the ECE and
    Brier score computed by ``L.expected_calibration_error`` and ``L.brier``.

    Args:
        y: Binary ground-truth labels.
        p_uncal: Probabilities before temperature scaling.
        p_cal: Probabilities after temperature scaling.
        path: Output image path (saved at 150 dpi).
        title: Axes title.
    """
    from sklearn.calibration import calibration_curve

    series = (
        ("Uncalibrated", p_uncal, "#888888"),
        ("Temperature-scaled", p_cal, "#C44E52"),
    )

    fig, ax = plt.subplots(figsize=(5.5, 5.5))
    ax.plot([0, 1], [0, 1], "--", color="gray", label="Perfect calibration")
    for name, probs, colour in series:
        frac_positive, mean_predicted = calibration_curve(
            y, probs, n_bins=10, strategy="uniform"
        )
        ece = L.expected_calibration_error(y, probs)
        brier_score = L.brier(y, probs)
        legend_text = f"{name} (ECE={ece:.3f}, Brier={brier_score:.3f})"
        ax.plot(mean_predicted, frac_positive, "o-", color=colour, label=legend_text)

    ax.set_xlabel("Mean predicted probability")
    ax.set_ylabel("Observed frequency")
    ax.set_title(title)
    ax.legend(loc="upper left", fontsize=8)

    fig.tight_layout()
    fig.savefig(path, dpi=150)
    plt.close(fig)
def plot_confusion(cm, path, title, normalize=False):
    """Save a 2x2 confusion-matrix heatmap to ``path``.

    Args:
        cm: 2x2 counts laid out as [[TN, FP], [FN, TP]] (rows = true class,
            columns = predicted class, benign first).
        path: Output image path (saved at 150 dpi).
        title: Axes title.
        normalize: If True, colour by row-normalized fractions and print the
            percentage under each raw count. Zero-count rows stay at 0
            instead of dividing by zero.
    """
    counts = np.asarray(cm, dtype=float)
    if normalize:
        row_totals = counts.sum(axis=1, keepdims=True).clip(min=1e-9)
        shown = counts / row_totals
    else:
        shown = counts.copy()
    peak = shown.max()

    fig, ax = plt.subplots(figsize=(4.8, 4.4))
    ax.imshow(shown, cmap="Blues", vmin=0, vmax=peak)
    ax.set_xticks([0, 1])
    ax.set_yticks([0, 1])
    ax.set_xticklabels(["Predicted benign", "Predicted malignant"])
    ax.set_yticklabels(["True benign", "True malignant"])

    for row in range(2):
        for col in range(2):
            cell_text = f"{int(counts[row, col])}"
            if normalize:
                cell_text += f"\n({shown[row, col] * 100:.1f}%)"
            # White text on dark (above half-peak) cells for readability.
            text_colour = "white" if shown[row, col] > peak / 2 else "black"
            ax.text(col, row, cell_text, ha="center", va="center",
                    color=text_colour, fontsize=11)

    ax.set_title(title)
    fig.tight_layout()
    fig.savefig(path, dpi=150)
    plt.close(fig)
def main():
    """Run the one-shot finalization pipeline described in the module docstring.

    Command-line arguments:
        --ckpt: Checkpoint of the model selected by validation AUROC (required).
        --data_dir: Dataset root containing the "Valid" and "Test" split folders.
        --output_dir: Root for ``configs/`` and ``results/`` (figures, tables,
            per-image CSVs, final_results.json).
        --best_run_name, --best_val_auroc: Provenance strings copied verbatim
            into the Markdown table and final_results.json.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt", required=True)
    parser.add_argument("--data_dir", default="/app/TN5000")
    parser.add_argument("--output_dir", default="/app/model_repo")
    parser.add_argument("--best_run_name", default="")
    parser.add_argument("--best_val_auroc", default="")
    args = parser.parse_args()

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    L.set_determinism(42, strict=True)

    out = Path(args.output_dir)
    res = out / "results"
    figs = res / "figures"
    tabs = res / "tables"
    for d in [out / "configs", figs, tabs]:
        d.mkdir(parents=True, exist_ok=True)

    model, pp, ck = load_model(args.ckpt, device)

    # ---- logits ----
    val_logits, val_y, val_ids = get_logits(model, args.data_dir, "Valid", pp, device)
    test_logits, test_y, test_ids = get_logits(model, args.data_dir, "Test", pp, device)
    val_p_uncal = L.sigmoid(val_logits)
    test_p_uncal = L.sigmoid(test_logits)

    # ---- calibration: temperature scaling fitted on VALIDATION only ----
    T = L.fit_temperature(val_logits, val_y)
    val_p_cal = L.apply_temperature(val_logits, T)
    test_p_cal = L.apply_temperature(test_logits, T)
    # Plain Python float for the reports: json.dumps (console dump below)
    # cannot serialize e.g. numpy.float32 or a 0-d tensor, in case
    # fit_temperature returns one.
    temperature = float(T)

    from sklearn.metrics import roc_auc_score

    cal_report = {
        "method": "temperature_scaling",
        "temperature": temperature,
        "valid": {
            "auroc_uncal": float(roc_auc_score(val_y, val_p_uncal)),
            "auroc_cal": float(roc_auc_score(val_y, val_p_cal)),
            "ece_uncal": float(L.expected_calibration_error(val_y, val_p_uncal)),
            "ece_cal": float(L.expected_calibration_error(val_y, val_p_cal)),
            "brier_uncal": float(L.brier(val_y, val_p_uncal)),
            "brier_cal": float(L.brier(val_y, val_p_cal)),
        },
    }
    # Temperature scaling is monotonic, so AUROC is unchanged. Adopt the
    # calibrated probabilities only if validation ECE does not get worse
    # (1e-6 tolerance for float noise); otherwise fall back to uncalibrated.
    use_calibrated = cal_report["valid"]["ece_cal"] <= cal_report["valid"]["ece_uncal"] + 1e-6
    cal_report["use_calibrated"] = bool(use_calibrated)
    prob_kind = "calibrated" if use_calibrated else "uncalibrated"
    val_p = val_p_cal if use_calibrated else val_p_uncal
    test_p = test_p_cal if use_calibrated else test_p_uncal
    L.save_json(cal_report, out / "configs" / "calibration.json")

    # ---- threshold selection on VALIDATION (probabilities chosen above) ----
    thr, sens_v, spec_v, achievable = L.threshold_for_sensitivity(val_y, val_p, TARGET_SENS)
    yj_thr, yj_sens, yj_spec = L.youden_threshold(val_y, val_p)
    thr_report = {
        # Name the probabilities actually used; previously this always said
        # "calibrated", even when calibration had been rejected.
        "primary_method": (
            f"highest-specificity threshold with sensitivity >= {TARGET_SENS} "
            f"on validation ({prob_kind} probabilities)"
        ),
        "locked_threshold": thr,
        "valid_sensitivity_at_threshold": sens_v,
        "valid_specificity_at_threshold": spec_v,
        "target_sensitivity": TARGET_SENS,
        "target_achievable": bool(achievable),
        "secondary_youden": {"threshold": yj_thr, "sensitivity": yj_sens, "specificity": yj_spec},
        "probabilities_used": prob_kind,
    }
    L.save_json(thr_report, out / "configs" / "threshold.json")

    # ---- preprocessing config (locked) ----
    L.save_json({**pp.to_dict(), "positive_class": "Malignant", "positive_index": 1,
                 "note": "Deterministic eval/inference preprocessing. No augmentation."},
                out / "configs" / "preprocess.json")

    # ---- VALIDATION metrics at locked threshold ----
    val_metrics = L.point_metrics(val_y, val_p, thr)

    # ---- TEST metrics (locked model + locked threshold, evaluated once) ----
    test_metrics = L.point_metrics(test_y, test_p, thr)
    test_ci = L.bootstrap_ci(test_y, test_p, thr, n_boot=N_BOOT, seed=BOOT_SEED)

    # ---- per-image CSVs ----
    save_per_image_csv(res / "valid_predictions.csv", val_ids, val_y, val_p, thr)
    save_per_image_csv(res / "test_predictions.csv", test_ids, test_y, test_p, thr)

    # ---- figures ----
    plot_roc(test_y, test_p, figs / "test_roc.png", "ROC — Test set")
    plot_pr(test_y, test_p, figs / "test_pr.png", "Precision-Recall — Test set")
    plot_reliability(val_y, val_p_uncal, val_p_cal, figs / "valid_calibration.png",
                     "Reliability diagram — Validation")
    plot_reliability(test_y, test_p_uncal, test_p_cal, figs / "test_calibration.png",
                     "Reliability diagram — Test")
    cm = np.array([[test_metrics["tn"], test_metrics["fp"]],
                   [test_metrics["fn"], test_metrics["tp"]]])
    plot_confusion(cm, figs / "test_confusion_counts.png", "Confusion matrix (counts) — Test")
    plot_confusion(cm, figs / "test_confusion_normalized.png",
                   "Confusion matrix (row-normalized) — Test", normalize=True)

    # ---- metrics table (markdown + csv) with CIs ----
    ci_keys = ["auroc", "sensitivity", "specificity", "ppv", "npv", "accuracy", "f1"]
    md = ["# Final Test Metrics (locked model + locked threshold)\n",
          f"- Selected run: {args.best_run_name} | selection val AUROC: {args.best_val_auroc}",
          f"- Backbone: {ck['backbone']} | Calibration: temperature scaling (T={temperature:.4f}, "
          f"{'used' if use_calibrated else 'not used'})",
          f"- Locked threshold (val sens>={TARGET_SENS}): {thr:.4f} | probabilities: "
          f"{prob_kind}",
          f"- CI method: stratified bootstrap, {N_BOOT} resamples, seed={BOOT_SEED}\n",
          "| Metric | Point estimate | 95% CI |",
          "|--------|---------------:|:------:|"]
    rows_csv = [["metric", "point_estimate", "ci_low", "ci_high"]]
    for k in ci_keys:
        pe = test_metrics[k]
        lo, hi = test_ci[k]
        md.append(f"| {k.upper()} | {pe:.4f} | [{lo:.4f}, {hi:.4f}] |")
        rows_csv.append([k, f"{pe:.6f}", f"{lo:.6f}", f"{hi:.6f}"])
    # Brier and ECE are reported as point estimates only (no CI).
    for k in ["brier", "ece"]:
        md.append(f"| {k.upper()} | {test_metrics[k]:.4f} | — |")
        rows_csv.append([k, f"{test_metrics[k]:.6f}", "", ""])
    md.append(f"\n**Confusion matrix (Test):** TN={test_metrics['tn']}, FP={test_metrics['fp']}, "
              f"FN={test_metrics['fn']}, TP={test_metrics['tp']}\n")
    # Explicit UTF-8: the table contains non-ASCII characters ("—"), so the
    # output must not depend on the platform's default encoding.
    (tabs / "test_metrics_with_ci.md").write_text("\n".join(md), encoding="utf-8")
    with open(tabs / "test_metrics_with_ci.csv", "w", newline="", encoding="utf-8") as f:
        csv.writer(f).writerows(rows_csv)

    # ---- consolidated results json ----
    final = {
        "selected_run": args.best_run_name,
        "selection_val_auroc": args.best_val_auroc,
        "backbone": ck["backbone"],
        "preprocess": pp.to_dict(),
        "calibration": cal_report,
        "threshold": thr_report,
        "valid_metrics_at_locked_threshold": val_metrics,
        "test_metrics_at_locked_threshold": test_metrics,
        "test_metrics_95ci": {k: list(test_ci[k]) for k in ci_keys},
        "ci_method": f"stratified bootstrap, {N_BOOT} resamples, seed={BOOT_SEED}",
    }
    L.save_json(final, res / "final_results.json")

    print("=== CALIBRATION ===")
    print(json.dumps(cal_report, indent=2))
    print("=== THRESHOLD ===")
    print(json.dumps(thr_report, indent=2))
    print("=== TEST METRICS ===")
    for k in ci_keys:
        print(f" {k:12s} {test_metrics[k]:.4f} CI [{test_ci[k][0]:.4f}, {test_ci[k][1]:.4f}]")
    print(f" brier {test_metrics['brier']:.4f}")
    print(f" ece {test_metrics['ece']:.4f}")
    print("Saved to", out)
# Script entry point: run the pipeline only when executed directly, not on import.
if __name__ == "__main__":
    main()