# cxr-vlm-code / data / iu_xray_builder.py
# convitom
# feat(data): add split_cascade report mode + MIMIC-CXR builder with CheXpert oracle labels
# 02426e6
"""
iu_xray_builder.py
------------------
Parses the IU X-ray (OpenI) dataset into a unified instruction JSON
compatible with `CXRInstructDataset`.

Why a separate builder?
    IU X-ray ships as:
        images/CXR{rid}_{n}_IM-{s}-{v}.png     (7,470 PNGs)
        labels/ecgen-radiology/{rid}.xml       (3,955 XMLs, each covers 1–N images)

    We flatten the XMLs into per-(image, task) samples. In the default modes,
    an XML with two `<parentImage>` children produces two samples per task
    (findings, impression): the ground-truth text is shared across images of
    the same study — same convention MIMIC-CXR uses.

    IU X-ray has NO VQA, so no VQA samples are emitted. The default
    report_mode emits two tasks (findings, impression); report_mode="merged"
    emits a single combined `report` task instead.

Output JSON entry (matches MIMIC-CXR output of build_instruct_json):
    {
        "image_path": "CXR1_1_IM-0001-3001.png",   # relative to images_dir;
                                                   # null in multi_image_merged mode
        "image_paths": null,                       # list of paths in multi_image_merged
                                                   # mode, otherwise null
        "task": "findings" | "impression" | "report",
        "target": "<report text>",
        "question": null,
        "structured_findings": null,               # "Findings: <text>" on impression
                                                   # samples in split_cascade mode
        "split": "train" | "validate" | "test",
        "report_id": "1"
    }

Splitting is done at the REPORT level (not image level) to avoid leakage:
all images belonging to the same study land in the same split.
"""
import argparse
import glob
import json
import random
from pathlib import Path
from typing import Dict, List, Optional
from xml.etree import ElementTree as ET
# ─── XML helpers ────────────────────────────────────────────────────────────
def _extract_sections(root) -> Dict[str, str]:
    """Return a dict mapping section label (uppercase) -> text.

    Every ``<AbstractText>`` descendant of ``root`` contributes one entry,
    keyed by its upper-cased ``Label`` attribute ("" when the attribute is
    absent). If the same label occurs more than once, the last one wins.

    Args:
        root: Parsed XML element to search (any ElementTree ``Element``).

    Returns:
        Mapping of upper-cased label to whitespace-stripped section text;
        sections with no text map to "".
    """
    sections: Dict[str, str] = {}
    for node in root.findall(".//AbstractText"):
        label = node.attrib.get("Label", "").upper()
        # itertext() rather than .text: .text stops at the first child
        # element, which would silently truncate a section that contains
        # inline markup. For a childless element both give the same string.
        sections[label] = "".join(node.itertext()).strip()
    return sections
def _extract_image_ids(root) -> List[str]:
    """Return the ``id`` attribute of every ``<parentImage>`` under ``root``.

    Ids are returned in document order. Elements with a missing or empty
    ``id`` are skipped instead of raising ``KeyError``, so one malformed
    entry cannot abort a whole dataset build.

    Args:
        root: Parsed XML element to search (any ElementTree ``Element``).

    Returns:
        List of non-empty image id strings (possibly empty).
    """
    candidate_ids = (node.get("id") for node in root.findall(".//parentImage"))
    return [img_id for img_id in candidate_ids if img_id]
def _is_valid_text(text: Optional[str]) -> bool:
    """Tell whether ``text`` holds real report content.

    Returns False for None / empty input and for text that, once every
    uppercase "X", "." and "," is removed and the remainder is stripped of
    surrounding whitespace, has fewer than 3 characters left.
    """
    if text:
        # IU reports anonymize with "XXXX" tokens; drop them together with
        # sentence punctuation in one pass and measure what remains.
        residue = text.translate(str.maketrans("", "", "X.,")).strip()
        return len(residue) >= 3
    return False
