#!/usr/bin/env python
"""
Final evaluation script for the validation/test splits of TN5000.

Uses the LOCKED final model weights, LOCKED preprocessing, LOCKED calibration
(temperature scaling) and LOCKED decision threshold stored in this repo.
The test split is intended to be evaluated only once, after the model,
calibration and threshold were frozen.

Usage:
    python evaluate.py --split test --config configs/final_config.yaml
    python evaluate.py --split valid

Reads (relative to --repo_dir, default '.'):
    final_model.pt (or the checkpoint path given by --weights)
    configs/preprocess.json
    configs/calibration.json
    configs/threshold.json

The optional --config YAML may supply ``data_dir`` and ``dataset_id``;
values passed explicitly on the command line take precedence over it.

Writes per-image predictions (CSV) and a metrics summary with bootstrap
95% CIs (JSON) to --output_dir (default: <repo_dir>/results/eval_<split>/).
"""
import argparse
import csv
import json
import sys
from pathlib import Path

import numpy as np
import yaml

import thyroid_lib as L

DEFAULT_DATASET_ID = "Johnyquest7/TN5000-thyroid-nodule-classification"

# CLI split name -> sub-directory name inside the dataset root.
SPLIT_DIRS = {"valid": "Valid", "test": "Test"}

# Metrics that receive a bootstrap confidence interval.
CI_KEYS = ["auroc", "sensitivity", "specificity", "ppv", "npv", "accuracy", "f1"]


def _parse_args(argv=None):
    """Parse CLI arguments and fill data_dir/dataset_id from the optional config.

    Precedence for ``data_dir`` and ``dataset_id``: explicit CLI value, then the
    value from --config, then the built-in default (``dataset_id`` only).
    """
    ap = argparse.ArgumentParser(description="Evaluate the locked TN5000 model on one split.")
    ap.add_argument("--split", choices=list(SPLIT_DIRS), required=True)
    ap.add_argument("--config", default=None,
                    help="optional final_config.yaml (for data_dir/dataset_id)")
    ap.add_argument("--repo_dir", default=".")
    ap.add_argument("--data_dir", default=None,
                    help="local TN5000 dir; else downloaded from Hub")
    # Default is None so an explicit CLI value can be distinguished from the
    # fallback and is never silently overridden by the config file.
    ap.add_argument("--dataset_id", default=None,
                    help=f"Hub dataset id (default: config value, else {DEFAULT_DATASET_ID})")
    ap.add_argument("--weights", default="final_model.pt")
    ap.add_argument("--output_dir", default=None)
    ap.add_argument("--n_boot", type=int, default=2000)
    ap.add_argument("--boot_seed", type=int, default=42)
    ap.add_argument("--batch_size", type=int, default=64)
    ap.add_argument("--num_workers", type=int, default=4)
    args = ap.parse_args(argv)

    if args.n_boot < 1:
        ap.error("--n_boot must be >= 1")
    if args.batch_size < 1:
        ap.error("--batch_size must be >= 1")
    if args.num_workers < 0:
        ap.error("--num_workers must be >= 0")

    cfg = {}
    if args.config:
        cfg_path = Path(args.config)
        if cfg_path.exists():
            with open(cfg_path, encoding="utf-8") as f:
                cfg = yaml.safe_load(f) or {}
            if not isinstance(cfg, dict):
                ap.error(f"--config {cfg_path} must contain a YAML mapping")
        else:
            # Keep the original best-effort behaviour, but make it visible: a
            # typo here would otherwise silently change which data is evaluated.
            print(f"WARNING: --config {cfg_path} not found; ignoring it.", file=sys.stderr)

    if not args.data_dir and cfg.get("data_dir"):
        args.data_dir = cfg["data_dir"]
    if not args.dataset_id:
        args.dataset_id = cfg.get("dataset_id") or DEFAULT_DATASET_ID
    return args


def _load_json(path):
    """Read a JSON file, closing the handle deterministically."""
    with open(path, encoding="utf-8") as f:
        return json.load(f)


def _json_default(obj):
    """``json.dump`` fallback that converts NumPy scalars/arrays to builtins.

    NOTE(review): guards against metrics helpers returning NumPy types
    (e.g. int64 counts), which the json module cannot serialize natively.
    """
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def _resolve_data_dir(args, split_name, out_dir):
    """Return the dataset root: --data_dir if given, else download the split from the Hub."""
    if args.data_dir:
        return Path(args.data_dir)
    from huggingface_hub import snapshot_download  # only needed for Hub downloads
    return Path(snapshot_download(repo_id=args.dataset_id, repo_type="dataset",
                                  local_dir=str(out_dir / "_data"),
                                  allow_patterns=[f"{split_name}/**"]))


