| |
"""
Evaluate MedAgent ImageAgent on bone fracture X-ray detection.

Dataset
-------
prithivMLmods/Bone-Fracture-Detection — 9,246 bone X-ray images.
License: MIT / public domain. No registration required.
NOTE: the script actually loads ``DATASET_NAME`` ("Hemg/bone-fracture-detection");
keep this description in sync with that constant.

2 classes:
    fractured     = Bone fracture present
    not_fractured = No fracture detected

Clinical context
----------------
The dataset does not include per-patient metadata (age/sex). We pass a clinically
realistic context (fracture symptoms + bone hypothesis) to trigger hypothesis-guided
bone-domain routing in ImageAgent — without it, grayscale bone X-rays are
classified as "chest" by the pixel-level heuristic.

CRITICAL: The symptoms and hypothesis fields MUST contain bone-fracture keywords
(e.g., "fracture", "bone") so that `_analyze_image()` routes to the bone specialist
model instead of the chest model. See the _BONE_TERMS logic in image_agent.py.

Model under test
----------------
Hemgg/bone-fracture-detection-using-xray (2 classes: Fractured / Not Fractured),
the bone-domain specialist recently added to ImageAgent.

Usage:
    python eval_bone.py                # 100 cases (shuffled, seed=42)
    python eval_bone.py --n 200        # custom count
    python eval_bone.py --no-shuffle   # first N cases
    python eval_bone.py --preview      # print 2 raw examples and exit
"""
|
|
| import argparse |
| import base64 |
| import io |
| import json |
| import os |
| import sys |
| import time |
| from collections import defaultdict, Counter |
| from datetime import date |
|
|
| sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) |
| try: |
| from dotenv import dotenv_values |
| for k, v in dotenv_values(".env").items(): |
| os.environ.setdefault(k, v) |
| except Exception: |
| pass |
|
|
| from datasets import load_dataset |
| from PIL import Image |
| from agents.image_agent import ImageAgent |
|
|
# Hugging Face dataset id handed to ``load_dataset`` (test split, else train).
DATASET_NAME = "Hemg/bone-fracture-detection"
# Output directories, relative to the current working directory.
RESULTS_DIR = "results"  # JSON metrics (eval_08_bone.json)
DOCS_DIR = "docs"        # Markdown report (eval_08_bone.md)

# Binary label vocabulary returned by ``normalize_to_bone``, with display names.
CLASS_FULL = dict(
    fractured="Bone Fracture Present",
    not_fractured="No Fracture Detected",
)
ALL_CLASSES = list(CLASS_FULL)
|
|
|
|
| |
|
|
def normalize_to_bone(label: str) -> str:
    """Map a model output or dataset label to ``"fractured"`` / ``"not_fractured"``.

    Matching is case-insensitive, treats ``_`` and ``-`` as spaces and collapses
    repeated whitespace. Negated phrases ("not fractured", "no fracture", ...)
    are checked *before* the positive keywords, because "not fractured" itself
    contains "fractur". Anything that matches neither group -- e.g. "normal",
    "healthy", "0", "no" or "unknown" -- maps to ``"not_fractured"``.

    Args:
        label: raw label text, e.g. "Not Fractured", "fractured", "1".

    Returns:
        ``"fractured"`` or ``"not_fractured"``.
    """
    s = " ".join(label.lower().replace("_", " ").replace("-", " ").split())

    negated = ("not fractur", "no fractur", "non fractur", "without fractur", "not broken")
    if any(phrase in s for phrase in negated):
        return "not_fractured"

    if any(term in s for term in ("fractur", "broken", "break", "crack")):
        return "fractured"

    # Numeric / yes-no encodings used when the label has no class names.
    if s in ("1", "yes"):
        return "fractured"

    # "normal", "healthy", "0", "no" and unrecognised labels (e.g. "unknown"
    # when the agent returned no candidates) all count as negative.
    return "not_fractured"
|
|
|
|
def pil_to_b64(img: Image.Image) -> str:
    """Encode ``img`` as PNG and return the bytes as a base64 text string."""
    with io.BytesIO() as png_buffer:
        img.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
    return base64.b64encode(png_bytes).decode("utf-8")
