cxr-vlm-code / utils /dataset_resolver.py
convitom
f
b961b41
"""
dataset_resolver.py
-------------------
Centralises the logic that decides, based on `train_cfg.data.dataset_name`:
1. Which tasks are trainable/evaluable for the chosen dataset.
2. Where images live (`image_root`).
3. Where the unified instruction JSON lives (and building it on-demand
for any supported dataset if it is missing).
4. Task-weight normalization (dropping disabled tasks).
5. The `run_id` used in all output paths:
{dataset_name}_run_{N}
Numbering scans the existing checkpoint directory (and the HF Hub repo,
when one is configured) — so re-running with the same dataset
auto-picks the next N.
Keeping this out of train.py / evaluate.py means those two entry points
stay short, and the MIMIC-CXR code path is untouched.
"""
from __future__ import annotations
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional
# Every value accepted for `train_cfg.data.dataset_name`.
SUPPORTED_DATASETS = (
    "MIMIC-CXR",
    "MIMIC-CXR_resized",
    "IU-Xray",
)
@dataclass
class DatasetSpec:
    """Resolved data-layer configuration for a training/eval run.

    Instances are normally produced by `resolve_dataset_spec`.
    """

    # One of SUPPORTED_DATASETS: "MIMIC-CXR", "MIMIC-CXR_resized" or "IU-Xray".
    dataset_name: str
    # Image directory; passed to CXRInstructDataset.
    image_root: str
    # Mode-suffixed unified instruction JSON; passed to CXRInstructDataset.
    instruct_json: str
    # Tasks that exist in this dataset and are enabled with a positive weight.
    tasks: List[str]
    # Weight per entry of `tasks`, normalized over `tasks` to sum to 1.
    task_weights: Dict[str, float]
    # "split" | "merged" | "split_cascade"
    report_mode: str = "split"
    # "all_views_split" | "frontal_only_split" | "multi_image_merged"
    image_mode: str = "all_views_split"
    # Images per sample; >1 only when image_mode == "multi_image_merged".
    max_images: int = 1
