File size: 12,754 Bytes
45af8e1 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 | #!/usr/bin/env python
"""
Finalization pipeline (run ONCE after the sweep selects the best model by
validation AUROC):
1. Load the best checkpoint (selected purely on val AUROC).
2. Recompute Valid + Test logits with deterministic eval preprocessing.
3. Calibrate on VALIDATION via temperature scaling; compare uncalibrated vs
calibrated (ECE, Brier, AUROC). Keep calibration only if it does not harm
discrimination (AUROC unchanged — temperature scaling is monotonic) and
improves/maintains calibration.
4. Select a LOCKED threshold on VALIDATION: highest-specificity threshold with
sensitivity >= 0.95 (primary); Youden's J reported as secondary.
5. Evaluate ONCE on TEST with the selected probabilities (calibrated if kept in step 3) + locked threshold.
6. Bootstrap 95% CIs (stratified, 2000 resamples, seed=42).
7. Save all figures (ROC, PR, calibration/reliability, confusion matrices),
tables (markdown + CSV), per-image prediction CSVs (valid + test),
calibration config, threshold config, preprocessing config.
Outputs go to --output_dir (default /app/model_repo).
"""
import argparse
import csv
import json
from pathlib import Path
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import thyroid_lib as L
# Evaluation constants. They are echoed into the saved threshold/CI reports,
# so changing any of them changes the published results.
TARGET_SENS: float = 0.95  # minimum validation sensitivity for the locked threshold
N_BOOT: int = 2000  # number of bootstrap resamples for the 95% CIs
BOOT_SEED: int = 42  # RNG seed passed to the bootstrap
def load_model(ckpt_path, device):
    """Rebuild the trained classifier from a checkpoint and put it in eval mode.

    Parameters
    ----------
    ckpt_path : str or Path
        Checkpoint file. It must contain the keys ``"backbone"``,
        ``"model_state"`` and ``"preprocess"``. ``"freeze_stage"`` and
        ``"dropout"`` are optional and default to 0 and 0.0.
    device : str
        Torch device string the model is moved to (``"cuda"`` or ``"cpu"``).

    Returns
    -------
    tuple
        ``(model, pp, ck)``:
        - ``model``: in eval mode on ``device``.
        - ``pp``: the ``L.PreprocessConfig`` restored from the checkpoint's
          ``"preprocess"`` entry.
        - ``ck``: the raw checkpoint dict.
    """
    import torch

    # NOTE(review): weights_only=False unpickles arbitrary Python objects;
    # only load checkpoints produced by this project's own training runs.
    ck = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    # build_model's second return value (a preprocessing config, judging by
    # how it is used) is discarded: the config stored in the checkpoint is
    # the one used for evaluation below.
    model, _ = L.build_model(
        ck["backbone"],
        freeze_stage=ck.get("freeze_stage", 0),
        dropout=ck.get("dropout", 0.0),
    )
    model.load_state_dict(ck["model_state"])
    model.to(device).eval()
    pp = L.PreprocessConfig.from_dict(ck["preprocess"])
    return model, pp, ck
def get_logits(model, data_dir, split, pp, device):
    """Run ``model`` over one dataset split and return ``(logits, labels, ids)``.

    Images are read from ``<data_dir>/<split>`` with the deterministic eval
    transform built from ``pp``. The loader is unshuffled, so the sample order
    is fixed and reproducible. AMP is disabled for inference.
    """
    import torch
    from torch.utils.data import DataLoader

    split_dir = Path(data_dir) / split
    dataset = L.ThyroidImageFolder(split_dir, L.build_eval_transform(pp))
    pin = device == "cuda"  # pinned host memory only helps GPU transfers
    loader = DataLoader(
        dataset,
        batch_size=64,
        shuffle=False,
        num_workers=4,
        pin_memory=pin,
    )
    logits, labels, ids = L.collect_logits(model, loader, device, amp=False)
    return logits, labels, ids