|
|
|
|
def build_patient_context() -> tuple[str, str]:
    """Return the fixed ``(symptoms, history)`` pair attached to every case.

    Both strings intentionally mention "bone" and "fracture" so that the
    agent's keyword-based routing (``_BONE_TERMS`` in image_agent.py) selects
    the bone-domain model rather than the chest model.
    """
    return (
        "Bone pain following trauma. X-ray submitted for bone fracture assessment.",
        "Musculoskeletal X-ray for bone fracture evaluation. Orthopedic referral.",
    )
|
|
|
|
| |
|
|
def preview(ds) -> None:
    """Print the dataset's column names followed by up to two raw example rows.

    Values that expose a ``size`` attribute are shown as images (size and mode);
    all other values are shown as their type name plus a repr cut to 120 chars.
    """
    column_names = list(ds.features.keys())
    print(f"\nDataset columns: {column_names}\n")

    for i in range(min(2, len(ds))):
        print(f"βββ Example {i+1} βββββββββββββββββββββββββββββββββββββββββ")
        for field_name, field_value in ds[i].items():
            if not hasattr(field_value, "size"):
                print(f" {field_name:<20} ({type(field_value).__name__}) = {repr(field_value)[:120]}")
                continue
            print(f" {field_name:<20} PIL Image size={field_value.size} mode={field_value.mode}")
        print()
|
|
|
|
| |
|
|
def _load_eval_split():
    """Return DATASET_NAME's ``test`` split, falling back to ``train``; exit(1) if neither loads."""
    try:
        return load_dataset(DATASET_NAME, split="test")
    except Exception:
        pass  # no usable test split -- try train below
    try:
        ds = load_dataset(DATASET_NAME, split="train")
    except Exception as exc:
        print(f"ERROR: Cannot load {DATASET_NAME}: {exc}", file=sys.stderr)
        sys.exit(1)
    print(" (no test split β using train)")
    return ds


def _detect_fields(sample: dict) -> tuple[str, str]:
    """Return the ``(image, label)`` column names present in ``sample``; exit(1) if either is missing."""
    img_field = next((f for f in ("image", "img", "xray", "x_ray") if f in sample), None)
    lbl_field = next((f for f in ("label", "class", "fracture", "finding") if f in sample), None)
    if not img_field or not lbl_field:
        print(f"ERROR: Cannot find image/label fields in {list(sample.keys())}", file=sys.stderr)
        sys.exit(1)
    return img_field, lbl_field


def _binary_metrics(tp: int, fp: int, fn: int, tn: int) -> tuple[float, float, float, float]:
    """Return (sensitivity, specificity, precision, F1) for "fractured", rounded to 4 dp.

    A ratio whose denominator is zero is reported as 0.0. F1 is computed from the
    unrounded precision and sensitivity, so rounding is applied only once.
    """
    sensitivity = tp / (tp + fn) if tp + fn else 0.0
    specificity = tn / (tn + fp) if tn + fp else 0.0
    precision = tp / (tp + fp) if tp + fp else 0.0
    f1 = 2 * precision * sensitivity / (precision + sensitivity) if precision + sensitivity else 0.0
    return round(sensitivity, 4), round(specificity, 4), round(precision, 4), round(f1, 4)


def _summarize_confusion(confusion: dict[str, list[str]]) -> dict[str, dict]:
    """For each ground-truth label with errors, report its most frequent wrong prediction."""
    summary: dict[str, dict] = {}
    for gt_lbl, preds in confusion.items():
        wrong = [p for p in preds if p != gt_lbl]
        if not wrong:
            continue
        predicted_instead, count = Counter(wrong).most_common(1)[0]
        summary[gt_lbl] = {
            "predicted_instead": predicted_instead, "count": count,
            "total_cases": len(preds),
            "error_rate": round(len(wrong) / len(preds), 3),
        }
    return summary


