File size: 8,726 Bytes
45af8e1 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 | #!/usr/bin/env python
"""
Evaluate the LOCKED thyroid ResNet-18 model on an EXTERNAL dataset, using the
exact same preprocessing, calibration (temperature scaling) and locked decision
threshold from the final model repo.
Usage:
python evaluate_external.py \
--model_repo Johnyquest7/agentic_thyroid_model \
--data_dir /path/to/external_dataset \
--output_dir external_results
The external dataset may be provided in either of two formats:
(A) Folder format with class subfolders:
<data_dir>/Benign/*.png|jpg|...
<data_dir>/Malignant/*.png|jpg|...
(folder names are case-insensitive; accepted names are benign/b/neg/negative/0 and malignant/m/pos/positive/cancer/1)
(B) CSV with image paths and labels:
--csv /path/to/labels.csv
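with columns (header names are case-insensitive):
image_path (or path / image / filepath): absolute, or relative to --data_dir;
if not found there (or --data_dir is unset), relative to the CSV's folder
label (optional; or class / target): 0/1 or benign/malignant, case-insensitive
If every image has a recognised label, full metrics + bootstrap 95% CIs are
computed. If any label is absent or unrecognised, only per-image probabilities
and predictions are written.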
The model weights and locked configs are downloaded from --model_repo (or read
from --local_repo_dir if provided).
"""
import argparse
import csv
import json
from pathlib import Path
import numpy as np
import thyroid_lib as L
# Lower-cased file extensions treated as images when scanning folders.
IMG_EXT = {".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff", ".webp"}
# Accepted spellings of each class, compared after ``str(v).strip().lower()``.
BENIGN_ALIASES = {"benign", "0", "b", "neg", "negative"}
MALIGNANT_ALIASES = {"malignant", "1", "m", "pos", "positive", "cancer"}


def parse_label(v):
    """Map a raw label value to 0 (benign), 1 (malignant) or None (unknown).

    ``v`` may be any object (a CSV cell, a folder name, an int, ...). It is
    converted with ``str`` and compared case-insensitively, ignoring
    surrounding whitespace, against ``BENIGN_ALIASES`` / ``MALIGNANT_ALIASES``.
    Numeric spellings equal to exactly 0 or 1 (e.g. ``"1.0"`` or ``0.0``, as
    produced by tools that store labels as floats) are accepted too.

    Returns:
        0, 1, or None when the value is not a recognised label.
    """
    s = str(v).strip().lower()
    if s in BENIGN_ALIASES:
        return 0
    if s in MALIGNANT_ALIASES:
        return 1
    # Fall back to a numeric interpretation so "1.0" / "0.0" are not lost.
    try:
        num = float(s)
    except ValueError:
        return None  # unknown
    if num == 0:
        return 0
    if num == 1:
        return 1
    return None  # numeric but not a valid class (e.g. 2, nan, inf)
def gather_folder(data_dir):
    """Collect images from *data_dir* laid out in folder format (A).

    If any immediate subdirectory name parses as a class label (via
    ``parse_label``), every image found recursively under those
    subdirectories is returned with that label; other subdirectories are
    ignored. Otherwise the whole tree under *data_dir* is scanned and every
    image is returned unlabeled.

    Subdirectories are visited in sorted order because ``Path.iterdir`` yields
    entries in arbitrary, filesystem-dependent order. Item order determines the
    row order of the predictions CSV and of the arrays passed to the seeded
    bootstrap, so it must not vary between machines. Case-variant duplicates
    (e.g. ``Benign`` and ``benign``) are both scanned rather than one being
    silently dropped.

    Returns:
        list of ``(path, label_or_None, image_id)`` tuples, where ``image_id``
        is the file stem.
    """
    data_dir = Path(data_dir)

    def images_under(root):
        # Regular files only: a directory named e.g. "x.png" is not an image.
        return [f for f in sorted(root.rglob("*"))
                if f.is_file() and f.suffix.lower() in IMG_EXT]

    class_dirs = []
    for sub in sorted(p for p in data_dir.iterdir() if p.is_dir()):
        lab = parse_label(sub.name)
        if lab is not None:
            class_dirs.append((sub, lab))

    if class_dirs:
        return [(f, lab, f.stem) for sub, lab in class_dirs for f in images_under(sub)]
    # Flat folder (no class subfolders): every image, no labels.
    return [(f, None, f.stem) for f in images_under(data_dir)]
def gather_csv(csv_path, data_dir):
    """Collect images listed in a CSV file (format B).

    Args:
        csv_path: CSV with an image-path column (``image_path``, ``path``,
            ``image`` or ``filepath``) and optionally a label column
            (``label``, ``class`` or ``target``). Header names are matched
            case-insensitively, ignoring surrounding whitespace and a leading
            UTF-8 byte-order mark (as written by Excel).
        data_dir: base folder for relative paths; when falsy the CSV's own
            folder is used.

    Relative paths are resolved against the base folder first and fall back to
    the CSV's folder when that candidate does not exist. Rows whose path cell
    is missing or blank are skipped (with a warning).

    Returns:
        list of ``(path, label_or_None, image_id)`` tuples; ``image_id`` is the
        file stem.

    Raises:
        ValueError: if the CSV has no header row or no recognised path column.
    """
    csv_path = Path(csv_path)
    base = Path(data_dir) if data_dir else csv_path.parent
    items = []
    skipped = 0
    # newline="" is what the csv module requires for correct quoting handling.
    with open(csv_path, newline="") as f:
        reader = csv.DictReader(f)
        if not reader.fieldnames:
            raise ValueError(f"CSV {csv_path} is empty or has no header row.")
        # normalized header -> original header (original is the row-dict key)
        cols = {c.replace("\ufeff", "").strip().lower(): c for c in reader.fieldnames}
        pcol = next((cols[k] for k in ("image_path", "path", "image", "filepath") if k in cols), None)
        lcol = next((cols[k] for k in ("label", "class", "target") if k in cols), None)
        if pcol is None:
            raise ValueError("CSV must have an image path column (image_path/path/image/filepath).")
        for row in reader:
            # DictReader fills cells missing from short rows with None.
            raw = (row[pcol] or "").strip()
            if not raw:
                skipped += 1
                continue
            p = Path(raw)
            if not p.is_absolute():
                cand = base / raw
                p = cand if cand.exists() else (csv_path.parent / raw)
            lab = parse_label(row[lcol]) if lcol else None
            items.append((p, lab, p.stem))
    if skipped:
        print(f"WARNING: skipped {skipped} CSV row(s) with an empty image path in {csv_path}")
    return items
class ListDataset:
    """Map-style dataset over ``(path, label_or_None, image_id)`` tuples.

    ``__getitem__`` opens the image with PIL, converts it to RGB, applies
    ``transform`` and returns ``(tensor, label, image_id)``. A missing label is
    encoded as -1, because None would not collate into a batch.

    PIL is imported inside ``__getitem__`` rather than stored on the instance.
    A module object cannot be pickled, and ``DataLoader`` pickles the dataset
    when it starts worker processes with the ``spawn`` start method, which is
    the default on Windows and macOS.
    """

    def __init__(self, items, transform):
        self.items = items
        self.transform = transform

    def __len__(self):
        return len(self.items)

    def __getitem__(self, i):
        from PIL import Image  # cached in sys.modules after the first call

        path, lab, iid = self.items[i]
        with Image.open(path) as im:
            x = self.transform(im.convert("RGB"))
        return x, (-1 if lab is None else lab), iid
def _read_json(path):
    """Return the parsed JSON document at *path*, closing the file afterwards."""
    with open(path, encoding="utf-8") as fh:
        return json.load(fh)


def _json_default(obj):
    """``json.dump`` fallback that converts NumPy scalars/arrays to Python types.

    NOTE(review): metric and CI values come from ``thyroid_lib``. If they are
    NumPy integers or float32 scalars (not ``int``/``float`` subclasses), plain
    ``json.dump`` would raise. Any other unknown type is still rejected.
    """
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def _resolve_repo(args, out_dir):
    """Return the directory holding the locked weights and ``configs/``.

    Uses ``--local_repo_dir`` when given; otherwise downloads the needed files
    from ``--model_repo`` into ``<out_dir>/_model_repo``.
    """
    if args.local_repo_dir:
        return Path(args.local_repo_dir)
    from huggingface_hub import snapshot_download
    # thyroid_lib.py is downloaded too, but it is not imported from there:
    # this script uses the locally importable ``thyroid_lib`` (imported as L).
    return Path(snapshot_download(
        repo_id=args.model_repo, repo_type="model",
        local_dir=str(out_dir / "_model_repo"),
        allow_patterns=[args.weights, "configs/*", "thyroid_lib.py"]))


def _load_model(repo, weights, device):
    """Rebuild the locked network from ``repo / weights``, move it to *device*, set eval mode."""
    import torch

    # SECURITY: weights_only=False unpickles arbitrary Python objects from the
    # checkpoint; only point --model_repo / --local_repo_dir at trusted sources.
    ck = torch.load(repo / weights, map_location="cpu", weights_only=False)
    model, _ = L.build_model(ck["backbone"], freeze_stage=ck.get("freeze_stage", 0),
                             dropout=ck.get("dropout", 0.0))
    model.load_state_dict(ck["model_state"])
    model.to(device).eval()
    return model


def _predict(model, loader, device):
    """Score every batch from *loader*; return ``(logits, labels, image_ids)``.

    ``logits`` is flattened to one value per image (``.view(-1)``). ``labels``
    is an int array that keeps the dataset's -1 sentinel for unlabeled images.
    ``image_ids`` is a list in loader order.
    """
    import torch

    logit_chunks, label_chunks, ids = [], [], []
    with torch.no_grad():
        for x, y, iid in loader:
            out = model(x.to(device)).view(-1)
            logit_chunks.append(out.float().cpu().numpy())
            label_chunks.append(np.asarray(y))
            ids.extend(iid)
    return (np.concatenate(logit_chunks),
            np.concatenate(label_chunks).astype(int),
            ids)


def main():
    """Evaluate the locked model on an external dataset (command-line entry point).

    Steps:
      1. Parse the CLI arguments.
      2. Load the locked preprocessing, calibration and threshold configs plus
         the checkpoint (locally or from the Hub).
      3. Score every external image.
      4. Write ``external_predictions.csv``.
      5. If every image has a recognised label, compute point metrics and
         bootstrap 95% CIs.

    A summary is always written to ``external_metrics.json`` in --output_dir.

    Raises:
        SystemExit: if neither --data_dir nor --csv is given, or no images are found.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--model_repo", default="Johnyquest7/agentic_thyroid_model")
    ap.add_argument("--local_repo_dir", default=None,
                    help="use a local copy of the model repo instead of downloading")
    ap.add_argument("--data_dir", default=None, help="external image folder (format A) or CSV base")
    ap.add_argument("--csv", default=None, help="CSV with image_path,label (format B)")
    ap.add_argument("--weights", default="final_model.pt")
    ap.add_argument("--output_dir", default="external_results")
    ap.add_argument("--n_boot", type=int, default=2000)
    ap.add_argument("--boot_seed", type=int, default=42)
    args = ap.parse_args()
    if not args.data_dir and not args.csv:
        raise SystemExit("Provide --data_dir (folder format) or --csv (CSV format).")

    # torch is imported only after argument validation.
    import torch
    from torch.utils.data import DataLoader

    device = "cuda" if torch.cuda.is_available() else "cpu"
    L.set_determinism(args.boot_seed, strict=True)
    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # ---- fetch locked repo artifacts ----
    repo = _resolve_repo(args, out_dir)
    cfg_dir = repo / "configs"
    pp = L.PreprocessConfig.from_dict(_read_json(cfg_dir / "preprocess.json"))
    calib = _read_json(cfg_dir / "calibration.json")
    thr_cfg = _read_json(cfg_dir / "threshold.json")
    T = calib["temperature"]
    use_cal = calib.get("use_calibrated", True)
    thr = thr_cfg["locked_threshold"]
    model = _load_model(repo, args.weights, device)

    # ---- gather external images ----
    items = gather_csv(args.csv, args.data_dir) if args.csv else gather_folder(args.data_dir)
    if not items:
        raise SystemExit("No images found in external dataset.")
    n_unlabeled = sum(lab is None for _, lab, _ in items)
    has_labels = n_unlabeled == 0
    print(f"Found {len(items)} external images; labels available: {has_labels}")
    if 0 < n_unlabeled < len(items):
        print(f"WARNING: {n_unlabeled} of {len(items)} images have a missing or unrecognised "
              "label; metrics are only computed when every image is labeled.")

    ds = ListDataset(items, L.build_eval_transform(pp))
    loader = DataLoader(ds, batch_size=64, shuffle=False, num_workers=4,
                        pin_memory=(device == "cuda"))
    logits, labels, ids_all = _predict(model, loader, device)
    # Apply the locked calibration (if enabled) and the locked decision threshold.
    probs = L.apply_temperature(logits, T) if use_cal else L.sigmoid(logits)
    pred = (probs >= thr).astype(int)

    # ---- per-image CSV ----
    with open(out_dir / "external_predictions.csv", "w", newline="", encoding="utf-8") as f:
        w = csv.writer(f)
        w.writerow(["image_id", "true_label", "probability_malignant",
                    "predicted_label", "predicted_class"])
        for iid, yy, pr, yhat in zip(ids_all, labels, probs, pred):
            # -1 is the "no label" sentinel -> empty cell
            w.writerow([iid, ("" if yy < 0 else int(yy)), f"{pr:.6f}",
                        int(yhat), L.IDX_TO_CLASS[int(yhat)]])

    result = {"n": len(items), "threshold": thr,
              "calibration": f"temperature(T={T:.4f})" if use_cal else "none",
              "labels_available": bool(has_labels)}
    if has_labels:
        if np.unique(labels).size < 2:
            print("WARNING: only one class is present in the external labels; "
                  "some metrics (e.g. AUROC) are undefined.")
        metrics = L.point_metrics(labels, probs, thr)
        ci = L.bootstrap_ci(labels, probs, thr, n_boot=args.n_boot, seed=args.boot_seed)
        ci_keys = ["auroc", "sensitivity", "specificity", "ppv", "npv", "accuracy", "f1"]
        result["metrics"] = metrics
        result["metrics_95ci"] = {k: list(ci[k]) for k in ci_keys}
        result["ci_method"] = f"stratified bootstrap, {args.n_boot} resamples, seed={args.boot_seed}"
        print("=== EXTERNAL METRICS ===")
        for k in ci_keys:
            print(f" {k:12s} {metrics[k]:.4f} CI [{ci[k][0]:.4f}, {ci[k][1]:.4f}]")
        print(f" brier {metrics['brier']:.4f}")
        print(f" ece {metrics['ece']:.4f}")
        print(f" confusion TN={metrics['tn']} FP={metrics['fp']} FN={metrics['fn']} TP={metrics['tp']}")
    else:
        print("No labels provided — wrote probabilities and predictions only.")

    # Close (and flush) the summary file deterministically.
    with open(out_dir / "external_metrics.json", "w", encoding="utf-8") as f:
        json.dump(result, f, indent=2, default=_json_default)
    print("Saved to", out_dir)
if __name__ == "__main__":
    # Execute the evaluation CLI only when run as a script, never on import.
    main()
|