def save_per_image_csv(path, ids, labels, probs, thr):
    """Write one CSV row per image: id, true label, score and thresholded prediction.

    Parameters
    ----------
    path : str or Path
        Destination CSV file (overwritten, UTF-8 encoded).
    ids : sequence
        Image identifiers, one per image.
    labels : sequence of int-like
        True class indices; names come from ``L.IDX_TO_CLASS``.
    probs : sequence of float
        Score written to the ``probability_malignant`` column.
    thr : float
        Decision threshold: an image is predicted as label 1 when ``prob >= thr``.

    Raises
    ------
    ValueError
        If ``ids``, ``labels`` and ``probs`` differ in length.
    """
    n_ids, n_labels, n_probs = len(ids), len(labels), len(probs)
    # zip() below would silently drop trailing rows on a mismatch; fail loudly
    # instead, before the output file is created or truncated.
    if not (n_ids == n_labels == n_probs):
        raise ValueError(
            "ids, labels and probs must have the same length, got "
            f"{n_ids}, {n_labels} and {n_probs}"
        )
    pred = (np.asarray(probs) >= thr).astype(int)
    # Explicit UTF-8 so the output does not depend on the platform's locale.
    with open(path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["image_id", "true_label", "true_class", "probability_malignant",
                         "predicted_label", "predicted_class"])
        for image_id, y, p, y_hat in zip(ids, labels, probs, pred):
            writer.writerow([image_id, int(y), L.IDX_TO_CLASS[int(y)], f"{p:.6f}",
                             int(y_hat), L.IDX_TO_CLASS[int(y_hat)]])
def plot_roc(y, p, path, title):
    """Save a ROC curve for labels ``y`` and scores ``p``, with AUROC in the legend.

    The figure is written to ``path`` at 150 dpi and then closed.
    """
    from sklearn.metrics import roc_auc_score, roc_curve

    false_pos_rate, true_pos_rate, _thresholds = roc_curve(y, p)
    area = roc_auc_score(y, p)

    fig, ax = plt.subplots(figsize=(5, 5))
    ax.plot(false_pos_rate, true_pos_rate, label=f"AUROC = {area:.3f}", color="#C44E52")
    # Chance-level diagonal for reference.
    ax.plot([0, 1], [0, 1], "--", color="gray")
    ax.set_xlabel("1 - Specificity (FPR)")
    ax.set_ylabel("Sensitivity (TPR)")
    ax.set_title(title)
    ax.legend(loc="lower right")
    fig.tight_layout()
    fig.savefig(path, dpi=150)
    plt.close(fig)
def plot_pr(y, p, path, title):
    """Save a precision-recall curve for labels ``y`` and scores ``p``.

    The legend shows average precision (AP). The figure is written to
    ``path`` at 150 dpi and then closed.
    """
    from sklearn.metrics import average_precision_score, precision_recall_curve

    precision, recall, _thresholds = precision_recall_curve(y, p)
    avg_precision = average_precision_score(y, p)

    fig, ax = plt.subplots(figsize=(5, 5))
    ax.plot(recall, precision, label=f"AP = {avg_precision:.3f}", color="#4C72B0")
    ax.set_xlabel("Recall (Sensitivity)")
    ax.set_ylabel("Precision (PPV)")
    ax.set_title(title)
    ax.legend(loc="lower left")
    fig.tight_layout()
    fig.savefig(path, dpi=150)
    plt.close(fig)
def plot_reliability(y, p_uncal, p_cal, path, title):
    """Save a reliability diagram comparing uncalibrated and temperature-scaled probabilities.

    Each curve is drawn from ``calibration_curve`` with 10 equal-width bins.
    Its legend entry reports the ECE (``L.expected_calibration_error``) and
    Brier score (``L.brier``) of that probability set.
    """
    from sklearn.calibration import calibration_curve

    fig, ax = plt.subplots(figsize=(5.5, 5.5))
    ax.plot([0, 1], [0, 1], "--", color="gray", label="Perfect calibration")

    series = (
        ("Uncalibrated", "#888888", p_uncal),
        ("Temperature-scaled", "#C44E52", p_cal),
    )
    for name, colour, probs in series:
        frac_pos, mean_pred = calibration_curve(y, probs, n_bins=10, strategy="uniform")
        ece = L.expected_calibration_error(y, probs)
        brier_score = L.brier(y, probs)
        ax.plot(mean_pred, frac_pos, "o-", color=colour,
                label=f"{name} (ECE={ece:.3f}, Brier={brier_score:.3f})")

    ax.set_xlabel("Mean predicted probability")
    ax.set_ylabel("Observed frequency")
    ax.set_title(title)
    ax.legend(loc="upper left", fontsize=8)
    fig.tight_layout()
    fig.savefig(path, dpi=150)
    plt.close(fig)