# ─── Main builder ───────────────────────────────────────────────────────────
def build_iu_xray_instruct_json(
    images_dir: str,
    labels_dir: str,
    output_path: str,
    train_ratio: float = 0.70,
    val_ratio: float = 0.15,
    test_ratio: float = 0.15,
    seed: int = 42,
    image_suffix: str = ".png",
    report_mode: str = "split",  # "split" | "merged" | "split_cascade"
    image_mode: str = "all_views_split",  # "all_views_split" | "frontal_only_split" | "multi_image_merged"
) -> str:
    """
    Parse all IU X-ray XMLs and emit the unified JSON.

    Args:
        images_dir: Folder containing the image files. An image id from the XML
            is kept only if ``images_dir / f"{id}{image_suffix}"`` is a file.
        labels_dir: Folder containing the ``*.xml`` report files.
        output_path: Destination JSON file; parent folders are created.
        train_ratio, val_ratio, test_ratio: Report-level split fractions; must
            sum to 1.0. Train and validate sizes are floored, test takes the
            remainder.
        seed: Seed for the report shuffle that precedes split assignment.
        image_suffix: Extension appended to each image id to form its filename.
        report_mode: "split" → emit 2 samples per image (task=findings, task=impression).
                         Original behaviour. Use when training two separate tasks.
                     "merged" → emit 1 sample per image (task=report) with target
                         "Findings: ...\\n\\nImpression: ...". Use when training a
                         single full-report generation task. Samples for which
                         `format_merged_report` returns None are dropped.
                     "split_cascade" → like "split" (2 separate tasks) BUT the
                         impression sample carries the ground-truth findings
                         text as its prompt context (in `structured_findings`,
                         formatted "Findings: ...") instead of CheXpert
                         labels. Impression thus learns findings→impression
                         summarisation while still seeing the image. Only
                         studies with BOTH findings and impression emit an
                         impression sample (findings is its required input).
                         NOTE: eval is teacher-forced (impression gets GT
                         findings); a true cascade eval that feeds the
                         model's own generated findings is future work.
        image_mode: "all_views_split" → one sample group per existing image.
                    "frontal_only_split" → one group per study, using the first
                        existing image of the study.
                    "multi_image_merged" → one group per study whose
                        `image_paths` is the list of all existing images.

    Returns:
        The output path as a string (``str(Path(output_path))``); it is not
        resolved, so a relative `output_path` yields a relative result.

    Raises:
        AssertionError: On an unknown report_mode / image_mode or ratios that
            do not sum to 1.0.
        FileNotFoundError: If `labels_dir` contains no ``*.xml`` file.
    """
    assert report_mode in ("split", "merged", "split_cascade"), \
        f"report_mode must be 'split', 'merged', or 'split_cascade', got {report_mode!r}"
    assert image_mode in ("all_views_split", "frontal_only_split", "multi_image_merged"), \
        f"image_mode must be one of all_views_split/frontal_only_split/multi_image_merged, got {image_mode!r}"
    assert abs(train_ratio + val_ratio + test_ratio - 1.0) < 1e-6, \
        "train/val/test ratios must sum to 1.0"

    # Only the merged mode needs the formatter. Importing it lazily (and
    # locally, to avoid an import cycle at module load) keeps the split modes
    # usable without the sibling dataset module, and doing it before any
    # parsing makes a missing dependency fail fast.
    if report_mode == "merged":
        from .dataset import format_merged_report

    images_dir = Path(images_dir)
    labels_dir = Path(labels_dir)
    output_path = Path(output_path)

    def _xml_sort_key(path: str):
        # Numeric stems first in numeric order, everything else afterwards by
        # name. The stem tiebreaker makes the order total, so the seeded
        # shuffle below is reproducible regardless of glob's listing order.
        stem = Path(path).stem
        if stem.isdecimal():
            return (0, int(stem), stem)
        return (1, 0, stem)

    xml_files = sorted(glob.glob(str(labels_dir / "*.xml")), key=_xml_sort_key)
    if not xml_files:
        raise FileNotFoundError(f"No XML files under {labels_dir}")

    # ── Pass 1: parse XMLs, keep reports with ≥1 valid section & ≥1 image ──
    reports: List[Dict] = []
    skipped_no_text = 0
    skipped_no_image = 0
    skipped_parse_error = 0
    for xml_path in xml_files:
        try:
            root = ET.parse(xml_path).getroot()
        except ET.ParseError as exc:
            # Best-effort: a malformed file is skipped, but counted and
            # reported so data loss is visible in the summary.
            skipped_parse_error += 1
            print(f"[iu_xray_builder] WARNING: skipping unparseable XML {xml_path}: {exc}")
            continue

        sections = _extract_sections(root)
        findings = sections.get("FINDINGS", "").strip()
        impression = sections.get("IMPRESSION", "").strip()
        has_find = _is_valid_text(findings)
        has_imp = _is_valid_text(impression)
        if not (has_find or has_imp):
            skipped_no_text += 1
            continue

        existing = [
            iid for iid in _extract_image_ids(root)
            if (images_dir / f"{iid}{image_suffix}").is_file()
        ]
        if not existing:
            skipped_no_image += 1
            continue

        reports.append({
            "report_id": Path(xml_path).stem,
            "images": existing,
            "findings": findings if has_find else None,
            "impression": impression if has_imp else None,
        })

    # ── Pass 2: assign splits at the report level ─────────────────────────
    rng = random.Random(seed)
    rng.shuffle(reports)
    n_total = len(reports)
    n_train = int(n_total * train_ratio)
    n_val = int(n_total * val_ratio)
    # Remainder → test (absorbs rounding error)
    split_labels = (
        ["train"] * n_train +
        ["validate"] * n_val +
        ["test"] * (n_total - n_train - n_val)
    )

    # ── Pass 3: expand to samples ─────────────────────────────────────────
    # Two orthogonal axes control sample shape:
    #
    #   report_mode (task axis):
    #     "split"          → findings + impression samples per image group
    #     "merged"         → 1 sample per image group (task=report)
    #     "split_cascade"  → like "split", impression carries GT findings
    #
    #   image_mode (image axis):
    #     "all_views_split"     → 1 group PER IMAGE
    #     "frontal_only_split"  → 1 group per study, first existing image only
    #     "multi_image_merged"  → 1 group per study, image_paths is a LIST
    samples: List[Dict] = []
    # Both counters are incremented once per image group, not once per study.
    skipped_merged_no_impression = 0
    skipped_cascade_no_findings = 0

    def _per_study_image_groups(report_imgs):
        """
        Yield (sample_id_suffix, image_path_or_list) for this study, applying
        the selected image_mode.

        sample_id_suffix is only used for logging; it has no functional effect.
        """
        if image_mode == "all_views_split":
            for img_id in report_imgs:
                yield img_id, f"{img_id}{image_suffix}"
        elif image_mode == "frontal_only_split":
            # IU X-ray convention: the FIRST <parentImage> in the XML is the
            # frontal (PA) view. We rely on this rather than parsing DICOM
            # ViewPosition (which is not provided in IU X-ray XMLs). For
            # MIMIC-CXR, swap this for a ViewPosition lookup against the
            # metadata CSV.
            # NOTE(review): report_imgs holds only images that exist on disk,
            # so if the first <parentImage> file is missing this picks the
            # next view instead — confirm that is acceptable.
            yield report_imgs[0], f"{report_imgs[0]}{image_suffix}"
        else:  # multi_image_merged
            paths = [f"{iid}{image_suffix}" for iid in report_imgs]
            yield report_imgs[0], paths

    def _make_sample(path_fields, task, target, split, report_id,
                     structured_findings=None):
        # Single place that fixes the key set and key order of an entry.
        return {
            **path_fields,
            "task": task,
            "target": target,
            "question": None,
            "structured_findings": structured_findings,
            "split": split,
            "report_id": report_id,
        }

    for report, split in zip(reports, split_labels):
        findings = report["findings"]
        impression = report["impression"]
        report_id = report["report_id"]
        for _sid, image_payload in _per_study_image_groups(report["images"]):
            # `image_path` stays a string in the two single-image modes so
            # existing dataloader code keeps working unchanged. In multi-image
            # mode we instead set `image_paths` (a list) and leave
            # `image_path` empty.
            if isinstance(image_payload, list):
                path_fields = {"image_path": None, "image_paths": image_payload}
            else:
                path_fields = {"image_path": image_payload, "image_paths": None}

            if report_mode == "merged":
                target = format_merged_report(findings, impression)
                if target is None:
                    skipped_merged_no_impression += 1
                    continue
                samples.append(
                    _make_sample(path_fields, "report", target, split, report_id)
                )
                continue

            # "split" and "split_cascade" share the findings sample.
            if findings is not None:
                samples.append(
                    _make_sample(path_fields, "findings", findings, split, report_id)
                )
            if impression is None:
                continue

            context = None
            if report_mode == "split_cascade":
                # The impression sample needs findings as its prompt context,
                # so it is only emitted when BOTH sections exist. The GT
                # findings ride in `structured_findings`.
                if findings is None:
                    skipped_cascade_no_findings += 1
                    continue
                context = f"Findings: {findings.strip()}"
            samples.append(
                _make_sample(path_fields, "impression", impression, split,
                             report_id, context)
            )

    # ── Write JSON ────────────────────────────────────────────────────────
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(samples, f, ensure_ascii=False, indent=2)

    # ── Log summary ───────────────────────────────────────────────────────
    by_split, by_task = {}, {}
    for s in samples:
        by_split[s["split"]] = by_split.get(s["split"], 0) + 1
        by_task[s["task"]] = by_task.get(s["task"], 0) + 1
    print(f"[iu_xray_builder] wrote {len(samples)} samples → {output_path}")
    print(f"  report_mode      : {report_mode}")
    print(f"  image_mode       : {image_mode}")
    print(f"  XMLs scanned     : {len(xml_files)}")
    print(f"  reports kept     : {n_total}")
    print(f"  skipped bad_xml  : {skipped_parse_error}")
    print(f"  skipped no_text  : {skipped_no_text}")
    print(f"  skipped no_image : {skipped_no_image}")
    if report_mode == "merged":
        print(f"  skipped no_impr  : {skipped_merged_no_impression}")
    if report_mode == "split_cascade":
        print(f"  skipped impr w/o findings : {skipped_cascade_no_findings}")
    print(f"  by split         : {by_split}")
    print(f"  by task          : {by_task}")
    return str(output_path)
