Add full reproducible thyroid ResNet-18 experiment: weights, scripts, configs, calibration, locked threshold, test eval w/ CIs, figures, data exploration, README, LOG
45af8e1 verified | #!/usr/bin/env python | |
| """ | |
| Finalization pipeline (run ONCE after the sweep selects the best model by | |
| validation AUROC): | |
| 1. Load the best checkpoint (selected purely on val AUROC). | |
| 2. Recompute Valid + Test logits with deterministic eval preprocessing. | |
| 3. Calibrate on VALIDATION via temperature scaling; compare uncalibrated vs | |
| calibrated (ECE, Brier, AUROC). Keep calibration only if it does not harm | |
| discrimination (AUROC unchanged — temperature scaling is monotonic) and | |
| improves/maintains calibration. | |
| 4. Select a LOCKED threshold on VALIDATION: highest-specificity threshold with | |
| sensitivity >= 0.95 (primary); Youden's J reported as secondary. | |
| 5. Evaluate ONCE on TEST with the selected probabilities (calibrated, or uncalibrated if calibration was rejected in step 3) + locked threshold. | |
| 6. Bootstrap 95% CIs (stratified, 2000 resamples, seed=42). | |
| 7. Save all figures (ROC, PR, calibration/reliability, confusion matrices), | |
| tables (markdown + CSV), per-image prediction CSVs (valid + test), | |
| calibration config, threshold config, preprocessing config. | |
| Outputs go to --output_dir (default /app/model_repo). | |
| """ | |
| import argparse | |
| import csv | |
| import json | |
| from pathlib import Path | |
| import numpy as np | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| import thyroid_lib as L | |
# Minimum sensitivity the locked threshold must reach on the validation split.
TARGET_SENS: float = 0.95
# Number of bootstrap resamples requested for the test-set confidence intervals.
N_BOOT: int = 2000
# Seed handed to the bootstrap so the reported intervals are reproducible.
BOOT_SEED: int = 42
def load_model(ckpt_path, device):
    """Rebuild the network described by a checkpoint and load its weights.

    Args:
        ckpt_path: Path to the checkpoint file. The checkpoint is read as a
            mapping with the keys ``"backbone"``, ``"model_state"`` and
            ``"preprocess"``; ``"freeze_stage"`` and ``"dropout"`` are optional
            and default to ``0`` and ``0.0``.
        device: Device the model is moved to after the weights are loaded.

    Returns:
        A ``(model, pp, ck)`` tuple: the model in eval mode on ``device``, the
        preprocessing config rebuilt from ``ck["preprocess"]``, and the raw
        checkpoint mapping.
    """
    import torch

    # SECURITY NOTE: weights_only=False lets torch.load unpickle arbitrary
    # Python objects, so only load checkpoints from a trusted source.
    ck = torch.load(ckpt_path, map_location="cpu", weights_only=False)

    # build_model also returns a second value, but the preprocessing that was
    # stored with the checkpoint is the one returned to the caller, so the
    # builder's second value is discarded instead of being bound and overwritten.
    model, _ = L.build_model(
        ck["backbone"],
        freeze_stage=ck.get("freeze_stage", 0),
        dropout=ck.get("dropout", 0.0),
    )
    model.load_state_dict(ck["model_state"])
    model.to(device).eval()

    pp = L.PreprocessConfig.from_dict(ck["preprocess"])
    return model, pp, ck
def get_logits(model, data_dir, split, pp, device):
    """Run the model over one dataset split and collect its raw outputs.

    Args:
        model: Model passed through to ``L.collect_logits``.
        data_dir: Dataset root; the split is read from ``data_dir / split``.
        split: Name of the split sub-directory (``main`` uses "Valid" and "Test").
        pp: Preprocessing config used to build the eval transform.
        device: Device identifier forwarded to ``L.collect_logits``. May be a
            string such as ``"cuda"`` / ``"cuda:0"`` / ``"cpu"`` or any object
            whose ``str()`` starts with ``"cuda"`` for CUDA devices.

    Returns:
        The ``(logits, labels, ids)`` triple produced by ``L.collect_logits``.
    """
    from torch.utils.data import DataLoader

    ds = L.ThyroidImageFolder(Path(data_dir) / split, L.build_eval_transform(pp))

    # Pin host memory for any CUDA target. Comparing against the exact string
    # "cuda" would silently disable pinning for "cuda:0" or a torch.device.
    on_cuda = str(device).startswith("cuda")

    # shuffle=False keeps the sample order fixed so ids line up with logits.
    loader = DataLoader(
        ds,
        batch_size=64,
        shuffle=False,
        num_workers=4,
        pin_memory=on_cuda,
    )
    logits, labels, ids = L.collect_logits(model, loader, device, amp=False)
    return logits, labels, ids
