#!/usr/bin/env python
"""
Finalization pipeline (run ONCE after the sweep selects the best model by
validation AUROC):

1. Load the best checkpoint (selected purely on validation AUROC).
2. Recompute Valid + Test logits with deterministic eval preprocessing.
3. Calibrate on VALIDATION via temperature scaling and compare uncalibrated vs.
   calibrated (ECE, Brier, AUROC). Temperature scaling is monotonic, so it is
   not expected to change AUROC; the calibrated probabilities are kept only if
   validation ECE does not get worse, otherwise the uncalibrated ones are used.
4. Select a LOCKED threshold on VALIDATION, using the probabilities chosen in
   step 3: the highest-specificity threshold with sensitivity >= 0.95
   (primary); Youden's J is reported as a secondary reference.
5. Evaluate ONCE on TEST with the chosen probabilities + locked threshold.
6. Bootstrap 95% CIs (stratified, 2000 resamples, seed=42).
7. Save all figures (ROC, PR, calibration/reliability, confusion matrices),
   tables (markdown + CSV), per-image prediction CSVs (valid + test), and the
   calibration, threshold and preprocessing configs.

Outputs go to --output_dir (default /app/model_repo).
"""
import argparse
import csv
import json
from pathlib import Path

import numpy as np
import matplotlib

matplotlib.use("Agg")  # headless backend; must be selected before pyplot is imported
import matplotlib.pyplot as plt

import thyroid_lib as L

TARGET_SENS = 0.95
N_BOOT = 2000
BOOT_SEED = 42


def load_model(ckpt_path, device):
    """Rebuild the selected model from ``ckpt_path`` and put it in eval mode.

    Returns:
        ``(model, pp, ck)``: the model moved to ``device``, the preprocessing
        config stored in the checkpoint, and the raw checkpoint dict.
    """
    import torch

    # weights_only=False unpickles arbitrary objects: only use this on
    # checkpoints produced by this project's own training runs.
    ck = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    # The second value returned by build_model is not used; the preprocessing
    # config saved with the checkpoint is authoritative (loaded below).
    model, _ = L.build_model(
        ck["backbone"],
        freeze_stage=ck.get("freeze_stage", 0),
        dropout=ck.get("dropout", 0.0),
    )
    model.load_state_dict(ck["model_state"])
    model.to(device).eval()
    pp = L.PreprocessConfig.from_dict(ck["preprocess"])
    return model, pp, ck


def get_logits(model, data_dir, split, pp, device):
    """Run ``model`` over ``data_dir/split`` with deterministic eval preprocessing.

    Returns:
        ``(logits, labels, ids)`` as produced by ``L.collect_logits``.
    """
    from torch.utils.data import DataLoader

    ds = L.ThyroidImageFolder(Path(data_dir) / split, L.build_eval_transform(pp))
    # shuffle=False keeps predictions aligned with the dataset order / image ids.
    loader = DataLoader(
        ds,
        batch_size=64,
        shuffle=False,
        num_workers=4,
        pin_memory=(device == "cuda"),
    )
    logits, labels, ids = L.collect_logits(model, loader, device, amp=False)
    return logits, labels, ids


def save_per_image_csv(path, ids, labels, probs, thr):
    """Write one CSV row per image with its label, probability and prediction.

    An image is predicted malignant (label 1) when its probability is >= ``thr``.
    """
    pred = (np.asarray(probs) >= thr).astype(int)
    # Explicit UTF-8 so image ids / class names are written identically on any locale.
    with open(path, "w", newline="", encoding="utf-8") as f:
        w = csv.writer(f)
        w.writerow(["image_id", "true_label", "true_class", "probability_malignant",
                    "predicted_label", "predicted_class"])
        for i, y, p, pr in zip(ids, labels, probs, pred):
            w.writerow([i, int(y), L.IDX_TO_CLASS[int(y)], f"{p:.6f}",
                        int(pr), L.IDX_TO_CLASS[int(pr)]])