def plot_confusion(cm, path, title, normalize=False):
    """Save a 2x2 confusion-matrix heatmap (rows = true class, columns = predicted).

    Parameters
    ----------
    cm : array-like
        Counts laid out as ``[[TN, FP], [FN, TP]]``.
    path : str or Path
        Output image file.
    title : str
        Plot title.
    normalize : bool
        If True, colour cells by row fraction and add the percentage under
        each count. Empty rows are guarded by clipping the row sum to 1e-9.
    """
    counts = np.asarray(cm, dtype=float)
    if normalize:
        row_totals = counts.sum(axis=1, keepdims=True).clip(min=1e-9)
        shown = counts / row_totals
    else:
        shown = counts.copy()
    top = shown.max()

    fig, ax = plt.subplots(figsize=(4.8, 4.4))
    ax.imshow(shown, cmap="Blues", vmin=0, vmax=top)
    ax.set_xticks([0, 1])
    ax.set_yticks([0, 1])
    ax.set_xticklabels(["Predicted benign", "Predicted malignant"])
    ax.set_yticklabels(["True benign", "True malignant"])

    for row in range(2):
        for col in range(2):
            cell_text = f"{int(counts[row, col])}"
            if normalize:
                cell_text += f"\n({shown[row, col] * 100:.1f}%)"
            # White text on dark (upper-half) cells, black on light ones.
            text_colour = "white" if shown[row, col] > top / 2 else "black"
            ax.text(col, row, cell_text, ha="center", va="center",
                    color=text_colour, fontsize=11)

    ax.set_title(title)
    fig.tight_layout()
    fig.savefig(path, dpi=150)
    plt.close(fig)
def _calibration_report(val_y, val_p_uncal, val_p_cal, T):
    """Compare uncalibrated vs temperature-scaled VALIDATION probabilities and decide.

    Returns the dict saved as ``configs/calibration.json``. It holds the
    method name, the fitted temperature, validation AUROC/ECE/Brier for both
    variants, and the ``use_calibrated`` decision.

    Calibration is kept only if it does not hurt discrimination (validation
    AUROC not lower) AND does not worsen calibration (validation ECE not
    higher). Both comparisons allow a 1e-6 tolerance for floating-point noise.
    With a positive temperature the scaling is monotonic, so AUROC should be
    unchanged. The AUROC check guards against a degenerate temperature fit.
    """
    from sklearn.metrics import roc_auc_score

    valid = {
        "auroc_uncal": float(roc_auc_score(val_y, val_p_uncal)),
        "auroc_cal": float(roc_auc_score(val_y, val_p_cal)),
        "ece_uncal": L.expected_calibration_error(val_y, val_p_uncal),
        "ece_cal": L.expected_calibration_error(val_y, val_p_cal),
        "brier_uncal": L.brier(val_y, val_p_uncal),
        "brier_cal": L.brier(val_y, val_p_cal),
    }
    tol = 1e-6
    keeps_discrimination = valid["auroc_cal"] >= valid["auroc_uncal"] - tol
    keeps_calibration = valid["ece_cal"] <= valid["ece_uncal"] + tol
    return {
        "method": "temperature_scaling",
        "temperature": T,
        "valid": valid,
        "use_calibrated": bool(keeps_discrimination and keeps_calibration),
    }