# ─── Dataset resolution ─────────────────────────────────────────────────────
def resolve_dataset_spec(train_cfg) -> DatasetSpec:
    """
    Read `train_cfg.data.dataset_name` and return the matching DatasetSpec.

    The instruction JSON is auto-built if it is missing, for every supported
    dataset, unless that dataset's auto-build flag is false
    (IU-Xray: `iu_xray.auto_build`; MIMIC-CXR: `data.mimic_auto_build`;
    MIMIC-CXR_resized: `mimic_cxr_resized.auto_build`; all default to true).

    The choice of which tasks are "available" depends on `data.report_mode`:
        "split"         → findings, impression (+ vqa for MIMIC)
        "merged"        → report (+ vqa for MIMIC)
        "split_cascade" → findings, impression (+ vqa for MIMIC); same task set
                          and weights as "split" — only the data builder differs
                          (impression sample carries GT findings as context).

    Task weights come from `train_cfg.tasks.<section>.{weight, enabled}`. A
    disabled task gets weight 0 and is dropped. The surviving weights are
    normalized to sum to 1.
    `tasks.report_generation` is optional: when the section is absent, the
    "report" task uses weight 0.6.

    Raises:
        ValueError: unknown report_mode / image_mode / dataset_name, a
            max_images_per_sample < 1 in multi_image_merged mode, or no
            enabled task available for the chosen dataset.
    """
    data_cfg = train_cfg.data
    name = _get(data_cfg, "dataset_name", "MIMIC-CXR")
    report_mode = _get(data_cfg, "report_mode", "split")
    image_mode = _get(data_cfg, "image_mode", "all_views_split")
    max_images = int(_get(data_cfg, "max_images_per_sample", 2))

    if report_mode not in ("split", "merged", "split_cascade"):
        raise ValueError(
            f"data.report_mode must be 'split', 'merged', or 'split_cascade', "
            f"got {report_mode!r}"
        )
    if image_mode not in ("all_views_split", "frontal_only_split", "multi_image_merged"):
        raise ValueError(
            f"data.image_mode must be one of all_views_split / frontal_only_split / "
            f"multi_image_merged, got {image_mode!r}"
        )
    if image_mode == "multi_image_merged" and max_images < 1:
        raise ValueError(
            f"data.max_images_per_sample must be >= 1 in multi_image_merged mode, "
            f"got {max_images}"
        )
    # In single-image modes max_images must be 1; otherwise the dataset would
    # pad each sample to N>1 (wasted compute, possibly wrong behaviour).
    effective_max_images = max_images if image_mode == "multi_image_merged" else 1

    if name not in SUPPORTED_DATASETS:
        raise ValueError(
            f"Unsupported dataset_name: {name!r}. "
            f"Expected one of {SUPPORTED_DATASETS}."
        )

    merged = report_mode == "merged"
    tasks_cfg = train_cfg.tasks

    def _section_weight(section) -> float:
        # Weight of a required task section; 0.0 when the section is disabled.
        return float(section.weight) if section.enabled else 0.0

    # `report_generation` is optional in the config. When present, its
    # `enabled` flag is honoured exactly like the other tasks.
    report_cfg = _get(tasks_cfg, "report_generation")
    if report_cfg is None:
        report_w = 0.6
    elif _get(report_cfg, "enabled", True):
        report_w = float(_get(report_cfg, "weight", 0.6))
    else:
        report_w = 0.0

    # In "merged" mode findings/impression are replaced by "report", and in
    # split modes "report" is unused. Muting the unused ones keeps them from
    # sneaking back in.
    all_weights = {
        "findings": 0.0 if merged else _section_weight(tasks_cfg.findings_generation),
        "impression": 0.0 if merged else _section_weight(tasks_cfg.impression_generation),
        "report": report_w if merged else 0.0,
        "vqa": _section_weight(tasks_cfg.vqa),
    }

    if name in ("MIMIC-CXR", "MIMIC-CXR_resized"):
        # Both MIMIC variants carry VQA next to report generation.
        available = ["report", "vqa"] if merged else ["findings", "impression", "vqa"]
    else:  # IU-Xray has no VQA.
        available = ["report"] if merged else ["findings", "impression"]

    if name == "MIMIC-CXR":
        image_root = data_cfg.mimic_cxr_root
        instruct_json = _ensure_mimic_json_exists(data_cfg, report_mode, image_mode)
    elif name == "MIMIC-CXR_resized":
        # Same semantic dataset as MIMIC-CXR, but its JSON is built by the
        # manifest-driven MIMIC-CXR_resized builder from the `mimic_cxr_resized`
        # config section.
        mr = data_cfg.mimic_cxr_resized
        image_root = mr.root
        instruct_json = _ensure_mimic_resized_json_exists(mr, report_mode, image_mode)
    else:  # IU-Xray
        iu = data_cfg.iu_xray
        image_root = iu.images_dir
        instruct_json = _ensure_iu_json_exists(iu, report_mode, image_mode)

    # Keep only enabled tasks that actually exist in the dataset.
    selected = [t for t in available if all_weights.get(t, 0.0) > 0]
    if not selected:
        raise ValueError(
            f"No enabled tasks match dataset {name}. "
            f"Enable at least one of {available} in `tasks:` config."
        )
    # Every selected weight is > 0, so the total is strictly positive.
    total = sum(all_weights[t] for t in selected)
    weights = {t: all_weights[t] / total for t in selected}

    return DatasetSpec(
        dataset_name=name,
        image_root=str(image_root),
        instruct_json=str(instruct_json),
        tasks=selected,
        task_weights=weights,
        report_mode=report_mode,
        image_mode=image_mode,
        max_images=effective_max_images,
    )
