convitom Claude Sonnet 4.6 commited on
Commit
215ecd6
·
1 Parent(s): f965a7f

feat(chexpert): U-MultiClass PNU abnormality guidance + abnormality-guided VQA

Browse files

- chexpert_classifier.py: 14 binary heads → 14×3 softmax (negative/
positive/uncertain per pathology, META-CXR / CheXpert U-MultiClass).
Add format_pnu/buckets_to_pnu as the single source of truth for the
PNU 3-section prompt string (shared with the oracle builder so GT and
predicted prompts are byte-identical).
- mimic_cxr_builder.py: GT chexpert.csv → PNU string (1→pos, 0→neg,
-1→uncertain, blank/NaN→neg). VQA now carries the SAME PNU context
(abnormality-guided VQA, RaDialog-style). O(1) image lookup (was
O(N²)); index every image so report-less studies still serve VQA.
- Drop the obsolete uncertain_policy knob (U-MultiClass is the only
behaviour now) from builder, dataset.py, dataset_resolver, config.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

configs/train_config.yaml CHANGED
@@ -59,15 +59,14 @@ data:
59
  mimic_cxr_root: "/path/to/MIMIC-CXR"
60
  instruct_json: "data/data_files/mimic_cxr_instruct_unified.json"
61
 
62
- # RaDialog abnormality guidance: the 14 CheXpert labels (oracle / GT) are
63
- # read from this CSV and baked into the prompt as
64
- # "Predicted Findings: ...". If left null the builder auto-discovers any
 
 
65
  # *chexpert*.csv under mimic_cxr_root; if none is found, structured_findings
66
  # is null and abnormality guidance is silently DISABLED (loud warning).
67
  mimic_chexpert_csv: null
68
- # How CheXpert -1.0 (uncertain) is mapped: "ignore" (only 1.0 positive,
69
- # default, matches the classifier head) | "positive" (treat -1.0 as positive).
70
- mimic_uncertain_policy: "ignore"
71
  # Optional VQA pairs dir with {train,valid,test}.json. null → skip VQA.
72
  mimic_vqa_root: null
73
  # Auto-build the unified JSON (with CheXpert labels) when the cached
 
59
  mimic_cxr_root: "/path/to/MIMIC-CXR"
60
  instruct_json: "data/data_files/mimic_cxr_instruct_unified.json"
61
 
62
+ # RaDialog abnormality guidance (U-MultiClass / META-CXR): the 14 CheXpert
63
+ # labels (oracle / GT) are read from this CSV and baked into the prompt as
64
+ # the PNU 3-section string ("Positive Abnormalities: ... / Negative ... /
65
+ # Uncertain ..."). CSV value → class: 1→positive, 0→negative, -1→uncertain,
66
+ # blank/NaN→negative. If left null the builder auto-discovers any
67
  # *chexpert*.csv under mimic_cxr_root; if none is found, structured_findings
68
  # is null and abnormality guidance is silently DISABLED (loud warning).
69
  mimic_chexpert_csv: null
 
 
 
70
  # Optional VQA pairs dir with {train,valid,test}.json. null → skip VQA.
71
  mimic_vqa_root: null
72
  # Auto-build the unified JSON (with CheXpert labels) when the cached
data/dataset.py CHANGED
@@ -26,7 +26,7 @@ from torch.utils.data import Dataset
26
  from PIL import Image
27
 
28
  from .prompt_templates import build_training_sample
29
- from model.image_encoder import BioViLTEncoder
30
 
31
 
32
  TaskType = Literal["findings", "impression", "report", "vqa", "mixed"]
@@ -289,7 +289,6 @@ def build_instruct_json(
289
  vqa_data_root: Optional[str] = None,
290
  report_mode: str = "split",
291
  image_mode: str = "all_views_split",
292
- uncertain_policy: str = "ignore",
293
  ) -> str:
294
  """
295
  Build the unified MIMIC-CXR instruction JSON.
@@ -297,25 +296,27 @@ def build_instruct_json(
297
  Thin delegate to `data.mimic_cxr_builder.build_mimic_cxr_instruct_json`,
298
  which walks the pre-split MIMIC layout (train/valid/test), parses
299
  findings/impression from the report .txt files, and bakes the 14 CheXpert
300
- labels (oracle, from `*chexpert*.csv`) into `structured_findings` as
301
- "Predicted Findings: ..." the RaDialog image + abnormality-guidance
302
- setup. `report_mode` / `image_mode` mirror the IU builder.
 
303
 
304
  Output entries match the shared schema, e.g.:
305
  {"image_path": "train/p10/p10000032/s50414267/02aa804e.jpg",
306
  "task": "findings", "target": "The lungs are clear...",
307
  "question": null,
308
- "structured_findings": "Predicted Findings: No Finding",
 
 
309
  "split": "train", "study_id": "s50414267",
310
  "subject_id": "p10000032"}
311
  """
312
  from .mimic_cxr_builder import build_mimic_cxr_instruct_json
