AttrLLM / visualization /mimic_loader.py
Qingpeng Kong
clean initial state
3e72399
"""
Loader for MIMIC-CXR benchmark examples.
Reads the curated 10-sample CSV and loads any precomputed attribution
results (same directory structure as the PubMedVision benchmark).
"""
from __future__ import annotations
import csv
import hashlib
from pathlib import Path
from typing import Any, Dict, List, Optional
# Re-use parsers from the existing medical_loader
from .medical_loader import (
parse_summary_txt,
parse_vllm_summary,
_build_all_cross_modal_pairs,
apply_method_to_clip_summary,
apply_method_to_vllm_summary,
load_mobius_sidecar,
load_vllm_result_json,
rename_summary_patch_labels_in_place,
rename_cross_pair_patch_labels_in_place,
)
# ── Path resolution ──────────────────────────────────────────────────────
# Directory that contains this module file (symlinks resolved).
_VIZ_DIR = Path(__file__).resolve().parent
# Parent of that directory; every results/dataset path below is built from it.
_PROJECT_ROOT = _VIZ_DIR.parent
def _resolve_mimic_dataset_dir() -> Path:
    """Return the path of the curated MIMIC dataset directory.

    The path is ``<project root>/results/mimic/dataset``. It is returned
    whether or not the directory actually exists on disk.
    """
    dataset_dir = _PROJECT_ROOT.joinpath("results", "mimic", "dataset")
    return dataset_dir
def _resolve_mimic_results_dir(method_suffix: str = "") -> Optional[Path]:
    """Resolve the on-disk results directory for one MIMIC method slot.

    For the BiomedCLIP and LLaVA-Med UnSAM slots, the 4×4 patch-grid variant
    (``mimic_biomedclip_patch/``, ``mimic_llavamed_patch/``) is preferred when
    it is present and non-empty; otherwise the lookup falls back to the
    directory named after the requested slot (``mimic<method_suffix>/``).
    The in-memory keys and UI labels still use the ``_unsam`` slot name for
    historical compatibility — only the on-disk source differs.

    Args:
        method_suffix: Suffix appended to ``mimic`` to form the directory
            name under ``results/``. Examples: "", "_biomedclip",
            "_llavamed", "_llavamed_unsam", "_vlm_unsam".

    Returns:
        The resolved directory, or ``None`` when no matching directory exists.
    """
    results_root = _PROJECT_ROOT / "results"
    patch_overrides = {
        "_biomedclip": "_biomedclip_patch",
        "_llavamed_unsam": "_llavamed_patch",
    }

    patch_suffix = patch_overrides.get(method_suffix)
    if patch_suffix:
        patch_dir = results_root / f"mimic{patch_suffix}"
        # is_dir() rather than exists(): iterdir() raises NotADirectoryError
        # if a regular file happens to sit at this path.
        if patch_dir.is_dir() and any(patch_dir.iterdir()):
            return patch_dir

    slot_dir = results_root / f"mimic{method_suffix}"
    # Only a real directory is usable; callers immediately join an example
    # id onto the returned path.
    return slot_dir if slot_dir.is_dir() else None
# ── Example registry ─────────────────────────────────────────────────────
# Registry of MIMIC-CXR examples keyed by example id. Declared empty here and
# rebound at import time to the result of _load_mimic_examples_from_csv().
MIMIC_EXAMPLES: Dict[str, Dict[str, Any]] = {}
def _load_mimic_examples_from_csv() -> Dict[str, Dict[str, Any]]:
    """Load the MIMIC-CXR curated CSV into a registry dict.

    Reads ``mimic_cxr_10.csv`` from the dataset directory and builds one
    metadata entry per row, keyed by ``coco_<cocoid>``. Each entry's
    ``has_results`` flag is set when a per-example folder exists under any
    of the default, BiomedCLIP, or LLaVA-Med UnSAM results directories.

    Returns:
        Mapping of example id to metadata; empty when the CSV is missing.

    Raises:
        KeyError: If the CSV has no ``cocoid`` column.
    """
    csv_path = _resolve_mimic_dataset_dir() / "mimic_cxr_10.csv"
    if not csv_path.exists():
        return {}

    def _text(row: Dict[str, Any], column: str) -> str:
        # DictReader fills columns missing from a short row with None, so
        # the ``.get(column, "")`` default alone does not guarantee a str.
        return row.get(column) or ""

    examples: Dict[str, Dict[str, Any]] = {}
    # newline="" lets the csv module handle line endings itself, which it
    # needs in order to parse quoted fields containing line breaks.
    with open(csv_path, "r", encoding="utf-8", newline="") as f:
        for row in csv.DictReader(f):
            cocoid = row["cocoid"]
            category = _text(row, "category")
            caption = _text(row, "caption")
            # Truncated caption for compact display.
            cap_short = caption[:60] + "..." if len(caption) > 60 else caption
            examples[f"coco_{cocoid}"] = {
                # Display title is just the category name (short and scannable).
                "title": category,
                "short": cap_short,
                "category": category,
                "caption": caption,
                "findings": _text(row, "findings"),
                "img_name": _text(row, "img_name"),
                "matched_keyword": _text(row, "matched_keyword"),
                "cocoid": cocoid,
                "source": "MIMIC-CXR",
                "has_results": False,  # updated below
            }

    # The candidate results directories do not depend on the example, so
    # resolve them once instead of once per example.
    # NOTE(review): load_mimic_example() also reads the "_tok30_dotmask",
    # "_tok30", "_llavamed" and "_vlm_unsam" directories, which are not
    # probed here — confirm whether that omission is intentional.
    result_dirs = [
        rdir
        for rdir in map(
            _resolve_mimic_results_dir, ("", "_biomedclip", "_llavamed_unsam")
        )
        if rdir is not None
    ]
    for eid, meta in examples.items():
        meta["has_results"] = any((rdir / eid).exists() for rdir in result_dirs)
    return examples