def save_per_image_csv(path, ids, labels, probs, thr):
    """Write one CSV row per image with its label, probability and prediction.

    An image is predicted positive (label 1) when its probability is greater
    than or equal to ``thr``. Class names are looked up in ``L.IDX_TO_CLASS``.

    Args:
        path: Destination CSV file (overwritten if it exists).
        ids: Image identifiers, one per sample.
        labels: True labels, one per sample.
        probs: Positive-class probabilities, one per sample.
        thr: Decision threshold applied to ``probs``.
    """
    header = ["image_id", "true_label", "true_class", "probability_malignant",
              "predicted_label", "predicted_class"]
    predicted = (np.asarray(probs) >= thr).astype(int)

    def _rows():
        # Yield rows lazily in the same order as the inputs.
        for image_id, truth, prob, guess in zip(ids, labels, probs, predicted):
            truth_idx = int(truth)
            guess_idx = int(guess)
            yield [image_id, truth_idx, L.IDX_TO_CLASS[truth_idx], f"{prob:.6f}",
                   guess_idx, L.IDX_TO_CLASS[guess_idx]]

    with open(path, "w", newline="") as handle:
        writer = csv.writer(handle)
        writer.writerow(header)
        writer.writerows(_rows())
def plot_roc(y, p, path, title):
    """Draw the ROC curve for labels ``y`` and scores ``p`` and save it.

    The legend shows the AUROC to three decimals and a dashed diagonal marks
    the chance line. The figure is saved to ``path`` at 150 dpi and closed.
    """
    from sklearn.metrics import roc_auc_score, roc_curve

    false_pos_rate, true_pos_rate, _thresholds = roc_curve(y, p)
    area = roc_auc_score(y, p)

    fig, ax = plt.subplots(figsize=(5, 5))
    ax.plot(false_pos_rate, true_pos_rate, label=f"AUROC = {area:.3f}", color="#C44E52")
    ax.plot([0, 1], [0, 1], "--", color="gray")  # chance diagonal
    ax.set_xlabel("1 - Specificity (FPR)")
    ax.set_ylabel("Sensitivity (TPR)")
    ax.set_title(title)
    ax.legend(loc="lower right")

    fig.tight_layout()
    fig.savefig(path, dpi=150)
    plt.close(fig)
def plot_pr(y, p, path, title):
    """Draw the precision-recall curve for labels ``y`` and scores ``p``.

    The legend shows the average precision to three decimals. The figure is
    saved to ``path`` at 150 dpi and closed.
    """
    from sklearn.metrics import average_precision_score, precision_recall_curve

    precision, recall, _thresholds = precision_recall_curve(y, p)
    avg_precision = average_precision_score(y, p)

    fig, ax = plt.subplots(figsize=(5, 5))
    ax.plot(recall, precision, label=f"AP = {avg_precision:.3f}", color="#4C72B0")
    ax.set_xlabel("Recall (Sensitivity)")
    ax.set_ylabel("Precision (PPV)")
    ax.set_title(title)
    ax.legend(loc="lower left")

    fig.tight_layout()
    fig.savefig(path, dpi=150)
    plt.close(fig)