313
  return build_mimic_cxr_instruct_json(
314
- mimic_root = mimic_cxr_root,
315
- output_path = output_path,
316
- chexpert_csv = chexpert_csv,
317
- vqa_root = vqa_data_root,
318
- report_mode = report_mode,
319
- image_mode = image_mode,
320
- uncertain_policy = uncertain_policy,
321
  )
 
26
  from PIL import Image
27
 
28
  from .prompt_templates import build_training_sample
29
+ from model.rad_dino import BioViLTEncoder
30
 
31
 
32
  TaskType = Literal["findings", "impression", "report", "vqa", "mixed"]
 
289
  vqa_data_root: Optional[str] = None,
290
  report_mode: str = "split",
291
  image_mode: str = "all_views_split",
 
292
  ) -> str:
293
  """
294
  Build the unified MIMIC-CXR instruction JSON.
 
296
  Thin delegate to `data.mimic_cxr_builder.build_mimic_cxr_instruct_json`,
297
  which walks the pre-split MIMIC layout (train/valid/test), parses
298
  findings/impression from the report .txt files, and bakes the 14 CheXpert
299
+ labels (oracle, from `*chexpert*.csv`) into `structured_findings` as the
300
+ PNU 3-section string (U-MultiClass, META-CXR format) the RaDialog
301
+ image + abnormality-guidance setup. `report_mode` / `image_mode` mirror
302
+ the IU builder.
303
 
304
  Output entries match the shared schema, e.g.:
305
  {"image_path": "train/p10/p10000032/s50414267/02aa804e.jpg",
306
  "task": "findings", "target": "The lungs are clear...",
307
  "question": null,
308
+ "structured_findings": "Positive Abnormalities: None\\n
309
+ Negative Abnormalities: No Finding, ...\\n
310
+ Uncertain Abnormalities: None",
311
  "split": "train", "study_id": "s50414267",
312
  "subject_id": "p10000032"}
313
  """
314
  from .mimic_cxr_builder import build_mimic_cxr_instruct_json
315
  return build_mimic_cxr_instruct_json(
316
+ mimic_root = mimic_cxr_root,
317
+ output_path = output_path,
318
+ chexpert_csv = chexpert_csv,
319
+ vqa_root = vqa_data_root,
320
+ report_mode = report_mode,
321
+ image_mode = image_mode,
 
322
  )
data/mimic_cxr_builder.py CHANGED
@@ -15,25 +15,34 @@ NOT the raw PhysioNet tree):
15
  └── test /pNN/...
16
  {anywhere under mimic_root}/ *chexpert*.csv (optional, auto-discovered)
17
 
18
- RaDialog-style abnormality guidance
19
- -----------------------------------
20
  The 14 CheXpert labels are read from `mimic-cxr-2.0.0-chexpert.csv`
21
  (CheXbert run on the ground-truth reports) and baked into the prompt as
22
- `structured_findings`:
23
 
24
- "Predicted Findings: Cardiomegaly, Pleural Effusion"
25
- "Predicted Findings: No Finding" (when no positive label)
 
 
 
 
26
 
27
  This is the *oracle* setting — GT labels, no trained image classifier and
28
- no model change. The CheXpert classifier module stays unused; the existing
29
- `structured_findings` prompt plumbing carries the string through train
30
- (dataset.py) and eval (evaluate.py) untouched.
 
 
31
 
32
  VQA
33
  ---
34
- VQA pairs live in a separate dataset and are attached by passing
35
- `vqa_root` (mirrors the notebook). Omit it to build findings/impression
36
- only.
 
 
 
37
  """
38
 
39
  import argparse
@@ -50,23 +59,11 @@ from typing import Dict, List, Optional, Tuple
50
  _FINDINGS_RE = re.compile(r"FINDINGS\s*:\s*(.*?)(?=\n\s*[A-Z ]{3,}\s*:|\Z)", re.S | re.I)
51
  _IMPRESSION_RE = re.compile(r"IMPRESSION\s*:\s*(.*?)(?=\n\s*[A-Z ]{3,}\s*:|\Z)", re.S | re.I)
52
 
53
- # 14 CheXpert columns, in the canonical order used by the classifier head.
54
- CHEXPERT_LABELS = [
55
- "No Finding",
56
- "Enlarged Cardiomediastinum",
57
- "Cardiomegaly",
58
- "Lung Opacity",
59
- "Lung Lesion",
60
- "Edema",
61
- "Consolidation",
62
- "Pneumonia",
63
- "Atelectasis",
64
- "Pneumothorax",
65
- "Pleural Effusion",
66
- "Pleural Other",
67
- "Fracture",
68
- "Support Devices",
69
- ]
70
 
71
 
72
  def _clean(txt: str) -> str:
@@ -83,7 +80,7 @@ def _parse_report(txt_path: Path) -> Tuple[Optional[str], Optional[str]]:
83
  )
84
 
85
 
86
- # ─── CheXpert CSV → "Predicted Findings: ..." string ────────────────────────
87
 
88
  def _discover_chexpert_csv(mimic_root: Path, explicit: Optional[str]) -> Optional[Path]:
89
  if explicit:
@@ -97,17 +94,26 @@ def _discover_chexpert_csv(mimic_root: Path, explicit: Optional[str]) -> Optiona
97
  return None
98
 
99
 
100
- def _load_chexpert_map(
101
- csv_path: Path,
102
- uncertain_policy: str = "ignore", # "ignore" → only 1.0 positive | "positive" → -1.0 also positive
103
- ) -> Dict[Tuple[str, str], str]:
104
  """