# Populate the registry once, at import time. The loader returns {} when the
# curated CSV is not present, so importing this module does not require it.
MIMIC_EXAMPLES = _load_mimic_examples_from_csv()
def get_mimic_examples_by_category(
    category: Optional[str] = None,
) -> Dict[str, Dict[str, Any]]:
    """Return the MIMIC examples belonging to one pathology category.

    Matching is case-insensitive. A missing/empty category, or the literal
    "all" (any case), returns the full registry object itself rather than
    a copy.
    """
    wanted = category.lower() if category else ""
    if wanted in ("", "all"):
        return MIMIC_EXAMPLES

    selected = {}
    for example_id, meta in MIMIC_EXAMPLES.items():
        if meta.get("category", "").lower() == wanted:
            selected[example_id] = meta
    return selected
def list_mimic_categories() -> List[str]:
    """List every distinct, non-empty pathology category in sorted order."""
    unique = set()
    for meta in MIMIC_EXAMPLES.values():
        label = meta.get("category")
        if label:
            unique.add(label)
    return sorted(unique)
# ── Image loading ────────────────────────────────────────────────────────
def get_mimic_image_path(example_id: str) -> Optional[str]:
    """Locate the original chest X-ray image for an example.

    Returns the image path as a string, or ``None`` when the example is
    unknown, has no ``img_name`` recorded, or the file is not present under
    the dataset's ``images`` directory.
    """
    meta = MIMIC_EXAMPLES.get(example_id)
    img_name = meta.get("img_name", "") if meta else ""
    if not img_name:
        return None

    candidate = _resolve_mimic_dataset_dir() / "images" / img_name
    if candidate.exists():
        return str(candidate)
    return None
# ── Result loading ───────────────────────────────────────────────────────
# Per-example PNG artefacts that accompany a CLIP-style result.
_MIMIC_IMAGE_KEYS = ("original", "overlay", "segmap")


def _load_mimic_clip_slot(
    data: Dict[str, Any], edir: Path, slot: str, method: str
) -> None:
    """Load one CLIP-style cross-modal result folder into ``data[slot]``.

    Does nothing when ``edir`` has no ``summary.txt``. Otherwise sets
    ``data["has_<slot>"]``, ``data["has_mobius"][slot]`` and ``data[slot]``
    (summary, sidecar, image paths, cross-modal pairs, base64 images).
    """
    import base64  # local: only needed when a result folder is present

    summary_path = edir / "summary.txt"
    if not summary_path.exists():
        return

    summary = parse_summary_txt(summary_path)
    mobius = load_mobius_sidecar(edir)
    apply_method_to_clip_summary(summary, mobius, method)
    rename_summary_patch_labels_in_place(summary)

    # Probe each image once and reuse the answer for both paths and base64.
    found = {}
    for key in _MIMIC_IMAGE_KEYS:
        image_path = edir / f"{key}.png"
        if image_path.exists():
            found[key] = image_path

    entry: Dict[str, Any] = {
        "summary": summary,
        "mobius_sidecar": mobius,
        # Missing images are reported as "" rather than omitted.
        "image_paths": {
            key: str(found[key]) if key in found else ""
            for key in _MIMIC_IMAGE_KEYS
        },
        "image_b64": {},
    }
    data[f"has_{slot}"] = True
    data["has_mobius"][slot] = mobius is not None
    data[slot] = entry

    # The pair builder is given the entry before any base64 payload is
    # attached, matching the order in which the entry has always been built.
    entry["all_cross_modal_pairs"] = _build_all_cross_modal_pairs(
        entry, mobius_sidecar=mobius, method=method,
    )
    rename_cross_pair_patch_labels_in_place(entry["all_cross_modal_pairs"])

    # Base64 images for the interactive view.
    for key, image_path in found.items():
        encoded = base64.b64encode(image_path.read_bytes()).decode("ascii")
        entry.setdefault("image_b64", {})[key] = encoded


