File size: 22,635 Bytes
28b13fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b961b41
28b13fc
 
 
 
 
 
 
 
 
 
02426e6
28b13fc
 
 
 
 
 
 
 
 
 
 
 
 
02426e6
 
 
 
 
28b13fc
 
 
 
 
02426e6
28b13fc
02426e6
 
28b13fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
02426e6
 
 
28b13fc
b961b41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28b13fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
02426e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
215ecd6
 
 
 
 
 
02426e6
 
 
 
b961b41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28b13fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
"""
dataset_resolver.py
-------------------
Centralises the logic that decides, based on `train_cfg.data.dataset_name`:

  1. Which tasks are trainable/evaluable for the chosen dataset.
  2. Where images live (`image_root`).
  3. Where the unified instruction JSON lives (building it on demand when
     the mode-specific cache file is missing and auto-build is enabled).
  4. Task-weight normalization (dropping disabled tasks).
  5. The `run_id` used in all output paths:
        {dataset_name}_run_{N}
     Numbering scans the existing run directories under the output root
     and, when an HF repo is configured, the run folders on the HF Hub — so
     re-running with the same dataset auto-picks the next N.

Keeping this out of train.py / evaluate.py means those two entry points
stay short, and the MIMIC-CXR code path is untouched.
"""

from __future__ import annotations

import re
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional


# Every value accepted for `train_cfg.data.dataset_name`; resolve_dataset_spec
# rejects anything else with a ValueError.
SUPPORTED_DATASETS = (
    "MIMIC-CXR",
    "MIMIC-CXR_resized",
    "IU-Xray",
)


@dataclass
class DatasetSpec:
    """Resolved data-layer configuration for a training/eval run.

    Produced by `resolve_dataset_spec`; its fields are also copied into
    `run_meta.json` by `save_run_config`.
    """
    dataset_name: str            # one of SUPPORTED_DATASETS: "MIMIC-CXR" | "MIMIC-CXR_resized" | "IU-Xray"
    image_root:   str            # passed to CXRInstructDataset
    instruct_json: str           # passed to CXRInstructDataset (report_mode/image_mode-suffixed cache path)
    tasks:        List[str]      # enabled tasks that exist in this dataset, in a fixed order
    task_weights: Dict[str, float]  # keyed by `tasks`; normalized so the weights sum to 1
    report_mode:  str = "split"       # "split" | "merged" | "split_cascade"
    image_mode:   str = "all_views_split"  # "all_views_split" | "frontal_only_split" | "multi_image_merged"
    max_images:   int = 1             # images per sample; >1 only when image_mode == "multi_image_merged"


# ─── Dataset resolution ─────────────────────────────────────────────────────

