File size: 5,389 Bytes
45af8e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
#!/usr/bin/env python
"""
Final evaluation script for the validation/test splits of TN5000.

Uses the LOCKED final model weights, LOCKED preprocessing, LOCKED calibration
(temperature scaling) and LOCKED decision threshold stored in this repo. The
test split is intended to be evaluated only once, after model/calibration/
threshold were frozen.

Usage:
    python evaluate.py --split test --config configs/final_config.yaml
    python evaluate.py --split valid

Reads (relative to --repo_dir, default '.'):
    final_model.pt              (or the checkpoint path given by --weights)
    configs/preprocess.json
    configs/calibration.json
    configs/threshold.json

Writes per-image predictions (<split>_predictions.csv) and point metrics with
bootstrap 95% CIs (<split>_metrics.json) to --output_dir
(default <repo_dir>/results/eval_<split>/).
"""
import argparse
import csv
import json
from pathlib import Path

import numpy as np
import yaml

import thyroid_lib as L


def _load_json(path):
    """Parse the JSON file at *path* (UTF-8), closing the handle deterministically."""
    with open(path, encoding="utf-8") as fh:
        return json.load(fh)


def _json_default(obj):
    """``json.dump`` fallback that converts NumPy scalars/arrays to builtin types.

    NOTE(review): the metric dicts come from ``thyroid_lib`` and may contain
    NumPy integers (e.g. confusion-matrix counts), which ``json`` cannot
    serialize natively (NumPy floats already subclass ``float``).
    TODO confirm against thyroid_lib.

    Raises:
        TypeError: for any other non-serializable object, matching the error
            type ``json`` itself raises.
    """
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def _write_predictions(path, ids, y, probs, pred):
    """Write one CSV row per image: id, true label/class, probability, predicted label/class."""
    with open(path, "w", newline="", encoding="utf-8") as f:
        w = csv.writer(f)
        w.writerow(["image_id", "true_label", "true_class", "probability_malignant",
                    "predicted_label", "predicted_class"])
        for img_id, label, prob, pred_label in zip(ids, y, probs, pred):
            w.writerow([img_id, int(label), L.IDX_TO_CLASS[int(label)], f"{prob:.6f}",
                        int(pred_label), L.IDX_TO_CLASS[int(pred_label)]])