def _load_mimic_vlm_slot(
    data: Dict[str, Any],
    edir: Path,
    stem: str,
    method: str,
    *,
    with_segments: bool = False,
) -> None:
    """Load the logprob and gen results of one VLM slot into ``data``.

    For each of ``vllm_logprob`` / ``vllm_gen`` whose summary file exists and
    parses to a truthy value, stores the parsed summary under
    ``<stem>_<kind>``, the result JSON under ``<stem>_<kind>_json``, and sets
    the matching ``has_*`` and ``has_mobius`` flags. With ``with_segments``,
    also records ``<stem>_segmap_path`` / ``<stem>_original_path`` when those
    PNGs exist.
    """
    for kind in ("logprob", "gen"):
        prefix = f"vllm_{kind}"  # on-disk file prefix is the same for every slot
        key = f"{stem}_{kind}"   # in-memory key is slot-specific

        summary_path = edir / f"{prefix}_summary.txt"
        if not summary_path.exists():
            continue
        parsed = parse_vllm_summary(summary_path)
        if not parsed:
            continue

        json_data = load_vllm_result_json(edir, prefix, method=method)
        apply_method_to_vllm_summary(parsed, json_data, method)
        data[f"has_{key}"] = True
        data[key] = parsed
        data[f"{key}_json"] = json_data
        data["has_mobius"][key] = bool(json_data.get("mobius_dict"))

        overlay = edir / f"{prefix}_overlay.png"
        if overlay.exists():
            parsed["overlay_path"] = str(overlay)

    if with_segments:
        for name in ("segmap", "original"):
            image_path = edir / f"{name}.png"
            if image_path.exists():
                data[f"{stem}_{name}_path"] = str(image_path)


def load_mimic_example(example_id: str, *, method: str = "shapley") -> Dict[str, Any]:
    """Load all available precomputed results for a MIMIC-CXR example.

    The returned dict is meant to have the same structure as
    load_benchmark_example() from medical_loader.py, so the UI handler can
    use the same logic.

    Args:
        example_id: Registry key of the example (``coco_<cocoid>``). Unknown
            ids yield empty metadata rather than an error.
        method: Attribution method name forwarded to the summary/JSON
            helpers imported from medical_loader.

    Returns:
        Dict holding the example metadata, one ``has_*`` flag per result
        slot, a ``has_mobius`` map, and the payload of every slot whose
        files were found on disk.
    """
    meta = MIMIC_EXAMPLES.get(example_id, {})
    data: Dict[str, Any] = {
        "example_id": example_id,
        "meta": meta,
        "caption": meta.get("caption", ""),
        "findings": meta.get("findings", ""),
        "method": method,
        "original_image_path": get_mimic_image_path(example_id),
        "has_mobius": {},
        # Flags
        "has_clip": False,
        "has_biomedclip": False,
        "has_vllm_logprob": False,
        "has_vllm_gen": False,
        "has_llavamed_logprob": False,
        "has_llavamed_gen": False,
        "has_vlm_unsam_logprob": False,
        "has_vlm_unsam_gen": False,
        "has_llavamed_unsam_logprob": False,
        "has_llavamed_unsam_gen": False,
    }

    # ── CLIP cross-modal ─────────────────────────────────────────────
    # First existing directory wins: tok30+dotmask, then tok30, then default.
    clip_dir = (
        _resolve_mimic_results_dir("_tok30_dotmask")
        or _resolve_mimic_results_dir("_tok30")
        or _resolve_mimic_results_dir()
    )
    if clip_dir:
        _load_mimic_clip_slot(data, clip_dir / example_id, "clip", method)

    # ── BiomedCLIP cross-modal ───────────────────────────────────────
    bc_dir = _resolve_mimic_results_dir("_biomedclip")
    if bc_dir:
        _load_mimic_clip_slot(data, bc_dir / example_id, "biomedclip", method)

    # ── VLM slots: (in-memory key stem, directory suffix, has segment PNGs) ──
    vlm_slots = (
        ("vllm", "", False),                           # Qwen2-VL
        ("llavamed", "_llavamed", False),              # LLaVA-Med
        ("vlm_unsam", "_vlm_unsam", True),             # Qwen2-VL + UnSAM segments
        ("llavamed_unsam", "_llavamed_unsam", True),   # LLaVA-Med UnSAM
    )
    for stem, suffix, with_segments in vlm_slots:
        slot_dir = _resolve_mimic_results_dir(suffix)
        if slot_dir:
            _load_mimic_vlm_slot(
                data, slot_dir / example_id, stem, method,
                with_segments=with_segments,
            )

    return data