"""One-shot helper to surgically edit the Colab training notebook. Replaces cell-cfg with the GPU auto-profile version and inserts a new 'pre-compute image features' cell after it. Idempotent — re-running replaces the new cell rather than duplicating it. Run from project root: python scripts/_apply_notebook_edits.py """ import json from pathlib import Path NB_PATH = Path(__file__).resolve().parent / "cxrvlm_colab_train.ipynb" NEW_CFG_SRC = r'''from omegaconf import OmegaConf import torch train_cfg = OmegaConf.load(PROJECT / 'configs' / 'train_config.yaml') model_cfg = OmegaConf.load(PROJECT / 'configs' / 'model_config.yaml') # ── dataset selector ── train_cfg.data.dataset_name = DATASET_NAME # ── training-scheme switches (thesis ablations) ── # report_mode: 'split' → 2 tasks (findings + impression separately) # 'merged' → 1 task (full report "Findings: ...\n\nImpression: ...") # 'split_cascade' → split, but impression's context = GT findings # image_mode : 'all_views_split' | 'frontal_only_split' | 'multi_image_merged' train_cfg.data.report_mode = 'split' train_cfg.data.image_mode = 'all_views_split' train_cfg.data.max_images_per_sample = 2 # only used in multi_image_merged # ── dataset-specific paths ── if DATASET_NAME == 'MIMIC-CXR': train_cfg.data.mimic_cxr_root = str(CXR_ROOT) # Base path; the resolver suffixes __{report_mode}__{image_mode} and # auto-builds (PNU CheXpert + VQA) via data.mimic_cxr_builder. train_cfg.data.instruct_json = str(mimic_json_path) train_cfg.data.mimic_auto_build = True # RaDialog / U-MultiClass abnormality guidance: locate the CheXpert # label CSV so the builder can bake the PNU structured_findings string. _cx = (sorted(DATA_SRC.rglob('*chexpert*.csv')) or sorted(DATA_SRC.rglob('*chexbert*.csv'))) train_cfg.data.mimic_chexpert_csv = str(_cx[0]) if _cx else None print('CheXpert CSV :', train_cfg.data.mimic_chexpert_csv or 'NOT FOUND — PNU abnormality guidance DISABLED!') # VQA pairs ({train,valid,test}.json) → abnormality-guided VQA. train_cfg.data.mimic_vqa_root = str(VQA_ROOT) if VQA_ROOT is not None else None print('VQA root :', train_cfg.data.mimic_vqa_root or '(none — VQA skipped)') elif DATASET_NAME == 'MIMIC-CXR_resized': # The MIMIC-CXR_resized builder is manifest-driven: it reads # `manifest_{train,val,test}.csv` for split + the 14 chex_* labels # (PNU bucketed directly from the CSV, no separate chexpert.csv needed), # uses `report_relpath` from the manifest to find each .txt, and pulls # VQA from `vqa/{vqa,vqa_val,vqa_test}.json`. train_cfg.data.mimic_cxr_resized.root = str(MR_ROOT) train_cfg.data.mimic_cxr_resized.manifest_dir = None # null → defaults to root train_cfg.data.mimic_cxr_resized.vqa_dir = None # null → {root}/vqa train_cfg.data.mimic_cxr_resized.reports_root = None # null → auto-probe {root} then {root}/reports train_cfg.data.mimic_cxr_resized.instruct_json = str(mr_json_path) train_cfg.data.mimic_cxr_resized.auto_build = True else: # IU-Xray train_cfg.data.iu_xray.images_dir = str(IU_IMAGES_DIR) train_cfg.data.iu_xray.labels_dir = str(IU_LABELS_DIR) train_cfg.data.iu_xray.instruct_json = str(iu_json_path) train_cfg.data.iu_xray.auto_build = True train_cfg.data.train_split = 'train' train_cfg.data.val_split = 'validate' train_cfg.data.test_split = 'test' # ── checkpoint root (Persistence keeps /content/ckpt/) ── CKPT_ROOT = WORK / 'ckpt' train_cfg.training.output_root = str(CKPT_ROOT) # ───────────────────────────────────────────────────────────────────────── # ── GPU auto-profile ──────────────────────────────────────────────────── # Pick batch size / precision / attention backend / GC / optimizer based on # what the current GPU can actually do. Override anything below this block # if you want to force a specific setting. # # Profile rules (compute capability + total VRAM): # T4 (sm_75, 15GB) → FP16 + SDPA + GC ON + bs=1 accum=16 + fp32 AdamW # 3090/L4/A10 (sm_80+, 24GB) → BF16 + FA2 + GC ON + bs=8 accum=2 + 8-bit AdamW # A100 40GB (sm_80, 40GB) → BF16 + FA2 + GC OFF + bs=8 accum=2 + 8-bit AdamW # A100/H100 80GB (sm_80+, 80G) → BF16 + FA2 + GC OFF + bs=8 accum=2 + 8-bit AdamW # unknown → conservative T4-style profile # # Why GC ON for 24GB? Bigger batch amortizes the ~25-30% GC overhead. # Math (eff batch = 16): # GC OFF, bs=4, accum=4 → 4 × T = 4.0T per eff-batch # GC ON, bs=8, accum=2 → 2 × 1.5T × 1.3 = 3.9T per eff-batch ✓ # Sub-linear GPU scaling (time(bs=8) ≈ 1.5 × time(bs=4), not 2×) is what # tips the balance. On 40GB+ there's room without GC so we skip it there. assert torch.cuda.is_available(), 'CUDA not available — refusing to write a CPU profile.' _props = torch.cuda.get_device_properties(0) _cap = (_props.major, _props.minor) _vram_gb = _props.total_memory / 1e9 _bf16_ok = torch.cuda.is_bf16_supported() _fa2_ok = _cap >= (8, 0) # FA2 needs Ampere+ (sm_80 or newer) print(f'GPU : {_props.name} ({_vram_gb:.1f} GB)') print(f'Compute cap : sm_{_cap[0]}{_cap[1]}') print(f'BF16 native : {_bf16_ok}') print(f'FA2 capable : {_fa2_ok}') # Try to detect whether flash-attn package is actually importable. If FA2 is # requested by the profile but the wheel isn't installed, cxr_vlm.py will # auto-fall-back to sdpa, but we surface it here so the user knows. _flash_attn_installed = False if _fa2_ok: try: import flash_attn # noqa: F401 _flash_attn_installed = True except Exception: _flash_attn_installed = False print(f'flash-attn : {"installed" if _flash_attn_installed else "NOT installed (will fall back to sdpa)"}') # ── Pick profile ───────────────────────────────────────────────────────── if _vram_gb >= 70: # A100/H100 80GB _profile = dict( label='A100/H100 80GB', per_device_train_batch_size=8, per_device_eval_batch_size=8, gradient_accumulation_steps=2, dataloader_num_workers=16, gradient_checkpointing=False, ) elif _vram_gb >= 35: # A100 40GB _profile = dict( label='A100 40GB', per_device_train_batch_size=8, per_device_eval_batch_size=8, gradient_accumulation_steps=2, dataloader_num_workers=12, gradient_checkpointing=False, ) elif _vram_gb >= 22: # 3090 / L4 / A10 24GB # GC ON + bigger batch beats GC OFF + smaller batch on throughput here. # Per-eff-batch wall time (eff=16): 4×T (GC OFF, bs=4) vs ~3.9×T (GC ON, # bs=8) — sub-linear scaling means bs=8 step is ~1.5×T, not 2×T, so the # GC overhead (~1.3×) is more than paid back. _profile = dict( label='RTX 3090 / L4 / A10 (24GB)', per_device_train_batch_size=8, per_device_eval_batch_size=8, gradient_accumulation_steps=2, dataloader_num_workers=8, gradient_checkpointing=True, ) elif _vram_gb >= 14: # T4 / V100 16GB _profile = dict( label='T4 / V100 (15-16GB)', per_device_train_batch_size=1, per_device_eval_batch_size=1, gradient_accumulation_steps=16, dataloader_num_workers=2, gradient_checkpointing=True, ) else: # tiny / unknown _profile = dict( label=f'unknown ({_vram_gb:.0f}GB) — conservative', per_device_train_batch_size=1, per_device_eval_batch_size=1, gradient_accumulation_steps=16, dataloader_num_workers=2, gradient_checkpointing=True, ) # Precision: BF16 on Ampere+, FP16 on Turing (T4) and older. _profile['bf16'] = bool(_bf16_ok) _profile['fp16'] = not _bf16_ok # Attention backend: FA2 if Ampere+ AND flash-attn wheel present, else SDPA. _profile['attn_implementation'] = ( 'flash_attention_2' if (_fa2_ok and _flash_attn_installed) else 'sdpa' ) # 8-bit AdamW: bnb's paged_adamw_8bit cuts optimizer-state VRAM ~4× with no # measurable quality loss. Skip on Turing where bnb paged optimizer perf is # weaker — keep adamw_torch there. _profile['optim'] = 'paged_adamw_8bit' if _cap >= (8, 0) else 'adamw_torch' # 4-bit compute dtype tracks precision. _profile['bnb_4bit_compute_dtype'] = 'bfloat16' if _bf16_ok else 'float16' _profile['torch_dtype'] = 'bfloat16' if _bf16_ok else 'float16' print(f'\n→ Profile : {_profile["label"]}') for k, v in _profile.items(): if k == 'label': continue print(f' {k:<32}= {v}') # ── Write profile into the configs ─────────────────────────────────────── train_cfg.training.per_device_train_batch_size = _profile['per_device_train_batch_size'] train_cfg.training.per_device_eval_batch_size = _profile['per_device_eval_batch_size'] train_cfg.training.gradient_accumulation_steps = _profile['gradient_accumulation_steps'] train_cfg.training.dataloader_num_workers = _profile['dataloader_num_workers'] train_cfg.training.fp16 = _profile['fp16'] train_cfg.training.bf16 = _profile['bf16'] train_cfg.training.dataloader_pin_memory = True train_cfg.training.dataloader_persistent_workers = True train_cfg.training.optim = _profile['optim'] # Ensure stage2 still uses the same per-run epoch count we want. train_cfg.stage2.num_epochs = 5 model_cfg.llm.attn_implementation = _profile['attn_implementation'] model_cfg.llm.gradient_checkpointing = _profile['gradient_checkpointing'] model_cfg.llm.torch_dtype = _profile['torch_dtype'] model_cfg.llm.bnb_4bit_compute_dtype = _profile['bnb_4bit_compute_dtype'] model_cfg.llm.bnb_4bit_quant_type = 'nf4' model_cfg.llm.bnb_4bit_use_double_quant = True # ── task weights (sampling ratio enforced by WeightedRandomSampler) ── # Defaults in train_config.yaml: 0.30 / 0.20 / 0.50 (RRG ≈ VQA, impression # lower because in split_cascade mode it sees GT findings as input). # Resolver auto-renormalizes and drops vqa for IU-Xray. Override here only # if you want to experiment per-run, e.g.: # train_cfg.tasks.findings_generation.weight = 0.30 # train_cfg.tasks.impression_generation.weight = 0.20 # train_cfg.tasks.vqa.weight = 0.50 # ── wandb off ── train_cfg.wandb.enabled = False # ── HuggingFace Hub run tracking ── train_cfg.hf_hub.enabled = True train_cfg.hf_hub.repo_id = 'hieu3636/cxr-vlm-runs' # <<< EDIT ME train_cfg.hf_hub.token_env = 'HF_TOKEN' train_cfg.hf_hub.private = True train_cfg.hf_hub.run_state_file = str(CKPT_ROOT / 'run_id.txt') # ── 4-bit QLoRA ── model_cfg.llm.load_in_8bit = False model_cfg.llm.load_in_4bit = True # Oracle PNU path does NOT use the CheXpert classifier module (labels come # from the GT csv/manifest baked into the prompt). Keep it disabled until # you wire the learned classifier for realistic inference. model_cfg.chexpert_classifier.enabled = False OmegaConf.save(train_cfg, PROJECT / 'configs' / 'train_config.yaml') OmegaConf.save(model_cfg, PROJECT / 'configs' / 'model_config.yaml') print('--- train_cfg.data ---'); print(OmegaConf.to_yaml(train_cfg.data)) print('--- train_cfg.tasks ---'); print(OmegaConf.to_yaml(train_cfg.tasks)) print('--- train_cfg.training ---');print(OmegaConf.to_yaml(train_cfg.training)) print('--- train_cfg.hf_hub ---'); print(OmegaConf.to_yaml(train_cfg.hf_hub)) print('--- model_cfg.llm ---'); print(OmegaConf.to_yaml(model_cfg.llm)) ''' FEATURE_CACHE_SRC = r'''# ─── Optional: pre-compute image patch features (skip frozen encoder forward) ── # # The image encoder is frozen + the transform is deterministic, so encoding the # same image every step is wasted work. Run this ONCE per dataset to cache # (P, 768) patch tensors under {WORK}/feature_cache/{DATASET_NAME}/ and the # training loop will load them instead of re-encoding. # # Set CACHE_FEATURES = False to skip (e.g. first time you set up the run, want # the smoke test to use the raw path, or you're debugging the encoder). # # Disk usage: ~3 MB per image (P=1024 patches × 768 dim × fp16). For ~30k # unique images that's ~90 GB — make sure WORK has the room, or set # CACHE_FEATURES=False on tight quotas. CACHE_FEATURES = True if CACHE_FEATURES: feature_cache_dir = WORK / 'feature_cache' / DATASET_NAME feature_cache_dir.mkdir(parents=True, exist_ok=True) train_cfg.data.feature_cache_dir = str(feature_cache_dir) OmegaConf.save(train_cfg, PROJECT / 'configs' / 'train_config.yaml') # Re-running this cell is safe: --overwrite is OFF by default so cached # files are skipped. To force a full rebuild, add `--overwrite` below. print(f'feature_cache_dir = {feature_cache_dir}') !python -m scripts.precompute_image_features \ --model_config configs/model_config.yaml \ --train_config configs/train_config.yaml \ --cache_dir "{feature_cache_dir}" \ --batch_size 16 else: train_cfg.data.feature_cache_dir = None OmegaConf.save(train_cfg, PROJECT / 'configs' / 'train_config.yaml') print('Feature cache DISABLED. Training will run the image encoder every step.') ''' def src_to_lines(s: str): """Convert a string into Jupyter's list-of-lines source representation.""" lines = s.split("\n") return [ln + "\n" for ln in lines[:-1]] + ([lines[-1]] if lines[-1] else []) def main(): with open(NB_PATH, "r", encoding="utf-8") as f: nb = json.load(f) # Find cell-cfg index cfg_idx = None for i, c in enumerate(nb["cells"]): if c.get("id") == "cell-cfg": cfg_idx = i break if cfg_idx is None: raise RuntimeError("cell-cfg not found in notebook") print(f"cell-cfg at index {cfg_idx}") # Replace cell-cfg nb["cells"][cfg_idx]["source"] = src_to_lines(NEW_CFG_SRC) nb["cells"][cfg_idx]["outputs"] = [] nb["cells"][cfg_idx]["execution_count"] = None # Remove any pre-existing feature-cache cells (idempotent re-run) nb["cells"] = [ c for c in nb["cells"] if c.get("id") not in ("cell-feature-cache", "cell-feature-cache-md") ] # Re-find cell-cfg index (may have shifted if we removed earlier ones — but # those would have been after it, so index is stable) for i, c in enumerate(nb["cells"]): if c.get("id") == "cell-cfg": cfg_idx = i break # Insert markdown + code cells after cell-cfg md_cell = { "cell_type": "markdown", "id": "cell-feature-cache-md", "metadata": {}, "source": ["## 4b. Pre-compute image features (optional speedup)\n"], } code_cell = { "cell_type": "code", "id": "cell-feature-cache", "metadata": {}, "execution_count": None, "outputs": [], "source": src_to_lines(FEATURE_CACHE_SRC), } nb["cells"].insert(cfg_idx + 1, md_cell) nb["cells"].insert(cfg_idx + 2, code_cell) with open(NB_PATH, "w", encoding="utf-8") as f: json.dump(nb, f, indent=1, ensure_ascii=False) f.write("\n") print(f"Wrote {NB_PATH}") print(f"New cell count: {len(nb['cells'])}") if __name__ == "__main__": main()