agentic_thyroid_model / evaluate_external.py
Johnyquest7's picture
Add full reproducible thyroid ResNet-18 experiment: weights, scripts, configs, calibration, locked threshold, test eval w/ CIs, figures, data exploration, README, LOG
45af8e1 verified
Raw
History Blame Contribute Delete
8.73 kB
#!/usr/bin/env python
r"""
Evaluate the LOCKED thyroid ResNet-18 model on an EXTERNAL dataset, using the
exact same preprocessing, calibration (temperature scaling) and locked decision
threshold from the final model repo.

Usage:
    python evaluate_external.py \
        --model_repo Johnyquest7/agentic_thyroid_model \
        --data_dir /path/to/external_dataset \
        --output_dir external_results

The external dataset may be provided in either of two formats:

(A) Folder format with class subfolders:
        <data_dir>/Benign/*.png|jpg|...
        <data_dir>/Malignant/*.png|jpg|...
    Immediate subfolder names are matched case-insensitively against the label
    aliases (benign / b / neg / negative / 0 and malignant / m / pos /
    positive / cancer / 1). Images are searched recursively inside each class
    subfolder. If no subfolder name is a recognised label, every image found
    recursively under <data_dir> is scored without labels.
    Recognised image extensions: .png .jpg .jpeg .bmp .tif .tiff .webp

(B) CSV with image paths and labels:
        --csv /path/to/labels.csv
    with columns (header names matched case-insensitively):
        image_path  (or path / image / filepath): absolute, or relative. A
                    relative path is resolved against --data_dir (or the CSV's
                    folder if --data_dir is not given) when the file exists
                    there, otherwise against the CSV's folder.
        label       (or class / target, optional): same aliases as in (A).

Full metrics + bootstrap 95% CIs are computed only if EVERY image has a
recognised label. If any label is missing or unrecognised, only per-image
probabilities and predictions are written.

Outputs (in --output_dir): external_predictions.csv (per-image results) and
external_metrics.json (summary, plus metrics when labels are available).

The model weights and locked configs are downloaded from --model_repo (or read
from --local_repo_dir if provided).
"""
import argparse
import csv
import json
from pathlib import Path
import numpy as np
import thyroid_lib as L
# File suffixes (lower-case) treated as images.
IMG_EXT = {".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff", ".webp"}
# Lower-cased spellings accepted for each class (folder names or CSV values).
BENIGN_ALIASES = {"benign", "0", "b", "neg", "negative"}
MALIGNANT_ALIASES = {"malignant", "1", "m", "pos", "positive", "cancer"}


def parse_label(v):
    """Map a raw label value to 0 (benign), 1 (malignant) or None (unknown).

    Matching is case-insensitive and ignores surrounding whitespace. Besides
    the textual aliases above, numeric spellings of 0/1 such as ``"1.0"`` or
    ``0.0`` are accepted: CSV exports (e.g. from pandas when a column had
    missing values) frequently write integer labels as floats, and rejecting
    them would silently disable metric computation.

    Args:
        v: Any value; it is converted with ``str()`` before matching.

    Returns:
        0, 1, or None when the value is not a recognised label.
    """
    s = str(v).strip().lower()
    if s in BENIGN_ALIASES:
        return 0
    if s in MALIGNANT_ALIASES:
        return 1
    try:
        num = float(s)
    except ValueError:
        return None
    # NaN / inf / other numbers fall through to "unknown".
    if num == 0.0:
        return 0
    if num == 1.0:
        return 1
    return None
