"""
dataset_resolver.py
-------------------
Centralises the logic that decides, based on `train_cfg.data.dataset_name`:

  1. Which tasks are trainable/evaluable for the chosen dataset.
  2. Where images live (`image_root`).
  3. Where the unified instruction JSON lives (building it on demand when it
     is missing and auto-build is enabled — IU X-ray, MIMIC-CXR and
     MIMIC-CXR_resized each have their own builder).
  4. Task-weight normalization (dropping disabled tasks).
  5. The `run_id` used in all output paths: {dataset_name}_run_{N}
     Numbering scans the existing local checkpoint directory (and the HF Hub
     repo, when one is configured) — so re-running with the same dataset
     auto-picks the next N.

Keeping this out of train.py / evaluate.py means those two entry points stay
short, and the MIMIC-CXR code path is untouched.
"""

from __future__ import annotations

import re
from collections.abc import Mapping
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional


SUPPORTED_DATASETS = ("MIMIC-CXR", "MIMIC-CXR_resized", "IU-Xray")

_REPORT_MODES = ("split", "merged", "split_cascade")
_IMAGE_MODES = ("all_views_split", "frontal_only_split", "multi_image_merged")

# Internal task name → (section under `tasks:` in train_cfg, weight assumed
# when that whole section is absent). A missing `report_generation` section
# counts as enabled with weight 0.6; other missing sections count as disabled.
_TASK_CONFIG = {
    "findings": ("findings_generation", 0.0),
    "impression": ("impression_generation", 0.0),
    "report": ("report_generation", 0.6),
    "vqa": ("vqa", 0.0),
}


@dataclass
class DatasetSpec:
    """Resolved data-layer configuration for a training/eval run."""
    dataset_name: str                    # one of SUPPORTED_DATASETS
    image_root: str                      # passed to CXRInstructDataset
    instruct_json: str                   # passed to CXRInstructDataset
    tasks: List[str]                     # enabled tasks that exist in this dataset
    task_weights: Dict[str, float]       # normalized over `tasks` (sums to 1)
    report_mode: str = "split"           # "split" | "merged" | "split_cascade"
    image_mode: str = "all_views_split"  # "all_views_split" | "frontal_only_split" | "multi_image_merged"
    max_images: int = 1                  # >1 only when image_mode == multi_image_merged


# ─── Dataset resolution ─────────────────────────────────────────────────────