def resolve_dataset_spec(train_cfg) -> DatasetSpec:
    """
    Read `train_cfg.data.dataset_name` and return the matching DatasetSpec.

    For every supported dataset the unified instruction JSON is built on
    demand when its report_mode/image_mode-suffixed cache file is missing and
    that dataset's auto-build flag is on (see the `_ensure_*_json_exists`
    helpers).

    The choice of which tasks are "available" depends on `data.report_mode`:
      "split"         → findings, impression (+ vqa for MIMIC)
      "merged"        → report (+ vqa for MIMIC)
      "split_cascade" → findings, impression (+ vqa for MIMIC); same task set
                        and weights as "split" — only the data builder differs
                        (impression sample carries GT findings as context).

    A task is kept only if it is available for the dataset AND enabled in
    `train_cfg.tasks` with a positive weight; the kept weights are then
    renormalized to sum to 1.

    Raises:
        ValueError: unknown report_mode / image_mode / dataset_name, a
            `max_images_per_sample` < 1 in multi_image_merged mode, or no
            enabled task left for the chosen dataset.
    """
    data_cfg    = train_cfg.data
    name        = _get(data_cfg, "dataset_name", "MIMIC-CXR")
    report_mode = _get(data_cfg, "report_mode", "split")
    image_mode  = _get(data_cfg, "image_mode",  "all_views_split")
    max_images  = int(_get(data_cfg, "max_images_per_sample", 2))
    if report_mode not in ("split", "merged", "split_cascade"):
        raise ValueError(
            f"data.report_mode must be 'split', 'merged', or 'split_cascade', "
            f"got {report_mode!r}"
        )
    if image_mode not in ("all_views_split", "frontal_only_split", "multi_image_merged"):
        raise ValueError(
            f"data.image_mode must be one of all_views_split / frontal_only_split / "
            f"multi_image_merged, got {image_mode!r}"
        )
    if image_mode == "multi_image_merged":
        if max_images < 1:
            raise ValueError(
                f"data.max_images_per_sample must be >= 1 in multi_image_merged "
                f"mode, got {max_images}"
            )
        effective_max_images = max_images
    else:
        # In single-image modes max_images must be 1; otherwise the dataset
        # would pad each sample to N>1 (wasted compute, possibly wrong behaviour).
        effective_max_images = 1

    if name not in SUPPORTED_DATASETS:
        raise ValueError(
            f"Unsupported dataset_name: {name!r}. "
            f"Expected one of {SUPPORTED_DATASETS}."
        )

    all_weights = _collect_task_weights(train_cfg.tasks, report_mode)
    available   = _available_tasks(name, report_mode)

    # Validate the task selection BEFORE resolving paths: the branches below
    # may auto-build an instruction JSON, which is wasted work if we are
    # going to reject the config anyway.
    selected = [t for t in available if all_weights.get(t, 0.0) > 0]
    if not selected:
        raise ValueError(
            f"No enabled tasks match dataset {name}. "
            f"Enable at least one of {available} in `tasks:` config."
        )

    if name == "MIMIC-CXR":
        image_root    = data_cfg.mimic_cxr_root
        instruct_json = _ensure_mimic_json_exists(data_cfg, report_mode, image_mode)

    elif name == "MIMIC-CXR_resized":
        # Same semantic dataset as MIMIC-CXR (same task set), but the on-disk
        # layout is the raw PhysioNet tree and the JSON comes from a separate,
        # manifest-driven builder (see _ensure_mimic_resized_json_exists).
        mr = data_cfg.mimic_cxr_resized
        image_root    = mr.root
        instruct_json = _ensure_mimic_resized_json_exists(mr, report_mode, image_mode)

    else:  # IU-Xray
        iu = data_cfg.iu_xray
        image_root    = iu.images_dir
        instruct_json = _ensure_iu_json_exists(iu, report_mode, image_mode)

    raw   = {t: all_weights[t] for t in selected}
    total = sum(raw.values())  # > 0: every selected weight is positive
    weights = {t: w / total for t, w in raw.items()}

    return DatasetSpec(
        dataset_name  = name,
        image_root    = str(image_root),
        instruct_json = str(instruct_json),
        tasks         = selected,
        task_weights  = weights,
        report_mode   = report_mode,
        image_mode    = image_mode,
        max_images    = effective_max_images,
    )


def _task_weight(task_cfg, default_weight: float = 0.0,
                 default_enabled: bool = False) -> float:
    """Weight of one `tasks.<name>` entry; 0.0 when it is disabled (or absent
    and `default_enabled` is False)."""
    if task_cfg is None:
        return float(default_weight) if default_enabled else 0.0
    if not _get(task_cfg, "enabled", default_enabled):
        return 0.0
    return float(_get(task_cfg, "weight", default_weight))


def _collect_task_weights(tasks_cfg, report_mode: str) -> Dict[str, float]:
    """Raw (un-normalized) weight per task name, honouring each `enabled` flag.

    In "merged" mode findings/impression are forced to 0.0 in favour of
    `report`; in the split modes `report` is forced to 0.0. `report_generation`
    defaults to weight 0.6, enabled, when it is not configured at all.
    """
    merged = report_mode == "merged"
    return {
        "findings":   0.0 if merged else _task_weight(_get(tasks_cfg, "findings_generation")),
        "impression": 0.0 if merged else _task_weight(_get(tasks_cfg, "impression_generation")),
        "report":     (_task_weight(_get(tasks_cfg, "report_generation"), 0.6, True)
                       if merged else 0.0),
        "vqa":        _task_weight(_get(tasks_cfg, "vqa")),
    }


def _available_tasks(dataset_name: str, report_mode: str) -> List[str]:
    """Tasks that exist for `dataset_name` under `report_mode` (IU-Xray has no VQA)."""
    report_tasks = ["report"] if report_mode == "merged" else ["findings", "impression"]
    if dataset_name in ("MIMIC-CXR", "MIMIC-CXR_resized"):
        return report_tasks + ["vqa"]
    return report_tasks


