"""One-shot helper to surgically edit the Colab training notebook.

Replaces the source of cell-cfg with the GPU auto-profile version and
inserts two new cells right after it: a markdown header and a
'pre-compute image features' code cell. Idempotent — re-running replaces
those two cells rather than duplicating them.

The notebook (cxrvlm_colab_train.ipynb) is looked up in the same directory
as this script. Run from project root:
    python scripts/_apply_notebook_edits.py
"""
| import json |
| from pathlib import Path |
|
|
# Resolved against this script's own directory (not the CWD), so the same
# notebook is edited no matter where the script is launched from.
NB_PATH = Path(__file__).resolve().parent.joinpath("cxrvlm_colab_train.ipynb")
|
|
|
|
NEW_CFG_SRC = r'''from omegaconf import OmegaConf
import torch

train_cfg = OmegaConf.load(PROJECT / 'configs' / 'train_config.yaml')
model_cfg = OmegaConf.load(PROJECT / 'configs' / 'model_config.yaml')

# ── dataset selector ──
train_cfg.data.dataset_name = DATASET_NAME

# ── training-scheme switches (thesis ablations) ──
# report_mode: 'split'          → 2 tasks (findings + impression separately)
#              'merged'         → 1 task  (full report "Findings: ...\n\nImpression: ...")
#              'split_cascade'  → split, but impression's context = GT findings
# image_mode : 'all_views_split' | 'frontal_only_split' | 'multi_image_merged'
train_cfg.data.report_mode = 'split'
train_cfg.data.image_mode = 'all_views_split'
train_cfg.data.max_images_per_sample = 2   # only used in multi_image_merged

# ── dataset-specific paths ──
if DATASET_NAME == 'MIMIC-CXR':
    train_cfg.data.mimic_cxr_root = str(CXR_ROOT)
    # Base path; the resolver suffixes __{report_mode}__{image_mode} and
    # auto-builds (PNU CheXpert + VQA) via data.mimic_cxr_builder.
    train_cfg.data.instruct_json = str(mimic_json_path)
    train_cfg.data.mimic_auto_build = True

    # RaDialog / U-MultiClass abnormality guidance: locate the CheXpert
    # label CSV so the builder can bake the PNU structured_findings string.
    _cx = (sorted(DATA_SRC.rglob('*chexpert*.csv'))
           or sorted(DATA_SRC.rglob('*chexbert*.csv')))
    train_cfg.data.mimic_chexpert_csv = str(_cx[0]) if _cx else None
    print('CheXpert CSV :', train_cfg.data.mimic_chexpert_csv
          or 'NOT FOUND — PNU abnormality guidance DISABLED!')

    # VQA pairs ({train,valid,test}.json) → abnormality-guided VQA.
    train_cfg.data.mimic_vqa_root = str(VQA_ROOT) if VQA_ROOT is not None else None
    print('VQA root :', train_cfg.data.mimic_vqa_root or '(none — VQA skipped)')

elif DATASET_NAME == 'MIMIC-CXR_resized':
    # The MIMIC-CXR_resized builder is manifest-driven: it reads
    # `manifest_{train,val,test}.csv` for split + the 14 chex_* labels
    # (PNU bucketed directly from the CSV, no separate chexpert.csv needed),
    # uses `report_relpath` from the manifest to find each .txt, and pulls
    # VQA from `vqa/{vqa,vqa_val,vqa_test}.json`.
    train_cfg.data.mimic_cxr_resized.root = str(MR_ROOT)
    train_cfg.data.mimic_cxr_resized.manifest_dir = None   # null → defaults to root
    train_cfg.data.mimic_cxr_resized.vqa_dir = None        # null → {root}/vqa
    train_cfg.data.mimic_cxr_resized.reports_root = None   # null → auto-probe {root} then {root}/reports
    train_cfg.data.mimic_cxr_resized.instruct_json = str(mr_json_path)
    train_cfg.data.mimic_cxr_resized.auto_build = True

else:  # IU-Xray
    train_cfg.data.iu_xray.images_dir = str(IU_IMAGES_DIR)
    train_cfg.data.iu_xray.labels_dir = str(IU_LABELS_DIR)
    train_cfg.data.iu_xray.instruct_json = str(iu_json_path)
    train_cfg.data.iu_xray.auto_build = True

train_cfg.data.train_split = 'train'
train_cfg.data.val_split = 'validate'
train_cfg.data.test_split = 'test'

# ── checkpoint root (Persistence keeps /content/ckpt/) ──
CKPT_ROOT = WORK / 'ckpt'
train_cfg.training.output_root = str(CKPT_ROOT)


# ─────────────────────────────────────────────────────────────────────────
# ── GPU auto-profile ────────────────────────────────────────────────────
# Pick batch size / precision / attention backend / GC / optimizer based on
# what the current GPU can actually do. Override anything below this block
# if you want to force a specific setting.
#
# Profile rules (compute capability + total VRAM):
#   T4 (sm_75, 15GB)              → FP16 + SDPA + GC ON  + bs=1 accum=16 + fp32 AdamW
#   3090/L4/A10 (sm_80+, 24GB)    → BF16 + FA2  + GC ON  + bs=8 accum=2  + 8-bit AdamW
#   A100 40GB (sm_80, 40GB)       → BF16 + FA2  + GC OFF + bs=8 accum=2  + 8-bit AdamW
#   A100/H100 80GB (sm_80+, 80GB) → BF16 + FA2  + GC OFF + bs=8 accum=2  + 8-bit AdamW
#   unknown                       → conservative T4-style profile
#
# Why GC ON for 24GB? Bigger batch amortizes the ~25-30% GC overhead.
# Math (eff batch = 16):
#   GC OFF, bs=4, accum=4 → 4 × T               = 4.0T per eff-batch
#   GC ON,  bs=8, accum=2 → 2 × 1.5T × 1.3      = 3.9T per eff-batch ✓
# Sub-linear GPU scaling (time(bs=8) ≈ 1.5 × time(bs=4), not 2×) is what
# tips the balance. On 40GB+ there's room without GC so we skip it there.

if not torch.cuda.is_available():
    raise RuntimeError('CUDA not available — refusing to write a CPU profile.')
_props = torch.cuda.get_device_properties(0)
_cap = (_props.major, _props.minor)
_vram_gb = _props.total_memory / 1e9
# Recent PyTorch versions make torch.cuda.is_bf16_supported() default to
# including_emulation=True, so it reports True on pre-Ampere GPUs (e.g. T4)
# that have no native BF16 tensor cores. Gate on sm_80+ explicitly so
# Turing/Volta always get the FP16 profile described above.
_bf16_ok = _cap >= (8, 0) and torch.cuda.is_bf16_supported()
_fa2_ok = _cap >= (8, 0)   # FA2 needs Ampere+ (sm_80 or newer)

print(f'GPU : {_props.name} ({_vram_gb:.1f} GB)')
print(f'Compute cap : sm_{_cap[0]}{_cap[1]}')
print(f'BF16 native : {_bf16_ok}')
print(f'FA2 capable : {_fa2_ok}')

# Try to detect whether flash-attn package is actually importable. If FA2 is
# requested by the profile but the wheel isn't installed, cxr_vlm.py will
# auto-fall-back to sdpa, but we surface it here so the user knows.
_flash_attn_installed = False
if _fa2_ok:
    try:
        import flash_attn  # noqa: F401
        _flash_attn_installed = True
    except Exception:
        _flash_attn_installed = False
print(f'flash-attn : {"installed" if _flash_attn_installed else "NOT installed (will fall back to sdpa)"}')

# ── Pick profile ─────────────────────────────────────────────────────────
if _vram_gb >= 70:            # A100/H100 80GB
    _profile = dict(
        label='A100/H100 80GB',
        per_device_train_batch_size=8, per_device_eval_batch_size=8,
        gradient_accumulation_steps=2, dataloader_num_workers=16,
        gradient_checkpointing=False,
    )
elif _vram_gb >= 35:          # A100 40GB
    _profile = dict(
        label='A100 40GB',
        per_device_train_batch_size=8, per_device_eval_batch_size=8,
        gradient_accumulation_steps=2, dataloader_num_workers=12,
        gradient_checkpointing=False,
    )
elif _vram_gb >= 22:          # 3090 / L4 / A10 24GB
    # GC ON + bigger batch beats GC OFF + smaller batch on throughput here.
    # Per-eff-batch wall time (eff=16): 4×T (GC OFF, bs=4) vs ~3.9×T (GC ON,
    # bs=8) — sub-linear scaling means bs=8 step is ~1.5×T, not 2×T, so the
    # GC overhead (~1.3×) is more than paid back.
    _profile = dict(
        label='RTX 3090 / L4 / A10 (24GB)',
        per_device_train_batch_size=8, per_device_eval_batch_size=8,
        gradient_accumulation_steps=2, dataloader_num_workers=8,
        gradient_checkpointing=True,
    )
elif _vram_gb >= 14:          # T4 / V100 16GB
    _profile = dict(
        label='T4 / V100 (15-16GB)',
        per_device_train_batch_size=1, per_device_eval_batch_size=1,
        gradient_accumulation_steps=16, dataloader_num_workers=2,
        gradient_checkpointing=True,
    )
else:                         # tiny / unknown
    _profile = dict(
        label=f'unknown ({_vram_gb:.0f}GB) — conservative',
        per_device_train_batch_size=1, per_device_eval_batch_size=1,
        gradient_accumulation_steps=16, dataloader_num_workers=2,
        gradient_checkpointing=True,
    )

# Precision: BF16 on Ampere+, FP16 on Turing (T4) and older.
_profile['bf16'] = bool(_bf16_ok)
_profile['fp16'] = not _bf16_ok

# Attention backend: FA2 if Ampere+ AND flash-attn wheel present, else SDPA.
_profile['attn_implementation'] = (
    'flash_attention_2' if (_fa2_ok and _flash_attn_installed) else 'sdpa'
)

# 8-bit AdamW: bnb's paged_adamw_8bit cuts optimizer-state VRAM ~4× with no
# measurable quality loss. Skip on Turing where bnb paged optimizer perf is
# weaker — keep adamw_torch there.
_profile['optim'] = 'paged_adamw_8bit' if _cap >= (8, 0) else 'adamw_torch'

# 4-bit compute dtype tracks precision.
_profile['bnb_4bit_compute_dtype'] = 'bfloat16' if _bf16_ok else 'float16'
_profile['torch_dtype'] = 'bfloat16' if _bf16_ok else 'float16'

print(f'\n→ Profile : {_profile["label"]}')
for k, v in _profile.items():
    if k == 'label': continue
    print(f' {k:<32}= {v}')

# ── Write profile into the configs ───────────────────────────────────────
train_cfg.training.per_device_train_batch_size = _profile['per_device_train_batch_size']
train_cfg.training.per_device_eval_batch_size = _profile['per_device_eval_batch_size']
train_cfg.training.gradient_accumulation_steps = _profile['gradient_accumulation_steps']
train_cfg.training.dataloader_num_workers = _profile['dataloader_num_workers']
train_cfg.training.fp16 = _profile['fp16']
train_cfg.training.bf16 = _profile['bf16']
train_cfg.training.dataloader_pin_memory = True
train_cfg.training.dataloader_persistent_workers = True
train_cfg.training.optim = _profile['optim']
# Ensure stage2 still uses the same per-run epoch count we want.
train_cfg.stage2.num_epochs = 5

model_cfg.llm.attn_implementation = _profile['attn_implementation']
model_cfg.llm.gradient_checkpointing = _profile['gradient_checkpointing']
model_cfg.llm.torch_dtype = _profile['torch_dtype']
model_cfg.llm.bnb_4bit_compute_dtype = _profile['bnb_4bit_compute_dtype']
model_cfg.llm.bnb_4bit_quant_type = 'nf4'
model_cfg.llm.bnb_4bit_use_double_quant = True

# ── task weights (sampling ratio enforced by WeightedRandomSampler) ──
# Defaults in train_config.yaml: 0.30 / 0.20 / 0.50 (RRG ≈ VQA, impression
# lower because in split_cascade mode it sees GT findings as input).
# Resolver auto-renormalizes and drops vqa for IU-Xray. Override here only
# if you want to experiment per-run, e.g.:
#   train_cfg.tasks.findings_generation.weight = 0.30
#   train_cfg.tasks.impression_generation.weight = 0.20
#   train_cfg.tasks.vqa.weight = 0.50

# ── wandb off ──
train_cfg.wandb.enabled = False

# ── HuggingFace Hub run tracking ──
train_cfg.hf_hub.enabled = True
train_cfg.hf_hub.repo_id = 'hieu3636/cxr-vlm-runs'   # <<< EDIT ME
train_cfg.hf_hub.token_env = 'HF_TOKEN'
train_cfg.hf_hub.private = True
train_cfg.hf_hub.run_state_file = str(CKPT_ROOT / 'run_id.txt')

# ── 4-bit QLoRA ──
model_cfg.llm.load_in_8bit = False
model_cfg.llm.load_in_4bit = True
# Oracle PNU path does NOT use the CheXpert classifier module (labels come
# from the GT csv/manifest baked into the prompt). Keep it disabled until
# you wire the learned classifier for realistic inference.
model_cfg.chexpert_classifier.enabled = False

OmegaConf.save(train_cfg, PROJECT / 'configs' / 'train_config.yaml')
OmegaConf.save(model_cfg, PROJECT / 'configs' / 'model_config.yaml')

print('--- train_cfg.data ---'); print(OmegaConf.to_yaml(train_cfg.data))
print('--- train_cfg.tasks ---'); print(OmegaConf.to_yaml(train_cfg.tasks))
print('--- train_cfg.training ---');print(OmegaConf.to_yaml(train_cfg.training))
print('--- train_cfg.hf_hub ---'); print(OmegaConf.to_yaml(train_cfg.hf_hub))
print('--- model_cfg.llm ---'); print(OmegaConf.to_yaml(model_cfg.llm))
'''
|
|
|
|
FEATURE_CACHE_SRC = r'''# ─── Optional: pre-compute image patch features (skip frozen encoder forward) ──
#
# The image encoder is frozen + the transform is deterministic, so encoding the
# same image every step is wasted work. Run this ONCE per dataset to cache
# (P, 768) patch tensors under {WORK}/feature_cache/{DATASET_NAME}/ and the
# training loop will load them instead of re-encoding.
#
# Set CACHE_FEATURES = False to skip (e.g. first time you set up the run, want
# the smoke test to use the raw path, or you're debugging the encoder).
#
# Disk usage: ~1.5 MB per image at fp16 (P=1024 patches × 768 dim × 2 bytes;
# double that if the features are stored as fp32). For ~30k unique images
# that's ~47 GB at fp16 (~94 GB at fp32) — make sure WORK has the room, or set
# CACHE_FEATURES=False on tight quotas.

CACHE_FEATURES = True

if CACHE_FEATURES:
    feature_cache_dir = WORK / 'feature_cache' / DATASET_NAME
    feature_cache_dir.mkdir(parents=True, exist_ok=True)
    train_cfg.data.feature_cache_dir = str(feature_cache_dir)
    OmegaConf.save(train_cfg, PROJECT / 'configs' / 'train_config.yaml')

    # Re-running this cell is safe: --overwrite is OFF by default so cached
    # files are skipped. To force a full rebuild, add `--overwrite` below.
    # `cd` into PROJECT so `-m scripts.…` and the relative configs/ paths
    # resolve regardless of the kernel's current working directory.
    print(f'feature_cache_dir = {feature_cache_dir}')
    !cd "{PROJECT}" && python -m scripts.precompute_image_features \
        --model_config configs/model_config.yaml \
        --train_config configs/train_config.yaml \
        --cache_dir "{feature_cache_dir}" \
        --batch_size 16
else:
    train_cfg.data.feature_cache_dir = None
    OmegaConf.save(train_cfg, PROJECT / 'configs' / 'train_config.yaml')
    print('Feature cache DISABLED. Training will run the image encoder every step.')
'''
|
|
|
|
def src_to_lines(s: str):
    r"""Split *s* into the list-of-strings form Jupyter stores in ``source``.

    Every piece except the last keeps its trailing "\n". A trailing empty
    piece (when *s* ends with "\n", or *s* is empty) is dropped. Only "\n"
    is treated as a line break, so a "\r" stays attached to its line.
    """
    *complete, tail = s.split("\n")
    result = [f"{chunk}\n" for chunk in complete]
    if tail:
        result.append(tail)
    return result
