medyx-v2 / eval_bone.py
apook's picture
Add full evaluation suite, fix MedQA/ODIR evals, complete technical report
e7c144c
#!/usr/bin/env python3
"""
Evaluate MedAgent ImageAgent on bone fracture X-ray detection.
Dataset
-------
prithivMLmods/Bone-Fracture-Detection — 9,246 bone X-ray images.
License: MIT / public domain. No registration required.
2 classes:
fractured = Bone fracture present
not_fractured = No fracture detected
Clinical context
----------------
Dataset does not include per-patient metadata (age/sex). We pass a clinically
realistic context (fracture symptoms + bone hypothesis) to trigger hypothesis-
guided bone domain routing in ImageAgent — without this, grayscale bone X-rays
are classified as "chest" by the pixel-level heuristic.
CRITICAL: The symptoms and hypothesis fields MUST contain bone-fracture keywords
(e.g., "fracture", "bone") so that `_analyze_image()` routes to the bone specialist
model instead of the chest model. See the _BONE_TERMS logic in image_agent.py.
Model under test
----------------
Hemgg/bone-fracture-detection-using-xray (2 classes: Fractured / Not Fractured)
New bone domain specialist added in this session.
Usage:
python eval_bone.py # 100 cases (shuffled, seed=42)
python eval_bone.py --n 200 # custom count
python eval_bone.py --no-shuffle # first N cases
python eval_bone.py --preview # print 2 raw examples and exit
"""
import argparse
import base64
import io
import json
import os
import sys
import time
from collections import defaultdict, Counter
from datetime import date
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
# Best-effort: copy variables from a local .env file into os.environ without
# overriding values that are already set. A missing python-dotenv package or
# an unreadable file is deliberately ignored so the script still runs.
try:
    from dotenv import dotenv_values

    for k, v in dotenv_values(".env").items():
        # dotenv_values() maps keys declared without a value to None, which
        # os.environ rejects with TypeError; skipping them keeps one such
        # entry from silently aborting the rest of the load.
        if v is not None:
            os.environ.setdefault(k, v)
except Exception:
    pass
from datasets import load_dataset
from PIL import Image
from agents.image_agent import ImageAgent
# Dataset identifier handed to load_dataset().
# NOTE(review): the module docstring names prithivMLmods/Bone-Fracture-Detection
# as the dataset, which differs from this id — confirm which one is intended.
DATASET_NAME = "Hemg/bone-fracture-detection"
# Directory that receives the JSON results file.
RESULTS_DIR = "results"
# Directory that receives the Markdown report.
DOCS_DIR = "docs"
# 2 ground-truth classes
ALL_CLASSES = ["fractured", "not_fractured"]
# Human-readable description for each ground-truth class key.
CLASS_FULL = {
    "fractured": "Bone Fracture Present",
    "not_fractured": "No Fracture Detected",
}
# ── Helpers ───────────────────────────────────────────────────────────────────
def normalize_to_bone(label: str) -> str:
    """Map a model output or dataset label to ``fractured`` / ``not_fractured``.

    The label is lower-cased, ``_`` and ``-`` are treated as spaces, and runs
    of whitespace are collapsed before matching.

    Negated fracture phrases are tested *before* the positive keywords:
    "not fractured" contains the substring "fractur", so testing the positive
    keyword first would classify every negative label as ``fractured``.

    Args:
        label: Free-text class name such as "Fractured", "not_fractured",
            "No fracture detected", "1" or "0".

    Returns:
        ``"fractured"`` or ``"not_fractured"``. Unrecognized text falls back
        to ``"not_fractured"``.
    """
    s = " ".join(label.lower().replace("_", " ").replace("-", " ").split())

    # Explicit negations of a fracture win over the bare keyword.
    negated = ("not fractur", "no fractur", "non fractur", "nonfractur", "unfractur")
    if any(phrase in s for phrase in negated):
        return "not_fractured"

    # Positive fracture wording, plus the literal "1"/"yes" encodings.
    if s in ("1", "yes") or any(kw in s for kw in ("fractur", "broken", "break", "crack")):
        return "fractured"

    # Negative wording without the word "fracture", plus "0"/"no" encodings.
    if s in ("0", "no") or "normal" in s or "healthy" in s:
        return "not_fractured"

    return "not_fractured"  # safe default