def resolve_dataset_spec(train_cfg) -> DatasetSpec:
    """
    Read `train_cfg.data.dataset_name` and return the matching DatasetSpec.

    For every supported dataset the mode-specific instruction JSON is
    auto-built if it is missing and auto-build is enabled for that dataset
    (`data.iu_xray.auto_build`, `data.mimic_auto_build`,
    `data.mimic_cxr_resized.auto_build`; all default to true).

    The choice of which tasks are "available" depends on `data.report_mode`:
        "split"         → findings, impression (+ vqa for MIMIC)
        "merged"        → report (+ vqa for MIMIC)
        "split_cascade" → findings, impression (+ vqa for MIMIC); same task
                          set and weights as "split" — only the data builder
                          differs (impression sample carries GT findings as
                          context).

    An available task is selected when its `tasks.<section>.enabled` flag is
    true (or absent) and its weight is > 0; this applies to
    `report_generation` as well. Sections of tasks that are not available
    for the dataset/mode are never read. Task validation happens before any
    instruction JSON is built, so config mistakes fail fast.

    Raises:
        ValueError: unknown report_mode / image_mode / dataset_name,
            `max_images_per_sample < 1` in multi_image_merged mode, an
            enabled task without a weight, or no enabled task available for
            the chosen dataset.
    """
    data_cfg = train_cfg.data
    name = _get(data_cfg, "dataset_name", "MIMIC-CXR")
    report_mode = _get(data_cfg, "report_mode", "split")
    image_mode = _get(data_cfg, "image_mode", "all_views_split")
    max_images = int(_get(data_cfg, "max_images_per_sample", 2))

    if report_mode not in _REPORT_MODES:
        raise ValueError(
            f"data.report_mode must be 'split', 'merged', or 'split_cascade', "
            f"got {report_mode!r}"
        )
    if image_mode not in _IMAGE_MODES:
        raise ValueError(
            f"data.image_mode must be one of all_views_split / frontal_only_split / "
            f"multi_image_merged, got {image_mode!r}"
        )
    if name not in SUPPORTED_DATASETS:
        raise ValueError(
            f"Unsupported dataset_name: {name!r}. "
            f"Expected one of {SUPPORTED_DATASETS}."
        )

    if image_mode == "multi_image_merged":
        if max_images < 1:
            raise ValueError(
                f"data.max_images_per_sample must be >= 1 for "
                f"image_mode='multi_image_merged', got {max_images}"
            )
        effective_max_images = max_images
    else:
        # In single-image modes max_images must be 1; otherwise the dataset
        # would pad each sample to N>1 (wasted compute, possibly wrong behaviour).
        effective_max_images = 1

    # "merged" trains a single report task; "split"/"split_cascade" train the
    # findings and impression sections separately. IU X-ray has no VQA.
    available = ["report"] if report_mode == "merged" else ["findings", "impression"]
    if name != "IU-Xray":
        available.append("vqa")

    # Keep only enabled tasks that actually exist in the selected dataset.
    tasks_cfg = train_cfg.tasks
    raw_weights: Dict[str, float] = {}
    for task in available:
        section, missing_weight = _TASK_CONFIG[task]
        weight = _task_weight(tasks_cfg, section, missing_weight)
        if weight > 0:
            raw_weights[task] = weight

    if not raw_weights:
        sections = [_TASK_CONFIG[t][0] for t in available]
        raise ValueError(
            f"No enabled tasks match dataset {name}. "
            f"Enable at least one of {sections} (with weight > 0) in `tasks:` config."
        )

    selected = list(raw_weights)
    total = sum(raw_weights.values())
    weights = {t: w / total for t, w in raw_weights.items()}

    # Only now touch the filesystem / build the (expensive) instruction JSON.
    if name == "MIMIC-CXR":
        image_root = data_cfg.mimic_cxr_root
        instruct_json = _ensure_mimic_json_exists(data_cfg, report_mode, image_mode)
    elif name == "MIMIC-CXR_resized":
        # Same semantic dataset as MIMIC-CXR, but its own on-disk layout and
        # builder (see _ensure_mimic_resized_json_exists).
        mr = data_cfg.mimic_cxr_resized
        image_root = mr.root
        instruct_json = _ensure_mimic_resized_json_exists(mr, report_mode, image_mode)
    else:  # IU-Xray
        iu = data_cfg.iu_xray
        image_root = iu.images_dir
        instruct_json = _ensure_iu_json_exists(iu, report_mode, image_mode)

    return DatasetSpec(
        dataset_name=name,
        image_root=str(image_root),
        instruct_json=str(instruct_json),
        tasks=selected,
        task_weights=weights,
        report_mode=report_mode,
        image_mode=image_mode,
        max_images=effective_max_images,
    )


def _task_weight(tasks_cfg, section: str, missing_weight: float = 0.0) -> float:
    """
    Return the effective weight of `tasks.<section>`.

    - Section absent entirely      → `missing_weight`.
    - Section with `enabled: false` → 0.0.
    - Otherwise                    → float(`weight`).

    Raises:
        ValueError: the section is enabled but has no `weight`.
    """
    cfg = _get(tasks_cfg, section)
    if cfg is None:
        return float(missing_weight)
    if not _get(cfg, "enabled", True):
        return 0.0
    weight = _get(cfg, "weight")
    if weight is None:
        raise ValueError(f"tasks.{section}.weight must be set when the task is enabled")
    return float(weight)


def _mode_cache_path(base, report_mode: str, image_mode: str) -> Path:
    """Return `base` with `__{report_mode}__{image_mode}` inserted before its suffix."""
    base = Path(base)
    return base.with_name(f"{base.stem}__{report_mode}__{image_mode}{base.suffix}")


