File size: 16,419 Bytes
28b13fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
02426e6
28b13fc
 
 
 
 
 
 
 
 
 
 
 
02426e6
 
 
 
 
 
 
 
 
 
 
28b13fc
 
 
 
02426e6
 
28b13fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
02426e6
28b13fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
02426e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28b13fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
02426e6
 
28b13fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
02426e6
28b13fc
02426e6
 
 
28b13fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
"""
iu_xray_builder.py
------------------
Parses the IU X-ray (OpenI) dataset into a unified instruction JSON
compatible with `CXRInstructDataset`.

Why a separate builder?
  IU X-ray ships as:
    images/CXR{rid}_{n}_IM-{s}-{v}.png   (7,470 PNGs)
    labels/ecgen-radiology/{rid}.xml     (3,955 XMLs, each covers 1–N images)
  We flatten the XMLs into per-(image, task) samples. An XML with two
  `<parentImage>` children produces two samples per task (findings,
  impression): the ground-truth text is shared across images of the
  same study — same convention MIMIC-CXR uses.

IU X-ray has NO VQA, so only report-text tasks are emitted: findings and
impression (or a single merged "report" task when report_mode="merged").

Output JSON entry (matches MIMIC-CXR output of build_instruct_json):
    {
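        "image_path":          "CXR1_1_IM-0001-3001.png",   # relative to images_dir; null when image_mode="multi_image_merged"
        "image_paths":         null,                        # list of relative paths when image_mode="multi_image_merged"
        "task":                "findings" | "impression",   # "report" when report_mode="merged"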
        "target":              "<report text>",
        "question":            null,
        "structured_findings": null,
        "split":               "train" | "validate" | "test",
        "report_id":           "1"
    }

Splitting is done at the REPORT level (not image level) to avoid leakage:
all images belonging to the same study land in the same split.
"""

import argparse
import glob
import json
import random
from pathlib import Path
from typing import Dict, List, Optional
from xml.etree import ElementTree as ET


# ─── XML helpers ────────────────────────────────────────────────────────────

def _extract_sections(root) -> Dict[str, str]:
    """Collect the labelled report sections of one parsed XML document.

    Every ``<AbstractText>`` element found anywhere below ``root`` contributes
    one entry, keyed by its upper-cased ``Label`` attribute (``""`` when the
    attribute is missing). If the same label occurs more than once, the last
    occurrence wins.

    The section text is gathered with ``itertext()`` so that text located
    inside or after nested child elements is preserved; reading ``.text``
    alone stops at the first child tag and would silently truncate the
    section. For plain-text elements the result is unchanged.

    Args:
        root: Element to search (typically the XML document root).

    Returns:
        Dict mapping section label (uppercase) -> whitespace-stripped text.
        An element without any text maps to ``""``.
    """
    sections: Dict[str, str] = {}
    for node in root.findall(".//AbstractText"):
        label = node.attrib.get("Label", "").upper()
        # itertext() yields nothing for an empty element, so join() gives "".
        sections[label] = "".join(node.itertext()).strip()
    return sections


def _extract_image_ids(root) -> List[str]:
    """Return the ``id`` attribute of every ``<parentImage>`` below ``root``.

    Ids are listed in the order ``findall`` returns the elements. An element
    that lacks an ``id`` attribute raises ``KeyError``.
    """
    image_ids: List[str] = []
    for parent_image in root.findall(".//parentImage"):
        image_ids.append(parent_image.attrib["id"])
    return image_ids


def _is_valid_text(text: Optional[str]) -> bool:
    """Tell whether ``text`` carries real report content.

    ``None`` and the empty string are rejected outright. Otherwise every
    ``X`` (IU reports anonymise with "XXXX" tokens), ``.`` and ``,`` is
    removed, surrounding whitespace is stripped, and the text is accepted
    only if at least three characters remain.
    """
    if not text:
        return False
    # One pass that deletes the anonymisation letter and the punctuation.
    remainder = text.translate(str.maketrans("", "", "X.,")).strip()
    return len(remainder) > 2


# ─── Main builder ───────────────────────────────────────────────────────────

def build_iu_xray_instruct_json(
    images_dir:   str,
    labels_dir:   str,
    output_path:  str,
    train_ratio:  float = 0.70,
    val_ratio:    float = 0.15,
    test_ratio:   float = 0.15,
    seed:         int   = 42,
    image_suffix: str   = ".png",
    report_mode:  str   = "split",                   # "split" | "merged" | "split_cascade"
    image_mode:   str   = "all_views_split",         # "all_views_split" | "frontal_only_split" | "multi_image_merged"
) -> str:
    """
    Parse all IU X-ray XMLs and emit the unified JSON.

    Args:
        images_dir:   Folder holding the image files. An image id found in an
                      XML is kept only if `<images_dir>/<id><image_suffix>`
                      exists as a file.
        labels_dir:   Folder whose `*.xml` files are scanned (not recursive).
        output_path:  Where the JSON list of samples is written; missing
                      parent folders are created.
        train_ratio:  Fraction of kept reports assigned to "train".
        val_ratio:    Fraction of kept reports assigned to "validate".
        test_ratio:   Only used to check that the three ratios sum to 1.0;
                      "test" actually receives whatever remains after the
                      train and validate counts are rounded down.
        seed:         Seed of the private RNG that shuffles reports before
                      the split is assigned.
        image_suffix: File extension appended to each image id.
        report_mode: "split"  → emit 2 samples per image (task=findings, task=impression).
                                Original behaviour. Use when training two separate tasks.
                     "merged" → emit 1 sample per image (task=report) with the
                                target built by `format_merged_report` from
                                `.dataset`. Use when training a single
                                full-report generation task. Samples for which
                                that helper returns None are dropped.
                     "split_cascade" → like "split" (2 separate tasks) BUT the
                                impression sample carries the ground-truth findings
                                text as its prompt context (in `structured_findings`,
                                formatted "Findings: ...") instead of CheXpert
                                labels. Impression thus learns findings→impression
                                summarisation while still seeing the image. Only
                                studies with BOTH findings and impression emit an
                                impression sample (findings is its required input).
                                NOTE: eval is teacher-forced (impression gets GT
                                findings); a true cascade eval that feeds the
                                model's own generated findings is future work.
        image_mode:  "all_views_split"    → one sample group per existing image.
                     "frontal_only_split" → one sample group per study, using
                                the first existing image of the study.
                     "multi_image_merged" → one sample group per study whose
                                `image_paths` lists every existing image
                                (`image_path` is then None).

    Returns:
        `str(output_path)`, i.e. the output location exactly as it was given
        (it is not converted to an absolute path).

    Raises:
        AssertionError:    Unknown `report_mode` / `image_mode`, or ratios not
                           summing to 1.0. NOTE: these checks are `assert`
                           statements and are therefore skipped under
                           `python -O`.
        FileNotFoundError: `labels_dir` contains no `*.xml` file.
        ImportError:       `report_mode="merged"` and `.dataset` cannot be
                           imported (e.g. the module is run outside its
                           package).
    """
    assert report_mode in ("split", "merged", "split_cascade"), \
        f"report_mode must be 'split', 'merged', or 'split_cascade', got {report_mode!r}"
    assert image_mode in ("all_views_split", "frontal_only_split", "multi_image_merged"), \
        f"image_mode must be one of all_views_split/frontal_only_split/multi_image_merged, got {image_mode!r}"
    assert abs(train_ratio + val_ratio + test_ratio - 1.0) < 1e-6, \
        "train/val/test ratios must sum to 1.0"

    images_dir = Path(images_dir)
    labels_dir = Path(labels_dir)
    output_path = Path(output_path)

    # Numeric stems sort by value; any non-numeric stem is pushed to the end.
    xml_files = sorted(
        glob.glob(str(labels_dir / "*.xml")),
        key=lambda p: int(Path(p).stem) if Path(p).stem.isdigit() else 10**9,
    )
    if not xml_files:
        raise FileNotFoundError(f"No XML files under {labels_dir}")

    # ── Pass 1: parse XMLs, keep reports with ≥1 valid section & ≥1 image ──
    reports: List[Dict] = []
    skipped_bad_xml  = 0
    skipped_no_text  = 0
    skipped_no_image = 0

    for xml_path in xml_files:
        try:
            tree = ET.parse(xml_path)
        except ET.ParseError:
            # Malformed files are still skipped (best effort), but counted so
            # the loss is visible in the summary instead of silent.
            skipped_bad_xml += 1
            continue
        root = tree.getroot()

        sections    = _extract_sections(root)
        findings    = sections.get("FINDINGS",   "").strip()
        impression  = sections.get("IMPRESSION", "").strip()
        has_find    = _is_valid_text(findings)
        has_imp     = _is_valid_text(impression)

        if not (has_find or has_imp):
            skipped_no_text += 1
            continue

        img_ids = _extract_image_ids(root)
        existing = [iid for iid in img_ids
                    if (images_dir / f"{iid}{image_suffix}").is_file()]
        if not existing:
            skipped_no_image += 1
            continue

        reports.append({
            "report_id": Path(xml_path).stem,
            "images":    existing,
            "findings":  findings   if has_find else None,
            "impression": impression if has_imp  else None,
        })

    # ── Pass 2: assign splits at the report level ─────────────────────────
    rng = random.Random(seed)
    rng.shuffle(reports)

    n_total = len(reports)
    n_train = int(n_total * train_ratio)
    n_val   = int(n_total * val_ratio)
    # Remainder → test (absorbs rounding error)
    split_labels = (
        ["train"]    * n_train +
        ["validate"] * n_val +
        ["test"]     * (n_total - n_train - n_val)
    )

    # ── Pass 3: expand to samples ─────────────────────────────────────────
    # Two orthogonal axes control sample shape:
    #
    #   report_mode (task axis):
    #     "split"         → up to 2 samples per image group (findings + impression)
    #     "merged"        → 1 sample per image group (task=report)
    #     "split_cascade" → like "split", impression carries GT findings
    #
    #   image_mode (image axis):
    #     "all_views_split"     → 1 sample group PER IMAGE
    #     "frontal_only_split"  → 1 sample group per study, first image only
    #     "multi_image_merged"  → 1 sample group per study, image_paths is a LIST
    if report_mode == "merged":
        # Imported lazily, and only for the one mode that uses it: this avoids
        # an import cycle at module load, and keeps the "split" /
        # "split_cascade" modes usable when the relative import cannot be
        # resolved (e.g. the file is executed directly as a script).
        from .dataset import format_merged_report

    samples: List[Dict] = []
    skipped_merged_no_impression = 0
    skipped_cascade_no_findings  = 0

    def _per_study_image_groups(report_imgs):
        """
        Yield (sample_id_suffix, image_path_or_list) for this study, applying
        the selected image_mode.

        sample_id_suffix is only used for logging; it has no functional effect.
        """
        if image_mode == "all_views_split":
            for img_id in report_imgs:
                yield img_id, f"{img_id}{image_suffix}"
        elif image_mode == "frontal_only_split":
            # IU X-ray convention: the FIRST <parentImage> in the XML is the
            # frontal (PA) view. We rely on this rather than parsing DICOM
            # ViewPosition (which is not provided in IU X-ray XMLs). For
            # MIMIC-CXR, swap this for a ViewPosition lookup against the
            # metadata CSV.
            # NOTE: report_imgs only contains images that exist on disk, so
            # if the first listed image file is missing, the next one is used.
            yield report_imgs[0], f"{report_imgs[0]}{image_suffix}"
        else:  # multi_image_merged
            paths = [f"{iid}{image_suffix}" for iid in report_imgs]
            yield report_imgs[0], paths

    def _make_sample(path_fields, task, target, split, report_id,
                     structured_findings=None):
        """Assemble one output entry; key order matches the JSON schema."""
        return {
            **path_fields,
            "task":                task,
            "target":              target,
            "question":            None,
            "structured_findings": structured_findings,
            "split":               split,
            "report_id":           report_id,
        }

    for report, split in zip(reports, split_labels):
        findings   = report["findings"]
        impression = report["impression"]
        report_id  = report["report_id"]

        for _sid, image_payload in _per_study_image_groups(report["images"]):
            # `image_path` stays a string in the two single-image modes so
            # existing dataloader code keeps working unchanged. In multi-image
            # mode we instead set `image_paths` (a list) and leave
            # `image_path` empty — dataset.py knows to pick whichever is set.
            if isinstance(image_payload, list):
                path_fields = {"image_path": None, "image_paths": image_payload}
            else:
                path_fields = {"image_path": image_payload, "image_paths": None}

            if report_mode == "merged":
                target = format_merged_report(findings, impression)
                if target is None:
                    skipped_merged_no_impression += 1
                    continue
                samples.append(
                    _make_sample(path_fields, "report", target, split, report_id)
                )
            elif report_mode == "split_cascade":
                # findings sample: identical to "split".
                if findings is not None:
                    samples.append(
                        _make_sample(path_fields, "findings", findings, split, report_id)
                    )
                # impression sample: needs findings as its prompt context, so
                # only emit when BOTH sections exist. The GT findings ride in
                # `structured_findings` (same plumbing CheXpert labels use) so
                # train (dataset.py) and eval (evaluate.py) pick it up with no
                # other code changes.
                if impression is not None:
                    if findings is None:
                        skipped_cascade_no_findings += 1
                    else:
                        samples.append(
                            _make_sample(
                                path_fields, "impression", impression, split, report_id,
                                structured_findings=f"Findings: {findings.strip()}",
                            )
                        )
            else:  # "split"
                for task_name, text in (
                    ("findings",   findings),
                    ("impression", impression),
                ):
                    if text is None:
                        continue
                    samples.append(
                        _make_sample(path_fields, task_name, text, split, report_id)
                    )

    # ── Write JSON ────────────────────────────────────────────────────────
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(samples, f, ensure_ascii=False, indent=2)

    # ── Log summary ───────────────────────────────────────────────────────
    by_split, by_task = {}, {}
    for s in samples:
        by_split[s["split"]] = by_split.get(s["split"], 0) + 1
        by_task[s["task"]]   = by_task.get(s["task"],  0) + 1

    print(f"[iu_xray_builder] wrote {len(samples)} samples → {output_path}")
    print(f"  report_mode      : {report_mode}")
    print(f"  image_mode       : {image_mode}")
    print(f"  XMLs scanned     : {len(xml_files)}")
    print(f"  reports kept     : {n_total}")
    print(f"  skipped bad_xml  : {skipped_bad_xml}")
    print(f"  skipped no_text  : {skipped_no_text}")
    print(f"  skipped no_image : {skipped_no_image}")
    if report_mode == "merged":
        print(f"  skipped no_impr  : {skipped_merged_no_impression}")
    if report_mode == "split_cascade":
        print(f"  skipped impr w/o findings : {skipped_cascade_no_findings}")
    print(f"  by split         : {by_split}")
    print(f"  by task          : {by_task}")

    return str(output_path)


# ─── CLI ────────────────────────────────────────────────────────────────────

def _parse_args():
    """Define the command-line interface and parse the process arguments.

    Returns the ``argparse.Namespace`` with the attributes ``images_dir``,
    ``labels_dir``, ``output``, ``train_ratio``, ``val_ratio``,
    ``test_ratio``, ``seed``, ``image_suffix``, ``report_mode`` and
    ``image_mode``.
    """
    parser = argparse.ArgumentParser(description="Build IU X-ray unified instruction JSON")

    # Mandatory locations.
    required_paths = (
        ("--images_dir", "Folder with CXR*.png files"),
        ("--labels_dir", "Folder with {id}.xml files (ecgen-radiology)"),
        ("--output",     "Output JSON path"),
    )
    for flag, description in required_paths:
        parser.add_argument(flag, required=True, help=description)

    # Report-level split fractions.
    split_defaults = (
        ("--train_ratio", 0.70),
        ("--val_ratio",   0.15),
        ("--test_ratio",  0.15),
    )
    for flag, fraction in split_defaults:
        parser.add_argument(flag, type=float, default=fraction)

    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--image_suffix", type=str, default=".png")

    # Task axis of the emitted samples.
    report_mode_help = (
        "split: 2 samples/img (findings + impression). "
        "merged: 1 sample/img with combined target. "
        "split_cascade: like split, but impression sample's "
        "prompt context = GT findings text (findings→impression)."
    )
    parser.add_argument(
        "--report_mode",
        type=str,
        default="split",
        choices=["split", "merged", "split_cascade"],
        help=report_mode_help,
    )

    # Image axis of the emitted samples.
    image_mode_help = (
        "all_views_split: 1 sample per image. "
        "frontal_only_split: 1 sample per study (frontal only). "
        "multi_image_merged: 1 sample per study with list of views."
    )
    parser.add_argument(
        "--image_mode",
        type=str,
        default="all_views_split",
        choices=["all_views_split", "frontal_only_split", "multi_image_merged"],
        help=image_mode_help,
    )

    return parser.parse_args()


if __name__ == "__main__":
    # Script entry point: forward every parsed CLI option to the builder.
    cli_options = _parse_args()
    build_iu_xray_instruct_json(
        cli_options.images_dir,
        cli_options.labels_dir,
        cli_options.output,  # the --output flag feeds the `output_path` parameter
        train_ratio=cli_options.train_ratio,
        val_ratio=cli_options.val_ratio,
        test_ratio=cli_options.test_ratio,
        seed=cli_options.seed,
        image_suffix=cli_options.image_suffix,
        report_mode=cli_options.report_mode,
        image_mode=cli_options.image_mode,
    )