def _ensure_iu_json_exists(iu_cfg,
report_mode: str = "split",
image_mode: str = "all_views_split") -> str:
"""
Build the IU X-ray unified JSON if missing (auto_build=true).
The cached JSON path is automatically suffixed with BOTH report_mode and
image_mode (e.g. iu_xray_instruct__split__all_views_split.json) so any
of the 6 mode combinations gets its own cached file and never overwrites
a JSON built with different settings.
"""
base = Path(iu_cfg.instruct_json)
out = base.with_name(f"{base.stem}__{report_mode}__{image_mode}{base.suffix}")
if out.is_file():
return str(out)
auto = _get(iu_cfg, "auto_build", True)
if not auto:
raise FileNotFoundError(
f"IU X-ray instruct JSON not found at {out} and auto_build=false. "
f"Run: python -m data.iu_xray_builder --images_dir {iu_cfg.images_dir} "
f"--labels_dir {iu_cfg.labels_dir} --output {out} "
f"--report_mode {report_mode} --image_mode {image_mode}"
)
# Lazy import to avoid pulling xml.etree on MIMIC-only runs
from data.iu_xray_builder import build_iu_xray_instruct_json
print(f"[dataset_resolver] IU X-ray JSON not found → auto-building "
f"(report_mode={report_mode}, image_mode={image_mode}) …")
build_iu_xray_instruct_json(
images_dir = iu_cfg.images_dir,
labels_dir = iu_cfg.labels_dir,
output_path = str(out),
train_ratio = float(_get(iu_cfg, "train_ratio", 0.70)),
val_ratio = float(_get(iu_cfg, "val_ratio", 0.15)),
test_ratio = float(_get(iu_cfg, "test_ratio", 0.15)),
seed = int(_get(iu_cfg, "seed", 42)),
image_suffix = str(_get(iu_cfg, "image_suffix", ".png")),
report_mode = report_mode,
image_mode = image_mode,
)
return str(out)
def _ensure_mimic_json_exists(data_cfg,
                              report_mode: str = "split",
                              image_mode: str = "all_views_split") -> str:
    """
    Return the path of the MIMIC-CXR unified JSON, building it if missing.

    The configured `data.instruct_json` path (default
    data/data_files/mimic_cxr_instruct_unified.json) has its stem suffixed
    with both report_mode and image_mode
    (e.g. mimic_cxr_instruct_unified__split__all_views_split.json). Each mode
    combination therefore gets its own cache, and the RaDialog
    CheXpert-guided JSON never collides with one built under other settings.

    Auto-build (default on) passes `data.mimic_chexpert_csv` and
    `data.mimic_vqa_root` to the builder. According to the original design
    notes, the builder uses the CheXpert CSV to bake the 14 oracle labels into
    structured_findings. Set `data.mimic_auto_build: false` to require a
    pre-built file instead.

    Raises:
        FileNotFoundError: the JSON is missing and auto-build is disabled.
        ValueError: the JSON is missing, auto-build is enabled, but
            `data.mimic_cxr_root` is not configured.
    """
    base = Path(_get(data_cfg, "instruct_json",
                     "data/data_files/mimic_cxr_instruct_unified.json"))
    out = base.with_name(f"{base.stem}__{report_mode}__{image_mode}{base.suffix}")
    if out.is_file():
        return str(out)
    mimic_root = _get(data_cfg, "mimic_cxr_root")
    if not bool(_get(data_cfg, "mimic_auto_build", True)):
        raise FileNotFoundError(
            f"MIMIC instruct JSON not found at {out} and "
            f"data.mimic_auto_build=false. Run: python -m data.mimic_cxr_builder "
            f"--mimic_root {mimic_root} --output {out} "
            f"--report_mode {report_mode} --image_mode {image_mode}"
        )
    if mimic_root is None:
        # str(None) would otherwise hand the builder a literal "None" root.
        raise ValueError(
            f"MIMIC instruct JSON not found at {out} and data.mimic_cxr_root "
            f"is not set, so it cannot be auto-built."
        )
    from data.mimic_cxr_builder import build_mimic_cxr_instruct_json
    print(f"[dataset_resolver] MIMIC JSON not found → auto-building "
          f"(report_mode={report_mode}, image_mode={image_mode}) …")
    build_mimic_cxr_instruct_json(
        mimic_root=str(mimic_root),
        output_path=str(out),
        chexpert_csv=_get(data_cfg, "mimic_chexpert_csv"),
        vqa_root=_get(data_cfg, "mimic_vqa_root"),
        report_mode=report_mode,
        image_mode=image_mode,
    )
    return str(out)