def _ensure_iu_json_exists(iu_cfg, report_mode: str = "split",
                           image_mode: str = "all_views_split") -> str:
    """
    Return the IU X-ray unified JSON path, building it if missing.

    The cached JSON path is the configured `instruct_json` suffixed with BOTH
    report_mode and image_mode (e.g. iu_xray_instruct__split__all_views_split.json),
    so every report_mode/image_mode combination gets its own cached file and
    never overwrites a JSON built with different settings.

    Raises:
        FileNotFoundError: the cached file is missing and `auto_build` is false.
    """
    out = _mode_cache_path(iu_cfg.instruct_json, report_mode, image_mode)
    if out.is_file():
        return str(out)

    if not bool(_get(iu_cfg, "auto_build", True)):
        raise FileNotFoundError(
            f"IU X-ray instruct JSON not found at {out} and auto_build=false. "
            f"Run: python -m data.iu_xray_builder --images_dir {iu_cfg.images_dir} "
            f"--labels_dir {iu_cfg.labels_dir} --output {out} "
            f"--report_mode {report_mode} --image_mode {image_mode}"
        )

    # Lazy import to avoid pulling xml.etree on MIMIC-only runs
    from data.iu_xray_builder import build_iu_xray_instruct_json

    print(f"[dataset_resolver] IU X-ray JSON not found → auto-building "
          f"(report_mode={report_mode}, image_mode={image_mode}) …")
    build_iu_xray_instruct_json(
        images_dir=iu_cfg.images_dir,
        labels_dir=iu_cfg.labels_dir,
        output_path=str(out),
        train_ratio=float(_get(iu_cfg, "train_ratio", 0.70)),
        val_ratio=float(_get(iu_cfg, "val_ratio", 0.15)),
        test_ratio=float(_get(iu_cfg, "test_ratio", 0.15)),
        seed=int(_get(iu_cfg, "seed", 42)),
        image_suffix=str(_get(iu_cfg, "image_suffix", ".png")),
        report_mode=report_mode,
        image_mode=image_mode,
    )
    return str(out)


def _ensure_mimic_json_exists(data_cfg, report_mode: str = "split",
                              image_mode: str = "all_views_split") -> str:
    """
    Return the MIMIC-CXR unified JSON path, building it if missing.

    The configured `data.instruct_json` path is suffixed with both report_mode
    and image_mode (mimic_..._instruct__split__all_views_split.json) so each
    mode combination gets its own cache and the RaDialog CheXpert-guided JSON
    never collides with one built under other settings.

    Auto-build (default on) passes `data.mimic_chexpert_csv` and
    `data.mimic_vqa_root` to the builder, which bakes the CheXpert labels into
    structured_findings. Set `data.mimic_auto_build: false` to require a
    pre-built file instead.

    Raises:
        FileNotFoundError: the cached file is missing and auto-build is off.
    """
    out = _mode_cache_path(
        _get(data_cfg, "instruct_json", "data/data_files/mimic_cxr_instruct_unified.json"),
        report_mode, image_mode,
    )
    if out.is_file():
        return str(out)

    if not bool(_get(data_cfg, "mimic_auto_build", True)):
        raise FileNotFoundError(
            f"MIMIC instruct JSON not found at {out} and "
            f"data.mimic_auto_build=false. Run: python -m data.mimic_cxr_builder "
            f"--mimic_root {_get(data_cfg, 'mimic_cxr_root')} --output {out} "
            f"--report_mode {report_mode} --image_mode {image_mode}"
        )

    from data.mimic_cxr_builder import build_mimic_cxr_instruct_json

    print(f"[dataset_resolver] MIMIC JSON not found → auto-building "
          f"(report_mode={report_mode}, image_mode={image_mode}) …")
    build_mimic_cxr_instruct_json(
        mimic_root=str(_get(data_cfg, "mimic_cxr_root")),
        output_path=str(out),
        chexpert_csv=_get(data_cfg, "mimic_chexpert_csv"),
        vqa_root=_get(data_cfg, "mimic_vqa_root"),
        report_mode=report_mode,
        image_mode=image_mode,
    )
    return str(out)