def _write_metrics_tables(tabs, args, backbone, T, use_calibrated, thr,
                          test_metrics, test_ci, ci_keys):
    """Write the final TEST metrics with 95% CIs as Markdown and CSV into ``tabs``.

    Metrics in ``ci_keys`` get a point estimate plus a bootstrap CI. Brier
    and ECE are reported as point estimates only. Both files are UTF-8, since
    the Markdown contains a non-ASCII em dash.
    """
    prob_kind = "calibrated" if use_calibrated else "uncalibrated"
    md = [
        "# Final Test Metrics (locked model + locked threshold)\n",
        f"- Selected run: {args.best_run_name} | selection val AUROC: {args.best_val_auroc}",
        f"- Backbone: {backbone} | Calibration: temperature scaling (T={T:.4f}, "
        f"{'used' if use_calibrated else 'not used'})",
        f"- Locked threshold (val sens>={TARGET_SENS}): {thr:.4f} | probabilities: "
        f"{prob_kind}",
        f"- CI method: stratified bootstrap, {N_BOOT} resamples, seed={BOOT_SEED}\n",
        "| Metric | Point estimate | 95% CI |",
        "|--------|---------------:|:------:|",
    ]
    rows_csv = [["metric", "point_estimate", "ci_low", "ci_high"]]
    for k in ci_keys:
        pe = test_metrics[k]
        lo, hi = test_ci[k]
        md.append(f"| {k.upper()} | {pe:.4f} | [{lo:.4f}, {hi:.4f}] |")
        rows_csv.append([k, f"{pe:.6f}", f"{lo:.6f}", f"{hi:.6f}"])
    for k in ("brier", "ece"):
        md.append(f"| {k.upper()} | {test_metrics[k]:.4f} | — |")
        rows_csv.append([k, f"{test_metrics[k]:.6f}", "", ""])
    md.append(f"\n**Confusion matrix (Test):** TN={test_metrics['tn']}, FP={test_metrics['fp']}, "
              f"FN={test_metrics['fn']}, TP={test_metrics['tp']}\n")
    (tabs / "test_metrics_with_ci.md").write_text("\n".join(md), encoding="utf-8")
    with open(tabs / "test_metrics_with_ci.csv", "w", newline="", encoding="utf-8") as f:
        csv.writer(f).writerows(rows_csv)