def main():
    """Evaluate the locked model on the valid/test split; write predictions and metrics.

    Loads the locked preprocessing, temperature calibration and decision
    threshold from ``<repo_dir>/configs/`` and the checkpoint from
    ``<repo_dir>/<--weights>``. It runs inference on ``<data_dir>/<Valid|Test>``
    (downloading that split from the Hub when no data dir is given). It writes
    ``<split>_predictions.csv`` and ``<split>_metrics.json`` (point metrics plus
    bootstrap CIs) to the output directory, and prints a summary to stdout.

    Settings precedence for data_dir/dataset_id:
    explicit CLI flag > ``--config`` YAML > built-in default.
    """
    import sys

    default_dataset_id = "Johnyquest7/TN5000-thyroid-nodule-classification"

    ap = argparse.ArgumentParser()
    ap.add_argument("--split", choices=["valid", "test"], required=True)
    ap.add_argument("--config", default=None, help="optional final_config.yaml (for data_dir/dataset_id)")
    ap.add_argument("--repo_dir", default=".")
    ap.add_argument("--data_dir", default=None, help="local TN5000 dir; else downloaded from Hub")
    # None sentinel distinguishes "not given on the CLI" from "given", so an
    # explicit flag is never silently overridden by the config file.
    ap.add_argument("--dataset_id", default=None,
                    help=f"Hub dataset id (default: config value, else {default_dataset_id})")
    ap.add_argument("--weights", default="final_model.pt")
    ap.add_argument("--output_dir", default=None)
    ap.add_argument("--n_boot", type=int, default=2000)
    ap.add_argument("--boot_seed", type=int, default=42)
    ap.add_argument("--batch_size", type=int, default=64)
    ap.add_argument("--num_workers", type=int, default=4)
    args = ap.parse_args()

    cfg = {}
    if args.config:
        cfg_path = Path(args.config)
        if cfg_path.exists():
            with open(cfg_path, encoding="utf-8") as fh:
                cfg = yaml.safe_load(fh) or {}
        else:
            # The config is optional, so stay best-effort -- but not silently.
            print(f"WARNING: --config {cfg_path} not found; ignoring it", file=sys.stderr)
    if not args.data_dir and cfg.get("data_dir"):
        args.data_dir = cfg["data_dir"]
    if not args.dataset_id:
        args.dataset_id = cfg.get("dataset_id") or default_dataset_id

    import torch
    from torch.utils.data import DataLoader

    device = "cuda" if torch.cuda.is_available() else "cpu"
    L.set_determinism(args.boot_seed, strict=True)

    repo = Path(args.repo_dir)
    split_name = {"valid": "Valid", "test": "Test"}[args.split]
    out_dir = Path(args.output_dir) if args.output_dir else repo / "results" / f"eval_{args.split}"
    out_dir.mkdir(parents=True, exist_ok=True)

    # data: a local directory if given, otherwise only this split from the Hub
    if args.data_dir:
        data_dir = Path(args.data_dir)
    else:
        from huggingface_hub import snapshot_download
        data_dir = Path(snapshot_download(repo_id=args.dataset_id, repo_type="dataset",
                                          local_dir=str(out_dir / "_data"),
                                          allow_patterns=[f"{split_name}/**"]))

    # locked artifacts
    cfg_dir = repo / "configs"
    pp = L.PreprocessConfig.from_dict(_load_json(cfg_dir / "preprocess.json"))
    calib = _load_json(cfg_dir / "calibration.json")
    thr_cfg = _load_json(cfg_dir / "threshold.json")
    T = calib["temperature"]
    use_cal = calib.get("use_calibrated", True)
    thr = thr_cfg["locked_threshold"]

    # model
    # SECURITY NOTE(review): weights_only=False unpickles arbitrary objects and can
    # execute code; only load checkpoints from a trusted source.
    ck = torch.load(repo / args.weights, map_location="cpu", weights_only=False)
    model, _ = L.build_model(ck["backbone"], freeze_stage=ck.get("freeze_stage", 0),
                             dropout=ck.get("dropout", 0.0))
    model.load_state_dict(ck["model_state"])
    model.to(device).eval()

    ds = L.ThyroidImageFolder(data_dir / split_name, L.build_eval_transform(pp))
    loader = DataLoader(ds, batch_size=args.batch_size, shuffle=False,
                        num_workers=args.num_workers, pin_memory=(device == "cuda"))
    logits, y, ids = L.collect_logits(model, loader, device, amp=False)
    probs = L.apply_temperature(logits, T) if use_cal else L.sigmoid(logits)

    metrics = L.point_metrics(y, probs, thr)
    ci = L.bootstrap_ci(y, probs, thr, n_boot=args.n_boot, seed=args.boot_seed)

    # per-image csv; positive prediction when probability >= locked threshold
    pred = (probs >= thr).astype(int)
    _write_predictions(out_dir / f"{args.split}_predictions.csv", ids, y, probs, pred)

    ci_keys = ["auroc", "sensitivity", "specificity", "ppv", "npv", "accuracy", "f1"]
    out = {"split": args.split, "n": metrics["n"], "n_pos": metrics["n_pos"],
           "n_neg": metrics["n_neg"], "threshold": thr,
           "calibration": f"temperature(T={T:.4f})" if use_cal else "none",
           "metrics": metrics, "metrics_95ci": {k: list(ci[k]) for k in ci_keys},
           "ci_method": f"stratified bootstrap, {args.n_boot} resamples, seed={args.boot_seed}"}
    with open(out_dir / f"{args.split}_metrics.json", "w", encoding="utf-8") as fh:
        json.dump(out, fh, indent=2, default=_json_default)

    print(f"=== {args.split.upper()} (n={metrics['n']}, thr={thr:.4f}) ===")
    for k in ci_keys:
        print(f"  {k:12s} {metrics[k]:.4f}  CI [{ci[k][0]:.4f}, {ci[k][1]:.4f}]")
    print(f"  brier        {metrics['brier']:.4f}")
    print(f"  ece          {metrics['ece']:.4f}")
    print(f"  confusion    TN={metrics['tn']} FP={metrics['fp']} FN={metrics['fn']} TP={metrics['tp']}")
    print("Saved to", out_dir)


if __name__ == "__main__":
    # Run the evaluation only when executed as a script, never on import.
    main()