| """ |
| dataset_resolver.py |
| ------------------- |
| Centralises the logic that decides, based on `train_cfg.data.dataset_name`: |
| |
| 1. Which tasks are trainable/evaluable for the chosen dataset. |
| 2. Where images live (`image_root`). |
| 3. Where the unified instruction JSON lives (and building it on-demand |
| for IU X-ray if missing). |
| 4. Task-weight normalization (dropping disabled tasks). |
| 5. The `run_id` used in all output paths: |
| {dataset_name}_run_{N} |
| Numbering scans the existing checkpoint directory and, when an HF repo |
| id is supplied, the run folders already on the HF Hub — so re-running |
| with the same dataset auto-picks the next N. |
| |
| Keeping this out of train.py / evaluate.py means those two entry points |
| stay short, and the MIMIC-CXR code path is untouched. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import re |
| from dataclasses import dataclass |
| from pathlib import Path |
| from typing import Dict, List, Optional |
|
|
|
|
# Dataset names accepted by `resolve_dataset_spec`; any other
# `data.dataset_name` makes it raise ValueError.
SUPPORTED_DATASETS = ("MIMIC-CXR", "MIMIC-CXR_resized", "IU-Xray")
|
|
|
|
@dataclass
class DatasetSpec:
    """
    Data-layer settings resolved for a single training/evaluation run.

    Produced by `resolve_dataset_spec`; the last three fields carry
    defaults so a spec can also be built by hand with only the first five.
    """
    # Dataset identifier, as given in `data.dataset_name`.
    dataset_name: str
    # Root location of the images (stored as a string).
    image_root: str
    # Path of the unified instruction JSON (stored as a string).
    instruct_json: str
    # Names of the tasks selected for this run.
    tasks: List[str]
    # Weight per selected task; `resolve_dataset_spec` normalises these
    # so that they sum to 1.
    task_weights: Dict[str, float]
    # One of "split", "merged" or "split_cascade" when produced by
    # `resolve_dataset_spec`.
    report_mode: str = "split"
    # One of "all_views_split", "frontal_only_split" or
    # "multi_image_merged" when produced by `resolve_dataset_spec`.
    image_mode: str = "all_views_split"
    # Images per sample; `resolve_dataset_spec` sets this to 1 unless
    # image_mode is "multi_image_merged".
    max_images: int = 1
|
|
|
|
| |
|
|
def resolve_dataset_spec(train_cfg) -> DatasetSpec:
    """
    Read `train_cfg.data.dataset_name` and return the matching DatasetSpec.

    The instruction JSON path is resolved through the dataset's
    `_ensure_*_json_exists` helper, which may auto-build the file when it
    is missing.

    The choice of which tasks are "available" depends on `data.report_mode`:
        "split"         → findings, impression (+ vqa for MIMIC)
        "merged"        → report (+ vqa for MIMIC)
        "split_cascade" → findings, impression (+ vqa for MIMIC); same task set
                          and weights as "split" — only the data builder differs
                          (impression sample carries GT findings as context).

    Weights:
        * findings / impression / vqa use `tasks.<name>.weight` when
          `tasks.<name>.enabled` is truthy, else 0.
        * report uses `tasks.report_generation.weight` (0.6 when the section
          or its weight is missing) and is selected by `report_mode` alone —
          its `enabled` flag is not consulted.
        * Only available tasks with a weight > 0 are kept, and the kept
          weights are normalised to sum to 1.

    `max_images` in the returned spec is `data.max_images_per_sample`
    (default 2) only for image_mode "multi_image_merged"; otherwise it is 1.

    Raises:
        ValueError: on an unknown report_mode, image_mode or dataset_name,
            or when no available task has a positive weight. The latter is
            checked BEFORE the instruction JSON is resolved so that a bad
            `tasks:` config fails without triggering an auto-build.
    """
    data_cfg = train_cfg.data
    name = _get(data_cfg, "dataset_name", "MIMIC-CXR")
    report_mode = _get(data_cfg, "report_mode", "split")
    image_mode = _get(data_cfg, "image_mode", "all_views_split")
    max_images = int(_get(data_cfg, "max_images_per_sample", 2))

    if report_mode not in ("split", "merged", "split_cascade"):
        raise ValueError(
            f"data.report_mode must be 'split', 'merged', or 'split_cascade', "
            f"got {report_mode!r}"
        )
    if image_mode not in ("all_views_split", "frontal_only_split", "multi_image_merged"):
        raise ValueError(
            f"data.image_mode must be one of all_views_split / frontal_only_split / "
            f"multi_image_merged, got {image_mode!r}"
        )
    if name not in SUPPORTED_DATASETS:
        raise ValueError(
            f"Unsupported dataset_name: {name!r}. "
            f"Expected one of {SUPPORTED_DATASETS}."
        )

    # Only the merged-image mode packs several images into one sample.
    effective_max_images = max_images if image_mode == "multi_image_merged" else 1

    def _weight_if_enabled(task_cfg) -> float:
        """Configured weight of a task section, or 0.0 when it is disabled."""
        return float(task_cfg.weight) if task_cfg.enabled else 0.0

    tasks_cfg = train_cfg.tasks
    merged = report_mode == "merged"
    if merged:
        # findings/impression are folded into "report" here, so their config
        # sections are deliberately not read (they may be absent).
        report_cfg = _get(tasks_cfg, "report_generation")
        all_weights = {
            "findings": 0.0,
            "impression": 0.0,
            "report": float(_get(report_cfg, "weight", 0.6)),
        }
    else:
        all_weights = {
            "findings": _weight_if_enabled(tasks_cfg.findings_generation),
            "impression": _weight_if_enabled(tasks_cfg.impression_generation),
            "report": 0.0,
        }
    all_weights["vqa"] = _weight_if_enabled(tasks_cfg.vqa)

    # VQA is offered for both MIMIC variants but not for IU-Xray.
    available = ["report"] if merged else ["findings", "impression"]
    if name in ("MIMIC-CXR", "MIMIC-CXR_resized"):
        available.append("vqa")

    selected = [t for t in available if all_weights.get(t, 0.0) > 0]
    if not selected:
        raise ValueError(
            f"No enabled tasks match dataset {name}. "
            f"Enable at least one of {available} in `tasks:` config."
        )

    # Every selected weight is > 0, so the total cannot be zero.
    total = sum(all_weights[t] for t in selected)
    weights = {t: all_weights[t] / total for t in selected}

    if name == "MIMIC-CXR":
        image_root = data_cfg.mimic_cxr_root
        instruct_json = _ensure_mimic_json_exists(data_cfg, report_mode, image_mode)
    elif name == "MIMIC-CXR_resized":
        mr = data_cfg.mimic_cxr_resized
        image_root = mr.root
        instruct_json = _ensure_mimic_resized_json_exists(mr, report_mode, image_mode)
    else:
        # Name was validated above, so this is IU-Xray.
        iu = data_cfg.iu_xray
        image_root = iu.images_dir
        instruct_json = _ensure_iu_json_exists(iu, report_mode, image_mode)

    return DatasetSpec(
        dataset_name=name,
        image_root=str(image_root),
        instruct_json=str(instruct_json),
        tasks=selected,
        task_weights=weights,
        report_mode=report_mode,
        image_mode=image_mode,
        max_images=effective_max_images,
    )
|
|
|
|
def _ensure_iu_json_exists(iu_cfg,
                           report_mode: str = "split",
                           image_mode: str = "all_views_split") -> str:
    """
    Return the path of the IU X-ray unified JSON, building it if needed.

    The cache file name is derived from `iu_cfg.instruct_json` by appending
    `__{report_mode}__{image_mode}` to the stem (for example
    iu_xray_instruct__split__all_views_split.json), so every combination of
    modes has its own file and never overwrites one built with different
    settings.

    If the cache file is missing it is built through
    `data.iu_xray_builder.build_iu_xray_instruct_json`, unless
    `iu_cfg.auto_build` is falsy (default: true).

    Raises:
        FileNotFoundError: the cache file is missing and auto_build is off.
    """
    configured = Path(iu_cfg.instruct_json)
    cache_name = f"{configured.stem}__{report_mode}__{image_mode}{configured.suffix}"
    cache_path = configured.with_name(cache_name)

    # Cache hit: nothing to build.
    if cache_path.is_file():
        return str(cache_path)

    if not _get(iu_cfg, "auto_build", True):
        raise FileNotFoundError(
            f"IU X-ray instruct JSON not found at {cache_path} and auto_build=false. "
            f"Run: python -m data.iu_xray_builder --images_dir {iu_cfg.images_dir} "
            f"--labels_dir {iu_cfg.labels_dir} --output {cache_path} "
            f"--report_mode {report_mode} --image_mode {image_mode}"
        )

    # Imported lazily so the builder is only loaded when a build is needed.
    from data.iu_xray_builder import build_iu_xray_instruct_json

    print(f"[dataset_resolver] IU X-ray JSON not found → auto-building "
          f"(report_mode={report_mode}, image_mode={image_mode}) …")

    build_kwargs = {
        "images_dir": iu_cfg.images_dir,
        "labels_dir": iu_cfg.labels_dir,
        "output_path": str(cache_path),
        "train_ratio": float(_get(iu_cfg, "train_ratio", 0.70)),
        "val_ratio": float(_get(iu_cfg, "val_ratio", 0.15)),
        "test_ratio": float(_get(iu_cfg, "test_ratio", 0.15)),
        "seed": int(_get(iu_cfg, "seed", 42)),
        "image_suffix": str(_get(iu_cfg, "image_suffix", ".png")),
        "report_mode": report_mode,
        "image_mode": image_mode,
    }
    build_iu_xray_instruct_json(**build_kwargs)
    return str(cache_path)
|
|
|
|
def _ensure_mimic_json_exists(data_cfg,
                              report_mode: str = "split",
                              image_mode: str = "all_views_split") -> str:
    """
    Return the path of the MIMIC-CXR unified JSON, building it if needed.

    The cache file name is derived from `data.instruct_json` (default
    data/data_files/mimic_cxr_instruct_unified.json) by appending
    `__{report_mode}__{image_mode}` to the stem, so each mode combination
    has its own cache file.

    If the cache file is missing it is built through
    `data.mimic_cxr_builder.build_mimic_cxr_instruct_json`, which receives
    `mimic_cxr_root`, `mimic_chexpert_csv` and `mimic_vqa_root` from the
    config. Set `data.mimic_auto_build: false` to require a pre-built file.

    Raises:
        FileNotFoundError: the cache file is missing and
            `data.mimic_auto_build` is false.
    """
    configured = Path(_get(data_cfg, "instruct_json",
                           "data/data_files/mimic_cxr_instruct_unified.json"))
    cache_name = f"{configured.stem}__{report_mode}__{image_mode}{configured.suffix}"
    cache_path = configured.with_name(cache_name)

    # Cache hit: nothing to build.
    if cache_path.is_file():
        return str(cache_path)

    auto_build = bool(_get(data_cfg, "mimic_auto_build", True))
    if not auto_build:
        raise FileNotFoundError(
            f"MIMIC instruct JSON not found at {cache_path} and "
            f"data.mimic_auto_build=false. Run: python -m data.mimic_cxr_builder "
            f"--mimic_root {_get(data_cfg, 'mimic_cxr_root')} --output {cache_path} "
            f"--report_mode {report_mode} --image_mode {image_mode}"
        )

    # Imported lazily so the builder is only loaded when a build is needed.
    from data.mimic_cxr_builder import build_mimic_cxr_instruct_json

    print(f"[dataset_resolver] MIMIC JSON not found → auto-building "
          f"(report_mode={report_mode}, image_mode={image_mode}) …")

    build_kwargs = {
        "mimic_root": str(_get(data_cfg, "mimic_cxr_root")),
        "output_path": str(cache_path),
        "chexpert_csv": _get(data_cfg, "mimic_chexpert_csv"),
        "vqa_root": _get(data_cfg, "mimic_vqa_root"),
        "report_mode": report_mode,
        "image_mode": image_mode,
    }
    build_mimic_cxr_instruct_json(**build_kwargs)
    return str(cache_path)
|
|
|
|
def _ensure_mimic_resized_json_exists(mr_cfg,
                                      report_mode: str = "split",
                                      image_mode: str = "all_views_split") -> str:
    """
    Return the path of the MIMIC-CXR_resized unified JSON, building it if
    needed.

    The cache file name is derived from `mr_cfg.instruct_json` (default
    data/data_files/mimic_cxr_resized_instruct.json) by appending
    `__{report_mode}__{image_mode}` to the stem, so each mode combination
    has its own cache file.

    Arguments handed to the builder when the cache file is missing:
      - `root`         : `mr_cfg.root`
      - `manifest_dir` : `mr_cfg.manifest_dir`, or `root` when unset/empty
      - `vqa_dir`      : `{root}/vqa` when `mr_cfg.vqa_dir` is unset,
                         None when it is the empty string, otherwise the
                         configured value
      - `reports_root` : `mr_cfg.reports_root` (may be None)

    Raises:
        FileNotFoundError: the cache file is missing and
            `mr_cfg.auto_build` is false.
    """
    configured = Path(_get(mr_cfg, "instruct_json",
                           "data/data_files/mimic_cxr_resized_instruct.json"))
    cache_name = f"{configured.stem}__{report_mode}__{image_mode}{configured.suffix}"
    cache_path = configured.with_name(cache_name)

    # Cache hit: nothing to build.
    if cache_path.is_file():
        return str(cache_path)

    auto_build = bool(_get(mr_cfg, "auto_build", True))
    if not auto_build:
        raise FileNotFoundError(
            f"MIMIC-CXR_resized instruct JSON not found at {cache_path} and "
            f"auto_build=false. Run: python -m data.mimic_cxr_resized_builder "
            f"--root {_get(mr_cfg, 'root')} --output {cache_path} "
            f"--report_mode {report_mode} --image_mode {image_mode}"
        )

    # Imported lazily so the builder is only loaded when a build is needed.
    from data.mimic_cxr_resized_builder import build_mimic_cxr_resized_instruct_json

    print(f"[dataset_resolver] MIMIC-CXR_resized JSON not found → auto-building "
          f"(report_mode={report_mode}, image_mode={image_mode}) …")

    root_path = str(_get(mr_cfg, "root"))
    manifest_dir = _get(mr_cfg, "manifest_dir") or root_path

    # Unset → default "<root>/vqa"; explicit "" → pass None; else as given.
    vqa_setting = _get(mr_cfg, "vqa_dir")
    if vqa_setting is None:
        vqa_dir = str(Path(root_path) / "vqa")
    else:
        vqa_dir = None if vqa_setting == "" else str(vqa_setting)

    build_kwargs = {
        "root": root_path,
        "manifest_dir": manifest_dir,
        "output_path": str(cache_path),
        "vqa_dir": vqa_dir,
        "reports_root": _get(mr_cfg, "reports_root"),
        "report_mode": report_mode,
        "image_mode": image_mode,
    }
    build_mimic_cxr_resized_instruct_json(**build_kwargs)
    return str(cache_path)
|
|
|
|
| |
|
|
def resolve_run_id(
    dataset_name: str,
    output_root: str,
    state_file: str,
    resuming: bool,
    explicit: Optional[str] = None,
    hf_repo_id: Optional[str] = None,
    hf_token: Optional[str] = None,
) -> str:
    """
    Pick a run_id of the form "{dataset_name}_run_{N}".

    Resolution order:
      1. `explicit` flag (always wins) — pass --run_id to force a specific id
         (e.g. continue a run after VM restart without flagging it as resume).
      2. `resuming=True` (i.e. --resume_from / --resume_from_hf): read
         state_file → fall back to the highest-numbered run found on local
         disk or on the HF Hub. An empty / whitespace-only state file is
         treated as missing, so it can never yield an empty run id.
      3. Fresh session: ALWAYS pick a brand-new id = max(local, remote) + 1
         (or 1 when no run exists). The local state file is NOT honoured
         here — a stale run_id.txt left over from a previous run would
         otherwise silently overwrite that run.
         Use `--run_id <name>` if you really mean to keep appending.

    Except when the id is read back from a non-empty state file, the chosen
    id is written to `state_file`.

    Raises:
        RuntimeError: resuming was requested but no usable state file and
            no existing run (local or remote) could be found.
    """
    prefix = f"{dataset_name}_run_"

    if explicit:
        _write_state(state_file, explicit)
        return explicit

    def _all_existing() -> List[int]:
        """Run numbers seen locally or on the HF Hub, ascending, de-duplicated."""
        local = _scan_local_runs(output_root, prefix)
        remote = _scan_remote_runs(hf_repo_id, hf_token, prefix)
        return sorted(set(local) | set(remote))

    state_path = Path(state_file)
    if resuming:
        if state_path.exists():
            saved = state_path.read_text().strip()
            if saved:
                return saved
            # A blank state file would otherwise resolve to run id "" and
            # send all outputs straight into `output_root`.
            print(f"[resolve_run_id] state file {state_path} is empty → "
                  f"falling back to scanning existing runs")

        existing = _all_existing()
        if existing:
            rid = f"{prefix}{max(existing)}"
            _write_state(state_file, rid)
            return rid
        raise RuntimeError(
            f"Cannot resume: no state file at {state_path}, no '{prefix}*' "
            f"folders under {output_root}, and none on HF Hub "
            f"({hf_repo_id or 'no repo configured'}). Pass --run_id explicitly."
        )

    # Fresh session: never reuse an id, whatever the state file says.
    existing = _all_existing()
    next_n = (max(existing) + 1) if existing else 1
    rid = f"{prefix}{next_n}"
    _write_state(state_file, rid)
    if existing:
        print(f"[resolve_run_id] fresh run → {rid} "
              f"(found existing: {[f'{prefix}{n}' for n in existing]})")
    else:
        print(f"[resolve_run_id] first run for this dataset → {rid}")
    return rid
|
|
|
|
def _scan_remote_runs(repo_id: Optional[str], token: Optional[str], prefix: str) -> List[int]:
    """
    Collect the run numbers N of '<prefix>N' entries in an HF Hub repo.

    Best-effort: returns [] when `repo_id` is empty, and also when importing
    `huggingface_hub` or listing the repo files fails for any reason (the
    failure is printed, not raised). The result is sorted ascending with
    duplicates removed.
    """
    if not repo_id:
        return []

    try:
        from huggingface_hub import HfApi
        hub = HfApi(token=token)
        repo_files = hub.list_repo_files(repo_id, token=token)
    except Exception as exc:
        # Deliberately broad: remote discovery must never abort a run.
        print(f"[resolve_run_id] could not list HF runs ({type(exc).__name__}: {exc})")
        return []

    # Matches "<prefix>N" either as a whole path or as its first segment.
    pattern = re.compile(rf"^{re.escape(prefix)}(\d+)(?:/|$)")
    hits = (pattern.match(path) for path in repo_files)
    return sorted({int(hit.group(1)) for hit in hits if hit})
|
|
|
|
def _scan_local_runs(output_root: str, prefix: str) -> List[int]:
    """
    Collect the run numbers N of '<prefix>N' directories directly under
    `output_root`.

    Only immediate sub-directories whose full name is the prefix followed
    by digits are counted; plain files are ignored. Returns [] when
    `output_root` is not a directory. The result is sorted ascending.
    """
    base = Path(output_root)
    if not base.is_dir():
        return []

    pattern = re.compile(rf"^{re.escape(prefix)}(\d+)$")
    hits = (pattern.match(entry.name) for entry in base.iterdir() if entry.is_dir())
    return sorted(int(hit.group(1)) for hit in hits if hit)
|
|
|
|
def _write_state(state_file: str, run_id: str) -> None:
    """Write `run_id` to `state_file`, replacing any previous content.

    Missing parent directories of `state_file` are created first.
    """
    target = Path(state_file)
    target.parent.mkdir(parents=True, exist_ok=True)
    with target.open("w") as handle:
        handle.write(run_id)
|
|
|
|
| |
|
|
def run_dir(output_root: str, run_id: str) -> Path:
    """Return `{output_root}/{run_id}` as a Path, creating the directory
    (and any missing parents) when it does not exist yet."""
    target = Path(output_root).joinpath(run_id)
    target.mkdir(parents=True, exist_ok=True)
    return target
|
|
|
|
def stage_dir(output_root: str, run_id: str, subdir: str) -> str:
    """Return `{output_root}/{run_id}/{subdir}` as a string (for HF Trainer),
    creating the directory and any missing parents."""
    target = run_dir(output_root, run_id).joinpath(subdir)
    target.mkdir(parents=True, exist_ok=True)
    return str(target)
|
|
|
|
| |
|
|
def save_run_config(
    run_dir_path,
    spec: "DatasetSpec",
    model_cfg,
    train_cfg,
    extra: Optional[Dict] = None,
) -> None:
    """
    Persist a snapshot of the resolved config into the run directory so each
    run is self-describing. Writes:

        {run_dir}/configs/model_config.yaml   — OmegaConf YAML dump
        {run_dir}/configs/train_config.yaml   — OmegaConf YAML dump
        {run_dir}/run_meta.json               — compact, human-readable summary

    When `omegaconf` cannot be imported, the two YAML files contain
    `str(cfg)` instead.

    `run_meta.json` is intentionally small: it carries the fields a person
    typically wants when comparing two runs side by side (dataset, training
    schedule, mode flags). The full YAML dumps are the source of truth.
    Config sections or keys that are missing are recorded as null (or as the
    per-field default for the stage flags) rather than aborting the snapshot.

    Args:
        run_dir_path: run directory; its final path component is recorded
            as `run_id`.
        spec: resolved dataset spec whose fields are copied into the summary.
        model_cfg: model configuration (sections read: llm, lora, projection,
            chexpert_classifier).
        train_cfg: training configuration (sections read: stage1, stage2,
            training).
        extra: optional mapping merged into `run_meta.json` last, so its keys
            override the generated ones — useful for adding e.g. the git
            commit hash or the resume source.
    """
    import json as _json
    from datetime import datetime, timezone

    try:
        from omegaconf import OmegaConf
    except ImportError:
        _to_yaml = str
    else:
        _to_yaml = OmegaConf.to_yaml

    run_dir_path = Path(run_dir_path)
    cfg_dir = run_dir_path / "configs"
    cfg_dir.mkdir(parents=True, exist_ok=True)

    (cfg_dir / "model_config.yaml").write_text(_to_yaml(model_cfg), encoding="utf-8")
    (cfg_dir / "train_config.yaml").write_text(_to_yaml(train_cfg), encoding="utf-8")

    def _stage_summary(stage_cfg, freeze_llm_default: bool) -> Dict:
        """Headline settings of one training stage, with per-field defaults."""
        return {
            "enabled":        _get(stage_cfg, "enabled", True),
            "num_epochs":     _get(stage_cfg, "num_epochs", None),
            "learning_rate":  _get(stage_cfg, "learning_rate", None),
            "freeze_llm":     _get(stage_cfg, "freeze_llm", freeze_llm_default),
            "freeze_encoder": _get(stage_cfg, "freeze_encoder", True),
        }

    # `_get` yields None for an absent section, and `_get(None, key, d)`
    # yields `d`, so missing sections degrade to defaults without converting
    # (and resolving) the whole config just to test for key presence.
    stage1 = _get(train_cfg, "stage1")
    stage2 = _get(train_cfg, "stage2")
    training = _get(train_cfg, "training")

    meta = {
        "run_id":           run_dir_path.name,
        "saved_at":         datetime.now(timezone.utc).isoformat(timespec="seconds"),

        # Data layer (copied from the resolved spec).
        "dataset":          spec.dataset_name,
        "image_root":       spec.image_root,
        "instruct_json":    spec.instruct_json,
        "report_mode":      spec.report_mode,
        "image_mode":       spec.image_mode,
        "max_images":       spec.max_images,
        "tasks":            spec.tasks,
        "task_weights":     spec.task_weights,

        # Training schedule.
        "stage1":           _stage_summary(stage1, True),
        "stage2":           _stage_summary(stage2, False),
        "batch_size":       _get(training, "per_device_train_batch_size", None),
        "grad_accum":       _get(training, "gradient_accumulation_steps", None),
        "cutoff_len":       _get(training, "cutoff_len", None),
        "fp16":             _get(training, "fp16", None),
        "bf16":             _get(training, "bf16", None),

        # Model.
        "llm":              _get(_get(model_cfg, "llm"), "name", None),
        "lora_r":           _get(_get(model_cfg, "lora"), "r", None),
        "lora_alpha":       _get(_get(model_cfg, "lora"), "lora_alpha", None),
        "num_image_tokens": _get(_get(model_cfg, "projection"), "num_image_tokens", None),
        "chexpert_enabled": _get(_get(model_cfg, "chexpert_classifier"), "enabled", None),
    }
    if extra:
        meta.update(extra)

    (run_dir_path / "run_meta.json").write_text(
        _json.dumps(meta, indent=2, ensure_ascii=False), encoding="utf-8",
    )
    print(f"[save_run_config] snapshot → {run_dir_path}/configs/, run_meta.json")
|
|
|
|
| |
|
|
def _get(obj, key: str, default=None):
    """
    Attribute lookup with a fallback, safe to use on OmegaConf nodes.

    Returns `getattr(obj, key)` unless that raises any exception or yields
    None, in which case `default` is returned. Falsy values other than None
    (0, False, "") are returned as-is.
    """
    try:
        value = getattr(obj, key)
    except Exception:
        # Deliberately broad: any lookup failure means "use the default".
        return default
    if value is None:
        return default
    return value
|
|