Add full reproducible thyroid ResNet-18 experiment: weights, scripts, configs, calibration, locked threshold, test eval w/ CIs, figures, data exploration, README, LOG
45af8e1 verified | #!/usr/bin/env python | |
| """ | |
| Final evaluation script for the validation/test splits of TN5000. | |
| Uses the LOCKED final model weights, LOCKED preprocessing, LOCKED calibration | |
| (temperature scaling) and LOCKED decision threshold stored in this repo. The | |
| test split is intended to be evaluated only once, after model/calibration/ | |
| threshold were frozen. | |
| Usage: | |
| python evaluate.py --split test --config configs/final_config.yaml | |
| python evaluate.py --split valid | |
| Reads (relative to --repo_dir, default '.'): | |
| final_model.pt (or another checkpoint passed via --weights) | |
| configs/preprocess.json | |
| configs/calibration.json | |
| configs/threshold.json | |
| Writes per-image predictions (<split>_predictions.csv) and a metrics file with | |
| bootstrap 95% CIs (<split>_metrics.json) to --output_dir (default | |
| <repo_dir>/results/eval_<split>/). | |
| """ | |
| import argparse | |
| import csv | |
| import json | |
| from pathlib import Path | |
| import numpy as np | |
| import yaml | |
| import thyroid_lib as L | |
def main():
    """Evaluate the locked model on one TN5000 split and write the results.

    Parses the command line, loads the preprocessing, calibration and
    threshold JSON files from ``<repo_dir>/configs``, restores the checkpoint
    given by ``--weights``, runs inference over the chosen split, and computes
    point metrics plus bootstrap confidence intervals via ``thyroid_lib``.

    Files written to ``--output_dir`` (default
    ``<repo_dir>/results/eval_<split>``):
        <split>_predictions.csv   one row per image
        <split>_metrics.json      metrics, CIs and the evaluation settings

    A short summary is also printed to stdout.
    """
    import sys

    ap = argparse.ArgumentParser()
    ap.add_argument("--split", choices=["valid", "test"], required=True)
    ap.add_argument("--config", default=None,
                    help="optional final_config.yaml (for data_dir/dataset_id)")
    ap.add_argument("--repo_dir", default=".")
    ap.add_argument("--data_dir", default=None,
                    help="local TN5000 dir; else downloaded from Hub")
    ap.add_argument("--dataset_id",
                    default="Johnyquest7/TN5000-thyroid-nodule-classification")
    ap.add_argument("--weights", default="final_model.pt")
    ap.add_argument("--output_dir", default=None)
    ap.add_argument("--n_boot", type=int, default=2000)
    ap.add_argument("--boot_seed", type=int, default=42)
    args = ap.parse_args()

    # Optional YAML config: an explicit --data_dir wins over the config value,
    # while a dataset_id present in the config replaces the CLI/default one.
    if args.config:
        config_path = Path(args.config)
        if config_path.exists():
            with open(config_path, encoding="utf-8") as fh:
                cfg = yaml.safe_load(fh) or {}
            if not args.data_dir and cfg.get("data_dir"):
                args.data_dir = cfg["data_dir"]
            if cfg.get("dataset_id"):
                args.dataset_id = cfg["dataset_id"]
        else:
            # Keep the best-effort behaviour, but do not hide the problem.
            print(f"warning: config file {config_path} not found; ignoring it",
                  file=sys.stderr)

    # Heavy imports are deferred so that --help stays fast.
    import torch
    from torch.utils.data import DataLoader

    device = "cuda" if torch.cuda.is_available() else "cpu"
    L.set_determinism(args.boot_seed, strict=True)

    repo = Path(args.repo_dir)
    split_name = {"valid": "Valid", "test": "Test"}[args.split]
    if args.output_dir:
        out_dir = Path(args.output_dir)
    else:
        out_dir = repo / "results" / f"eval_{args.split}"
    out_dir.mkdir(parents=True, exist_ok=True)

    # data
    if args.data_dir:
        data_dir = Path(args.data_dir)
    else:
        from huggingface_hub import snapshot_download
        data_dir = Path(snapshot_download(repo_id=args.dataset_id,
                                          repo_type="dataset",
                                          local_dir=str(out_dir / "_data"),
                                          allow_patterns=[f"{split_name}/**"]))

    # locked artifacts
    config_dir = repo / "configs"
    pp = L.PreprocessConfig.from_dict(_read_json(config_dir / "preprocess.json"))
    calib = _read_json(config_dir / "calibration.json")
    thr_cfg = _read_json(config_dir / "threshold.json")
    temperature = calib["temperature"]
    use_cal = calib.get("use_calibrated", True)
    thr = thr_cfg["locked_threshold"]

    # model
    # NOTE(security): weights_only=False unpickles arbitrary Python objects.
    # Only load checkpoints that come from a trusted source.
    ck = torch.load(repo / args.weights, map_location="cpu", weights_only=False)
    model, _ = L.build_model(ck["backbone"],
                             freeze_stage=ck.get("freeze_stage", 0),
                             dropout=ck.get("dropout", 0.0))
    model.load_state_dict(ck["model_state"])
    model.to(device).eval()

    ds = L.ThyroidImageFolder(data_dir / split_name, L.build_eval_transform(pp))
    loader = DataLoader(ds, batch_size=64, shuffle=False, num_workers=4,
                        pin_memory=(device == "cuda"))
    logits, y, ids = L.collect_logits(model, loader, device, amp=False)
    if use_cal:
        probs = L.apply_temperature(logits, temperature)
    else:
        probs = L.sigmoid(logits)

    metrics = L.point_metrics(y, probs, thr)
    ci = L.bootstrap_ci(y, probs, thr, n_boot=args.n_boot, seed=args.boot_seed)

    # per-image csv
    pred = (probs >= thr).astype(int)
    predictions_path = out_dir / f"{args.split}_predictions.csv"
    with open(predictions_path, "w", newline="", encoding="utf-8") as fh:
        writer = csv.writer(fh)
        writer.writerow(["image_id", "true_label", "true_class",
                         "probability_malignant", "predicted_label",
                         "predicted_class"])
        for image_id, true_lbl, prob, pred_lbl in zip(ids, y, probs, pred):
            writer.writerow([image_id, int(true_lbl),
                             L.IDX_TO_CLASS[int(true_lbl)], f"{prob:.6f}",
                             int(pred_lbl), L.IDX_TO_CLASS[int(pred_lbl)]])

    ci_keys = ["auroc", "sensitivity", "specificity", "ppv", "npv", "accuracy", "f1"]
    out = {
        "split": args.split,
        "n": metrics["n"],
        "n_pos": metrics["n_pos"],
        "n_neg": metrics["n_neg"],
        "threshold": thr,
        "calibration": f"temperature(T={temperature:.4f})" if use_cal else "none",
        "metrics": metrics,
        "metrics_95ci": {k: list(ci[k]) for k in ci_keys},
        "ci_method": f"stratified bootstrap, {args.n_boot} resamples, seed={args.boot_seed}",
    }
    # A context manager guarantees the metrics file is flushed and closed.
    with open(out_dir / f"{args.split}_metrics.json", "w", encoding="utf-8") as fh:
        json.dump(out, fh, indent=2)

    print(f"=== {args.split.upper()} (n={metrics['n']}, thr={thr:.4f}) ===")
    for k in ci_keys:
        print(f" {k:12s} {metrics[k]:.4f} CI [{ci[k][0]:.4f}, {ci[k][1]:.4f}]")
    print(f" brier {metrics['brier']:.4f}")
    print(f" ece {metrics['ece']:.4f}")
    print(f" confusion TN={metrics['tn']} FP={metrics['fp']} FN={metrics['fn']} TP={metrics['tp']}")
    print("Saved to", out_dir)


def _read_json(path):
    """Return the parsed contents of the JSON file at *path*, closing it afterwards."""
    with open(path, encoding="utf-8") as fh:
        return json.load(fh)
# Run the evaluation only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == "__main__":
    main()