def plot_roc(y, p, path, title):
    """Save a ROC curve (with AUROC in the legend) to ``path``."""
    from sklearn.metrics import roc_curve, roc_auc_score

    fpr, tpr, _ = roc_curve(y, p)
    auc = roc_auc_score(y, p)
    fig, ax = plt.subplots(figsize=(5, 5))
    ax.plot(fpr, tpr, label=f"AUROC = {auc:.3f}", color="#C44E52")
    ax.plot([0, 1], [0, 1], "--", color="gray")
    ax.set_xlabel("1 - Specificity (FPR)")
    ax.set_ylabel("Sensitivity (TPR)")
    ax.set_title(title)
    ax.legend(loc="lower right")
    fig.tight_layout()
    fig.savefig(path, dpi=150)
    plt.close(fig)


def plot_pr(y, p, path, title):
    """Save a precision-recall curve (with average precision in the legend) to ``path``."""
    from sklearn.metrics import precision_recall_curve, average_precision_score

    prec, rec, _ = precision_recall_curve(y, p)
    ap = average_precision_score(y, p)
    fig, ax = plt.subplots(figsize=(5, 5))
    ax.plot(rec, prec, label=f"AP = {ap:.3f}", color="#4C72B0")
    ax.set_xlabel("Recall (Sensitivity)")
    ax.set_ylabel("Precision (PPV)")
    ax.set_title(title)
    ax.legend(loc="lower left")
    fig.tight_layout()
    fig.savefig(path, dpi=150)
    plt.close(fig)


def plot_reliability(y, p_uncal, p_cal, path, title):
    """Save a reliability diagram overlaying uncalibrated and temperature-scaled probabilities.

    Each curve's legend entry also reports its ECE and Brier score.
    """
    from sklearn.calibration import calibration_curve

    fig, ax = plt.subplots(figsize=(5.5, 5.5))
    ax.plot([0, 1], [0, 1], "--", color="gray", label="Perfect calibration")
    for p, lab, col in [(p_uncal, "Uncalibrated", "#888888"),
                        (p_cal, "Temperature-scaled", "#C44E52")]:
        fpos, mpred = calibration_curve(y, p, n_bins=10, strategy="uniform")
        ece = L.expected_calibration_error(y, p)
        br = L.brier(y, p)
        ax.plot(mpred, fpos, "o-", color=col,
                label=f"{lab} (ECE={ece:.3f}, Brier={br:.3f})")
    ax.set_xlabel("Mean predicted probability")
    ax.set_ylabel("Observed frequency")
    ax.set_title(title)
    ax.legend(loc="upper left", fontsize=8)
    fig.tight_layout()
    fig.savefig(path, dpi=150)
    plt.close(fig)


def plot_confusion(cm, path, title, normalize=False):
    """Save a 2x2 confusion-matrix heatmap laid out as ``[[TN, FP], [FN, TP]]``.

    With ``normalize=True`` each row is divided by its total (row-normalised);
    cells then show the count plus the row percentage.
    """
    cm = np.asarray(cm, dtype=float)
    disp = cm.copy()
    if normalize:
        # clip() avoids division by zero for a class with no samples.
        disp = cm / cm.sum(axis=1, keepdims=True).clip(min=1e-9)
    fig, ax = plt.subplots(figsize=(4.8, 4.4))
    ax.imshow(disp, cmap="Blues", vmin=0, vmax=disp.max())
    ax.set_xticks([0, 1])
    ax.set_yticks([0, 1])
    ax.set_xticklabels(["Predicted benign", "Predicted malignant"])
    ax.set_yticklabels(["True benign", "True malignant"])
    for i in range(2):
        for j in range(2):
            txt = f"{int(cm[i,j])}" + (f"\n({disp[i,j]*100:.1f}%)" if normalize else "")
            # White text on dark cells, black on light ones.
            ax.text(j, i, txt, ha="center", va="center",
                    color="white" if disp[i, j] > disp.max() / 2 else "black",
                    fontsize=11)
    ax.set_title(title)
    fig.tight_layout()
    fig.savefig(path, dpi=150)
    plt.close(fig)