def _ensure_mimic_resized_json_exists(mr_cfg,
                                      report_mode: str = "split",
                                      image_mode: str = "all_views_split") -> str:
    """
    Return the path of the MIMIC-CXR_resized unified JSON, building it if missing.

    This dataset is **manifest-driven**, not directory-walking:
      - 3 manifest CSVs (manifest_{train,val,test}.csv) carry every row's
        split label, image/report relative path, and the 14 CheXpert
        labels as chex_* columns. No separate *split*.csv or *chexpert*.csv
        is read.
      - VQA is read from `vqa_dir/{vqa.json, vqa_val.json, vqa_test.json}`.

    Config defaults: manifests are looked up in `manifest_dir` (falling back
    to `root`), and VQA in `vqa_dir` (falling back to `{root}/vqa`). An
    explicit empty-string `vqa_dir` disables VQA.

    The cache path is suffixed with report_mode+image_mode (same convention
    as the other two builders), so each mode combination gets its own cache.

    Raises:
        FileNotFoundError: the JSON is missing and `auto_build` is false.
        ValueError: the JSON is missing, auto-build is enabled, but `root`
            is not configured.
    """
    base = Path(_get(mr_cfg, "instruct_json",
                     "data/data_files/mimic_cxr_resized_instruct.json"))
    out = base.with_name(f"{base.stem}__{report_mode}__{image_mode}{base.suffix}")
    if out.is_file():
        return str(out)
    root = _get(mr_cfg, "root")
    if not bool(_get(mr_cfg, "auto_build", True)):
        raise FileNotFoundError(
            f"MIMIC-CXR_resized instruct JSON not found at {out} and "
            f"auto_build=false. Run: python -m data.mimic_cxr_resized_builder "
            f"--root {root} --output {out} "
            f"--report_mode {report_mode} --image_mode {image_mode}"
        )
    if root is None:
        # str(None) would otherwise yield a "None" root and a "None/vqa" VQA dir.
        raise ValueError(
            f"MIMIC-CXR_resized instruct JSON not found at {out} and "
            f"mimic_cxr_resized.root is not set, so it cannot be auto-built."
        )
    from data.mimic_cxr_resized_builder import build_mimic_cxr_resized_instruct_json
    print(f"[dataset_resolver] MIMIC-CXR_resized JSON not found → auto-building "
          f"(report_mode={report_mode}, image_mode={image_mode}) …")
    root_path = str(root)
    manifest_dir = _get(mr_cfg, "manifest_dir") or root_path
    vqa_dir_cfg = _get(mr_cfg, "vqa_dir")
    if vqa_dir_cfg is None:
        vqa_dir = str(Path(root_path) / "vqa")
    elif vqa_dir_cfg == "":
        vqa_dir = None  # explicit opt-out
    else:
        vqa_dir = str(vqa_dir_cfg)
    build_mimic_cxr_resized_instruct_json(
        root=root_path,
        manifest_dir=manifest_dir,
        output_path=str(out),
        vqa_dir=vqa_dir,
        reports_root=_get(mr_cfg, "reports_root"),
        report_mode=report_mode,
        image_mode=image_mode,
    )
    return str(out)