def _ensure_mimic_resized_json_exists(mr_cfg, report_mode: str = "split",
                                      image_mode: str = "all_views_split") -> str:
    """
    Return the MIMIC-CXR_resized unified JSON path, building it if missing.

    This dataset is **manifest-driven**, not directory-walking:
      - 3 manifest CSVs (manifest_{train,val,test}.csv) carry every row's
        split label, image/report relative path, and the 14 CheXpert labels
        as chex_* columns. No separate *split*.csv or *chexpert*.csv is read.
      - VQA is read from `vqa_dir/{vqa.json, vqa_val.json, vqa_test.json}`.

    The cache path is suffixed with report_mode+image_mode (same convention as
    the other two builders) so each mode combination gets its own cache.

    Defaults: manifests are looked up in `root`, VQA in `{root}/vqa`. Either
    can be overridden (`manifest_dir`, `vqa_dir`); `vqa_dir: ""` disables VQA.

    Raises:
        FileNotFoundError: the cached file is missing and `auto_build` is false.
    """
    out = _mode_cache_path(
        _get(mr_cfg, "instruct_json", "data/data_files/mimic_cxr_resized_instruct.json"),
        report_mode, image_mode,
    )
    if out.is_file():
        return str(out)

    if not bool(_get(mr_cfg, "auto_build", True)):
        raise FileNotFoundError(
            f"MIMIC-CXR_resized instruct JSON not found at {out} and "
            f"auto_build=false. Run: python -m data.mimic_cxr_resized_builder "
            f"--root {_get(mr_cfg, 'root')} --output {out} "
            f"--report_mode {report_mode} --image_mode {image_mode}"
        )

    from data.mimic_cxr_resized_builder import build_mimic_cxr_resized_instruct_json

    print(f"[dataset_resolver] MIMIC-CXR_resized JSON not found → auto-building "
          f"(report_mode={report_mode}, image_mode={image_mode}) …")

    root_path = str(_get(mr_cfg, "root"))
    manifest_dir = _get(mr_cfg, "manifest_dir") or root_path

    vqa_dir_cfg = _get(mr_cfg, "vqa_dir")
    if vqa_dir_cfg is None:
        vqa_dir = str(Path(root_path) / "vqa")
    elif vqa_dir_cfg == "":
        vqa_dir = None  # explicit opt-out
    else:
        vqa_dir = str(vqa_dir_cfg)

    build_mimic_cxr_resized_instruct_json(
        root=root_path,
        manifest_dir=manifest_dir,
        output_path=str(out),
        vqa_dir=vqa_dir,
        reports_root=_get(mr_cfg, "reports_root"),
        report_mode=report_mode,
        image_mode=image_mode,
    )
    return str(out)


# ─── Run ID resolution (dataset-prefixed) ───────────────────────────────────

def resolve_run_id(
    dataset_name: str,
    output_root: str,
    state_file: str,
    resuming: bool,
    explicit: Optional[str] = None,
    hf_repo_id: Optional[str] = None,
    hf_token: Optional[str] = None,
) -> str:
    """
    Pick a run_id of the form "{dataset_name}_run_{N}".

    Resolution order:
      1. `explicit` flag (always wins) — pass --run_id to force a specific id
         (e.g. continue a run after VM restart without flagging it as resume).
      2. `resuming=True` (i.e. --resume_from / --resume_from_hf): read a
         non-empty state_file → fall back to latest run on local disk → fall
         back to latest run on HF Hub. A state value whose prefix does not
         match this dataset is still used, but a warning is printed.
      3. Fresh session: ALWAYS pick a brand-new id = max(local, remote) + 1.
         The local state file is NOT honoured here — a stale run_id.txt left
         over from a previous run would otherwise silently overwrite that run.
         Use `--run_id <id>` if you really mean to keep appending.

    The chosen id is written back to `state_file` (except when it was read
    from that file).

    Raises:
        RuntimeError: resuming, but no usable state file and no existing run
            for this dataset locally or on the Hub.
    """
    prefix = f"{dataset_name}_run_"

    if explicit:
        _write_state(state_file, explicit)
        return explicit

    def _all_existing() -> List[int]:
        local = _scan_local_runs(output_root, prefix)
        remote = _scan_remote_runs(hf_repo_id, hf_token, prefix)
        return sorted(set(local) | set(remote))

    state_path = Path(state_file)

    if resuming:
        saved = _read_state(state_path)
        if saved:
            if not saved.startswith(prefix):
                print(f"[resolve_run_id] WARNING: state file {state_path} holds "
                      f"{saved!r}, which does not match '{prefix}*' — resuming it "
                      f"anyway (pass --run_id to override).")
            return saved
        # No usable state file but user said --resume_from: pick the latest
        # run that exists anywhere (local OR remote) as best-effort fallback.
        existing = _all_existing()
        if existing:
            rid = f"{prefix}{max(existing)}"
            _write_state(state_file, rid)
            return rid
        raise RuntimeError(
            f"Cannot resume: no usable state file at {state_path}, no '{prefix}*' "
            f"folders under {output_root}, and none on HF Hub "
            f"({hf_repo_id or 'no repo configured'}). Pass --run_id explicitly."
        )

    # Fresh session — always allocate a new id, ignoring stale state file.
    existing = _all_existing()
    next_n = (max(existing) + 1) if existing else 1
    rid = f"{prefix}{next_n}"
    _write_state(state_file, rid)
    if existing:
        print(f"[resolve_run_id] fresh run → {rid} "
              f"(found existing: {[f'{prefix}{n}' for n in existing]})")
    else:
        print(f"[resolve_run_id] first run for this dataset → {rid}")
    return rid