def _ensure_iu_json_exists(iu_cfg,
                           report_mode: str = "split",
                           image_mode:  str = "all_views_split") -> str:
    """
    Build the IU X-ray unified JSON if missing (auto_build=true).

    The cached JSON path is automatically suffixed with BOTH report_mode and
    image_mode (e.g. iu_xray_instruct__split__all_views_split.json) so any
    of the 6 mode combinations gets its own cached file and never overwrites
    a JSON built with different settings.
    """
    base = Path(iu_cfg.instruct_json)
    out  = base.with_name(f"{base.stem}__{report_mode}__{image_mode}{base.suffix}")

    if out.is_file():
        return str(out)

    auto = _get(iu_cfg, "auto_build", True)
    if not auto:
        raise FileNotFoundError(
            f"IU X-ray instruct JSON not found at {out} and auto_build=false. "
            f"Run: python -m data.iu_xray_builder --images_dir {iu_cfg.images_dir} "
            f"--labels_dir {iu_cfg.labels_dir} --output {out} "
            f"--report_mode {report_mode} --image_mode {image_mode}"
        )

    # Lazy import to avoid pulling xml.etree on MIMIC-only runs
    from data.iu_xray_builder import build_iu_xray_instruct_json

    print(f"[dataset_resolver] IU X-ray JSON not found → auto-building "
          f"(report_mode={report_mode}, image_mode={image_mode}) …")
    build_iu_xray_instruct_json(
        images_dir   = iu_cfg.images_dir,
        labels_dir   = iu_cfg.labels_dir,
        output_path  = str(out),
        train_ratio  = float(_get(iu_cfg, "train_ratio", 0.70)),
        val_ratio    = float(_get(iu_cfg, "val_ratio",   0.15)),
        test_ratio   = float(_get(iu_cfg, "test_ratio",  0.15)),
        seed         = int(_get(iu_cfg, "seed", 42)),
        image_suffix = str(_get(iu_cfg, "image_suffix", ".png")),
        report_mode  = report_mode,
        image_mode   = image_mode,
    )
    return str(out)


def _ensure_mimic_json_exists(data_cfg,
                              report_mode: str = "split",
                              image_mode:  str = "all_views_split") -> str:
    """
    Return the MIMIC-CXR unified JSON path, building it first if missing.

    The configured `data.instruct_json` path (default
    data/data_files/mimic_cxr_instruct_unified.json) is suffixed with both
    report_mode and image_mode (mimic_..._instruct__split__all_views_split.json)
    so each of the mode combinations gets its own cache and the RaDialog
    CheXpert-guided JSON never collides with one built under other settings.

    Auto-build (default on) reads `*chexpert*.csv` to bake the 14 oracle
    labels into structured_findings; `data.mimic_chexpert_csv` and
    `data.mimic_vqa_root` are forwarded to the builder as-is (None allowed).
    Set `data.mimic_auto_build: false` to require a pre-built file instead.

    Raises:
        FileNotFoundError: cache missing and `data.mimic_auto_build` is false.
        ValueError: cache missing, auto-build on, but `data.mimic_cxr_root`
            is not set.
    """
    base = Path(_get(data_cfg, "instruct_json",
                     "data/data_files/mimic_cxr_instruct_unified.json"))
    out  = base.with_name(f"{base.stem}__{report_mode}__{image_mode}{base.suffix}")
    if out.is_file():
        return str(out)

    mimic_root = _get(data_cfg, "mimic_cxr_root")
    if not bool(_get(data_cfg, "mimic_auto_build", True)):
        raise FileNotFoundError(
            f"MIMIC instruct JSON not found at {out} and "
            f"data.mimic_auto_build=false. Run: python -m data.mimic_cxr_builder "
            f"--mimic_root {mimic_root} --output {out} "
            f"--report_mode {report_mode} --image_mode {image_mode}"
        )
    if mimic_root is None:
        # Without this guard str(None) would hand the builder the literal
        # path "None" and fail somewhere far less obvious.
        raise ValueError(
            f"data.mimic_cxr_root must be set to auto-build the MIMIC instruct "
            f"JSON at {out}."
        )

    from data.mimic_cxr_builder import build_mimic_cxr_instruct_json
    print(f"[dataset_resolver] MIMIC JSON not found → auto-building "
          f"(report_mode={report_mode}, image_mode={image_mode}) …")
    build_mimic_cxr_instruct_json(
        mimic_root   = str(mimic_root),
        output_path  = str(out),
        chexpert_csv = _get(data_cfg, "mimic_chexpert_csv"),
        vqa_root     = _get(data_cfg, "mimic_vqa_root"),
        report_mode  = report_mode,
        image_mode   = image_mode,
    )
    return str(out)