105
- Return {(subject_id, study_id): "Predicted Findings: A, B"} where the ids
106
- are the bare integers as strings (CSV stores them without the p/s prefix).
 
 
 
 
 
 
107
  """
108
- pos_threshold = {"1", "1.0"}
109
- if uncertain_policy == "positive":
110
- pos_threshold = pos_threshold | {"-1", "-1.0"}
 
 
 
 
 
 
111
 
112
  out: Dict[Tuple[str, str], str] = {}
113
  with open(csv_path, newline="") as f:
@@ -121,25 +127,17 @@ def _load_chexpert_map(
121
  f"{csv_path} missing subject_id/study_id columns "
122
  f"(have: {reader.fieldnames})"
123
  )
124
- label_cols = [(name, col[name.lower()]) for name in CHEXPERT_LABELS
125
  if name.lower() in col]
126
 
127
  for row in reader:
128
  subj = str(row[subj_c]).strip().lstrip("p").split(".")[0]
129
  study = str(row[study_c]).strip().lstrip("s").split(".")[0]
130
- positives = [
131
- name for name, c in label_cols
132
- if str(row.get(c, "")).strip() in pos_threshold
133
- ]
134
- # "No Finding" alone is reported as such; otherwise list the
135
- # genuine positives (drop a redundant "No Finding" if any
136
- # pathology is also positive).
137
- real = [p for p in positives if p != "No Finding"]
138
- if real:
139
- txt = ", ".join(real)
140
- else:
141
- txt = "No Finding"
142
- out[(subj, study)] = f"Predicted Findings: {txt}"
143
  return out
144
 
145
 
@@ -152,18 +150,17 @@ def build_mimic_cxr_instruct_json(
152
  vqa_root: Optional[str] = None,
153
  report_mode: str = "split", # "split" | "merged" | "split_cascade"
154
  image_mode: str = "all_views_split", # "all_views_split" | "frontal_only_split" | "multi_image_merged"
155
- uncertain_policy: str = "ignore", # how CheXpert -1.0 (uncertain) is treated
156
  ) -> str:
157
  """
158
  Build the unified MIMIC-CXR instruction JSON.
159
 
160
  report_mode mirrors iu_xray_builder:
161
  "split" → findings + impression samples; BOTH carry the CheXpert
162
- "Predicted Findings: ..." string in structured_findings
163
- (RaDialog: image + 14 labels → text).
164
  "merged" → one task=report sample, target "Findings: ...\n\n
165
- Impression: ...", carries the CheXpert string.
166
- "split_cascade" → findings sample carries the CheXpert string; the
167
  impression sample instead carries "Findings: <GT
168
  findings>" as context (findings→impression). Same
169
  convention as the IU builder.