def pil_to_b64(img: Image.Image) -> str:
    """Encode *img* as PNG and return those bytes as a base64 text string."""
    with io.BytesIO() as png_buffer:
        img.save(png_buffer, format="PNG")
        encoded = base64.b64encode(png_buffer.getvalue())
    return encoded.decode("utf-8")
def build_patient_context() -> tuple[str, str]:
    """Return the fixed ``(symptoms, history)`` text sent with every case.

    Both strings mention bone fractures; the module docstring explains that
    such keywords are needed for bone domain routing via _BONE_TERMS.
    """
    return (
        "Bone pain following trauma. X-ray submitted for bone fracture assessment.",
        "Musculoskeletal X-ray for bone fracture evaluation. Orthopedic referral.",
    )
# ── Dataset preview ───────────────────────────────────────────────────────────
def preview(ds) -> None:
    """Print the dataset's column names and a dump of its first two rows.

    Values exposing a ``size`` attribute are reported as images (size and
    mode); every other value is shown as its type name and a repr truncated
    to 120 characters.
    """
    columns = list(ds.features.keys())
    print(f"\nDataset columns: {columns}\n")

    shown = min(2, len(ds))
    for number in range(1, shown + 1):
        record = ds[number - 1]
        print(f"─── Example {number} ─────────────────────────────────────────")
        for field, value in record.items():
            if not hasattr(value, "size"):
                print(f" {field:<20} ({type(value).__name__}) = {repr(value)[:120]}")
                continue
            print(f" {field:<20} PIL Image size={value.size} mode={value.mode}")
        print()