# ─── Run ID resolution (dataset-prefixed) ───────────────────────────────────
def resolve_run_id(
    dataset_name: str,
    output_root: str,
    state_file: str,
    resuming: bool,
    explicit: Optional[str] = None,
    hf_repo_id: Optional[str] = None,
    hf_token: Optional[str] = None,
) -> str:
    """
    Pick a run_id of the form "{dataset_name}_run_{N}".

    Resolution order:
      1. `explicit` flag (always wins) — pass --run_id to force a specific id
         (e.g. continue a run after VM restart without flagging it as resume).
         The id is also written to `state_file`.
      2. `resuming=True` (i.e. --resume_from / --resume_from_hf): read
         state_file → fall back to latest run on local disk → fall back to
         latest run on HF Hub. An empty/whitespace-only state file counts as
         missing.
      3. Fresh session: ALWAYS pick a brand-new id = max(local, remote) + 1.
         The local state file is NOT honoured here — a stale run_id.txt left
         over from a previous run would otherwise silently overwrite that run.
         Use `--run_id <name>` if you really mean to keep appending.

    Raises:
        RuntimeError: resuming, but no usable state file and no existing
            '{dataset_name}_run_*' run locally or on the Hub.
    """
    prefix = f"{dataset_name}_run_"
    if explicit:
        _write_state(state_file, explicit)
        return explicit

    def _all_existing() -> List[int]:
        # Union of run numbers found under output_root and on the HF repo.
        local = _scan_local_runs(output_root, prefix)
        remote = _scan_remote_runs(hf_repo_id, hf_token, prefix)
        return sorted(set(local) | set(remote))

    state_path = Path(state_file)
    if resuming:
        # An empty state file must not yield run_id "": every
        # `{output_root}/{run_id}` path would collapse onto output_root.
        saved = state_path.read_text().strip() if state_path.is_file() else ""
        if saved:
            return saved
        # No usable state file but user said --resume_from: pick the latest
        # run that exists anywhere (local OR remote) as best-effort fallback.
        existing = _all_existing()
        if existing:
            rid = f"{prefix}{max(existing)}"
            _write_state(state_file, rid)
            return rid
        raise RuntimeError(
            f"Cannot resume: no usable state file at {state_path}, no '{prefix}*' "
            f"folders under {output_root}, and none on HF Hub "
            f"({hf_repo_id or 'no repo configured'}). Pass --run_id explicitly."
        )

    # Fresh session — always allocate a new id, ignoring stale state file.
    existing = _all_existing()
    next_n = (max(existing) + 1) if existing else 1
    rid = f"{prefix}{next_n}"
    _write_state(state_file, rid)
    if existing:
        print(f"[resolve_run_id] fresh run → {rid} "
              f"(found existing: {[f'{prefix}{n}' for n in existing]})")
    else:
        print(f"[resolve_run_id] first run for this dataset → {rid}")
    return rid
def _scan_remote_runs(repo_id: Optional[str], token: Optional[str], prefix: str) -> List[int]:
"""List existing '<prefix>N' folders on the HF Hub repo. Best-effort —
returns [] on any failure (no token, no repo, network down, …)."""
if not repo_id:
return []
try:
from huggingface_hub import HfApi
api = HfApi(token=token)
files = api.list_repo_files(repo_id, token=token)
except Exception as e:
print(f"[resolve_run_id] could not list HF runs ({type(e).__name__}: {e})")
return []
rx = re.compile(rf"^{re.escape(prefix)}(\d+)(?:/|$)")
nums = set()
for f in files:
m = rx.match(f)
if m:
nums.add(int(m.group(1)))
return sorted(nums)
def _scan_local_runs(output_root: str, prefix: str) -> List[int]:
root = Path(output_root)
if not root.is_dir():
return []
rx = re.compile(rf"^{re.escape(prefix)}(\d+)$")
out = []
for d in root.iterdir():
if not d.is_dir():
continue
m = rx.match(d.name)
if m:
out.append(int(m.group(1)))
return sorted(out)
def _write_state(state_file: str, run_id: str) -> None:
p = Path(state_file)
p.parent.mkdir(parents=True, exist_ok=True)
p.write_text(run_id)
# ─── Path helpers ───────────────────────────────────────────────────────────
def run_dir(output_root: str, run_id: str) -> Path:
    """Return `{output_root}/{run_id}` as a Path, creating it (and parents) if missing."""
    target = Path(output_root).joinpath(run_id)
    target.mkdir(exist_ok=True, parents=True)
    return target
def stage_dir(output_root: str, run_id: str, subdir: str) -> str:
    """
    Create `{output_root}/{run_id}` and then `{output_root}/{run_id}/{subdir}`
    (with parents), and return the latter as a string, e.g. for an HF Trainer
    `output_dir`.
    """
    run_path = Path(output_root) / run_id
    run_path.mkdir(parents=True, exist_ok=True)
    stage_path = run_path / subdir
    stage_path.mkdir(parents=True, exist_ok=True)
    return str(stage_path)