def gather_folder(data_dir):
    """Collect images from an external dataset folder.

    If ``data_dir`` has immediate subfolders whose names parse as class labels
    (via ``parse_label``, e.g. ``Benign``/``Malignant`` or ``0``/``1``), every
    image found recursively under those subfolders is returned with that
    label. Otherwise every image found recursively under ``data_dir`` is
    returned unlabelled.

    Args:
        data_dir: Dataset root (str or Path).

    Returns:
        List of ``(path, label_or_None, image_id)`` tuples, where ``image_id``
        is the file stem. Order is deterministic: class subfolders in sorted
        path order, then files in sorted path order.
    """
    data_dir = Path(data_dir)

    def _images_under(root):
        """Return sorted regular, non-hidden image files found recursively under *root*."""
        # rglob also yields directories (e.g. a folder literally named
        # "scan.png") and hidden files such as macOS "._scan.png" resource
        # forks; PIL cannot open either, so both are filtered out here.
        return [
            f
            for f in sorted(root.rglob("*"))
            if f.suffix.lower() in IMG_EXT and f.is_file() and not f.name.startswith(".")
        ]

    # Visit subfolders in sorted order so item order (and hence the output CSV)
    # does not depend on the filesystem's directory iteration order. Every
    # label-named subfolder is kept (e.g. both "Benign" and "benign" on a
    # case-sensitive filesystem) so no images are silently dropped.
    class_dirs = []
    for sub in sorted(data_dir.iterdir()):
        if not sub.is_dir():
            continue
        lab = parse_label(sub.name)
        if lab is not None:
            class_dirs.append((sub, lab))

    if not class_dirs:
        # Flat folder (no recognisable class subfolders): score without labels.
        return [(f, None, f.stem) for f in _images_under(data_dir)]

    return [(f, lab, f.stem) for sub, lab in class_dirs for f in _images_under(sub)]
def gather_csv(csv_path, data_dir):
    """Collect images listed in a CSV file.

    Args:
        csv_path: Path to the CSV. Header names are matched case-insensitively
            and ignoring surrounding whitespace. The path column may be named
            image_path, path, image or filepath; the optional label column may
            be named label, class or target.
        data_dir: Optional base directory for relative paths; falsy means
            "use the CSV's folder".

    Returns:
        List of ``(path, label_or_None, image_id)`` tuples in CSV row order,
        with ``image_id`` set to the file stem. A relative path is resolved
        against the base directory when the file exists there, otherwise
        against the CSV's folder. Labels ``parse_label`` does not recognise
        (or a missing label column) give None.

    Raises:
        ValueError: if the CSV has no header row, no image-path column, or a
            row whose image path is empty.
    """
    csv_path = Path(csv_path)
    base = Path(data_dir) if data_dir else csv_path.parent
    items = []
    # newline="" is what the csv module requires; "utf-8-sig" drops the BOM
    # that Excel's "CSV UTF-8" export prepends, which would otherwise be glued
    # onto the first header name and hide the image_path column.
    with open(csv_path, newline="", encoding="utf-8-sig") as f:
        reader = csv.DictReader(f)
        if not reader.fieldnames:
            raise ValueError(f"CSV {csv_path} is empty or has no header row.")
        cols = {c.strip().lower(): c for c in reader.fieldnames}
        pcol = cols.get("image_path") or cols.get("path") or cols.get("image") or cols.get("filepath")
        lcol = cols.get("label") or cols.get("class") or cols.get("target")
        if pcol is None:
            raise ValueError(
                "CSV must have an image path column (image_path/path/image/filepath)."
            )
        for row in reader:
            # Short rows yield None for missing cells; strip stray spaces that
            # "a, b" style CSVs leave around values.
            raw = (row[pcol] or "").strip()
            if not raw:
                raise ValueError(f"{csv_path}, line {reader.line_num}: empty image path.")
            p = Path(raw)
            if not p.is_absolute():
                cand = base / raw
                p = cand if cand.exists() else (csv_path.parent / raw)
            lab = parse_label(row[lcol]) if lcol else None
            items.append((p, lab, p.stem))
    return items