def _load_model(weights_path, device):
    """Rebuild the model from a checkpoint dict and return it on *device* in eval mode."""
    import torch

    # weights_only=False unpickles arbitrary Python objects: only ever point
    # this at trusted, repo-owned checkpoints.
    ck = torch.load(weights_path, map_location="cpu", weights_only=False)
    model, _ = L.build_model(ck["backbone"], freeze_stage=ck.get("freeze_stage", 0),
                             dropout=ck.get("dropout", 0.0))
    model.load_state_dict(ck["model_state"])
    return model.to(device).eval()


def _write_predictions(path, ids, y, probs, pred):
    """Write one CSV row per image with true/predicted labels and P(malignant)."""
    with open(path, "w", newline="", encoding="utf-8") as f:
        w = csv.writer(f)
        w.writerow(["image_id", "true_label", "true_class", "probability_malignant",
                    "predicted_label", "predicted_class"])
        for image_id, true_idx, prob, pred_idx in zip(ids, y, probs, pred):
            w.writerow([image_id, int(true_idx), L.IDX_TO_CLASS[int(true_idx)], f"{prob:.6f}",
                        int(pred_idx), L.IDX_TO_CLASS[int(pred_idx)]])


def _print_summary(split, metrics, ci, thr, out_dir):
    """Print point metrics with CIs, calibration scores and the confusion matrix."""
    print(f"=== {split.upper()} (n={metrics['n']}, thr={thr:.4f}) ===")
    for k in CI_KEYS:
        print(f" {k:12s} {metrics[k]:.4f} CI [{ci[k][0]:.4f}, {ci[k][1]:.4f}]")
    print(f" brier {metrics['brier']:.4f}")
    print(f" ece {metrics['ece']:.4f}")
    print(f" confusion TN={metrics['tn']} FP={metrics['fp']} FN={metrics['fn']} TP={metrics['tp']}")
    print("Saved to", out_dir)


def main():
    """Evaluate the locked model on --split and write predictions + metrics."""
    args = _parse_args()

    # torch is imported lazily so argument errors are reported without the import cost.
    import torch
    from torch.utils.data import DataLoader

    device = "cuda" if torch.cuda.is_available() else "cpu"
    L.set_determinism(args.boot_seed, strict=True)

    repo = Path(args.repo_dir)
    split_name = SPLIT_DIRS[args.split]
    out_dir = Path(args.output_dir) if args.output_dir else repo / "results" / f"eval_{args.split}"
    out_dir.mkdir(parents=True, exist_ok=True)

    data_dir = _resolve_data_dir(args, split_name, out_dir)

    # Locked artifacts.
    pp = L.PreprocessConfig.from_dict(_load_json(repo / "configs" / "preprocess.json"))
    calib = _load_json(repo / "configs" / "calibration.json")
    thr_cfg = _load_json(repo / "configs" / "threshold.json")
    T = calib["temperature"]
    use_cal = calib.get("use_calibrated", True)
    thr = thr_cfg["locked_threshold"]

    model = _load_model(repo / args.weights, device)

    split_dir = data_dir / split_name
    ds = L.ThyroidImageFolder(split_dir, L.build_eval_transform(pp))
    if len(ds) == 0:
        raise SystemExit(f"No images found under {split_dir}")
    loader = DataLoader(ds, batch_size=args.batch_size, shuffle=False,
                        num_workers=args.num_workers, pin_memory=(device == "cuda"))

    logits, y, ids = L.collect_logits(model, loader, device, amp=False)
    probs = L.apply_temperature(logits, T) if use_cal else L.sigmoid(logits)

    metrics = L.point_metrics(y, probs, thr)
    ci = L.bootstrap_ci(y, probs, thr, n_boot=args.n_boot, seed=args.boot_seed)

    # Predicted positive iff probability >= locked threshold.
    pred = (probs >= thr).astype(int)
    _write_predictions(out_dir / f"{args.split}_predictions.csv", ids, y, probs, pred)

    out = {
        "split": args.split,
        "n": metrics["n"],
        "n_pos": metrics["n_pos"],
        "n_neg": metrics["n_neg"],
        "threshold": thr,
        "calibration": f"temperature(T={T:.4f})" if use_cal else "none",
        "metrics": metrics,
        "metrics_95ci": {k: list(ci[k]) for k in CI_KEYS},
        "ci_method": f"stratified bootstrap, {args.n_boot} resamples, seed={args.boot_seed}",
    }
    with open(out_dir / f"{args.split}_metrics.json", "w", encoding="utf-8") as f:
        json.dump(out, f, indent=2, default=_json_default)

    _print_summary(args.split, metrics, ci, thr, out_dir)


if __name__ == "__main__":
    main()