def plot_reliability(y, p_uncal, p_cal, path, title):
    """Plot reliability curves for uncalibrated and temperature-scaled scores.

    Both probability vectors are binned into 10 uniform bins and drawn against
    the perfect-calibration diagonal. Each legend entry carries the ECE and
    Brier score computed by ``L.expected_calibration_error`` and ``L.brier``.
    The figure is saved to ``path`` at 150 dpi and closed.
    """
    from sklearn.calibration import calibration_curve

    series = (
        (p_uncal, "Uncalibrated", "#888888"),
        (p_cal, "Temperature-scaled", "#C44E52"),
    )

    fig, ax = plt.subplots(figsize=(5.5, 5.5))
    ax.plot([0, 1], [0, 1], "--", color="gray", label="Perfect calibration")

    for probs, name, colour in series:
        observed, predicted = calibration_curve(y, probs, n_bins=10, strategy="uniform")
        ece_value = L.expected_calibration_error(y, probs)
        brier_value = L.brier(y, probs)
        legend_text = f"{name} (ECE={ece_value:.3f}, Brier={brier_value:.3f})"
        ax.plot(predicted, observed, "o-", color=colour, label=legend_text)

    ax.set_xlabel("Mean predicted probability")
    ax.set_ylabel("Observed frequency")
    ax.set_title(title)
    ax.legend(loc="upper left", fontsize=8)

    fig.tight_layout()
    fig.savefig(path, dpi=150)
    plt.close(fig)
def plot_confusion(cm, path, title, normalize=False):
    """Render a 2x2 confusion matrix as a heat map and save it.

    Args:
        cm: 2x2 matrix of counts; rows are true classes and columns are
            predicted classes, in benign/malignant order per the tick labels.
        path: Output image path (saved at 150 dpi).
        title: Axes title.
        normalize: When true, colour cells by row fraction and print the
            row percentage under each count.
    """
    counts = np.asarray(cm, dtype=float)
    if normalize:
        # Guard against empty rows so the division never hits zero.
        row_totals = counts.sum(axis=1, keepdims=True).clip(min=1e-9)
        shown = counts / row_totals
    else:
        shown = counts.copy()

    peak = shown.max()

    fig, ax = plt.subplots(figsize=(4.8, 4.4))
    ax.imshow(shown, cmap="Blues", vmin=0, vmax=peak)
    ax.set_xticks([0, 1])
    ax.set_yticks([0, 1])
    ax.set_xticklabels(["Predicted benign", "Predicted malignant"])
    ax.set_yticklabels(["True benign", "True malignant"])

    for row, col in np.ndindex(2, 2):
        cell_text = f"{int(counts[row, col])}"
        if normalize:
            cell_text += f"\n({shown[row, col]*100:.1f}%)"
        # White text on dark cells, black on light ones.
        text_colour = "white" if shown[row, col] > peak / 2 else "black"
        ax.text(col, row, cell_text, ha="center", va="center",
                color=text_colour, fontsize=11)

    ax.set_title(title)
    fig.tight_layout()
    fig.savefig(path, dpi=150)
    plt.close(fig)