class ListDataset:
    """Minimal map-style dataset over ``(path, label_or_None, image_id)`` items.

    Each image is opened with PIL, converted to RGB and passed through
    ``transform``. A missing label is returned as ``-1`` (the caller treats
    negative labels as "unknown").
    """

    def __init__(self, items, transform):
        # PIL is imported in __getitem__ instead of being stored on the
        # instance: module objects cannot be pickled, and DataLoader pickles
        # the dataset when it starts worker processes with the "spawn" start
        # method (the default on Windows and macOS).
        self.items = items
        self.transform = transform

    def __len__(self):
        return len(self.items)

    def __getitem__(self, i):
        # Cheap after the first call: the module is cached in sys.modules.
        from PIL import Image

        path, lab, iid = self.items[i]
        with Image.open(path) as im:
            x = self.transform(im.convert("RGB"))
        return x, (-1 if lab is None else lab), iid
def _load_json(path):
    """Read a UTF-8 JSON file and return the parsed object, closing the file afterwards."""
    with open(path, encoding="utf-8") as fh:
        return json.load(fh)


def _json_default(obj):
    """``json.dump`` fallback that converts NumPy scalars/arrays to plain Python types.

    The return types of ``L.point_metrics`` / ``L.bootstrap_ci`` are not visible
    here; if they contain NumPy integers (e.g. confusion counts) or arrays, the
    json module would otherwise raise TypeError mid-write.
    """
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def main():
    """Command-line entry point: score an external dataset with the locked model.

    Loads the locked preprocessing, calibration and threshold configs plus the
    weights (from the Hub or --local_repo_dir), and runs inference over the
    external images. Writes per-image predictions to
    ``<output_dir>/external_predictions.csv``. Writes a summary to
    ``<output_dir>/external_metrics.json``; when every image is labelled, the
    summary also includes point metrics and bootstrap CIs.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--model_repo", default="Johnyquest7/agentic_thyroid_model")
    ap.add_argument("--local_repo_dir", default=None,
                    help="use a local copy of the model repo instead of downloading")
    ap.add_argument("--data_dir", default=None, help="external image folder (format A) or CSV base")
    ap.add_argument("--csv", default=None, help="CSV with image_path,label (format B)")
    ap.add_argument("--weights", default="final_model.pt")
    ap.add_argument("--output_dir", default="external_results")
    ap.add_argument("--n_boot", type=int, default=2000)
    ap.add_argument("--boot_seed", type=int, default=42)
    ap.add_argument("--batch_size", type=int, default=64, help="inference batch size")
    ap.add_argument("--num_workers", type=int, default=4,
                    help="DataLoader worker processes (0 = load images in the main process)")
    args = ap.parse_args()
    if not args.data_dir and not args.csv:
        raise SystemExit("Provide --data_dir (folder format) or --csv (CSV format).")

    import torch
    from torch.utils.data import DataLoader

    device = "cuda" if torch.cuda.is_available() else "cpu"
    L.set_determinism(args.boot_seed, strict=True)
    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # ---- fetch locked repo artifacts ----
    if args.local_repo_dir:
        repo = Path(args.local_repo_dir)
    else:
        from huggingface_hub import snapshot_download
        repo = Path(snapshot_download(repo_id=args.model_repo, repo_type="model",
                                      local_dir=str(out_dir / "_model_repo"),
                                      allow_patterns=[args.weights, "configs/*", "thyroid_lib.py"]))
    cfg_dir = repo / "configs"
    pp = L.PreprocessConfig.from_dict(_load_json(cfg_dir / "preprocess.json"))
    calib = _load_json(cfg_dir / "calibration.json")
    thr_cfg = _load_json(cfg_dir / "threshold.json")
    T = calib["temperature"]
    use_cal = calib.get("use_calibrated", True)
    thr = thr_cfg["locked_threshold"]

    # SECURITY: weights_only=False lets torch.load unpickle arbitrary objects
    # from the checkpoint; only point --model_repo / --local_repo_dir at
    # sources you trust.
    ck = torch.load(repo / args.weights, map_location="cpu", weights_only=False)
    model, _ = L.build_model(ck["backbone"], freeze_stage=ck.get("freeze_stage", 0),
                             dropout=ck.get("dropout", 0.0))
    model.load_state_dict(ck["model_state"])
    model.to(device).eval()

    # ---- gather external images ----
    items = gather_csv(args.csv, args.data_dir) if args.csv else gather_folder(args.data_dir)
    if not items:
        raise SystemExit("No images found in external dataset.")
    n_unlabelled = sum(1 for it in items if it[1] is None)
    has_labels = n_unlabelled == 0
    print(f"Found {len(items)} external images; labels available: {has_labels}")
    if 0 < n_unlabelled < len(items):
        # Metrics need every image labelled; make a partially labelled dataset
        # visible instead of silently skipping evaluation.
        print(f"WARNING: {n_unlabelled} of {len(items)} images have a missing or "
              "unrecognised label; metrics will not be computed.")

    ds = ListDataset(items, L.build_eval_transform(pp))
    loader = DataLoader(ds, batch_size=args.batch_size, shuffle=False,
                        num_workers=args.num_workers, pin_memory=(device == "cuda"))

    logits_all, labels_all, ids_all = [], [], []
    with torch.no_grad():
        for x, y, iid in loader:
            # Presumably one logit per image ((B,) or (B, 1)) -- TODO confirm
            # against L.build_model; view(-1) flattens it to (B,).
            out = model(x.to(device)).view(-1)
            logits_all.append(out.float().cpu().numpy())
            labels_all.append(np.asarray(y))
            ids_all.extend(iid)
    logits = np.concatenate(logits_all)
    labels = np.concatenate(labels_all).astype(int)
    probs = L.apply_temperature(logits, T) if use_cal else L.sigmoid(logits)
    pred = (probs >= thr).astype(int)

    # ---- per-image CSV ----
    with open(out_dir / "external_predictions.csv", "w", newline="", encoding="utf-8") as f:
        w = csv.writer(f)
        w.writerow(["image_id", "true_label", "probability_malignant",
                    "predicted_label", "predicted_class"])
        for image_id, y_true, p_mal, y_hat in zip(ids_all, labels, probs, pred):
            # Negative labels are ListDataset's "unknown" placeholder.
            w.writerow([image_id, ("" if y_true < 0 else int(y_true)), f"{p_mal:.6f}",
                        int(y_hat), L.IDX_TO_CLASS[int(y_hat)]])

    result = {"n": len(items), "threshold": thr,
              "calibration": "temperature(T=%.4f)" % T if use_cal else "none",
              "labels_available": bool(has_labels)}
    if has_labels:
        metrics = L.point_metrics(labels, probs, thr)
        ci = L.bootstrap_ci(labels, probs, thr, n_boot=args.n_boot, seed=args.boot_seed)
        ci_keys = ["auroc", "sensitivity", "specificity", "ppv", "npv", "accuracy", "f1"]
        result["metrics"] = metrics
        result["metrics_95ci"] = {k: list(ci[k]) for k in ci_keys}
        result["ci_method"] = f"stratified bootstrap, {args.n_boot} resamples, seed={args.boot_seed}"
        print("=== EXTERNAL METRICS ===")
        for k in ci_keys:
            print(f" {k:12s} {metrics[k]:.4f} CI [{ci[k][0]:.4f}, {ci[k][1]:.4f}]")
        print(f" brier {metrics['brier']:.4f}")
        print(f" ece {metrics['ece']:.4f}")
        print(f" confusion TN={metrics['tn']} FP={metrics['fp']} FN={metrics['fn']} TP={metrics['tp']}")
    else:
        print("No labels provided — wrote probabilities and predictions only.")

    with open(out_dir / "external_metrics.json", "w", encoding="utf-8") as f:
        json.dump(result, f, indent=2, default=_json_default)
    print("Saved to", out_dir)


if __name__ == "__main__":
    main()