agentic_thyroid_model / evaluate.py
Johnyquest7's picture
Add full reproducible thyroid ResNet-18 experiment: weights, scripts, configs, calibration, locked threshold, test eval w/ CIs, figures, data exploration, README, LOG
45af8e1 verified
Raw
History Blame Contribute Delete
5.39 kB
#!/usr/bin/env python
"""
Final evaluation script for the validation/test splits of TN5000.
Uses the LOCKED final model weights, LOCKED preprocessing, LOCKED calibration
(temperature scaling) and LOCKED decision threshold stored in this repo. The
test split is intended to be evaluated only once, after model/calibration/
threshold were frozen.
Usage:
python evaluate.py --split test --config configs/final_config.yaml
python evaluate.py --split valid
Reads (relative to --repo_dir, default '.'):
final_model.pt (or the checkpoint path given by --weights)
configs/preprocess.json
configs/calibration.json
configs/threshold.json
Writes per-image predictions (<split>_predictions.csv) and point metrics with
bootstrap 95% CIs (<split>_metrics.json) to --output_dir
(default <repo_dir>/results/eval_<split>/).
"""
import argparse
import csv
import json
import sys
from pathlib import Path

import numpy as np
import yaml

import thyroid_lib as L
def main():
ap = argparse.ArgumentParser()
ap.add_argument("--split", choices=["valid", "test"], required=True)
ap.add_argument("--config", default=None, help="optional final_config.yaml (for data_dir/dataset_id)")
ap.add_argument("--repo_dir", default=".")
ap.add_argument("--data_dir", default=None, help="local TN5000 dir; else downloaded from Hub")
ap.add_argument("--dataset_id", default="Johnyquest7/TN5000-thyroid-nodule-classification")
ap.add_argument("--weights", default="final_model.pt")
ap.add_argument("--output_dir", default=None)
ap.add_argument("--n_boot", type=int, default=2000)
ap.add_argument("--boot_seed", type=int, default=42)
args = ap.parse_args()
if args.config and Path(args.config).exists():
cfg = yaml.safe_load(open(args.config)) or {}
if not args.data_dir and cfg.get("data_dir"):
args.data_dir = cfg["data_dir"]
if cfg.get("dataset_id"):
args.dataset_id = cfg["dataset_id"]
import torch
from torch.utils.data import DataLoader
device = "cuda" if torch.cuda.is_available() else "cpu"
L.set_determinism(args.boot_seed, strict=True)
repo = Path(args.repo_dir)
split_name = {"valid": "Valid", "test": "Test"}[args.split]
out_dir = Path(args.output_dir) if args.output_dir else repo / "results" / f"eval_{args.split}"
out_dir.mkdir(parents=True, exist_ok=True)
# data
if args.data_dir:
data_dir = Path(args.data_dir)
else:
from huggingface_hub import snapshot_download
data_dir = Path(snapshot_download(repo_id=args.dataset_id, repo_type="dataset",
local_dir=str(out_dir / "_data"),
allow_patterns=[f"{split_name}/**"]))
# locked artifacts
pp = L.PreprocessConfig.from_dict(json.load(open(repo / "configs" / "preprocess.json")))
calib = json.load(open(repo / "configs" / "calibration.json"))
thr_cfg = json.load(open(repo / "configs" / "threshold.json"))
T = calib["temperature"]; use_cal = calib.get("use_calibrated", True)
thr = thr_cfg["locked_threshold"]
# model
ck = torch.load(repo / args.weights, map_location="cpu", weights_only=False)
model, _ = L.build_model(ck["backbone"], freeze_stage=ck.get("freeze_stage", 0),
dropout=ck.get("dropout", 0.0))
model.load_state_dict(ck["model_state"]); model.to(device).eval()
ds = L.ThyroidImageFolder(data_dir / split_name, L.build_eval_transform(pp))
loader = DataLoader(ds, batch_size=64, shuffle=False, num_workers=4,
pin_memory=(device == "cuda"))
logits, y, ids = L.collect_logits(model, loader, device, amp=False)
probs = L.apply_temperature(logits, T) if use_cal else L.sigmoid(logits)
metrics = L.point_metrics(y, probs, thr)
ci = L.bootstrap_ci(y, probs, thr, n_boot=args.n_boot, seed=args.boot_seed)
# per-image csv
pred = (probs >= thr).astype(int)
with open(out_dir / f"{args.split}_predictions.csv", "w", newline="") as f:
w = csv.writer(f)
w.writerow(["image_id", "true_label", "true_class", "probability_malignant",
"predicted_label", "predicted_class"])
for i, yy, pr, pd in zip(ids, y, probs, pred):
w.writerow([i, int(yy), L.IDX_TO_CLASS[int(yy)], f"{pr:.6f}",
int(pd), L.IDX_TO_CLASS[int(pd)]])
ci_keys = ["auroc", "sensitivity", "specificity", "ppv", "npv", "accuracy", "f1"]
out = {"split": args.split, "n": metrics["n"], "n_pos": metrics["n_pos"],
"n_neg": metrics["n_neg"], "threshold": thr,
"calibration": "temperature(T=%.4f)" % T if use_cal else "none",
"metrics": metrics, "metrics_95ci": {k: list(ci[k]) for k in ci_keys},
"ci_method": f"stratified bootstrap, {args.n_boot} resamples, seed={args.boot_seed}"}
json.dump(out, open(out_dir / f"{args.split}_metrics.json", "w"), indent=2)
print(f"=== {args.split.upper()} (n={metrics['n']}, thr={thr:.4f}) ===")
for k in ci_keys:
print(f" {k:12s} {metrics[k]:.4f} CI [{ci[k][0]:.4f}, {ci[k][1]:.4f}]")
print(f" brier {metrics['brier']:.4f}")
print(f" ece {metrics['ece']:.4f}")
print(f" confusion TN={metrics['tn']} FP={metrics['fp']} FN={metrics['fn']} TP={metrics['tp']}")
print("Saved to", out_dir)
if __name__ == "__main__":
main()