def _ensure_mimic_resized_json_exists(mr_cfg,
                                      report_mode: str = "split",
                                      image_mode:  str = "all_views_split") -> str:
    """
    Return the MIMIC-CXR_resized unified JSON path, building it if missing.

    This dataset is **manifest-driven**, not directory-walking:
        - 3 manifest CSVs (manifest_{train,val,test}.csv) carry every row's
          split label, image/report relative path, and the 14 CheXpert
          labels as chex_* columns. No separate *split*.csv or *chexpert*.csv
          is read.
        - VQA is read from `vqa_dir/{vqa.json, vqa_val.json, vqa_test.json}`.

    The cache path is suffixed with report_mode+image_mode (same convention
    as the other two builders) so each mode combination gets its own cache.

    Config keys read from `mr_cfg`: `instruct_json`, `root`, `auto_build`
    (default true), `manifest_dir` (default: `root`), `vqa_dir` (default:
    `{root}/vqa`; an explicit empty string disables VQA), `reports_root`.

    Raises:
        FileNotFoundError: cache missing and `auto_build` is false.
        ValueError: cache missing, auto-build on, but `root` is not set.
    """
    base = Path(_get(mr_cfg, "instruct_json",
                     "data/data_files/mimic_cxr_resized_instruct.json"))
    out  = base.with_name(f"{base.stem}__{report_mode}__{image_mode}{base.suffix}")
    if out.is_file():
        return str(out)

    root = _get(mr_cfg, "root")
    if not bool(_get(mr_cfg, "auto_build", True)):
        raise FileNotFoundError(
            f"MIMIC-CXR_resized instruct JSON not found at {out} and "
            f"auto_build=false. Run: python -m data.mimic_cxr_resized_builder "
            f"--root {root} --output {out} "
            f"--report_mode {report_mode} --image_mode {image_mode}"
        )
    if root is None:
        # Without this guard str(None) would become the literal path "None"
        # (and the default VQA dir "None/vqa").
        raise ValueError(
            f"data.mimic_cxr_resized.root must be set to auto-build the "
            f"MIMIC-CXR_resized instruct JSON at {out}."
        )

    from data.mimic_cxr_resized_builder import build_mimic_cxr_resized_instruct_json
    print(f"[dataset_resolver] MIMIC-CXR_resized JSON not found → auto-building "
          f"(report_mode={report_mode}, image_mode={image_mode}) …")
    root_path = str(root)
    # Convention defaults: manifest CSVs sit at `root`, VQA at `{root}/vqa`.
    # Either can be overridden in config; an explicit empty string for
    # vqa_dir disables VQA entirely.
    manifest_dir = _get(mr_cfg, "manifest_dir") or root_path
    vqa_dir_cfg  = _get(mr_cfg, "vqa_dir")
    if vqa_dir_cfg is None:
        vqa_dir = str(Path(root_path) / "vqa")
    elif vqa_dir_cfg == "":
        vqa_dir = None     # explicit opt-out
    else:
        vqa_dir = str(vqa_dir_cfg)
    build_mimic_cxr_resized_instruct_json(
        root         = root_path,
        manifest_dir = manifest_dir,
        output_path  = str(out),
        vqa_dir      = vqa_dir,
        reports_root = _get(mr_cfg, "reports_root"),
        report_mode  = report_mode,
        image_mode   = image_mode,
    )
    return str(out)


# ─── Run ID resolution (dataset-prefixed) ───────────────────────────────────