def evaluate(n_cases: int, shuffle: bool = True) -> None:
    """Run ImageAgent on up to ``n_cases`` dataset images and report binary fracture metrics.

    Writes ``RESULTS_DIR/eval_08_bone.json`` and, via ``_write_report``, a Markdown
    report under ``DOCS_DIR``. Cases whose image is not a PIL image, or for which
    the agent raises or returns a malformed result, are counted in ``errors`` and
    excluded from every metric (they are not scored as "not_fractured").

    Args:
        n_cases: maximum number of cases to evaluate (capped at the split size).
        shuffle: shuffle with seed 42 before taking the first ``n_cases``.
    """
    os.makedirs(RESULTS_DIR, exist_ok=True)
    os.makedirs(DOCS_DIR, exist_ok=True)

    print(f"Loading {DATASET_NAME}β¦")
    ds = _load_eval_split()

    total = len(ds)
    if total == 0:
        print(f"ERROR: {DATASET_NAME} split is empty", file=sys.stderr)
        sys.exit(1)
    n_run = min(n_cases, total)  # never request more rows than the split holds
    if shuffle:
        ds = ds.shuffle(seed=42)
        print(f"Loaded: {total} cases. Shuffled (seed=42), using first {n_run}.\n")
    else:
        print(f"Loaded: {total} cases. Using first {n_run} (unshuffled).\n")

    preview(ds)

    img_field, lbl_field = _detect_fields(ds[0])
    print(f"Using fields: image='{img_field}', label='{lbl_field}'\n")

    # Integer ClassLabel values are decoded through the feature's names when available.
    label_names: list[str] | None = None
    feat = ds.features.get(lbl_field)
    if hasattr(feat, "names"):
        label_names = feat.names
        print(f"Label names from dataset features: {label_names}\n")

    agent = ImageAgent()
    symptoms, history = build_patient_context()

    top1_hits = 0
    errors = 0
    tp = fp = fn = tn = 0
    confusion: dict[str, list[str]] = defaultdict(list)
    domain_counts: Counter = Counter()
    records: list[dict] = []
    start_wall = time.time()

    print(f"{'#':>5} {'GT':<16} {'Our Pred':<16} {'Confidence':>11} {'Domain':<10} Result")
    print("β" * 70)

    # Iterate lazily so decoded images are not all held in memory at once.
    for idx, case in enumerate(ds.select(range(n_run))):
        raw_label = case[lbl_field]
        if isinstance(raw_label, int) and label_names:
            gt_str = normalize_to_bone(label_names[raw_label])
        else:
            gt_str = normalize_to_bone(str(raw_label))

        pil_img = case[img_field]
        if not isinstance(pil_img, Image.Image):
            errors += 1
            continue

        patient_data = {
            "image_b64": pil_to_b64(pil_img),
            "symptoms": symptoms,
            "medical_history": history,
            "lab_values": "",
            # The "bone fracture" hypothesis is a second bone-routing trigger
            # alongside the symptoms text.
            "hypothesis": {
                "primary_hypothesis": {
                    "disease": "bone fracture",
                    "confidence": 0.5,
                },
            },
        }

        t0 = time.time()
        try:
            result = agent.analyze(patient_data)
            candidates = result.get("candidates", [])
            raw_top = str(candidates[0]["disease"]) if candidates else "unknown"
            conf = float(candidates[0]["score"]) if candidates else 0.0
        except Exception as exc:
            # A failed or malformed analysis is an error, not a negative
            # prediction: skip it so it cannot bias accuracy/sensitivity/specificity.
            errors += 1
            print(f"{idx+1:>5} {gt_str:<16} {'ERROR':16} {'':>11} {'':10} ERR: {exc}",
                  file=sys.stderr)
            continue
        elapsed = round(time.time() - t0, 2)

        our_pred = normalize_to_bone(raw_top)
        domain = result.get("image_domain") or "?"
        domain_counts[domain] += 1

        hit1 = (our_pred == gt_str)
        top1_hits += hit1

        gt_positive = (gt_str == "fractured")
        pred_positive = (our_pred == "fractured")
        if gt_positive and pred_positive:
            tp += 1
        elif pred_positive:
            fp += 1
        elif gt_positive:
            fn += 1
        else:
            tn += 1

        confusion[gt_str].append(our_pred)

        tag = "HIT" if hit1 else "MISS"
        print(f"{idx+1:>5} {gt_str:<16} {our_pred:<16} {conf:>10.1%} {domain:<10} {tag}")

        records.append({
            "idx": idx, "gt": gt_str, "our_pred": our_pred,
            "raw_pred": raw_top, "confidence": round(conf, 4),
            "hit1": hit1, "time_sec": elapsed, "domain": domain,
        })

        if (idx + 1) % 10 == 0:
            # Running accuracy over scored cases only (errors are excluded).
            acc = top1_hits / max(len(records), 1) * 100
            print(f"\n ββ Case {idx+1}/{n_run} | Top-1 Acc: {acc:.1f}% ββ\n")

    total_time = round(time.time() - start_wall, 1)
    evaluated = len(records)
    top1_pct = round(top1_hits / max(evaluated, 1) * 100, 1)
    avg_time = round(total_time / max(evaluated, 1), 2)

    sensitivity, specificity, precision, f1_fracture = _binary_metrics(tp, fp, fn, tn)
    confusion_summary = _summarize_confusion(confusion)

    print("\n" + "=" * 70)
    print("EVALUATION SUMMARY β Bone Fracture Detection (ImageAgent, bone model)")
    print("=" * 70)
    print(f" Cases evaluated : {evaluated} ({errors} errors)")
    print(f" Total wall time : {total_time}s ({avg_time}s/case)")
    print(f" Top-1 Accuracy : {top1_hits}/{evaluated} = {top1_pct}% (binary: fractured vs not)")
    print("\n Fracture detection (binary):")
    print(f" TP={tp} FP={fp} FN={fn} TN={tn}")
    print(f" Sensitivity (Recall) : {sensitivity:.3f}")
    print(f" Specificity : {specificity:.3f}")
    print(f" Precision : {precision:.3f}")
    print(f" F1 (fracture) : {f1_fracture:.4f}")
    print(f" Image domains : {dict(domain_counts)}")
    print("=" * 70)

    json_path = os.path.join(RESULTS_DIR, "eval_08_bone.json")
    payload = {
        "dataset": DATASET_NAME,
        "note": "Bone Fracture Detection β MIT license. 9,246 X-ray images. Requires hypothesis-guided bone routing.",
        "date_run": date.today().isoformat(),
        "agent_evaluated": "image_agent only (isolated bone model; hypothesis-guided routing required)",
        "bone_model": "Hemgg/bone-fracture-detection-using-xray (2 classes: Fractured / Not Fractured)",
        "routing_note": "Bone X-rays are grayscale β pixel heuristic would route to 'chest'. Hypothesis override with 'bone fracture' keywords forces bone domain.",
        "evaluation_strategy": "Binary: model 'Fractured' β positive; 'Not Fractured' β negative",
        "cases_evaluated": evaluated,
        "top1_accuracy": top1_pct,
        "fracture_sensitivity": sensitivity,
        "fracture_specificity": specificity,
        "fracture_precision": precision,
        "fracture_f1": f1_fracture,
        "confusion_matrix": {"tp": tp, "fp": fp, "fn": fn, "tn": tn},
        "avg_time_per_case": avg_time,
        "confusion_summary": confusion_summary,
        # Per-domain routing counts, for checking that the bone override fired.
        "domain_counts": dict(domain_counts),
        "correct_examples": [r for r in records if r["hit1"]][:5],
        "wrong_examples": [r for r in records if not r["hit1"]][:5],
        "total_runtime_sec": total_time,
        "errors_encountered": errors,
    }
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2)
    print(f"\nJSON results saved β {json_path}")

    _write_report(payload, confusion_summary)