def main():
    """Run the one-shot finalization pipeline described in the module docstring.

    Command-line arguments:
      --ckpt            checkpoint of the selected model (required)
      --data_dir        dataset root containing the ``Valid`` and ``Test`` splits
      --output_dir      root for the ``configs/`` and ``results/`` outputs
      --best_run_name   free-text run name, only echoed into the reports
      --best_val_auroc  selection AUROC (kept as a string), only echoed into the reports
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--ckpt", required=True)
    ap.add_argument("--data_dir", default="/app/TN5000")
    ap.add_argument("--output_dir", default="/app/model_repo")
    ap.add_argument("--best_run_name", default="")
    ap.add_argument("--best_val_auroc", default="")
    args = ap.parse_args()

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    L.set_determinism(42, strict=True)

    out = Path(args.output_dir)
    cfg_dir = out / "configs"
    res = out / "results"
    figs = res / "figures"
    tabs = res / "tables"
    for d in (cfg_dir, figs, tabs):
        d.mkdir(parents=True, exist_ok=True)

    model, pp, ck = load_model(args.ckpt, device)

    # ---- logits ----
    val_logits, val_y, val_ids = get_logits(model, args.data_dir, "Valid", pp, device)
    test_logits, test_y, test_ids = get_logits(model, args.data_dir, "Test", pp, device)
    val_p_uncal = L.sigmoid(val_logits)
    test_p_uncal = L.sigmoid(test_logits)

    # ---- calibration: temperature fitted on VALIDATION only ----
    T = L.fit_temperature(val_logits, val_y)
    val_p_cal = L.apply_temperature(val_logits, T)
    test_p_cal = L.apply_temperature(test_logits, T)
    cal_report = _calibration_report(val_y, val_p_uncal, val_p_cal, T)
    use_calibrated = cal_report["use_calibrated"]
    prob_kind = "calibrated" if use_calibrated else "uncalibrated"
    val_p = val_p_cal if use_calibrated else val_p_uncal
    test_p = test_p_cal if use_calibrated else test_p_uncal
    L.save_json(cal_report, cfg_dir / "calibration.json")

    # ---- threshold selection on VALIDATION (probabilities chosen above) ----
    thr, sens_v, spec_v, achievable = L.threshold_for_sensitivity(val_y, val_p, TARGET_SENS)
    yj_thr, yj_sens, yj_spec = L.youden_threshold(val_y, val_p)
    thr_report = {
        # Describe the probabilities actually used instead of always claiming "calibrated".
        "primary_method": (
            f"highest-specificity threshold with sensitivity >= {TARGET_SENS} "
            f"on validation ({prob_kind} probabilities)"
        ),
        "locked_threshold": thr,
        "valid_sensitivity_at_threshold": sens_v,
        "valid_specificity_at_threshold": spec_v,
        "target_sensitivity": TARGET_SENS,
        "target_achievable": bool(achievable),
        "secondary_youden": {"threshold": yj_thr, "sensitivity": yj_sens, "specificity": yj_spec},
        "probabilities_used": prob_kind,
    }
    L.save_json(thr_report, cfg_dir / "threshold.json")

    # ---- preprocessing config (locked) ----
    L.save_json({**pp.to_dict(), "positive_class": "Malignant", "positive_index": 1,
                 "note": "Deterministic eval/inference preprocessing. No augmentation."},
                cfg_dir / "preprocess.json")

    # ---- metrics at the threshold locked on validation ----
    val_metrics = L.point_metrics(val_y, val_p, thr)
    test_metrics = L.point_metrics(test_y, test_p, thr)
    test_ci = L.bootstrap_ci(test_y, test_p, thr, n_boot=N_BOOT, seed=BOOT_SEED)

    # ---- per-image CSVs ----
    save_per_image_csv(res / "valid_predictions.csv", val_ids, val_y, val_p, thr)
    save_per_image_csv(res / "test_predictions.csv", test_ids, test_y, test_p, thr)

    # ---- figures ----
    plot_roc(test_y, test_p, figs / "test_roc.png", "ROC — Test set")
    plot_pr(test_y, test_p, figs / "test_pr.png", "Precision-Recall — Test set")
    # Reliability diagrams always show both variants, whichever one was kept.
    plot_reliability(val_y, val_p_uncal, val_p_cal, figs / "valid_calibration.png",
                     "Reliability diagram — Validation")
    plot_reliability(test_y, test_p_uncal, test_p_cal, figs / "test_calibration.png",
                     "Reliability diagram — Test")
    cm = np.array([[test_metrics["tn"], test_metrics["fp"]],
                   [test_metrics["fn"], test_metrics["tp"]]])
    plot_confusion(cm, figs / "test_confusion_counts.png", "Confusion matrix (counts) — Test")
    plot_confusion(cm, figs / "test_confusion_normalized.png",
                   "Confusion matrix (row-normalized) — Test", normalize=True)

    # ---- metrics tables (markdown + csv) with CIs ----
    ci_keys = ["auroc", "sensitivity", "specificity", "ppv", "npv", "accuracy", "f1"]
    _write_metrics_tables(tabs, args, ck["backbone"], T, use_calibrated, thr,
                          test_metrics, test_ci, ci_keys)

    # ---- consolidated results json ----
    final = {
        "selected_run": args.best_run_name,
        "selection_val_auroc": args.best_val_auroc,
        "backbone": ck["backbone"],
        "preprocess": pp.to_dict(),
        "calibration": cal_report,
        "threshold": thr_report,
        "valid_metrics_at_locked_threshold": val_metrics,
        "test_metrics_at_locked_threshold": test_metrics,
        "test_metrics_95ci": {k: list(test_ci[k]) for k in ci_keys},
        "ci_method": f"stratified bootstrap, {N_BOOT} resamples, seed={BOOT_SEED}",
    }
    L.save_json(final, res / "final_results.json")

    print("=== CALIBRATION ===")
    print(json.dumps(cal_report, indent=2))
    print("=== THRESHOLD ===")
    print(json.dumps(thr_report, indent=2))
    print("=== TEST METRICS ===")
    for k in ci_keys:
        print(f" {k:12s} {test_metrics[k]:.4f} CI [{test_ci[k][0]:.4f}, {test_ci[k][1]:.4f}]")
    print(f" brier {test_metrics['brier']:.4f}")
    print(f" ece {test_metrics['ece']:.4f}")
    print("Saved to", out)
# Script entry point: run the finalization pipeline only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    main()
|