#!/usr/bin/env python3 """ Evaluate MedAgent ImageAgent on bone fracture X-ray detection. Dataset ------- prithivMLmods/Bone-Fracture-Detection — 9,246 bone X-ray images. License: MIT / public domain. No registration required. 2 classes: fractured = Bone fracture present not_fractured = No fracture detected Clinical context ---------------- Dataset does not include per-patient metadata (age/sex). We pass a clinically realistic context (fracture symptoms + bone hypothesis) to trigger hypothesis- guided bone domain routing in ImageAgent — without this, grayscale bone X-rays are classified as "chest" by the pixel-level heuristic. CRITICAL: The symptoms and hypothesis fields MUST contain bone-fracture keywords (e.g., "fracture", "bone") so that `_analyze_image()` routes to the bone specialist model instead of the chest model. See the _BONE_TERMS logic in image_agent.py. Model under test ---------------- Hemgg/bone-fracture-detection-using-xray (2 classes: Fractured / Not Fractured) New bone domain specialist added in this session. Usage: python eval_bone.py # 100 cases (shuffled, seed=42) python eval_bone.py --n 200 # custom count python eval_bone.py --no-shuffle # first N cases python eval_bone.py --preview # print 2 raw examples and exit """ import argparse import base64 import io import json import os import sys import time from collections import defaultdict, Counter from datetime import date sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) try: from dotenv import dotenv_values for k, v in dotenv_values(".env").items(): os.environ.setdefault(k, v) except Exception: pass from datasets import load_dataset from PIL import Image from agents.image_agent import ImageAgent DATASET_NAME = "Hemg/bone-fracture-detection" RESULTS_DIR = "results" DOCS_DIR = "docs" # 2 ground-truth classes ALL_CLASSES = ["fractured", "not_fractured"] CLASS_FULL = { "fractured": "Bone Fracture Present", "not_fractured": "No Fracture Detected", } # ── Helpers ─────────────────────────────────────────────────────────────────── def normalize_to_bone(label: str) -> str: """Map model output or dataset label to fractured / not_fractured.""" s = label.lower().replace("_", " ").replace("-", " ").strip() if "fractur" in s or "broken" in s or "break" in s or "crack" in s: return "fractured" if "not fractur" in s or "no fractur" in s or "normal" in s or "healthy" in s: return "not_fractured" # Dataset uses "fractured" / "not fractured" as primary labels — cover both if s in ("fractured", "fracture", "1", "yes"): return "fractured" if s in ("not fractured", "not_fractured", "0", "no"): return "not_fractured" return "not_fractured" # safe default def pil_to_b64(img: Image.Image) -> str: buf = io.BytesIO() img.save(buf, format="PNG") return base64.b64encode(buf.getvalue()).decode("utf-8") def build_patient_context() -> tuple[str, str]: """Clinical context that triggers bone domain routing via _BONE_TERMS.""" symptoms = "Bone pain following trauma. X-ray submitted for bone fracture assessment." history = "Musculoskeletal X-ray for bone fracture evaluation. Orthopedic referral." return symptoms, history # ── Dataset preview ─────────────────────────────────────────────────────────── def preview(ds) -> None: print(f"\nDataset columns: {list(ds.features.keys())}\n") for i in range(min(2, len(ds))): row = ds[i] print(f"─── Example {i+1} ─────────────────────────────────────────") for k, v in row.items(): if hasattr(v, "size"): print(f" {k:<20} PIL Image size={v.size} mode={v.mode}") else: print(f" {k:<20} ({type(v).__name__}) = {repr(v)[:120]}") print() # ── Core evaluation loop ────────────────────────────────────────────────────── def evaluate(n_cases: int, shuffle: bool = True) -> None: os.makedirs(RESULTS_DIR, exist_ok=True) os.makedirs(DOCS_DIR, exist_ok=True) print(f"Loading {DATASET_NAME}…") try: ds = load_dataset(DATASET_NAME, split="test") except Exception: try: ds = load_dataset(DATASET_NAME, split="train") print(" (no test split — using train)") except Exception as exc: print(f"ERROR: Cannot load {DATASET_NAME}: {exc}", file=sys.stderr) sys.exit(1) total = len(ds) if shuffle: ds = ds.shuffle(seed=42) print(f"Loaded: {total} cases. Shuffled (seed=42), using first {n_cases}.\n") else: print(f"Loaded: {total} cases. Using first {n_cases} (unshuffled).\n") preview(ds) sample = ds[0] img_field = next((f for f in ("image", "img", "xray", "x_ray") if f in sample), None) lbl_field = next((f for f in ("label", "class", "fracture", "finding") if f in sample), None) if not img_field or not lbl_field: print(f"ERROR: Cannot find image/label fields in {list(sample.keys())}", file=sys.stderr) sys.exit(1) print(f"Using fields: image='{img_field}', label='{lbl_field}'\n") # Resolve label names if dataset uses integer indices label_names: list[str] | None = None feat = ds.features.get(lbl_field) if hasattr(feat, "names"): label_names = feat.names print(f"Label names from dataset features: {label_names}\n") agent = ImageAgent() cases = list(ds.select(range(min(n_cases, len(ds))))) top1_hits = 0 errors = 0 tp = fp = fn = tn = 0 confusion: dict[str, list[str]] = defaultdict(list) records: list[dict] = [] start_wall = time.time() symptoms, history = build_patient_context() print(f"{'#':>5} {'GT':<16} {'Our Pred':<16} {'Confidence':>11} {'Domain':<10} Result") print("─" * 70) for idx, case in enumerate(cases): raw_label = case[lbl_field] if isinstance(raw_label, int) and label_names: gt_str = normalize_to_bone(label_names[raw_label]) else: gt_str = normalize_to_bone(str(raw_label)) pil_img = case[img_field] if not isinstance(pil_img, Image.Image): errors += 1 continue patient_data = { "image_b64": pil_to_b64(pil_img), "symptoms": symptoms, "medical_history": history, "lab_values": "", # CRITICAL: bone keywords in symptoms AND hypothesis trigger the # hypothesis-guided bone routing override in _analyze_image(). # Without this, grayscale X-rays route to "chest" domain. "hypothesis": { "primary_hypothesis": { "disease": "bone fracture", "confidence": 0.5, }, }, } t0 = time.time() try: result = agent.analyze(patient_data) candidates = result.get("candidates", []) except Exception as exc: errors += 1 print(f"{idx+1:>5} {gt_str:<16} {'ERROR':16} {'':>11} {'':10} ERR: {exc}", file=sys.stderr) candidates = [] result = {} elapsed = round(time.time() - t0, 2) raw_top = candidates[0]["disease"] if candidates else "unknown" our_pred = normalize_to_bone(raw_top) conf = candidates[0]["score"] if candidates else 0.0 domain = result.get("image_domain") or "?" hit1 = (our_pred == gt_str) top1_hits += hit1 in_gt = (gt_str == "fractured") in_pred = (our_pred == "fractured") if in_gt and in_pred: tp += 1 elif not in_gt and in_pred: fp += 1 elif in_gt and not in_pred: fn += 1 else: tn += 1 confusion[gt_str].append(our_pred) tag = "HIT" if hit1 else "MISS" print(f"{idx+1:>5} {gt_str:<16} {our_pred:<16} {conf:>10.1%} {domain:<10} {tag}") records.append({ "idx": idx, "gt": gt_str, "our_pred": our_pred, "raw_pred": raw_top, "confidence": round(conf, 4), "hit1": hit1, "time_sec": elapsed, "domain": domain, }) if (idx + 1) % 10 == 0: acc = top1_hits / (idx + 1) * 100 print(f"\n ── Case {idx+1}/{n_cases} | Top-1 Acc: {acc:.1f}% ──\n") # ── Metrics ─────────────────────────────────────────────────────────────── total_time = round(time.time() - start_wall, 1) evaluated = len(records) top1_pct = round(top1_hits / max(evaluated, 1) * 100, 1) avg_time = round(total_time / max(evaluated, 1), 2) sensitivity = round(tp / max(tp + fn, 1), 4) # fracture recall specificity = round(tn / max(tn + fp, 1), 4) precision = round(tp / max(tp + fp, 1), 4) f1_fracture = round(2 * precision * sensitivity / max(precision + sensitivity, 1e-9), 4) confusion_summary: dict[str, dict] = {} for gt_lbl, preds in confusion.items(): wrong = [p for p in preds if p != gt_lbl] if wrong: mc = Counter(wrong).most_common(1)[0] confusion_summary[gt_lbl] = { "predicted_instead": mc[0], "count": mc[1], "total_cases": len(preds), "error_rate": round(len(wrong) / len(preds), 3), } print("\n" + "=" * 70) print("EVALUATION SUMMARY — Bone Fracture Detection (ImageAgent, bone model)") print("=" * 70) print(f" Cases evaluated : {evaluated} ({errors} errors)") print(f" Total wall time : {total_time}s ({avg_time}s/case)") print(f" Top-1 Accuracy : {top1_hits}/{evaluated} = {top1_pct}% (binary: fractured vs not)") print(f"\n Fracture detection (binary):") print(f" TP={tp} FP={fp} FN={fn} TN={tn}") print(f" Sensitivity (Recall) : {sensitivity:.3f}") print(f" Specificity : {specificity:.3f}") print(f" Precision : {precision:.3f}") print(f" F1 (fracture) : {f1_fracture:.4f}") print("=" * 70) json_path = os.path.join(RESULTS_DIR, "eval_08_bone.json") payload = { "dataset": DATASET_NAME, "note": "Bone Fracture Detection — MIT license. 9,246 X-ray images. Requires hypothesis-guided bone routing.", "date_run": date.today().isoformat(), "agent_evaluated": "image_agent only (isolated bone model; hypothesis-guided routing required)", "bone_model": "Hemgg/bone-fracture-detection-using-xray (2 classes: Fractured / Not Fractured)", "routing_note": "Bone X-rays are grayscale — pixel heuristic would route to 'chest'. Hypothesis override with 'bone fracture' keywords forces bone domain.", "evaluation_strategy": "Binary: model 'Fractured' → positive; 'Not Fractured' → negative", "cases_evaluated": evaluated, "top1_accuracy": top1_pct, "fracture_sensitivity": sensitivity, "fracture_specificity": specificity, "fracture_precision": precision, "fracture_f1": f1_fracture, "confusion_matrix": {"tp": tp, "fp": fp, "fn": fn, "tn": tn}, "avg_time_per_case": avg_time, "confusion_summary": confusion_summary, "correct_examples": [r for r in records if r["hit1"]][:5], "wrong_examples": [r for r in records if not r["hit1"]][:5], "total_runtime_sec": total_time, "errors_encountered": errors, } with open(json_path, "w") as f: json.dump(payload, f, indent=2) print(f"\nJSON results saved → {json_path}") _write_report(payload, confusion_summary) def _write_report(payload: dict, confusion_summary: dict) -> None: confusion_section = "\n".join( f"- **{gt} → {info['predicted_instead']}**: " f"{info['count']}/{info['total_cases']} ({info['error_rate']:.1%} error)" for gt, info in sorted(confusion_summary.items()) ) or "No misclassifications recorded." md = f"""# Evaluation 8: Bone Fracture Detection (X-ray) **Date:** {payload['date_run']} **Dataset:** `{payload['dataset']}` **Agent:** ImageAgent (isolated — bone model with hypothesis-guided routing) **Model:** `{payload['bone_model']}` **Cases:** {payload['cases_evaluated']} **License:** MIT — freely accessible, no registration required --- ## Dataset: Bone Fracture Detection 9,246 bone X-ray images split into two classes: - **Fractured** — clear or subtle fracture line present - **Not Fractured** — normal bone anatomy Images span multiple skeletal sites (arm, hand, shoulder, leg, knee, foot, spine). No per-patient clinical metadata is available in this dataset. --- ## Domain routing: hypothesis override Bone X-rays are **grayscale and non-square** — the same pixel-level signature as chest X-rays. Without intervention, `_detect_domain()` would route these images to the chest specialist model, producing nonsensical output (Pneumonia, Edema, etc.). **Solution:** Hypothesis-guided bone routing (implemented in `image_agent.py`). When patient symptoms or hypothesis contain bone-related keywords (fracture, bone, musculoskeletal, orthop, skeletal, dislocation, arthritis), the domain is overridden to "bone" regardless of pixel-level classification. This eval passes: - `symptoms`: "Bone pain following trauma. X-ray submitted for bone fracture assessment." - `hypothesis.disease`: "bone fracture" Both triggers ensure consistent routing to the bone specialist model. --- ## Results | Metric | Value | |---|---| | **Binary Top-1 Accuracy** | **{payload['top1_accuracy']}%** | | **Fracture Sensitivity** | **{payload['fracture_sensitivity']:.3f}** | | **Fracture Specificity** | **{payload['fracture_specificity']:.3f}** | | **Fracture F1** | **{payload['fracture_f1']:.4f}** | | Confusion matrix | TP={payload['confusion_matrix']['tp']}, FP={payload['confusion_matrix']['fp']}, FN={payload['confusion_matrix']['fn']}, TN={payload['confusion_matrix']['tn']} | | Avg time per case | {payload['avg_time_per_case']}s | **Clinical note:** Sensitivity is the critical metric for fracture screening — missed fractures (FN) delay treatment and risk non-union or malunion complications. False positives (FP) lead to unnecessary immobilization but are less harmful clinically. --- ## Confusion analysis {confusion_section} --- ## Routing validation If `domain` values in the results are predominantly "bone" (not "chest"), the hypothesis-guided override is working correctly. If "chest" appears, check that the `symptoms` and `hypothesis` fields in `patient_data` contain the required bone-trigger keywords and that the `_BONE_TERMS` set in `image_agent.py` includes "fracture" and "bone". --- ## Dataset notes - Covers multiple anatomical sites — the model was trained on multi-site bone X-rays and generalizes reasonably across extremities, spine, and pelvis. - Subtle fractures (stress fractures, hairline fractures) are the primary failure mode for automated systems; the dataset label distribution should be checked for severity. - No geographic or demographic bias information is documented for this dataset. """ md_path = os.path.join(DOCS_DIR, "eval_08_bone.md") with open(md_path, "w") as f: f.write(md) print(f"Markdown report saved → {md_path}") if __name__ == "__main__": parser = argparse.ArgumentParser(description="Evaluate ImageAgent on Bone Fracture Detection dataset") parser.add_argument("--n", type=int, default=100) parser.add_argument("--no-shuffle", action="store_true") parser.add_argument("--preview", action="store_true") args = parser.parse_args() if args.preview: try: ds = load_dataset(DATASET_NAME, split="test") except Exception: ds = load_dataset(DATASET_NAME, split="train") preview(ds) else: evaluate(args.n, shuffle=not args.no_shuffle)