def _scan_remote_runs(repo_id: Optional[str], token: Optional[str],
                      prefix: str) -> List[int]:
    """
    Return the sorted numbers N of top-level `{prefix}N` entries in the HF Hub
    repo `repo_id`.

    Best-effort — returns [] on any failure (no repo, missing
    huggingface_hub, bad token, network down, …) after printing the reason.
    """
    if not repo_id:
        return []
    try:
        from huggingface_hub import HfApi
        api = HfApi(token=token)
        files = api.list_repo_files(repo_id, token=token)
    except Exception as e:  # deliberate best-effort boundary
        print(f"[resolve_run_id] could not list HF runs ({type(e).__name__}: {e})")
        return []

    # A run folder shows up as "{prefix}N/<file>" in the flat file listing.
    rx = re.compile(rf"^{re.escape(prefix)}(\d+)(?:/|$)")
    return sorted({int(m.group(1)) for f in files if (m := rx.match(f))})


def _scan_local_runs(output_root: str, prefix: str) -> List[int]:
    """Return the sorted numbers N of `{prefix}N` directories directly under `output_root`."""
    root = Path(output_root)
    if not root.is_dir():
        return []
    rx = re.compile(rf"^{re.escape(prefix)}(\d+)$")
    return sorted(
        int(m.group(1))
        for d in root.iterdir()
        if d.is_dir() and (m := rx.match(d.name))
    )


def _read_state(state_path: Path) -> str:
    """Return the stripped run_id stored in `state_path`, or "" if there is none."""
    if not state_path.is_file():
        return ""
    return state_path.read_text(encoding="utf-8").strip()


def _write_state(state_file: str, run_id: str) -> None:
    """Write `run_id` to `state_file`, creating parent directories as needed."""
    p = Path(state_file)
    p.parent.mkdir(parents=True, exist_ok=True)
    p.write_text(run_id, encoding="utf-8")


# ─── Path helpers ───────────────────────────────────────────────────────────

def run_dir(output_root: str, run_id: str) -> Path:
    """`{output_root}/{run_id}` — created if missing."""
    p = Path(output_root) / run_id
    p.mkdir(parents=True, exist_ok=True)
    return p


def stage_dir(output_root: str, run_id: str, subdir: str) -> str:
    """`{output_root}/{run_id}/{subdir}` as a string (for HF Trainer) — created if missing."""
    p = run_dir(output_root, run_id) / subdir
    p.mkdir(parents=True, exist_ok=True)
    return str(p)


# ─── Run-config snapshot ────────────────────────────────────────────────────