def main():
    """Run the one-shot finalization pipeline for the selected checkpoint.

    Steps, in order:
      1. Parse CLI arguments and create the output directory tree.
      2. Load the checkpoint and compute logits on the "Valid" and "Test" splits.
      3. Fit a temperature on validation logits; keep the calibrated
         probabilities only if validation ECE does not get worse.
      4. Pick the locked threshold on validation for ``TARGET_SENS`` and also
         record the Youden threshold.
      5. Compute validation and test metrics at the locked threshold and
         bootstrap confidence intervals on test.
      6. Write configs (``configs/*.json``), per-image CSVs, figures, the
         metrics table (markdown + CSV) and ``results/final_results.json``,
         then print a summary.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--ckpt", required=True,
                    help="Path to the selected model checkpoint.")
    ap.add_argument("--data_dir", default="/app/TN5000",
                    help="Dataset root containing the Valid/ and Test/ split folders.")
    ap.add_argument("--output_dir", default="/app/model_repo",
                    help="Directory that receives configs/ and results/.")
    ap.add_argument("--best_run_name", default="",
                    help="Name of the selected run (recorded in the reports).")
    ap.add_argument("--best_val_auroc", default="",
                    help="Validation AUROC used for selection (recorded as given).")
    args = ap.parse_args()

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    L.set_determinism(42, strict=True)

    out = Path(args.output_dir)
    res = out / "results"
    figs = res / "figures"
    tabs = res / "tables"
    for d in [out / "configs", figs, tabs]:
        d.mkdir(parents=True, exist_ok=True)

    model, pp, ck = load_model(args.ckpt, device)

    # ---- logits ----
    val_logits, val_y, val_ids = get_logits(model, args.data_dir, "Valid", pp, device)
    test_logits, test_y, test_ids = get_logits(model, args.data_dir, "Test", pp, device)
    val_p_uncal = L.sigmoid(val_logits)
    test_p_uncal = L.sigmoid(test_logits)

    # ---- calibration: temperature scaling on VALIDATION ----
    T = L.fit_temperature(val_logits, val_y)
    val_p_cal = L.apply_temperature(val_logits, T)
    test_p_cal = L.apply_temperature(test_logits, T)

    from sklearn.metrics import roc_auc_score

    cal_report = {
        "method": "temperature_scaling",
        "temperature": T,
        "valid": {
            "auroc_uncal": float(roc_auc_score(val_y, val_p_uncal)),
            "auroc_cal": float(roc_auc_score(val_y, val_p_cal)),
            "ece_uncal": L.expected_calibration_error(val_y, val_p_uncal),
            "ece_cal": L.expected_calibration_error(val_y, val_p_cal),
            "brier_uncal": L.brier(val_y, val_p_uncal),
            "brier_cal": L.brier(val_y, val_p_cal),
        },
    }
    # Decision: temperature scaling is monotonic -> AUROC unchanged. Use calibrated
    # if ECE improves or is within tolerance; else fall back to uncalibrated.
    use_calibrated = cal_report["valid"]["ece_cal"] <= cal_report["valid"]["ece_uncal"] + 1e-6
    cal_report["use_calibrated"] = bool(use_calibrated)
    val_p = val_p_cal if use_calibrated else val_p_uncal
    test_p = test_p_cal if use_calibrated else test_p_uncal
    # Single source of truth for how the selected probabilities are described,
    # so every report agrees with the decision above.
    prob_kind = "calibrated" if use_calibrated else "uncalibrated"
    L.save_json(cal_report, out / "configs" / "calibration.json")

    # ---- threshold selection on VALIDATION (selected probabilities) ----
    thr, sens_v, spec_v, achievable = L.threshold_for_sensitivity(val_y, val_p, TARGET_SENS)
    yj_thr, yj_sens, yj_spec = L.youden_threshold(val_y, val_p)
    thr_report = {
        # Previously hard-coded "(calibrated probabilities)", which contradicted
        # "probabilities_used" whenever calibration was rejected.
        "primary_method": f"highest-specificity threshold with sensitivity >= {TARGET_SENS} on validation ({prob_kind} probabilities)",
        "locked_threshold": thr,
        "valid_sensitivity_at_threshold": sens_v,
        "valid_specificity_at_threshold": spec_v,
        "target_sensitivity": TARGET_SENS,
        "target_achievable": bool(achievable),
        "secondary_youden": {"threshold": yj_thr, "sensitivity": yj_sens, "specificity": yj_spec},
        "probabilities_used": prob_kind,
    }
    L.save_json(thr_report, out / "configs" / "threshold.json")

    # ---- preprocessing config (locked) ----
    L.save_json({**pp.to_dict(), "positive_class": "Malignant", "positive_index": 1,
                 "note": "Deterministic eval/inference preprocessing. No augmentation."},
                out / "configs" / "preprocess.json")

    # ---- VALIDATION metrics at locked threshold ----
    val_metrics = L.point_metrics(val_y, val_p, thr)

    # ---- TEST metrics (locked) ----
    test_metrics = L.point_metrics(test_y, test_p, thr)
    test_ci = L.bootstrap_ci(test_y, test_p, thr, n_boot=N_BOOT, seed=BOOT_SEED)

    # ---- per-image CSVs ----
    save_per_image_csv(res / "valid_predictions.csv", val_ids, val_y, val_p, thr)
    save_per_image_csv(res / "test_predictions.csv", test_ids, test_y, test_p, thr)

    # ---- figures ----
    plot_roc(test_y, test_p, figs / "test_roc.png", "ROC — Test set")
    plot_pr(test_y, test_p, figs / "test_pr.png", "Precision-Recall — Test set")
    plot_reliability(val_y, val_p_uncal, val_p_cal, figs / "valid_calibration.png",
                     "Reliability diagram — Validation")
    plot_reliability(test_y, test_p_uncal, test_p_cal, figs / "test_calibration.png",
                     "Reliability diagram — Test")
    cm = np.array([[test_metrics["tn"], test_metrics["fp"]],
                   [test_metrics["fn"], test_metrics["tp"]]])
    plot_confusion(cm, figs / "test_confusion_counts.png", "Confusion matrix (counts) — Test")
    plot_confusion(cm, figs / "test_confusion_normalized.png",
                   "Confusion matrix (row-normalized) — Test", normalize=True)

    # ---- metrics table (markdown + csv) with CIs ----
    ci_keys = ["auroc", "sensitivity", "specificity", "ppv", "npv", "accuracy", "f1"]
    md = ["# Final Test Metrics (locked model + locked threshold)\n",
          f"- Selected run: {args.best_run_name} | selection val AUROC: {args.best_val_auroc}",
          f"- Backbone: {ck['backbone']} | Calibration: temperature scaling (T={T:.4f}, "
          f"{'used' if use_calibrated else 'not used'})",
          f"- Locked threshold (val sens>={TARGET_SENS}): {thr:.4f} | probabilities: "
          f"{prob_kind}",
          f"- CI method: stratified bootstrap, {N_BOOT} resamples, seed={BOOT_SEED}\n",
          "| Metric | Point estimate | 95% CI |",
          "|--------|---------------:|:------:|"]
    rows_csv = [["metric", "point_estimate", "ci_low", "ci_high"]]
    for k in ci_keys:
        pe = test_metrics[k]
        lo, hi = test_ci[k]
        md.append(f"| {k.upper()} | {pe:.4f} | [{lo:.4f}, {hi:.4f}] |")
        rows_csv.append([k, f"{pe:.6f}", f"{lo:.6f}", f"{hi:.6f}"])
    # Brier and ECE are reported as point estimates only (no CI columns).
    for k in ["brier", "ece"]:
        md.append(f"| {k.upper()} | {test_metrics[k]:.4f} | — |")
        rows_csv.append([k, f"{test_metrics[k]:.6f}", "", ""])
    md.append(f"\n**Confusion matrix (Test):** TN={test_metrics['tn']}, FP={test_metrics['fp']}, "
              f"FN={test_metrics['fn']}, TP={test_metrics['tp']}\n")
    # The table contains a non-ASCII dash, so pin the encoding instead of
    # relying on the platform's locale default.
    (tabs / "test_metrics_with_ci.md").write_text("\n".join(md), encoding="utf-8")
    with open(tabs / "test_metrics_with_ci.csv", "w", newline="") as f:
        csv.writer(f).writerows(rows_csv)

    # ---- consolidated results json ----
    final = {
        "selected_run": args.best_run_name,
        "selection_val_auroc": args.best_val_auroc,
        "backbone": ck["backbone"],
        "preprocess": pp.to_dict(),
        "calibration": cal_report,
        "threshold": thr_report,
        "valid_metrics_at_locked_threshold": val_metrics,
        "test_metrics_at_locked_threshold": test_metrics,
        "test_metrics_95ci": {k: list(test_ci[k]) for k in ci_keys},
        "ci_method": f"stratified bootstrap, {N_BOOT} resamples, seed={BOOT_SEED}",
    }
    L.save_json(final, res / "final_results.json")

    print("=== CALIBRATION ===")
    print(json.dumps(cal_report, indent=2))
    print("=== THRESHOLD ===")
    print(json.dumps(thr_report, indent=2))
    print("=== TEST METRICS ===")
    for k in ci_keys:
        print(f" {k:12s} {test_metrics[k]:.4f} CI [{test_ci[k][0]:.4f}, {test_ci[k][1]:.4f}]")
    print(f" brier {test_metrics['brier']:.4f}")
    print(f" ece {test_metrics['ece']:.4f}")
    print("Saved to", out)
# Run the finalization pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()