def main():
    """Run the one-shot finalization pipeline described in the module docstring."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--ckpt", required=True)
    ap.add_argument("--data_dir", default="/app/TN5000")
    ap.add_argument("--output_dir", default="/app/model_repo")
    ap.add_argument("--best_run_name", default="")
    ap.add_argument("--best_val_auroc", default="")
    args = ap.parse_args()

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    L.set_determinism(42, strict=True)

    out = Path(args.output_dir)
    res = out / "results"
    figs = res / "figures"
    tabs = res / "tables"
    for d in [out / "configs", figs, tabs]:
        d.mkdir(parents=True, exist_ok=True)

    model, pp, ck = load_model(args.ckpt, device)

    # ---- logits ----
    val_logits, val_y, val_ids = get_logits(model, args.data_dir, "Valid", pp, device)
    test_logits, test_y, test_ids = get_logits(model, args.data_dir, "Test", pp, device)
    val_p_uncal = L.sigmoid(val_logits)
    test_p_uncal = L.sigmoid(test_logits)

    # ---- calibration: temperature scaling on VALIDATION ----
    T = L.fit_temperature(val_logits, val_y)
    # Plain Python float for reports/formatting, in case the helper returns a
    # NumPy/torch scalar; T itself is passed unchanged to apply_temperature.
    temperature = float(T)
    val_p_cal = L.apply_temperature(val_logits, T)
    test_p_cal = L.apply_temperature(test_logits, T)

    from sklearn.metrics import roc_auc_score

    cal_report = {
        "method": "temperature_scaling",
        "temperature": temperature,
        "valid": {
            "auroc_uncal": float(roc_auc_score(val_y, val_p_uncal)),
            "auroc_cal": float(roc_auc_score(val_y, val_p_cal)),
            "ece_uncal": float(L.expected_calibration_error(val_y, val_p_uncal)),
            "ece_cal": float(L.expected_calibration_error(val_y, val_p_cal)),
            "brier_uncal": float(L.brier(val_y, val_p_uncal)),
            "brier_cal": float(L.brier(val_y, val_p_cal)),
        },
    }
    # Temperature scaling is monotonic, so AUROC is not expected to change; the
    # decision is made on calibration: keep the calibrated probabilities if
    # validation ECE improves or stays within a tiny tolerance, else fall back.
    valid_cal = cal_report["valid"]
    use_calibrated = valid_cal["ece_cal"] <= valid_cal["ece_uncal"] + 1e-6
    cal_report["use_calibrated"] = bool(use_calibrated)
    prob_kind = "calibrated" if use_calibrated else "uncalibrated"
    val_p = val_p_cal if use_calibrated else val_p_uncal
    test_p = test_p_cal if use_calibrated else test_p_uncal
    L.save_json(cal_report, out / "configs" / "calibration.json")

    # ---- threshold selection on VALIDATION (probabilities chosen above) ----
    thr, sens_v, spec_v, achievable = L.threshold_for_sensitivity(val_y, val_p, TARGET_SENS)
    yj_thr, yj_sens, yj_spec = L.youden_threshold(val_y, val_p)
    thr_report = {
        # Name the probabilities actually thresholded: calibration may have
        # been rejected above, in which case these are uncalibrated.
        "primary_method": (
            f"highest-specificity threshold with sensitivity >= {TARGET_SENS} "
            f"on validation ({prob_kind} probabilities)"
        ),
        "locked_threshold": float(thr),
        "valid_sensitivity_at_threshold": sens_v,
        "valid_specificity_at_threshold": spec_v,
        "target_sensitivity": TARGET_SENS,
        "target_achievable": bool(achievable),
        "secondary_youden": {"threshold": yj_thr, "sensitivity": yj_sens,
                             "specificity": yj_spec},
        "probabilities_used": prob_kind,
    }
    L.save_json(thr_report, out / "configs" / "threshold.json")

    # ---- preprocessing config (locked) ----
    L.save_json(
        {**pp.to_dict(), "positive_class": "Malignant", "positive_index": 1,
         "note": "Deterministic eval/inference preprocessing. No augmentation."},
        out / "configs" / "preprocess.json",
    )

    # ---- VALIDATION metrics at locked threshold ----
    val_metrics = L.point_metrics(val_y, val_p, thr)

    # ---- TEST metrics (locked) ----
    test_metrics = L.point_metrics(test_y, test_p, thr)
    test_ci = L.bootstrap_ci(test_y, test_p, thr, n_boot=N_BOOT, seed=BOOT_SEED)

    # ---- per-image CSVs ----
    save_per_image_csv(res / "valid_predictions.csv", val_ids, val_y, val_p, thr)
    save_per_image_csv(res / "test_predictions.csv", test_ids, test_y, test_p, thr)

    # ---- figures ----
    plot_roc(test_y, test_p, figs / "test_roc.png", "ROC — Test set")
    plot_pr(test_y, test_p, figs / "test_pr.png", "Precision-Recall — Test set")
    plot_reliability(val_y, val_p_uncal, val_p_cal, figs / "valid_calibration.png",
                     "Reliability diagram — Validation")
    plot_reliability(test_y, test_p_uncal, test_p_cal, figs / "test_calibration.png",
                     "Reliability diagram — Test")
    cm = np.array([[test_metrics["tn"], test_metrics["fp"]],
                   [test_metrics["fn"], test_metrics["tp"]]])
    plot_confusion(cm, figs / "test_confusion_counts.png",
                   "Confusion matrix (counts) — Test")
    plot_confusion(cm, figs / "test_confusion_normalized.png",
                   "Confusion matrix (row-normalized) — Test", normalize=True)

    # ---- metrics table (markdown + csv) with CIs ----
    ci_keys = ["auroc", "sensitivity", "specificity", "ppv", "npv", "accuracy", "f1"]
    # Do not imply the sensitivity target was met when the helper says it was not.
    target_note = "" if achievable else ", target NOT achievable on validation"
    md = [
        "# Final Test Metrics (locked model + locked threshold)\n",
        f"- Selected run: {args.best_run_name} | selection val AUROC: {args.best_val_auroc}",
        f"- Backbone: {ck['backbone']} | Calibration: temperature scaling (T={temperature:.4f}, "
        f"{'used' if use_calibrated else 'not used'})",
        f"- Locked threshold (val sens>={TARGET_SENS}{target_note}): {thr:.4f} | probabilities: "
        f"{prob_kind}",
        f"- CI method: stratified bootstrap, {N_BOOT} resamples, seed={BOOT_SEED}\n",
        "| Metric | Point estimate | 95% CI |",
        "|--------|---------------:|:------:|",
    ]
    rows_csv = [["metric", "point_estimate", "ci_low", "ci_high"]]
    for k in ci_keys:
        pe = test_metrics[k]
        lo, hi = test_ci[k]
        md.append(f"| {k.upper()} | {pe:.4f} | [{lo:.4f}, {hi:.4f}] |")
        rows_csv.append([k, f"{pe:.6f}", f"{lo:.6f}", f"{hi:.6f}"])
    # Calibration metrics are reported without bootstrap CIs.
    for k in ["brier", "ece"]:
        md.append(f"| {k.upper()} | {test_metrics[k]:.4f} | — |")
        rows_csv.append([k, f"{test_metrics[k]:.6f}", "", ""])
    md.append(f"\n**Confusion matrix (Test):** TN={test_metrics['tn']}, FP={test_metrics['fp']}, "
              f"FN={test_metrics['fn']}, TP={test_metrics['tp']}\n")
    # Explicit UTF-8: the table contains non-ASCII characters ("—").
    (tabs / "test_metrics_with_ci.md").write_text("\n".join(md), encoding="utf-8")
    with open(tabs / "test_metrics_with_ci.csv", "w", newline="", encoding="utf-8") as f:
        csv.writer(f).writerows(rows_csv)

    # ---- consolidated results json ----
    final = {
        "selected_run": args.best_run_name,
        "selection_val_auroc": args.best_val_auroc,
        "backbone": ck["backbone"],
        "preprocess": pp.to_dict(),
        "calibration": cal_report,
        "threshold": thr_report,
        "valid_metrics_at_locked_threshold": val_metrics,
        "test_metrics_at_locked_threshold": test_metrics,
        "test_metrics_95ci": {k: list(test_ci[k]) for k in ci_keys},
        "ci_method": f"stratified bootstrap, {N_BOOT} resamples, seed={BOOT_SEED}",
    }
    L.save_json(final, res / "final_results.json")

    print("=== CALIBRATION ===")
    print(json.dumps(cal_report, indent=2))
    print("=== THRESHOLD ===")
    print(json.dumps(thr_report, indent=2))
    if not achievable:
        print(f"WARNING: validation sensitivity target {TARGET_SENS} was not achievable.")
    print("=== TEST METRICS ===")
    for k in ci_keys:
        print(f" {k:12s} {test_metrics[k]:.4f} CI [{test_ci[k][0]:.4f}, {test_ci[k][1]:.4f}]")
    print(f" brier {test_metrics['brier']:.4f}")
    print(f" ece {test_metrics['ece']:.4f}")
    print("Saved to", out)


if __name__ == "__main__":
    main()