|
|
|
|
def _find_cell_index(cells, cell_id):
    """Return the index of the first cell whose ``id`` equals *cell_id*, or None."""
    return next(
        (i for i, cell in enumerate(cells) if cell.get("id") == cell_id), None
    )


def main():
    """Rewrite ``cell-cfg`` and (re)insert the feature-cache cells after it.

    Reads the notebook at ``NB_PATH`` and replaces the source of the cell
    with id ``cell-cfg`` by ``NEW_CFG_SRC``, clearing its outputs and
    execution count. It removes any cells with ids ``cell-feature-cache`` /
    ``cell-feature-cache-md`` left over from a previous run. It then inserts
    a markdown header cell and a code cell holding ``FEATURE_CACHE_SRC``
    directly after ``cell-cfg``. The notebook is written back with a
    1-space indent, non-ASCII characters preserved, and a trailing newline.

    Raises:
        RuntimeError: if no cell with id ``cell-cfg`` exists, or that cell is
            not a code cell.
    """
    with open(NB_PATH, "r", encoding="utf-8") as f:
        nb = json.load(f)

    # ── locate and rewrite cell-cfg ──
    cfg_idx = _find_cell_index(nb["cells"], "cell-cfg")
    if cfg_idx is None:
        raise RuntimeError("cell-cfg not found in notebook")
    cfg_cell = nb["cells"][cfg_idx]
    if cfg_cell.get("cell_type") != "code":
        # Writing outputs/execution_count onto a non-code cell would produce
        # an invalid notebook, so refuse instead of corrupting it.
        raise RuntimeError(
            f"cell-cfg is a {cfg_cell.get('cell_type')!r} cell, expected 'code'"
        )
    print(f"cell-cfg at index {cfg_idx}")

    cfg_cell["source"] = src_to_lines(NEW_CFG_SRC)
    cfg_cell["outputs"] = []
    cfg_cell["execution_count"] = None

    # ── drop feature-cache cells from a previous run (keeps this idempotent) ──
    stale_ids = ("cell-feature-cache", "cell-feature-cache-md")
    nb["cells"] = [c for c in nb["cells"] if c.get("id") not in stale_ids]

    # Removing stale cells can shift cell-cfg's position; look it up again.
    cfg_idx = _find_cell_index(nb["cells"], "cell-cfg")

    # ── insert the markdown header + feature-cache code cell after cell-cfg ──
    md_cell = {
        "cell_type": "markdown",
        "id": "cell-feature-cache-md",
        "metadata": {},
        "source": ["## 4b. Pre-compute image features (optional speedup)\n"],
    }
    code_cell = {
        "cell_type": "code",
        "id": "cell-feature-cache",
        "metadata": {},
        "execution_count": None,
        "outputs": [],
        "source": src_to_lines(FEATURE_CACHE_SRC),
    }
    nb["cells"][cfg_idx + 1:cfg_idx + 1] = [md_cell, code_cell]

    # Write to a sibling temp file, then swap it in with Path.replace, so a
    # failure mid-dump (e.g. disk full) cannot leave a truncated notebook.
    tmp_path = NB_PATH.with_name(NB_PATH.name + ".tmp")
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(nb, f, indent=1, ensure_ascii=False)
        f.write("\n")
    tmp_path.replace(NB_PATH)

    print(f"Wrote {NB_PATH}")
    print(f"New cell count: {len(nb['cells'])}")


if __name__ == "__main__":
    main()
|
|