def resolve_run_id(
    dataset_name:     str,
    output_root:      str,
    state_file:       str,
    resuming:         bool,
    explicit:         Optional[str] = None,
    hf_repo_id:       Optional[str] = None,
    hf_token:         Optional[str] = None,
) -> str:
    """
    Pick a run_id of the form "{dataset_name}_run_{N}".

    Resolution order:
      1. `explicit` flag (always wins) — pass --run_id to force a specific id
         (e.g. continue a run after VM restart without flagging it as resume).
         It is written to `state_file` and returned verbatim.
      2. `resuming=True` (i.e. --resume_from / --resume_from_hf): use the id
         stored in `state_file`; if that file is missing or blank, fall back
         to the latest run on local disk or on HF Hub (highest N of either).
      3. Fresh session: ALWAYS pick a brand-new id = (highest N found locally
         or remotely) + 1, or 1 if none exist.
         The local state file is NOT honoured here — a stale run_id.txt left
         over from a previous run would otherwise silently overwrite that run.
         Use `--run_id <name>` if you really mean to keep appending.

    Except for the resume-from-state-file path, the chosen id is written
    back to `state_file`.

    Raises:
        RuntimeError: resuming, but there is no usable state file and no
            existing '{dataset_name}_run_*' run locally or on the Hub.
    """
    prefix = f"{dataset_name}_run_"

    if explicit:
        _write_state(state_file, explicit)
        return explicit

    def _all_existing() -> List[int]:
        local  = _scan_local_runs(output_root, prefix)
        remote = _scan_remote_runs(hf_repo_id, hf_token, prefix)
        return sorted(set(local) | set(remote))

    state_path = Path(state_file)
    if resuming:
        # An empty/whitespace-only state file must not yield run_id "" —
        # treat it exactly like a missing file.
        saved = state_path.read_text().strip() if state_path.is_file() else ""
        if saved:
            return saved
        # No usable state file but user said --resume_from: pick the latest
        # run that exists anywhere (local OR remote) as best-effort fallback.
        existing = _all_existing()
        if existing:
            rid = f"{prefix}{max(existing)}"
            _write_state(state_file, rid)
            return rid
        raise RuntimeError(
            f"Cannot resume: no state file at {state_path}, no '{prefix}*' "
            f"folders under {output_root}, and none on HF Hub "
            f"({hf_repo_id or 'no repo configured'}). Pass --run_id explicitly."
        )

    # Fresh session — always allocate a new id, ignoring stale state file.
    existing = _all_existing()
    next_n = (max(existing) + 1) if existing else 1
    rid = f"{prefix}{next_n}"
    _write_state(state_file, rid)
    if existing:
        print(f"[resolve_run_id] fresh run → {rid} "
              f"(found existing: {[f'{prefix}{n}' for n in existing]})")
    else:
        print(f"[resolve_run_id] first run for this dataset → {rid}")
    return rid


def _scan_remote_runs(repo_id: Optional[str], token: Optional[str], prefix: str) -> List[int]:
    """List existing '<prefix>N' folders on the HF Hub repo. Best-effort —
    returns [] on any failure (no token, no repo, network down, …)."""
    if not repo_id:
        return []
    try:
        from huggingface_hub import HfApi
        api   = HfApi(token=token)
        files = api.list_repo_files(repo_id, token=token)
    except Exception as e:
        print(f"[resolve_run_id] could not list HF runs ({type(e).__name__}: {e})")
        return []
    rx = re.compile(rf"^{re.escape(prefix)}(\d+)(?:/|$)")
    nums = set()
    for f in files:
        m = rx.match(f)
        if m:
            nums.add(int(m.group(1)))
    return sorted(nums)


def _scan_local_runs(output_root: str, prefix: str) -> List[int]:
    root = Path(output_root)
    if not root.is_dir():
        return []
    rx = re.compile(rf"^{re.escape(prefix)}(\d+)$")
    out = []
    for d in root.iterdir():
        if not d.is_dir():
            continue
        m = rx.match(d.name)
        if m:
            out.append(int(m.group(1)))
    return sorted(out)


def _write_state(state_file: str, run_id: str) -> None:
    p = Path(state_file)
    p.parent.mkdir(parents=True, exist_ok=True)
    p.write_text(run_id)


# ─── Path helpers ───────────────────────────────────────────────────────────

def run_dir(output_root: str, run_id: str) -> Path:
    """Return `{output_root}/{run_id}` as a Path, creating it (and any
    missing parents) if needed."""
    target = Path(output_root).joinpath(run_id)
    target.mkdir(exist_ok=True, parents=True)
    return target


def stage_dir(output_root: str, run_id: str, subdir: str) -> str:
    """Return `{output_root}/{run_id}/{subdir}` as a string (for HF Trainer),
    creating the run directory and the sub-directory if they are missing."""
    stage_path = run_dir(output_root, run_id).joinpath(subdir)
    stage_path.mkdir(exist_ok=True, parents=True)
    return str(stage_path)


# ─── Run-config snapshot ────────────────────────────────────────────────────