# ─── CLI ────────────────────────────────────────────────────────────────────
def _parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse the builder's command-line options.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, so existing ``_parse_args()`` calls behave
            as before; passing a list allows programmatic use.

    Returns:
        Namespace with images_dir, labels_dir, output, train_ratio, val_ratio,
        test_ratio, seed, image_suffix, report_mode and image_mode.
    """
    p = argparse.ArgumentParser(description="Build IU X-ray unified instruction JSON")
    p.add_argument("--images_dir", required=True,
                   help="Folder with CXR*.png files")
    p.add_argument("--labels_dir", required=True,
                   help="Folder with {id}.xml files (ecgen-radiology)")
    p.add_argument("--output", required=True,
                   help="Output JSON path")
    p.add_argument("--train_ratio", type=float, default=0.70)
    p.add_argument("--val_ratio", type=float, default=0.15)
    p.add_argument("--test_ratio", type=float, default=0.15)
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--image_suffix", type=str, default=".png")
    p.add_argument("--report_mode", type=str, default="split",
                   choices=["split", "merged", "split_cascade"],
                   help="split: 2 samples/img (findings + impression). "
                        "merged: 1 sample/img with combined target. "
                        "split_cascade: like split, but impression sample's "
                        "prompt context = GT findings text (findings→impression).")
    p.add_argument("--image_mode", type=str, default="all_views_split",
                   choices=["all_views_split", "frontal_only_split", "multi_image_merged"],
                   help="all_views_split: 1 sample per image. "
                        "frontal_only_split: 1 sample per study (frontal only). "
                        "multi_image_merged: 1 sample per study with list of views.")
    return p.parse_args(argv)
if __name__ == "__main__":
    # Script entry point: forward every parsed CLI option to the builder.
    cli = _parse_args()
    # These options share their name with the builder's keyword arguments;
    # only --output differs (it feeds `output_path`).
    passthrough = {
        name: getattr(cli, name)
        for name in (
            "images_dir",
            "labels_dir",
            "train_ratio",
            "val_ratio",
            "test_ratio",
            "seed",
            "image_suffix",
            "report_mode",
            "image_mode",
        )
    }
    build_iu_xray_instruct_json(output_path=cli.output, **passthrough)