|
|
|
|
def _write_report(payload: dict, confusion_summary: dict) -> None:
    """Render the evaluation results as ``DOCS_DIR/eval_08_bone.md``.

    Args:
        payload: results dict; reads date_run, dataset, bone_model,
            cases_evaluated, top1_accuracy, fracture_sensitivity,
            fracture_specificity, fracture_f1, confusion_matrix (tp/fp/fn/tn)
            and avg_time_per_case.
        confusion_summary: ``{gt_label: {"predicted_instead", "count",
            "total_cases", "error_rate"}}``; one bullet is emitted per entry,
            sorted by ground-truth label.
    """
    confusion_section = "\n".join(
        f"- **{gt} β {info['predicted_instead']}**: "
        f"{info['count']}/{info['total_cases']} ({info['error_rate']:.1%} error)"
        for gt, info in sorted(confusion_summary.items())
    ) or "No misclassifications recorded."

    md = f"""# Evaluation 8: Bone Fracture Detection (X-ray)

**Date:** {payload['date_run']}
**Dataset:** `{payload['dataset']}`
**Agent:** ImageAgent (isolated β bone model with hypothesis-guided routing)
**Model:** `{payload['bone_model']}`
**Cases:** {payload['cases_evaluated']}
**License:** MIT β freely accessible, no registration required

---

## Dataset: Bone Fracture Detection

9,246 bone X-ray images split into two classes:
- **Fractured** β clear or subtle fracture line present
- **Not Fractured** β normal bone anatomy

Images span multiple skeletal sites (arm, hand, shoulder, leg, knee, foot, spine).
No per-patient clinical metadata is available in this dataset.

---

## Domain routing: hypothesis override

Bone X-rays are **grayscale and non-square** β the same pixel-level signature as
chest X-rays. Without intervention, `_detect_domain()` would route these images
to the chest specialist model, producing nonsensical output (Pneumonia, Edema, etc.).

**Solution:** Hypothesis-guided bone routing (implemented in `image_agent.py`).
When patient symptoms or hypothesis contain bone-related keywords (fracture, bone,
musculoskeletal, orthop, skeletal, dislocation, arthritis), the domain is overridden
to "bone" regardless of pixel-level classification. This eval passes:
- `symptoms`: "Bone pain following trauma. X-ray submitted for bone fracture assessment."
- `hypothesis.disease`: "bone fracture"

Both triggers ensure consistent routing to the bone specialist model.

---

## Results

| Metric | Value |
|---|---|
| **Binary Top-1 Accuracy** | **{payload['top1_accuracy']}%** |
| **Fracture Sensitivity** | **{payload['fracture_sensitivity']:.3f}** |
| **Fracture Specificity** | **{payload['fracture_specificity']:.3f}** |
| **Fracture F1** | **{payload['fracture_f1']:.4f}** |
| Confusion matrix | TP={payload['confusion_matrix']['tp']}, FP={payload['confusion_matrix']['fp']}, FN={payload['confusion_matrix']['fn']}, TN={payload['confusion_matrix']['tn']} |
| Avg time per case | {payload['avg_time_per_case']}s |

**Clinical note:** Sensitivity is the critical metric for fracture screening β
missed fractures (FN) delay treatment and risk non-union or malunion complications.
False positives (FP) lead to unnecessary immobilization but are less harmful clinically.

---

## Confusion analysis

{confusion_section}

---

## Routing validation

If `domain` values in the results are predominantly "bone" (not "chest"), the
hypothesis-guided override is working correctly. If "chest" appears, check that
the `symptoms` and `hypothesis` fields in `patient_data` contain the required
bone-trigger keywords and that the `_BONE_TERMS` set in `image_agent.py` includes
"fracture" and "bone".

---

## Dataset notes

- Covers multiple anatomical sites β the model was trained on multi-site bone X-rays
and generalizes reasonably across extremities, spine, and pelvis.
- Subtle fractures (stress fractures, hairline fractures) are the primary failure mode
for automated systems; the dataset label distribution should be checked for severity.
- No geographic or demographic bias information is documented for this dataset.
"""
    # Create the output directory so the report can also be written standalone.
    os.makedirs(DOCS_DIR, exist_ok=True)
    md_path = os.path.join(DOCS_DIR, "eval_08_bone.md")
    # Explicit UTF-8: the report contains non-ASCII characters that the platform
    # default encoding (e.g. cp1252 on Windows) may be unable to encode.
    with open(md_path, "w", encoding="utf-8") as f:
        f.write(md)
    print(f"Markdown report saved β {md_path}")
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Evaluate ImageAgent on Bone Fracture Detection dataset")
    parser.add_argument("--n", type=int, default=100,
                        help="number of cases to evaluate (default: 100)")
    parser.add_argument("--no-shuffle", action="store_true",
                        help="take the first N cases instead of a seed-42 shuffle")
    parser.add_argument("--preview", action="store_true",
                        help="print 2 raw dataset examples and exit")
    args = parser.parse_args()
    # A zero/negative count would overwrite the results and report with an empty run.
    if args.n < 1:
        parser.error("--n must be a positive integer")

    if args.preview:
        # Same split preference as evaluate(): "test" if it loads, else "train".
        try:
            ds = load_dataset(DATASET_NAME, split="test")
        except Exception:
            ds = load_dataset(DATASET_NAME, split="train")
        preview(ds)
    else:
        evaluate(args.n, shuffle=not args.no_shuffle)
|
|