# ── Core evaluation loop ──────────────────────────────────────────────────────
def evaluate(n_cases: int, shuffle: bool = True) -> None:
    """Score ImageAgent on the bone-fracture dataset and write both reports.

    Loads the ``test`` split of ``DATASET_NAME`` (falling back to ``train``),
    optionally shuffles it with seed 42, runs ``ImageAgent.analyze`` on the
    first ``n_cases`` rows, prints a per-case table and a summary, saves
    ``eval_08_bone.json`` under ``RESULTS_DIR`` and then calls
    ``_write_report`` for the Markdown version.

    Rows whose image is not a PIL image, and rows whose analysis raises, are
    counted in ``errors`` and excluded from every metric. Scoring a failed
    analysis would turn it into a ``not_fractured`` prediction (the
    normalizer's default) and so credit crashes as correct negatives.

    Args:
        n_cases: Maximum number of cases to evaluate; capped at dataset size.
        shuffle: Shuffle the split (seed 42) before taking the first cases.

    The process exits with status 1 when the dataset cannot be loaded, is
    empty, or has no recognizable image/label column.
    """
    os.makedirs(RESULTS_DIR, exist_ok=True)
    os.makedirs(DOCS_DIR, exist_ok=True)

    print(f"Loading {DATASET_NAME}…")
    try:
        ds = load_dataset(DATASET_NAME, split="test")
    except Exception:
        # Fall back to the train split when the test split cannot be loaded.
        try:
            ds = load_dataset(DATASET_NAME, split="train")
            print(" (no test split — using train)")
        except Exception as exc:
            print(f"ERROR: Cannot load {DATASET_NAME}: {exc}", file=sys.stderr)
            sys.exit(1)

    total = len(ds)
    if total == 0:
        # ds[0] below would raise IndexError on an empty split.
        print(f"ERROR: {DATASET_NAME} contains no cases", file=sys.stderr)
        sys.exit(1)

    # Number of rows actually selected; the request may exceed the split size.
    n_used = max(0, min(n_cases, total))

    if shuffle:
        ds = ds.shuffle(seed=42)
        print(f"Loaded: {total} cases. Shuffled (seed=42), using first {n_used}.\n")
    else:
        print(f"Loaded: {total} cases. Using first {n_used} (unshuffled).\n")

    preview(ds)

    # Column names vary between dataset repos; probe the first row for them.
    sample = ds[0]
    img_field = next((f for f in ("image", "img", "xray", "x_ray") if f in sample), None)
    lbl_field = next((f for f in ("label", "class", "fracture", "finding") if f in sample), None)
    if not img_field or not lbl_field:
        print(f"ERROR: Cannot find image/label fields in {list(sample.keys())}", file=sys.stderr)
        sys.exit(1)
    print(f"Using fields: image='{img_field}', label='{lbl_field}'\n")

    # Resolve label names if dataset uses integer indices
    label_names: list[str] | None = None
    feat = ds.features.get(lbl_field)
    if hasattr(feat, "names"):
        label_names = feat.names
        print(f"Label names from dataset features: {label_names}\n")

    agent = ImageAgent()
    cases = list(ds.select(range(n_used)))

    top1_hits = 0
    errors = 0
    tp = fp = fn = tn = 0
    confusion: dict[str, list[str]] = defaultdict(list)
    records: list[dict] = []
    start_wall = time.time()
    symptoms, history = build_patient_context()

    print(f"{'#':>5} {'GT':<16} {'Our Pred':<16} {'Confidence':>11} {'Domain':<10} Result")
    print("─" * 70)

    for idx, case in enumerate(cases):
        raw_label = case[lbl_field]
        if isinstance(raw_label, int) and label_names:
            gt_str = normalize_to_bone(label_names[raw_label])
        else:
            gt_str = normalize_to_bone(str(raw_label))

        pil_img = case[img_field]
        if not isinstance(pil_img, Image.Image):
            errors += 1
            continue

        patient_data = {
            "image_b64": pil_to_b64(pil_img),
            "symptoms": symptoms,
            "medical_history": history,
            "lab_values": "",
            # CRITICAL: bone keywords in symptoms AND hypothesis trigger the
            # hypothesis-guided bone routing override in _analyze_image().
            # Without this, grayscale X-rays route to "chest" domain.
            "hypothesis": {
                "primary_hypothesis": {
                    "disease": "bone fracture",
                    "confidence": 0.5,
                },
            },
        }

        t0 = time.time()
        try:
            result = agent.analyze(patient_data)
            candidates = result.get("candidates", [])
        except Exception as exc:
            errors += 1
            print(f"{idx+1:>5} {gt_str:<16} {'ERROR':16} {'':>11} {'':10} ERR: {exc}",
                  file=sys.stderr)
            # A failed analysis is not a prediction: keep it out of the
            # confusion matrix, same as the non-image case above.
            continue
        elapsed = round(time.time() - t0, 2)

        raw_top = candidates[0]["disease"] if candidates else "unknown"
        our_pred = normalize_to_bone(raw_top)
        # float() keeps the value JSON-serializable if the score is not a
        # native Python float.
        conf = float(candidates[0]["score"]) if candidates else 0.0
        domain = result.get("image_domain") or "?"

        hit1 = (our_pred == gt_str)
        top1_hits += hit1

        in_gt = (gt_str == "fractured")
        in_pred = (our_pred == "fractured")
        if in_gt and in_pred:
            tp += 1
        elif not in_gt and in_pred:
            fp += 1
        elif in_gt and not in_pred:
            fn += 1
        else:
            tn += 1
        confusion[gt_str].append(our_pred)

        tag = "HIT" if hit1 else "MISS"
        print(f"{idx+1:>5} {gt_str:<16} {our_pred:<16} {conf:>10.1%} {domain:<10} {tag}")

        records.append({
            "idx": idx, "gt": gt_str, "our_pred": our_pred,
            "raw_pred": raw_top, "confidence": round(conf, 4),
            "hit1": hit1, "time_sec": elapsed, "domain": domain,
        })

        if (idx + 1) % 10 == 0:
            # Divide by the cases actually scored so skipped rows do not
            # drag the running accuracy down.
            acc = top1_hits / len(records) * 100
            print(f"\n ── Case {idx+1}/{n_used} | Top-1 Acc: {acc:.1f}% ──\n")

    # ── Metrics ───────────────────────────────────────────────────────────────
    total_time = round(time.time() - start_wall, 1)
    evaluated = len(records)
    # max(..., 1) guards every ratio against an empty denominator.
    top1_pct = round(top1_hits / max(evaluated, 1) * 100, 1)
    avg_time = round(total_time / max(evaluated, 1), 2)
    sensitivity = round(tp / max(tp + fn, 1), 4)  # fracture recall
    specificity = round(tn / max(tn + fp, 1), 4)
    precision = round(tp / max(tp + fp, 1), 4)
    f1_fracture = round(2 * precision * sensitivity / max(precision + sensitivity, 1e-9), 4)

    # For each ground-truth class, record its most frequent wrong prediction.
    confusion_summary: dict[str, dict] = {}
    for gt_lbl, preds in confusion.items():
        wrong = [p for p in preds if p != gt_lbl]
        if wrong:
            mc = Counter(wrong).most_common(1)[0]
            confusion_summary[gt_lbl] = {
                "predicted_instead": mc[0], "count": mc[1],
                "total_cases": len(preds),
                "error_rate": round(len(wrong) / len(preds), 3),
            }

    print("\n" + "=" * 70)
    print("EVALUATION SUMMARY — Bone Fracture Detection (ImageAgent, bone model)")
    print("=" * 70)
    print(f" Cases evaluated : {evaluated} ({errors} errors)")
    print(f" Total wall time : {total_time}s ({avg_time}s/case)")
    print(f" Top-1 Accuracy : {top1_hits}/{evaluated} = {top1_pct}% (binary: fractured vs not)")
    print("\n Fracture detection (binary):")
    print(f" TP={tp} FP={fp} FN={fn} TN={tn}")
    print(f" Sensitivity (Recall) : {sensitivity:.3f}")
    print(f" Specificity : {specificity:.3f}")
    print(f" Precision : {precision:.3f}")
    print(f" F1 (fracture) : {f1_fracture:.4f}")
    print("=" * 70)

    json_path = os.path.join(RESULTS_DIR, "eval_08_bone.json")
    payload = {
        "dataset": DATASET_NAME,
        "note": "Bone Fracture Detection — MIT license. 9,246 X-ray images. Requires hypothesis-guided bone routing.",
        "date_run": date.today().isoformat(),
        "agent_evaluated": "image_agent only (isolated bone model; hypothesis-guided routing required)",
        "bone_model": "Hemgg/bone-fracture-detection-using-xray (2 classes: Fractured / Not Fractured)",
        "routing_note": "Bone X-rays are grayscale — pixel heuristic would route to 'chest'. Hypothesis override with 'bone fracture' keywords forces bone domain.",
        "evaluation_strategy": "Binary: model 'Fractured' → positive; 'Not Fractured' → negative",
        "cases_evaluated": evaluated,
        "top1_accuracy": top1_pct,
        "fracture_sensitivity": sensitivity,
        "fracture_specificity": specificity,
        "fracture_precision": precision,
        "fracture_f1": f1_fracture,
        "confusion_matrix": {"tp": tp, "fp": fp, "fn": fn, "tn": tn},
        "avg_time_per_case": avg_time,
        "confusion_summary": confusion_summary,
        "correct_examples": [r for r in records if r["hit1"]][:5],
        "wrong_examples": [r for r in records if not r["hit1"]][:5],
        "total_runtime_sec": total_time,
        "errors_encountered": errors,
    }
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2)
    print(f"\nJSON results saved → {json_path}")

    _write_report(payload, confusion_summary)