def save_run_config(
    run_dir_path,
    spec: "DatasetSpec",
    model_cfg,
    train_cfg,
    extra: Optional[Dict] = None,
) -> None:
    """
    Persist a snapshot of the resolved config into the run directory so each
    run is self-describing. Writes:

        {run_dir}/configs/model_config.yaml   — full OmegaConf dump
        {run_dir}/configs/train_config.yaml   — full OmegaConf dump
        {run_dir}/run_meta.json               — compact, human-readable summary

    `run_meta.json` is intentionally small: it carries the fields a person
    typically wants when comparing two runs side by side (dataset, training
    schedule, mode flags). The full YAML dumps are the source of truth.
    Optional config sections that are absent (e.g. no `chexpert_classifier`)
    simply produce None/default fields instead of raising.

    `extra` is merged into `run_meta.json` — useful for adding e.g. the git
    commit hash or the resume source. Values that are not JSON-serializable
    (e.g. a Path) are written via `str()`.
    """
    import json as _json
    from datetime import datetime, timezone

    try:
        from omegaconf import OmegaConf
        _to_yaml = OmegaConf.to_yaml
    except ImportError:
        _to_yaml = str

    run_dir_path = Path(run_dir_path)
    cfg_dir = run_dir_path / "configs"
    cfg_dir.mkdir(parents=True, exist_ok=True)

    (cfg_dir / "model_config.yaml").write_text(_to_yaml(model_cfg), encoding="utf-8")
    (cfg_dir / "train_config.yaml").write_text(_to_yaml(train_cfg), encoding="utf-8")

    # Compact summary — only the fields that meaningfully change behaviour.
    # Sections are fetched through `_get` so a missing section falls back to
    # `{}` (every leaf lookup then returns its default). This also avoids
    # resolving every interpolation in the config just to test whether a
    # key exists, which can raise on unresolvable entries.
    stage1   = _get(train_cfg, "stage1", {})
    stage2   = _get(train_cfg, "stage2", {})
    training = _get(train_cfg, "training", {})
    llm      = _get(model_cfg, "llm", {})
    lora     = _get(model_cfg, "lora", {})
    proj     = _get(model_cfg, "projection", {})
    chexpert = _get(model_cfg, "chexpert_classifier", {})

    meta = {
        "run_id":      run_dir_path.name,
        "saved_at":    datetime.now(timezone.utc).isoformat(timespec="seconds"),

        # — Data
        "dataset":     spec.dataset_name,
        "image_root":  spec.image_root,
        "instruct_json": spec.instruct_json,
        "report_mode": spec.report_mode,
        "image_mode":  spec.image_mode,
        "max_images":  spec.max_images,
        "tasks":       spec.tasks,
        "task_weights": spec.task_weights,

        # — Training schedule
        "stage1": {
            "enabled":       _get(stage1, "enabled", True),
            "num_epochs":    _get(stage1, "num_epochs", None),
            "learning_rate": _get(stage1, "learning_rate", None),
            "freeze_llm":    _get(stage1, "freeze_llm", True),
            "freeze_encoder": _get(stage1, "freeze_encoder", True),
        },
        "stage2": {
            "enabled":       _get(stage2, "enabled", True),
            "num_epochs":    _get(stage2, "num_epochs", None),
            "learning_rate": _get(stage2, "learning_rate", None),
            "freeze_llm":    _get(stage2, "freeze_llm", False),
            "freeze_encoder": _get(stage2, "freeze_encoder", True),
        },
        "batch_size":          _get(training, "per_device_train_batch_size", None),
        "grad_accum":          _get(training, "gradient_accumulation_steps", None),
        "cutoff_len":          _get(training, "cutoff_len", None),
        "fp16":                _get(training, "fp16", None),
        "bf16":                _get(training, "bf16", None),

        # — Model
        "llm":                 _get(llm, "name", None),
        "lora_r":              _get(lora, "r", None),
        "lora_alpha":          _get(lora, "lora_alpha", None),
        "num_image_tokens":    _get(proj, "num_image_tokens", None),
        "chexpert_enabled":    _get(chexpert, "enabled", None),
    }
    if extra:
        meta.update(extra)

    # default=str: `extra` is free-form; a stray Path or similar should not
    # abort the run after the YAML snapshots were already written.
    (run_dir_path / "run_meta.json").write_text(
        _json.dumps(meta, indent=2, ensure_ascii=False, default=str), encoding="utf-8",
    )
    print(f"[save_run_config] snapshot → {run_dir_path}/configs/, run_meta.json")
    print(f"[save_run_config] snapshot → {run_dir_path}/configs/, run_meta.json")


# ─── Misc ───────────────────────────────────────────────────────────────────

def _get(obj, key: str, default=None):
    """OmegaConf-safe .get with default."""
    try:
        v = getattr(obj, key)
        return v if v is not None else default
    except Exception:
        return default