def save_run_config(
    run_dir_path,
    spec: "DatasetSpec",
    model_cfg,
    train_cfg,
    extra: Optional[Dict] = None,
) -> None:
    """
    Persist a snapshot of the resolved config into the run directory so each
    run is self-describing.

    Writes:
      {run_dir}/configs/model_config.yaml   — full OmegaConf dump
      {run_dir}/configs/train_config.yaml   — full OmegaConf dump
      {run_dir}/run_meta.json               — compact, human-readable summary

    `run_meta.json` is intentionally small: it carries the fields a person
    typically wants when comparing two runs side by side (dataset, training
    schedule, mode flags). The full YAML dumps are the source of truth. If
    omegaconf is not installed the "YAML" files hold `str(cfg)` instead.

    Config sections may be OmegaConf DictConfigs, plain dicts or attribute
    objects; missing sections/keys are recorded with their defaults.

    `extra` is merged into `run_meta.json` — useful for adding e.g. the git
    commit hash or the resume source.
    """
    import json as _json
    from datetime import datetime, timezone

    try:
        from omegaconf import OmegaConf
    except ImportError:  # omegaconf unavailable — fall back to repr dumps
        OmegaConf = None

    def _to_yaml(cfg) -> str:
        return OmegaConf.to_yaml(cfg) if OmegaConf is not None else str(cfg)

    run_dir_path = Path(run_dir_path)
    cfg_dir = run_dir_path / "configs"
    cfg_dir.mkdir(parents=True, exist_ok=True)
    (cfg_dir / "model_config.yaml").write_text(_to_yaml(model_cfg), encoding="utf-8")
    (cfg_dir / "train_config.yaml").write_text(_to_yaml(train_cfg), encoding="utf-8")

    # Compact summary — only the fields that meaningfully change behaviour.
    # Sections are read with _get instead of resolving the whole config via
    # OmegaConf.to_container(resolve=True), which can raise on interpolations
    # that cannot be resolved and fails outright on non-mapping configs.
    training = _get(train_cfg, "training")
    llm_cfg = _get(model_cfg, "llm")
    lora_cfg = _get(model_cfg, "lora")
    projection_cfg = _get(model_cfg, "projection")
    chexpert_cfg = _get(model_cfg, "chexpert_classifier")

    meta = {
        "run_id": run_dir_path.name,
        "saved_at": datetime.now(timezone.utc).isoformat(timespec="seconds"),
        # — Data
        "dataset": spec.dataset_name,
        "image_root": spec.image_root,
        "instruct_json": spec.instruct_json,
        "report_mode": spec.report_mode,
        "image_mode": spec.image_mode,
        "max_images": spec.max_images,
        "tasks": spec.tasks,
        "task_weights": spec.task_weights,
        # — Training schedule
        "stage1": _stage_summary(_get(train_cfg, "stage1"), freeze_llm_default=True),
        "stage2": _stage_summary(_get(train_cfg, "stage2"), freeze_llm_default=False),
        "batch_size": _get(training, "per_device_train_batch_size", None),
        "grad_accum": _get(training, "gradient_accumulation_steps", None),
        "cutoff_len": _get(training, "cutoff_len", None),
        "fp16": _get(training, "fp16", None),
        "bf16": _get(training, "bf16", None),
        # — Model
        "llm": _get(llm_cfg, "name", None),
        "lora_r": _get(lora_cfg, "r", None),
        "lora_alpha": _get(lora_cfg, "lora_alpha", None),
        "num_image_tokens": _get(projection_cfg, "num_image_tokens", None),
        "chexpert_enabled": _get(chexpert_cfg, "enabled", None),
    }
    if extra:
        meta.update(extra)

    (run_dir_path / "run_meta.json").write_text(
        _json.dumps(meta, indent=2, ensure_ascii=False), encoding="utf-8",
    )
    print(f"[save_run_config] snapshot → {run_dir_path}/configs/, run_meta.json")


def _stage_summary(stage_cfg, freeze_llm_default: bool) -> Dict:
    """run_meta.json entry for one training stage; absent keys get their defaults."""
    return {
        "enabled": _get(stage_cfg, "enabled", True),
        "num_epochs": _get(stage_cfg, "num_epochs", None),
        "learning_rate": _get(stage_cfg, "learning_rate", None),
        "freeze_llm": _get(stage_cfg, "freeze_llm", freeze_llm_default),
        "freeze_encoder": _get(stage_cfg, "freeze_encoder", True),
    }


# ─── Misc ───────────────────────────────────────────────────────────────────

def _get(obj, key: str, default=None):
    """
    OmegaConf-safe `.get` with a default.

    Mappings (plain dicts and OmegaConf DictConfig) are looked up by key;
    anything else (SimpleNamespace, plain objects) by attribute. Returns
    `default` when the key/attribute is missing, when accessing it raises
    (OmegaConf raises for missing/unresolvable values), or when its value is
    None. `obj` itself may be None, which also yields `default`.
    """
    try:
        value = obj[key] if isinstance(obj, Mapping) else getattr(obj, key)
    except Exception:  # OmegaConf raises several distinct error types here
        return default
    return default if value is None else value