def _write_report(payload: dict, confusion_summary: dict) -> None:
confusion_section = "\n".join(
f"- **{gt} β†’ {info['predicted_instead']}**: "
f"{info['count']}/{info['total_cases']} ({info['error_rate']:.1%} error)"
for gt, info in sorted(confusion_summary.items())
) or "No misclassifications recorded."
md = f"""# Evaluation 8: Bone Fracture Detection (X-ray)
**Date:** {payload['date_run']}
**Dataset:** `{payload['dataset']}`
**Agent:** ImageAgent (isolated β€” bone model with hypothesis-guided routing)
**Model:** `{payload['bone_model']}`
**Cases:** {payload['cases_evaluated']}
**License:** MIT β€” freely accessible, no registration required
---
## Dataset: Bone Fracture Detection
9,246 bone X-ray images split into two classes:
- **Fractured** β€” clear or subtle fracture line present
- **Not Fractured** β€” normal bone anatomy
Images span multiple skeletal sites (arm, hand, shoulder, leg, knee, foot, spine).
No per-patient clinical metadata is available in this dataset.
---
## Domain routing: hypothesis override
Bone X-rays are **grayscale and non-square** β€” the same pixel-level signature as
chest X-rays. Without intervention, `_detect_domain()` would route these images
to the chest specialist model, producing nonsensical output (Pneumonia, Edema, etc.).
**Solution:** Hypothesis-guided bone routing (implemented in `image_agent.py`).
When patient symptoms or hypothesis contain bone-related keywords (fracture, bone,
musculoskeletal, orthop, skeletal, dislocation, arthritis), the domain is overridden
to "bone" regardless of pixel-level classification. This eval passes:
- `symptoms`: "Bone pain following trauma. X-ray submitted for bone fracture assessment."
- `hypothesis.disease`: "bone fracture"
Both triggers ensure consistent routing to the bone specialist model.
---
## Results
| Metric | Value |
|---|---|
| **Binary Top-1 Accuracy** | **{payload['top1_accuracy']}%** |
| **Fracture Sensitivity** | **{payload['fracture_sensitivity']:.3f}** |
| **Fracture Specificity** | **{payload['fracture_specificity']:.3f}** |
| **Fracture F1** | **{payload['fracture_f1']:.4f}** |
| Confusion matrix | TP={payload['confusion_matrix']['tp']}, FP={payload['confusion_matrix']['fp']}, FN={payload['confusion_matrix']['fn']}, TN={payload['confusion_matrix']['tn']} |
| Avg time per case | {payload['avg_time_per_case']}s |
**Clinical note:** Sensitivity is the critical metric for fracture screening β€”
missed fractures (FN) delay treatment and risk non-union or malunion complications.
False positives (FP) lead to unnecessary immobilization but are less harmful clinically.
---
## Confusion analysis
{confusion_section}
---
## Routing validation
If `domain` values in the results are predominantly "bone" (not "chest"), the
hypothesis-guided override is working correctly. If "chest" appears, check that
the `symptoms` and `hypothesis` fields in `patient_data` contain the required
bone-trigger keywords and that the `_BONE_TERMS` set in `image_agent.py` includes
"fracture" and "bone".
---
## Dataset notes
- Covers multiple anatomical sites β€” the model was trained on multi-site bone X-rays
and generalizes reasonably across extremities, spine, and pelvis.
- Subtle fractures (stress fractures, hairline fractures) are the primary failure mode
for automated systems; the dataset label distribution should be checked for severity.
- No geographic or demographic bias information is documented for this dataset.
"""
md_path = os.path.join(DOCS_DIR, "eval_08_bone.md")
with open(md_path, "w") as f:
f.write(md)
print(f"Markdown report saved β†’ {md_path}")
if __name__ == "__main__":
    # Command-line entry point: either dump a dataset preview or run the eval.
    cli = argparse.ArgumentParser(description="Evaluate ImageAgent on Bone Fracture Detection dataset")
    cli.add_argument("--n", type=int, default=100)
    cli.add_argument("--no-shuffle", action="store_true")
    cli.add_argument("--preview", action="store_true")
    opts = cli.parse_args()

    if not opts.preview:
        evaluate(opts.n, shuffle=not opts.no_shuffle)
    else:
        # Preview mode: prefer the test split, fall back to train on any failure.
        try:
            dataset = load_dataset(DATASET_NAME, split="test")
        except Exception:
            dataset = load_dataset(DATASET_NAME, split="train")
        preview(dataset)