@@ -201,9 +198,9 @@ def build_mimic_cxr_instruct_json(
201
  # ── CheXpert labels ───────────────────────────────────────────────────
202
  csv_path = _discover_chexpert_csv(mimic_root, chexpert_csv)
203
  if csv_path is not None:
204
- chexpert_map = _load_chexpert_map(csv_path, uncertain_policy)
205
  print(f"[mimic_cxr_builder] CheXpert CSV: {csv_path} "
206
- f"({len(chexpert_map):,} studies, uncertain={uncertain_policy})")
207
  else:
208
  chexpert_map = {}
209
  print("[mimic_cxr_builder] WARNING: no *chexpert*.csv found under "
@@ -213,23 +210,24 @@ def build_mimic_cxr_instruct_json(
213
 
214
  # ── Pass 1: index studies ─────────────────────────────────────────────
215
  samples: List[Dict] = []
216
- image_index: Dict[str, str] = {} # subject-relative pathsplit label
 
 
217
  n_studies = n_missing_report = n_no_chexpert = 0
218
  skipped_merged_no_impression = skipped_cascade_no_findings = 0
219
 
220
  def _structured_for(subj: str, study: str) -> Optional[str]:
221
  return chexpert_map.get((subj.lstrip("p"), study.lstrip("s")))
222
 
223
- def _image_groups(study_dir: Path, split_sub: str, subj: str, study: str):
 
 
 
 
 
 
 
224
  """Yield path_fields dicts honouring image_mode (same rules as IU)."""
225
- imgs = sorted(study_dir.glob("*.jpg"))
226
- if not imgs:
227
- return
228
- def _rel(img: Path) -> str:
229
- return f"{split_sub}/{img.parent.parent.parent.name}/{subj}/{study}/{img.name}"
230
- rels = [_rel(im) for im in imgs]
231
- for r in rels:
232
- image_index[r] = split_dirs[split_sub]
233
  if image_mode == "all_views_split":
234
  for r in rels:
235
  yield {"image_path": r, "image_paths": None}
@@ -242,11 +240,15 @@ def build_mimic_cxr_instruct_json(
242
  for p_dir in sorted(split_dir.glob("p*")):
243
  for pat_dir in p_dir.glob("p*"):
244
  for study_dir in pat_dir.glob("s*"):
245
- jpgs = list(study_dir.glob("*.jpg"))
246
- if not jpgs:
 
247
  continue
248
  n_studies += 1
249
- subj, study = pat_dir.name, study_dir.name
 
 
 
250
  txts = list(study_dir.glob("*.txt"))
251
  if not txts:
252
  n_missing_report += 1
@@ -257,7 +259,7 @@ def build_mimic_cxr_instruct_json(
257
  n_no_chexpert += 1
258
  split_label = split_dirs[split_sub]
259
 
260
- for path_fields in _image_groups(study_dir, split_sub, subj, study):
261
  base = {
262
  **path_fields,
263
  "question": None,
@@ -310,22 +312,26 @@ def build_mimic_cxr_instruct_json(
310
  sub_rel = str(row["image_path"]).lstrip("/")
311
  if sub_rel.startswith("files/"):
312
  sub_rel = sub_rel[len("files/"):]
313
- # match against any indexed image whose tail equals sub_rel
314
- hit = next((k for k in image_index if k.endswith(sub_rel)), None)
315
- if hit is None:
316
  n_vqa_dropped += 1
317
  continue
318
  ans = row.get("answer", [])
319
  answer = (", ".join(map(str, ans)) if isinstance(ans, list)
320
  else str(ans)) or "No."
 
 
321
  samples.append({
322
- "image_path": hit, "image_paths": None,
323
  "task": "vqa", "target": answer,
324
  "question": row["question"],
325
- "structured_findings": None,
 
 
 
326
  "split": split_label,
327
- "study_id": row.get("study_id"),
328
- "subject_id": row.get("subject_id"),
329
  })
330
  n_vqa += 1
331
 
@@ -372,9 +378,6 @@ def _parse_args():
372
  choices=["split", "merged", "split_cascade"])
373
  p.add_argument("--image_mode", default="all_views_split",
374
  choices=["all_views_split", "frontal_only_split", "multi_image_merged"])
375
- p.add_argument("--uncertain_policy", default="ignore",
376
- choices=["ignore", "positive"],
377
- help="CheXpert -1.0 (uncertain): ignore (default) or treat as positive.")
378
  return p.parse_args()
379
 
380
 
@@ -387,5 +390,4 @@ if __name__ == "__main__":
387
  vqa_root = a.vqa_root,
388
  report_mode = a.report_mode,
389
  image_mode = a.image_mode,
390
- uncertain_policy = a.uncertain_policy,
391
  )
 
15
  └── test /pNN/...
16
  {anywhere under mimic_root}/ *chexpert*.csv (optional, auto-discovered)
17
 
18
+ RaDialog-style abnormality guidance (U-MultiClass / META-CXR)
19
+ -------------------------------------------------------------
20
  The 14 CheXpert labels are read from `mimic-cxr-2.0.0-chexpert.csv`
21
  (CheXbert run on the ground-truth reports) and baked into the prompt as
22
+ `structured_findings` in the PNU 3-section format:
23
 
24
+ Positive Abnormalities: Cardiomegaly, Pleural Effusion
25
+ Negative Abnormalities: No Finding, Edema, ...
26
+ Uncertain Abnormalities: Atelectasis
27
+
28
+ CSV value → class: 1 → positive, 0 → negative, -1 → uncertain,
29
+ blank/NaN → negative (META-CXR convention: missing == negative).
30
 
31
  This is the *oracle* setting — GT labels, no trained image classifier and
32
+ no model change. The string format is shared verbatim with
33
+ `model.chexpert_classifier.format_pnu`, so the learned-classifier path
34
+ (at inference) produces byte-identical prompts. The existing
35
+ `structured_findings` plumbing carries it through train (dataset.py) and
36
+ eval (evaluate.py) untouched.
37
 
38
  VQA
39
  ---
40
+ VQA pairs live in 3 files {train,valid,test}.json (MIMIC-Ext-CXR-VQA);
41
+ attach them by passing `vqa_root`. Each row is one (image, question,
42
+ answer) sample — one image can yield many rows. VQA samples get the SAME
43
+ PNU CheXpert context as findings/impression (abnormality-guided VQA, à la
44
+ RaDialog), looked up by subject_id/study_id. Omit `vqa_root` to build
45
+ findings/impression only.
46
  """
47
 
48
  import argparse
 
59
  _FINDINGS_RE = re.compile(r"FINDINGS\s*:\s*(.*?)(?=\n\s*[A-Z ]{3,}\s*:|\Z)", re.S | re.I)
60
  _IMPRESSION_RE = re.compile(r"IMPRESSION\s*:\s*(.*?)(?=\n\s*[A-Z ]{3,}\s*:|\Z)", re.S | re.I)
61
 
62
+ # The 14-label list, PNU string formatter and class indices live in
63
+ # model.chexpert_classifier — single source of truth shared with the learned
64
+ # classifier so GT-oracle and predicted prompts are byte-identical. Imported
65
+ # lazily inside _load_chexpert_map (it pulls the model package, which is
66
+ # always available in the train/eval env where JSON building runs).
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
 
69
  def _clean(txt: str) -> str:
 
80
  )
81
 
82
 
83
+ # ─── CheXpert CSV → PNU structured-findings string ──────────────────────────
84
 
85
  def _discover_chexpert_csv(mimic_root: Path, explicit: Optional[str]) -> Optional[Path]:
86
  if explicit:
 
94
  return None
95
 
96
 
97
+ def _load_chexpert_map(csv_path: Path) -> Dict[Tuple[str, str], str]:
 
 
 
98
  """
99
+ Return {(subject_id, study_id): <PNU string>} where the ids are the bare
100
+ integers as strings (CSV stores them without the p/s prefix).
101
+
102
+ U-MultiClass mapping of each CheXpert cell:
103
+ 1 / 1.0 → positive
104
+ 0 / 0.0 → negative
105
+ -1 / -1.0 → uncertain
106
+ blank/NaN → negative (META-CXR convention: missing == negative)
107
  """
108
+ from model.chexpert_classifier import (
109
+ PATHOLOGIES, buckets_to_pnu,
110
+ CLASS_NEGATIVE, CLASS_POSITIVE, CLASS_UNCERTAIN,
111
+ )
112
+ val_to_cls = {
113
+ "1": CLASS_POSITIVE, "1.0": CLASS_POSITIVE,
114
+ "0": CLASS_NEGATIVE, "0.0": CLASS_NEGATIVE,
115
+ "-1": CLASS_UNCERTAIN, "-1.0": CLASS_UNCERTAIN,
116
+ }
117
 
118
  out: Dict[Tuple[str, str], str] = {}
119
  with open(csv_path, newline="") as f:
 
127
  f"{csv_path} missing subject_id/study_id columns "
128
  f"(have: {reader.fieldnames})"
129
  )
130
+ label_cols = [(name, col[name.lower()]) for name in PATHOLOGIES
131
  if name.lower() in col]
132
 
133
  for row in reader:
134
  subj = str(row[subj_c]).strip().lstrip("p").split(".")[0]
135
  study = str(row[study_c]).strip().lstrip("s").split(".")[0]
136
+ mapping = {
137
+ name: val_to_cls.get(str(row.get(c, "")).strip(), CLASS_NEGATIVE)
138
+ for name, c in label_cols
139
+ }
140
+ out[(subj, study)] = buckets_to_pnu(mapping)
 
 
 
 
 
 
 
 
141
  return out
142
 
143
 
 
150
  vqa_root: Optional[str] = None,
151
  report_mode: str = "split", # "split" | "merged" | "split_cascade"
152
  image_mode: str = "all_views_split", # "all_views_split" | "frontal_only_split" | "multi_image_merged"
 
153
  ) -> str:
154
  """
155
  Build the unified MIMIC-CXR instruction JSON.
156
 
157
  report_mode mirrors iu_xray_builder:
158
  "split" → findings + impression samples; BOTH carry the CheXpert
159
+ PNU string in structured_findings (RaDialog: image +
160
+ 14 labels → text).
161
  "merged" → one task=report sample, target "Findings: ...\n\n
162
+ Impression: ...", carries the CheXpert PNU string.
163
+ "split_cascade" → findings sample carries the CheXpert PNU string; the
164
  impression sample instead carries "Findings: <GT
165
  findings>" as context (findings→impression). Same
166
  convention as the IU builder.
 
198
  # ── CheXpert labels ───────────────────────────────────────────────────
199
  csv_path = _discover_chexpert_csv(mimic_root, chexpert_csv)
200
  if csv_path is not None:
201
+ chexpert_map = _load_chexpert_map(csv_path)
202
  print(f"[mimic_cxr_builder] CheXpert CSV: {csv_path} "
203
+ f"({len(chexpert_map):,} studies, PNU U-MultiClass)")
204
  else:
205
  chexpert_map = {}
206
  print("[mimic_cxr_builder] WARNING: no *chexpert*.csv found under "
 
210
 
211
  # ── Pass 1: index studies ─────────────────────────────────────────────
212
  samples: List[Dict] = []
213
+ # sub_rel ("pXX/pXXXX/sYYYY/img.jpg")full stored image_path
214
+ # ("{split}/pXX/pXXXX/sYYYY/img.jpg"). O(1) VQA lookup.
215
+ image_index: Dict[str, str] = {}
216
  n_studies = n_missing_report = n_no_chexpert = 0
217
  skipped_merged_no_impression = skipped_cascade_no_findings = 0
218
 
219
  def _structured_for(subj: str, study: str) -> Optional[str]:
220
  return chexpert_map.get((subj.lstrip("p"), study.lstrip("s")))
221
 
222
+ def _rels_for(study_dir: Path, split_sub: str, subj: str, study: str) -> List[str]:
223
+ """Split-prefixed relative image paths for one study, sorted."""
224
+ return [
225
+ f"{split_sub}/{im.parent.parent.parent.name}/{subj}/{study}/{im.name}"
226
+ for im in sorted(study_dir.glob("*.jpg"))
227
+ ]
228
+
229
+ def _image_groups(rels: List[str]):
230
  """Yield path_fields dicts honouring image_mode (same rules as IU)."""
 
 
 
 
 
 
 
 
231
  if image_mode == "all_views_split":
232
  for r in rels:
233
  yield {"image_path": r, "image_paths": None}
 
240
  for p_dir in sorted(split_dir.glob("p*")):
241
  for pat_dir in p_dir.glob("p*"):
242
  for study_dir in pat_dir.glob("s*"):
243
+ subj, study = pat_dir.name, study_dir.name
244
+ rels = _rels_for(study_dir, split_sub, subj, study)
245
+ if not rels:
246
  continue
247
  n_studies += 1
248
+ # Index EVERY image up front — a VQA row may reference a
249
+ # study that has images but no findings/impression report.
250
+ for r in rels:
251
+ image_index[r.split("/", 1)[1]] = r
252
  txts = list(study_dir.glob("*.txt"))
253
  if not txts:
254
  n_missing_report += 1
 
259
  n_no_chexpert += 1
260
  split_label = split_dirs[split_sub]
261
 
262
+ for path_fields in _image_groups(rels):
263
  base = {
264
  **path_fields,
265
  "question": None,
 
312
  sub_rel = str(row["image_path"]).lstrip("/")
313
  if sub_rel.startswith("files/"):
314
  sub_rel = sub_rel[len("files/"):]
315
+ full = image_index.get(sub_rel) # O(1)
316
+ if full is None:
 
317
  n_vqa_dropped += 1
318
  continue
319
  ans = row.get("answer", [])
320
  answer = (", ".join(map(str, ans)) if isinstance(ans, list)
321
  else str(ans)) or "No."
322
+ subj = str(row.get("subject_id", ""))
323
+ study = str(row.get("study_id", ""))
324
  samples.append({
325
+ "image_path": full, "image_paths": None,
326
  "task": "vqa", "target": answer,
327
  "question": row["question"],
328
+ # Abnormality-guided VQA (RaDialog): same PNU CheXpert
329
+ # context as findings/impression. None if no chexpert.csv
330
+ # (graceful — falls back to image + question only).
331
+ "structured_findings": _structured_for(subj, study),
332
  "split": split_label,
333
+ "study_id": study,
334
+ "subject_id": subj,
335
  })
336
  n_vqa += 1
337
 
 
378
  choices=["split", "merged", "split_cascade"])
379
  p.add_argument("--image_mode", default="all_views_split",
380
  choices=["all_views_split", "frontal_only_split", "multi_image_merged"])
 
 
 
381
  return p.parse_args()
382
 
383
 
 
390
  vqa_root = a.vqa_root,
391
  report_mode = a.report_mode,
392
  image_mode = a.image_mode,
 
393
  )
model/chexpert_classifier.py CHANGED
@@ -1,21 +1,28 @@
1
  """
2
  chexpert_classifier.py
3
  ----------------------
4
- Multi-label CheXpert pathology classifier.
5
- Trained separately on MIMIC-CXR with CheXbert labels.
6
 
7
- This component provides structured findings (e.g. "Pleural Effusion: Positive")
8
- that are appended to the LLM prompt alongside image tokens, improving clinical
9
- accuracy of generated reports.
10
 
11
- Reference: RaDialog (Pellegrini et al., 2023) CheXpert Classifier provides
12
- structured findings to the LLM prompt to improve clinical correctness.
 
 
 
 
 
 
 
 
 
13
  """
14
 
15
  import torch
16
  import torch.nn as nn
17
- from pathlib import Path
18
- from typing import Optional, List, Dict
19
 
20
 
21
  PATHOLOGIES = [
@@ -35,40 +42,77 @@ PATHOLOGIES = [
35
  "Support Devices",
36
  ]
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
  class CheXpertClassifier(nn.Module):
40
  """
41
- Lightweight multi-label classifier on top of BioViL-T global embeddings.
42
 
43
- Trained separately (Stage 0) with binary cross-entropy loss.
44
- Frozen during Stage 1 and Stage 2 of main model training.
45
 
46
  Args:
47
- input_dim: BioViL-T global embedding dim (512 for BioViL-T global)
48
- num_classes: number of pathology classes (14)
49
- threshold: classification threshold for positive predictions
50
- checkpoint: path to trained weights (None = random init / not loaded)
51
  """
52
 
53
  def __init__(
54
  self,
55
- input_dim: int = 512,
56
  num_classes: int = 14,
57
- threshold: float = 0.5,
58
- checkpoint: Optional[str] = None,
59
  ):
60
  super().__init__()
61
 
62
  self.num_classes = num_classes
63
- self.threshold = threshold
64
  self.pathologies = PATHOLOGIES
65
 
66
- # Simple MLP classifier head
67
  self.classifier = nn.Sequential(
68
  nn.Linear(input_dim, 256),
69
  nn.ReLU(),
70
  nn.Dropout(0.2),
71
- nn.Linear(256, num_classes),
72
  )
73
 
74
  if checkpoint is not None:
@@ -82,46 +126,44 @@ class CheXpertClassifier(nn.Module):
82
  def forward(self, global_features: torch.Tensor) -> torch.Tensor:
83
  """
84
  Args:
85
- global_features: (B, input_dim) — global CXR embedding from BioViL-T
86
 
87
  Returns:
88
- logits: (B, 14)
 
 
 
89
  """
90
- return self.classifier(global_features)
 
91
 
92
  @torch.no_grad()
93
  def predict(self, global_features: torch.Tensor) -> List[Dict[str, str]]:
94
  """
95
- Run inference and return human-readable findings per sample.
96
-
97
- Returns:
98
- List of dicts like {"Pleural Effusion": "Positive", "Cardiomegaly": "Negative", ...}
99
- """
100
- logits = self.forward(global_features) # (B, 14)
101
- probs = torch.sigmoid(logits) # (B, 14)
102
- preds = (probs > self.threshold).cpu() # (B, 14) bool
103
-
104
- results = []
105
- for i in range(preds.size(0)):
106
- finding = {}
107
- for j, name in enumerate(self.pathologies):
108
- finding[name] = "Positive" if preds[i, j].item() else "Negative"
109
- results.append(finding)
110
- return results
111
-
112
- def findings_to_text(self, findings: Dict[str, str]) -> str:
113
  """
114
- Convert findings dict to a structured text string for LLM prompt.
 
 
 
 
 
 
 
 
115
 
116
- Example output:
117
- "Predicted Findings: Pleural Effusion: Positive, Cardiomegaly: Negative, ..."
118
  """
119
- positive = [k for k, v in findings.items() if v == "Positive"]
120
- negative = [k for k, v in findings.items() if v == "Negative"]
121
-
122
- if not positive:
123
- pos_str = "No Finding"
124
- else:
125
- pos_str = ", ".join(positive)
126
-
127
- return f"Predicted Findings: {pos_str}"
 
 
 
1
  """
2
  chexpert_classifier.py
3
  ----------------------
4
+ Multi-label, multi-CLASS CheXpert pathology classifier (U-MultiClass).
 
5
 
6
+ Each of the 14 pathologies is predicted as one of THREE classes —
7
+ negative / positive / uncertain via a per-pathology softmax, mirroring
8
+ META-CXR's MHCAC head and the CheXpert "U-MultiClass" uncertainty policy.
9
 
10
+ The structured findings injected into the LLM prompt use the PNU
11
+ (Positive / Negative / Uncertain) 3-section format. `format_pnu()` is the
12
+ single source of truth for that string so the oracle path
13
+ (data/mimic_cxr_builder.py, GT from chexpert.csv) and the learned path
14
+ (this classifier at inference) produce byte-identical prompts.
15
+
16
+ Trained separately (Stage 0) on MIMIC-CXR CheXbert labels; frozen during
17
+ Stage 1 / Stage 2 of the main VLM.
18
+
19
+ Reference: RaDialog (Pellegrini et al., 2023) for the prompt-conditioning
20
+ idea; META-CXR (Edirisinghe et al., 2025) for the explicit uncertain class.
21
  """
22
 
23
  import torch
24
  import torch.nn as nn
25
+ from typing import Optional, List, Dict, Sequence
 
26
 
27
 
28
  PATHOLOGIES = [
 
42
  "Support Devices",
43
  ]
44
 
45
+ # Per-pathology class indices (softmax dim order). Keep this stable: the
46
+ # trained checkpoint and the GT-label mapping in mimic_cxr_builder.py both
47
+ # rely on it.
48
+ CLASS_NEGATIVE = 0
49
+ CLASS_POSITIVE = 1
50
+ CLASS_UNCERTAIN = 2
51
+ NUM_STATES = 3
52
+ CLASS_NAMES = {CLASS_NEGATIVE: "negative",
53
+ CLASS_POSITIVE: "positive",
54
+ CLASS_UNCERTAIN: "uncertain"}
55
+
56
+
57
+ def format_pnu(positive: Sequence[str],
58
+ negative: Sequence[str],
59
+ uncertain: Sequence[str]) -> str:
60
+ """
61
+ Build the PNU structured-findings string (META-CXR prompt format).
62
+
63
+ Positive Abnormalities: Cardiomegaly, Pleural Effusion
64
+ Negative Abnormalities: No Finding, Edema, ...
65
+ Uncertain Abnormalities: Atelectasis
66
+
67
+ Empty sections render as "None" so the three lines are always present
68
+ (the LLM sees a fixed structure regardless of the case).
69
+ """
70
+ def _fmt(xs: Sequence[str]) -> str:
71
+ return ", ".join(xs) if xs else "None"
72
+ return (f"Positive Abnormalities: {_fmt(positive)}\n"
73
+ f"Negative Abnormalities: {_fmt(negative)}\n"
74
+ f"Uncertain Abnormalities: {_fmt(uncertain)}")
75
+
76
+
77
+ def buckets_to_pnu(class_by_pathology: Dict[str, int]) -> str:
78
+ """Group a {pathology: class_idx} dict into the PNU string."""
79
+ pos = [p for p, c in class_by_pathology.items() if c == CLASS_POSITIVE]
80
+ neg = [p for p, c in class_by_pathology.items() if c == CLASS_NEGATIVE]
81
+ unc = [p for p, c in class_by_pathology.items() if c == CLASS_UNCERTAIN]
82
+ return format_pnu(pos, neg, unc)
83
+
84
 
85
  class CheXpertClassifier(nn.Module):
86
  """
87
+ Multi-label, 3-class-per-label classifier on BioViL-T global embeddings.
88
 
89
+ Output logits have shape (B, 14, 3); a per-pathology softmax/argmax
90
+ yields negative / positive / uncertain.
91
 
92
  Args:
93
+ input_dim: global CXR embedding dim
94
+ num_classes: number of pathologies (14)
95
+ checkpoint: trained weights (None = not loaded)
 
96
  """
97
 
98
  def __init__(
99
  self,
100
+ input_dim: int = 512,
101
  num_classes: int = 14,
102
+ checkpoint: Optional[str] = None,
 
103
  ):
104
  super().__init__()
105
 
106
  self.num_classes = num_classes
107
+ self.num_states = NUM_STATES
108
  self.pathologies = PATHOLOGIES
109
 
110
+ # MLP head → num_classes * 3 logits, reshaped to (B, num_classes, 3)
111
  self.classifier = nn.Sequential(
112
  nn.Linear(input_dim, 256),
113
  nn.ReLU(),
114
  nn.Dropout(0.2),
115
+ nn.Linear(256, num_classes * NUM_STATES),
116
  )
117
 
118
  if checkpoint is not None:
 
126
  def forward(self, global_features: torch.Tensor) -> torch.Tensor:
127
  """
128
  Args:
129
+ global_features: (B, input_dim)
130
 
131
  Returns:
132
+ logits: (B, num_classes, 3) — softmax over the last dim gives
133
+ P(negative), P(positive), P(uncertain) per pathology.
134
+ Train with cross-entropy over the last dim (the natural
135
+ U-MultiClass objective).
136
  """
137
+ flat = self.classifier(global_features) # (B, 14*3)
138
+ return flat.view(-1, self.num_classes, NUM_STATES) # (B, 14, 3)
139
 
140
  @torch.no_grad()
141
  def predict(self, global_features: torch.Tensor) -> List[Dict[str, str]]:
142
  """
143
+ Returns a list (per sample) of {pathology: "negative"|"positive"|
144
+ "uncertain"} using argmax over the 3-state softmax.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
  """
146
+ logits = self.forward(global_features) # (B, 14, 3)
147
+ cls = logits.argmax(dim=-1).cpu() # (B, 14)
148
+ out: List[Dict[str, str]] = []
149
+ for i in range(cls.size(0)):
150
+ out.append({
151
+ name: CLASS_NAMES[int(cls[i, j].item())]
152
+ for j, name in enumerate(self.pathologies)
153
+ })
154
+ return out
155
 
156
+ @torch.no_grad()
157
+ def findings_to_text(self, global_features: torch.Tensor) -> List[str]:
158
  """
159
+ Per-sample PNU structured-findings string, identical in format to the
160
+ GT oracle path (data/mimic_cxr_builder.py). One string per sample.
161
+ """
162
+ logits = self.forward(global_features) # (B, 14, 3)
163
+ cls = logits.argmax(dim=-1).cpu() # (B, 14)
164
+ texts: List[str] = []
165
+ for i in range(cls.size(0)):
166
+ mapping = {name: int(cls[i, j].item())
167
+ for j, name in enumerate(self.pathologies)}
168
+ texts.append(buckets_to_pnu(mapping))
169
+ return texts
model/{image_encoder.py → rad_dino.py} RENAMED
File without changes
utils/dataset_resolver.py CHANGED
@@ -223,13 +223,12 @@ def _ensure_mimic_json_exists(data_cfg,
223
  print(f"[dataset_resolver] MIMIC JSON not found → auto-building "
224
  f"(report_mode={report_mode}, image_mode={image_mode}) …")
225
  build_mimic_cxr_instruct_json(
226
- mimic_root = str(_get(data_cfg, "mimic_cxr_root")),
227
- output_path = str(out),
228
- chexpert_csv = _get(data_cfg, "mimic_chexpert_csv"),
229
- vqa_root = _get(data_cfg, "mimic_vqa_root"),
230
- report_mode = report_mode,
231
- image_mode = image_mode,
232
- uncertain_policy = str(_get(data_cfg, "mimic_uncertain_policy", "ignore")),
233
  )
234
  return str(out)
235
 
 
223
  print(f"[dataset_resolver] MIMIC JSON not found → auto-building "
224
  f"(report_mode={report_mode}, image_mode={image_mode}) …")
225
  build_mimic_cxr_instruct_json(
226
+ mimic_root = str(_get(data_cfg, "mimic_cxr_root")),
227
+ output_path = str(out),
228
+ chexpert_csv = _get(data_cfg, "mimic_chexpert_csv"),
229
+ vqa_root = _get(data_cfg, "mimic_vqa_root"),
230
+ report_mode = report_mode,
231
+ image_mode = image_mode,
 
232
  )
233
  return str(out)
234