# ─── Run-config snapshot ────────────────────────────────────────────────────
def save_run_config(
    run_dir_path,
    spec: "DatasetSpec",
    model_cfg,
    train_cfg,
    extra: Optional[Dict] = None,
) -> None:
    """
    Persist a snapshot of the resolved config into the run directory so each
    run is self-describing. Writes:
        {run_dir}/configs/model_config.yaml   — full OmegaConf dump
        {run_dir}/configs/train_config.yaml   — full OmegaConf dump
        {run_dir}/run_meta.json               — compact, human-readable summary

    `run_meta.json` is intentionally small. It carries the fields a person
    typically wants when comparing two runs side by side (dataset, training
    schedule, mode flags). The full YAML dumps are the source of truth.
    Sections missing from the configs (e.g. no `stage2` or no `training`) are
    tolerated: their summary fields fall back to the defaults below.

    `extra` is merged into `run_meta.json` (overriding same-named keys). It is
    useful for adding, e.g., the git commit hash or the resume source; its
    values must be JSON-serializable.

    If omegaconf cannot be imported, the YAML files contain `str(cfg)` instead.
    """
    import json as _json
    from datetime import datetime, timezone

    try:
        from omegaconf import OmegaConf
    except ImportError:
        OmegaConf = None

    def _to_yaml(cfg) -> str:
        # Unresolved YAML dump via OmegaConf; plain str() when it is unavailable.
        return OmegaConf.to_yaml(cfg) if OmegaConf is not None else str(cfg)

    run_dir_path = Path(run_dir_path)
    cfg_dir = run_dir_path / "configs"
    cfg_dir.mkdir(parents=True, exist_ok=True)
    (cfg_dir / "model_config.yaml").write_text(_to_yaml(model_cfg), encoding="utf-8")
    (cfg_dir / "train_config.yaml").write_text(_to_yaml(train_cfg), encoding="utf-8")

    # Look sections up via _get instead of resolving the whole config to a
    # container. This avoids a full interpolation pass, and a missing section
    # yields None, whose fields then all fall back to their defaults.
    stage1 = _get(train_cfg, "stage1")
    stage2 = _get(train_cfg, "stage2")
    training = _get(train_cfg, "training")
    llm = _get(model_cfg, "llm")
    lora = _get(model_cfg, "lora")
    projection = _get(model_cfg, "projection")
    chexpert = _get(model_cfg, "chexpert_classifier")

    # Compact summary — only the fields that meaningfully change behaviour.
    meta = {
        "run_id": run_dir_path.name,
        "saved_at": datetime.now(timezone.utc).isoformat(timespec="seconds"),
        # — Data
        "dataset": spec.dataset_name,
        "image_root": spec.image_root,
        "instruct_json": spec.instruct_json,
        "report_mode": spec.report_mode,
        "image_mode": spec.image_mode,
        "max_images": spec.max_images,
        "tasks": spec.tasks,
        "task_weights": spec.task_weights,
        # — Training schedule
        "stage1": {
            "enabled": _get(stage1, "enabled", True),
            "num_epochs": _get(stage1, "num_epochs", None),
            "learning_rate": _get(stage1, "learning_rate", None),
            "freeze_llm": _get(stage1, "freeze_llm", True),
            "freeze_encoder": _get(stage1, "freeze_encoder", True),
        },
        "stage2": {
            "enabled": _get(stage2, "enabled", True),
            "num_epochs": _get(stage2, "num_epochs", None),
            "learning_rate": _get(stage2, "learning_rate", None),
            "freeze_llm": _get(stage2, "freeze_llm", False),
            "freeze_encoder": _get(stage2, "freeze_encoder", True),
        },
        "batch_size": _get(training, "per_device_train_batch_size", None),
        "grad_accum": _get(training, "gradient_accumulation_steps", None),
        "cutoff_len": _get(training, "cutoff_len", None),
        "fp16": _get(training, "fp16", None),
        "bf16": _get(training, "bf16", None),
        # — Model
        "llm": _get(llm, "name", None),
        "lora_r": _get(lora, "r", None),
        "lora_alpha": _get(lora, "lora_alpha", None),
        "num_image_tokens": _get(projection, "num_image_tokens", None),
        "chexpert_enabled": _get(chexpert, "enabled", None),
    }
    if extra:
        meta.update(extra)
    (run_dir_path / "run_meta.json").write_text(
        _json.dumps(meta, indent=2, ensure_ascii=False), encoding="utf-8",
    )
    print(f"[save_run_config] snapshot → {run_dir_path}/configs/, run_meta.json")
# ─── Misc ───────────────────────────────────────────────────────────────────
def _get(obj, key: str, default=None):
"""OmegaConf-safe .get with default."""
try:
v = getattr(obj, key)
return v if v is not None else default
except Exception:
return default