#!/usr/bin/env python
"""
Final evaluation script for the validation/test splits of TN5000.
Uses the LOCKED final model weights, LOCKED preprocessing, LOCKED calibration
(temperature scaling) and LOCKED decision threshold stored in this repo. The
test split is intended to be evaluated only once, after model/calibration/
threshold were frozen.
Usage:
    python evaluate.py --split test --config configs/final_config.yaml
    python evaluate.py --split valid

Reads (relative to --repo_dir, default '.'):
    final_model.pt (or the file given by --weights)
    configs/preprocess.json
    configs/calibration.json
    configs/threshold.json

The optional --config YAML is only consulted for data_dir and dataset_id.

Writes per-image predictions (CSV) and a metrics summary (JSON, with bootstrap
95% CI) to --output_dir (default <repo_dir>/results/eval_<split>/).
"""
import argparse
import csv
import json
from pathlib import Path
import numpy as np
import yaml
import thyroid_lib as L
def _load_json(path):
    """Parse the JSON file at *path*, closing the file handle afterwards."""
    with open(path, encoding="utf-8") as fh:
        return json.load(fh)


def main():
    """Evaluate the locked model on one split and write predictions and metrics.

    Parses the command line, loads the locked preprocessing / calibration /
    threshold JSON files and the checkpoint from ``--repo_dir``, runs inference
    over ``<data_dir>/<Valid|Test>``, then writes ``<split>_predictions.csv``
    and ``<split>_metrics.json`` to the output directory and prints a summary
    to stdout.

    Raises:
        KeyError: if a required key is missing from calibration.json
            (``temperature``), threshold.json (``locked_threshold``) or the
            checkpoint (``backbone``, ``model_state``).
        OSError: if one of the locked artifact files cannot be opened.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--split", choices=["valid", "test"], required=True)
    ap.add_argument("--config", default=None, help="optional final_config.yaml (for data_dir/dataset_id)")
    ap.add_argument("--repo_dir", default=".")
    ap.add_argument("--data_dir", default=None, help="local TN5000 dir; else downloaded from Hub")
    ap.add_argument("--dataset_id", default="Johnyquest7/TN5000-thyroid-nodule-classification")
    ap.add_argument("--weights", default="final_model.pt")
    ap.add_argument("--output_dir", default=None)
    ap.add_argument("--n_boot", type=int, default=2000)
    ap.add_argument("--boot_seed", type=int, default=42)
    args = ap.parse_args()

    if args.config:
        if Path(args.config).exists():
            with open(args.config, encoding="utf-8") as fh:
                cfg = yaml.safe_load(fh) or {}
            # An explicit --data_dir wins over the config file ...
            if not args.data_dir and cfg.get("data_dir"):
                args.data_dir = cfg["data_dir"]
            # ... whereas a dataset_id in the config always replaces the CLI value.
            if cfg.get("dataset_id"):
                args.dataset_id = cfg["dataset_id"]
        else:
            # Keep going with CLI/default values as before, but do not hide a
            # mistyped path: the run would otherwise silently use other data.
            import sys
            print(f"warning: --config {args.config} not found; ignoring it", file=sys.stderr)

    # Heavy imports are deferred until after argument parsing.
    import torch
    from torch.utils.data import DataLoader

    device = "cuda" if torch.cuda.is_available() else "cpu"
    L.set_determinism(args.boot_seed, strict=True)

    repo = Path(args.repo_dir)
    split_name = {"valid": "Valid", "test": "Test"}[args.split]
    out_dir = Path(args.output_dir) if args.output_dir else repo / "results" / f"eval_{args.split}"
    out_dir.mkdir(parents=True, exist_ok=True)

    # data
    if args.data_dir:
        data_dir = Path(args.data_dir)
    else:
        from huggingface_hub import snapshot_download
        # Only the requested split is fetched, into a folder next to the outputs.
        data_dir = Path(snapshot_download(repo_id=args.dataset_id, repo_type="dataset",
                                          local_dir=str(out_dir / "_data"),
                                          allow_patterns=[f"{split_name}/**"]))

    # locked artifacts
    pp = L.PreprocessConfig.from_dict(_load_json(repo / "configs" / "preprocess.json"))
    calib = _load_json(repo / "configs" / "calibration.json")
    thr_cfg = _load_json(repo / "configs" / "threshold.json")
    T = calib["temperature"]
    use_cal = calib.get("use_calibrated", True)
    thr = thr_cfg["locked_threshold"]

    # model
    # NOTE(security): weights_only=False unpickles arbitrary objects; only load
    # checkpoints that come from a trusted source.
    ck = torch.load(repo / args.weights, map_location="cpu", weights_only=False)
    model, _ = L.build_model(ck["backbone"], freeze_stage=ck.get("freeze_stage", 0),
                             dropout=ck.get("dropout", 0.0))
    model.load_state_dict(ck["model_state"])
    model.to(device).eval()

    ds = L.ThyroidImageFolder(data_dir / split_name, L.build_eval_transform(pp))
    loader = DataLoader(ds, batch_size=64, shuffle=False, num_workers=4,
                        pin_memory=(device == "cuda"))
    logits, y, ids = L.collect_logits(model, loader, device, amp=False)
    probs = L.apply_temperature(logits, T) if use_cal else L.sigmoid(logits)

    metrics = L.point_metrics(y, probs, thr)
    ci = L.bootstrap_ci(y, probs, thr, n_boot=args.n_boot, seed=args.boot_seed)

    # per-image csv
    pred = (probs >= thr).astype(int)
    with open(out_dir / f"{args.split}_predictions.csv", "w", newline="") as f:
        w = csv.writer(f)
        w.writerow(["image_id", "true_label", "true_class", "probability_malignant",
                    "predicted_label", "predicted_class"])
        for image_id, true_label, prob, pred_label in zip(ids, y, probs, pred):
            w.writerow([image_id, int(true_label), L.IDX_TO_CLASS[int(true_label)], f"{prob:.6f}",
                        int(pred_label), L.IDX_TO_CLASS[int(pred_label)]])

    ci_keys = ["auroc", "sensitivity", "specificity", "ppv", "npv", "accuracy", "f1"]
    out = {"split": args.split, "n": metrics["n"], "n_pos": metrics["n_pos"],
           "n_neg": metrics["n_neg"], "threshold": thr,
           "calibration": ("temperature(T=%.4f)" % T) if use_cal else "none",
           "metrics": metrics, "metrics_95ci": {k: list(ci[k]) for k in ci_keys},
           "ci_method": f"stratified bootstrap, {args.n_boot} resamples, seed={args.boot_seed}"}
    # Context manager so the metrics file is flushed and closed before returning.
    with open(out_dir / f"{args.split}_metrics.json", "w", encoding="utf-8") as fh:
        json.dump(out, fh, indent=2)

    print(f"=== {args.split.upper()} (n={metrics['n']}, thr={thr:.4f}) ===")
    for k in ci_keys:
        print(f" {k:12s} {metrics[k]:.4f} CI [{ci[k][0]:.4f}, {ci[k][1]:.4f}]")
    print(f" brier {metrics['brier']:.4f}")
    print(f" ece {metrics['ece']:.4f}")
    print(f" confusion TN={metrics['tn']} FP={metrics['fp']} FN={metrics['fn']} TP={metrics['tp']}")
    print("Saved to", out_dir)
# Entry-point guard: run the evaluation only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
|