# cxr-vlm-code / scripts / _apply_notebook_edits.py — convitom, c61f01a
"""One-shot helper to surgically edit the Colab training notebook.
Replaces cell-cfg with the GPU auto-profile version and inserts a new
'pre-compute image features' cell after it. Idempotent — re-running
replaces the new cell rather than duplicating it.
Run from project root:
python scripts/_apply_notebook_edits.py
"""
import json
from pathlib import Path
# Notebook that main() rewrites in place. It is resolved relative to this
# script's own directory, so the current working directory does not matter.
NB_PATH = Path(__file__).resolve().parent / "cxrvlm_colab_train.ipynb"
# Source text that main() writes into the notebook cell whose id is "cell-cfg".
NEW_CFG_SRC = r'''from omegaconf import OmegaConf
import torch
train_cfg = OmegaConf.load(PROJECT / 'configs' / 'train_config.yaml')
model_cfg = OmegaConf.load(PROJECT / 'configs' / 'model_config.yaml')
# ── dataset selector ──
train_cfg.data.dataset_name = DATASET_NAME
# ── training-scheme switches (thesis ablations) ──
# report_mode: 'split' → 2 tasks (findings + impression separately)
# 'merged' → 1 task (full report "Findings: ...\n\nImpression: ...")
# 'split_cascade' → split, but impression's context = GT findings
# image_mode : 'all_views_split' | 'frontal_only_split' | 'multi_image_merged'
train_cfg.data.report_mode = 'split'
train_cfg.data.image_mode = 'all_views_split'
train_cfg.data.max_images_per_sample = 2 # only used in multi_image_merged
# ── dataset-specific paths ──
if DATASET_NAME == 'MIMIC-CXR':
train_cfg.data.mimic_cxr_root = str(CXR_ROOT)
# Base path; the resolver suffixes __{report_mode}__{image_mode} and
# auto-builds (PNU CheXpert + VQA) via data.mimic_cxr_builder.
train_cfg.data.instruct_json = str(mimic_json_path)
train_cfg.data.mimic_auto_build = True
# RaDialog / U-MultiClass abnormality guidance: locate the CheXpert
# label CSV so the builder can bake the PNU structured_findings string.
_cx = (sorted(DATA_SRC.rglob('*chexpert*.csv'))
or sorted(DATA_SRC.rglob('*chexbert*.csv')))
train_cfg.data.mimic_chexpert_csv = str(_cx[0]) if _cx else None
print('CheXpert CSV :', train_cfg.data.mimic_chexpert_csv
or 'NOT FOUND — PNU abnormality guidance DISABLED!')
# VQA pairs ({train,valid,test}.json) → abnormality-guided VQA.
train_cfg.data.mimic_vqa_root = str(VQA_ROOT) if VQA_ROOT is not None else None
print('VQA root :', train_cfg.data.mimic_vqa_root or '(none — VQA skipped)')
elif DATASET_NAME == 'MIMIC-CXR_resized':
# The MIMIC-CXR_resized builder is manifest-driven: it reads
# `manifest_{train,val,test}.csv` for split + the 14 chex_* labels
# (PNU bucketed directly from the CSV, no separate chexpert.csv needed),
# uses `report_relpath` from the manifest to find each .txt, and pulls
# VQA from `vqa/{vqa,vqa_val,vqa_test}.json`.
train_cfg.data.mimic_cxr_resized.root = str(MR_ROOT)
train_cfg.data.mimic_cxr_resized.manifest_dir = None # null → defaults to root
train_cfg.data.mimic_cxr_resized.vqa_dir = None # null → {root}/vqa
train_cfg.data.mimic_cxr_resized.reports_root = None # null → auto-probe {root} then {root}/reports
train_cfg.data.mimic_cxr_resized.instruct_json = str(mr_json_path)
train_cfg.data.mimic_cxr_resized.auto_build = True
else: # IU-Xray
train_cfg.data.iu_xray.images_dir = str(IU_IMAGES_DIR)
train_cfg.data.iu_xray.labels_dir = str(IU_LABELS_DIR)
train_cfg.data.iu_xray.instruct_json = str(iu_json_path)
train_cfg.data.iu_xray.auto_build = True
train_cfg.data.train_split = 'train'
train_cfg.data.val_split = 'validate'
train_cfg.data.test_split = 'test'
# ── checkpoint root (Persistence keeps /content/ckpt/) ──
CKPT_ROOT = WORK / 'ckpt'
train_cfg.training.output_root = str(CKPT_ROOT)
# ─────────────────────────────────────────────────────────────────────────
# ── GPU auto-profile ────────────────────────────────────────────────────
# Pick batch size / precision / attention backend / GC / optimizer based on
# what the current GPU can actually do. Override anything below this block
# if you want to force a specific setting.
#
# Profile rules (compute capability + total VRAM):
# T4 (sm_75, 15GB) → FP16 + SDPA + GC ON + bs=1 accum=16 + fp32 AdamW
# 3090/L4/A10 (sm_80+, 24GB) → BF16 + FA2 + GC ON + bs=8 accum=2 + 8-bit AdamW
# A100 40GB (sm_80, 40GB) → BF16 + FA2 + GC OFF + bs=8 accum=2 + 8-bit AdamW
# A100/H100 80GB (sm_80+, 80G) → BF16 + FA2 + GC OFF + bs=8 accum=2 + 8-bit AdamW
# unknown → conservative T4-style profile
#
# Why GC ON for 24GB? Bigger batch amortizes the ~25-30% GC overhead.
# Math (eff batch = 16):
# GC OFF, bs=4, accum=4 → 4 × T = 4.0T per eff-batch
# GC ON, bs=8, accum=2 → 2 × 1.5T × 1.3 = 3.9T per eff-batch ✓
# Sub-linear GPU scaling (time(bs=8) ≈ 1.5 × time(bs=4), not 2×) is what
# tips the balance. On 40GB+ there's room without GC so we skip it there.
assert torch.cuda.is_available(), 'CUDA not available — refusing to write a CPU profile.'
_props = torch.cuda.get_device_properties(0)
_cap = (_props.major, _props.minor)
_vram_gb = _props.total_memory / 1e9
_bf16_ok = torch.cuda.is_bf16_supported()
_fa2_ok = _cap >= (8, 0) # FA2 needs Ampere+ (sm_80 or newer)
print(f'GPU : {_props.name} ({_vram_gb:.1f} GB)')
print(f'Compute cap : sm_{_cap[0]}{_cap[1]}')
print(f'BF16 native : {_bf16_ok}')
print(f'FA2 capable : {_fa2_ok}')
# Try to detect whether flash-attn package is actually importable. If FA2 is
# requested by the profile but the wheel isn't installed, cxr_vlm.py will
# auto-fall-back to sdpa, but we surface it here so the user knows.
_flash_attn_installed = False
if _fa2_ok:
try:
import flash_attn # noqa: F401
_flash_attn_installed = True
except Exception:
_flash_attn_installed = False
print(f'flash-attn : {"installed" if _flash_attn_installed else "NOT installed (will fall back to sdpa)"}')
# ── Pick profile ─────────────────────────────────────────────────────────
if _vram_gb >= 70: # A100/H100 80GB
_profile = dict(
label='A100/H100 80GB',
per_device_train_batch_size=8, per_device_eval_batch_size=8,
gradient_accumulation_steps=2, dataloader_num_workers=16,
gradient_checkpointing=False,
)
elif _vram_gb >= 35: # A100 40GB
_profile = dict(
label='A100 40GB',
per_device_train_batch_size=8, per_device_eval_batch_size=8,
gradient_accumulation_steps=2, dataloader_num_workers=12,
gradient_checkpointing=False,
)
elif _vram_gb >= 22: # 3090 / L4 / A10 24GB
# GC ON + bigger batch beats GC OFF + smaller batch on throughput here.
# Per-eff-batch wall time (eff=16): 4×T (GC OFF, bs=4) vs ~3.9×T (GC ON,
# bs=8) — sub-linear scaling means bs=8 step is ~1.5×T, not 2×T, so the
# GC overhead (~1.3×) is more than paid back.
_profile = dict(
label='RTX 3090 / L4 / A10 (24GB)',
per_device_train_batch_size=8, per_device_eval_batch_size=8,
gradient_accumulation_steps=2, dataloader_num_workers=8,
gradient_checkpointing=True,
)
elif _vram_gb >= 14: # T4 / V100 16GB
_profile = dict(
label='T4 / V100 (15-16GB)',
per_device_train_batch_size=1, per_device_eval_batch_size=1,
gradient_accumulation_steps=16, dataloader_num_workers=2,
gradient_checkpointing=True,
)
else: # tiny / unknown
_profile = dict(
label=f'unknown ({_vram_gb:.0f}GB) — conservative',
per_device_train_batch_size=1, per_device_eval_batch_size=1,
gradient_accumulation_steps=16, dataloader_num_workers=2,
gradient_checkpointing=True,
)
# Precision: BF16 on Ampere+, FP16 on Turing (T4) and older.
_profile['bf16'] = bool(_bf16_ok)
_profile['fp16'] = not _bf16_ok
# Attention backend: FA2 if Ampere+ AND flash-attn wheel present, else SDPA.
_profile['attn_implementation'] = (
'flash_attention_2' if (_fa2_ok and _flash_attn_installed) else 'sdpa'
)
# 8-bit AdamW: bnb's paged_adamw_8bit cuts optimizer-state VRAM ~4× with no
# measurable quality loss. Skip on Turing where bnb paged optimizer perf is
# weaker — keep adamw_torch there.
_profile['optim'] = 'paged_adamw_8bit' if _cap >= (8, 0) else 'adamw_torch'
# 4-bit compute dtype tracks precision.
_profile['bnb_4bit_compute_dtype'] = 'bfloat16' if _bf16_ok else 'float16'
_profile['torch_dtype'] = 'bfloat16' if _bf16_ok else 'float16'
print(f'\n→ Profile : {_profile["label"]}')
for k, v in _profile.items():
if k == 'label': continue
print(f' {k:<32}= {v}')
# ── Write profile into the configs ───────────────────────────────────────
train_cfg.training.per_device_train_batch_size = _profile['per_device_train_batch_size']
train_cfg.training.per_device_eval_batch_size = _profile['per_device_eval_batch_size']
train_cfg.training.gradient_accumulation_steps = _profile['gradient_accumulation_steps']
train_cfg.training.dataloader_num_workers = _profile['dataloader_num_workers']
train_cfg.training.fp16 = _profile['fp16']
train_cfg.training.bf16 = _profile['bf16']
train_cfg.training.dataloader_pin_memory = True
train_cfg.training.dataloader_persistent_workers = True
train_cfg.training.optim = _profile['optim']
# Ensure stage2 still uses the same per-run epoch count we want.
train_cfg.stage2.num_epochs = 5
model_cfg.llm.attn_implementation = _profile['attn_implementation']
model_cfg.llm.gradient_checkpointing = _profile['gradient_checkpointing']
model_cfg.llm.torch_dtype = _profile['torch_dtype']
model_cfg.llm.bnb_4bit_compute_dtype = _profile['bnb_4bit_compute_dtype']
model_cfg.llm.bnb_4bit_quant_type = 'nf4'
model_cfg.llm.bnb_4bit_use_double_quant = True
# ── task weights (sampling ratio enforced by WeightedRandomSampler) ──
# Defaults in train_config.yaml: 0.30 / 0.20 / 0.50 (RRG ≈ VQA, impression
# lower because in split_cascade mode it sees GT findings as input).
# Resolver auto-renormalizes and drops vqa for IU-Xray. Override here only
# if you want to experiment per-run, e.g.:
# train_cfg.tasks.findings_generation.weight = 0.30
# train_cfg.tasks.impression_generation.weight = 0.20
# train_cfg.tasks.vqa.weight = 0.50
# ── wandb off ──
train_cfg.wandb.enabled = False
# ── HuggingFace Hub run tracking ──
train_cfg.hf_hub.enabled = True
train_cfg.hf_hub.repo_id = 'hieu3636/cxr-vlm-runs' # <<< EDIT ME
train_cfg.hf_hub.token_env = 'HF_TOKEN'
train_cfg.hf_hub.private = True
train_cfg.hf_hub.run_state_file = str(CKPT_ROOT / 'run_id.txt')
# ── 4-bit QLoRA ──
model_cfg.llm.load_in_8bit = False
model_cfg.llm.load_in_4bit = True
# Oracle PNU path does NOT use the CheXpert classifier module (labels come
# from the GT csv/manifest baked into the prompt). Keep it disabled until
# you wire the learned classifier for realistic inference.
model_cfg.chexpert_classifier.enabled = False
OmegaConf.save(train_cfg, PROJECT / 'configs' / 'train_config.yaml')
OmegaConf.save(model_cfg, PROJECT / 'configs' / 'model_config.yaml')
print('--- train_cfg.data ---'); print(OmegaConf.to_yaml(train_cfg.data))
print('--- train_cfg.tasks ---'); print(OmegaConf.to_yaml(train_cfg.tasks))
print('--- train_cfg.training ---');print(OmegaConf.to_yaml(train_cfg.training))
print('--- train_cfg.hf_hub ---'); print(OmegaConf.to_yaml(train_cfg.hf_hub))
print('--- model_cfg.llm ---'); print(OmegaConf.to_yaml(model_cfg.llm))
'''
FEATURE_CACHE_SRC = r'''# ─── Optional: pre-compute image patch features (skip frozen encoder forward) ──
#
# The image encoder is frozen + the transform is deterministic, so encoding the
# same image every step is wasted work. Run this ONCE per dataset to cache
# (P, 768) patch tensors under {WORK}/feature_cache/{DATASET_NAME}/ and the
# training loop will load them instead of re-encoding.
#
# Set CACHE_FEATURES = False to skip (e.g. first time you set up the run, want
# the smoke test to use the raw path, or you're debugging the encoder).
#
# Disk usage: ~3 MB per image (P=1024 patches × 768 dim × fp16). For ~30k
# unique images that's ~90 GB — make sure WORK has the room, or set
# CACHE_FEATURES=False on tight quotas.
CACHE_FEATURES = True
if CACHE_FEATURES:
feature_cache_dir = WORK / 'feature_cache' / DATASET_NAME
feature_cache_dir.mkdir(parents=True, exist_ok=True)
train_cfg.data.feature_cache_dir = str(feature_cache_dir)
OmegaConf.save(train_cfg, PROJECT / 'configs' / 'train_config.yaml')
# Re-running this cell is safe: --overwrite is OFF by default so cached
# files are skipped. To force a full rebuild, add `--overwrite` below.
print(f'feature_cache_dir = {feature_cache_dir}')
!python -m scripts.precompute_image_features \
--model_config configs/model_config.yaml \
--train_config configs/train_config.yaml \
--cache_dir "{feature_cache_dir}" \
--batch_size 16
else:
train_cfg.data.feature_cache_dir = None
OmegaConf.save(train_cfg, PROJECT / 'configs' / 'train_config.yaml')
print('Feature cache DISABLED. Training will run the image encoder every step.')
'''
def src_to_lines(s: str):
    """Split *s* into the list-of-lines form used for a notebook cell source.

    The text is cut on "\\n" only. Every piece except the last gets its
    newline back; the last piece is kept only when it is non-empty, so text
    ending in a newline (or an empty string) yields no trailing "" entry.
    """
    # str.split always returns at least one element, so the unpacking is safe.
    *head, tail = s.split("\n")
    result = [piece + "\n" for piece in head]
    if tail:
        result.append(tail)
    return result
def main():
    """Patch the notebook at NB_PATH in place.

    Steps, in order:
      1. Load the notebook JSON.
      2. Overwrite the source of the cell with id "cell-cfg" with NEW_CFG_SRC
         and clear its outputs and execution count.
      3. Drop any cells left by a previous run (ids "cell-feature-cache" and
         "cell-feature-cache-md") so re-running does not duplicate them.
      4. Insert a markdown heading cell and a code cell holding
         FEATURE_CACHE_SRC directly after "cell-cfg".
      5. Write the notebook back with a trailing newline.

    Raises:
        RuntimeError: if the notebook has no cell with id "cell-cfg".
    """
    with open(NB_PATH, "r", encoding="utf-8") as fh:
        notebook = json.load(fh)

    def locate_cfg():
        """Return the index of the first "cell-cfg" cell, or None if absent."""
        return next(
            (
                pos
                for pos, cell in enumerate(notebook["cells"])
                if cell.get("id") == "cell-cfg"
            ),
            None,
        )

    cfg_pos = locate_cfg()
    if cfg_pos is None:
        raise RuntimeError("cell-cfg not found in notebook")
    print(f"cell-cfg at index {cfg_pos}")

    # Swap in the new config source and reset the cell's run state.
    notebook["cells"][cfg_pos].update(
        source=src_to_lines(NEW_CFG_SRC),
        outputs=[],
        execution_count=None,
    )

    # Idempotency: discard the cells a previous run of this script inserted.
    stale_ids = ("cell-feature-cache", "cell-feature-cache-md")
    notebook["cells"] = [
        cell for cell in notebook["cells"] if cell.get("id") not in stale_ids
    ]

    # Look the index up again, since the filtering above may have removed
    # cells that sat before cell-cfg.
    cfg_pos = locate_cfg()

    heading_cell = {
        "cell_type": "markdown",
        "id": "cell-feature-cache-md",
        "metadata": {},
        "source": ["## 4b. Pre-compute image features (optional speedup)\n"],
    }
    cache_cell = {
        "cell_type": "code",
        "id": "cell-feature-cache",
        "metadata": {},
        "execution_count": None,
        "outputs": [],
        "source": src_to_lines(FEATURE_CACHE_SRC),
    }
    # Zero-width slice assignment places both cells right after cell-cfg,
    # heading first.
    insert_at = cfg_pos + 1
    notebook["cells"][insert_at:insert_at] = [heading_cell, cache_cell]

    with open(NB_PATH, "w", encoding="utf-8") as fh:
        json.dump(notebook, fh, indent=1, ensure_ascii=False)
        fh.write("\n")

    total_cells = len(notebook["cells"])
    print(f"Wrote {NB_PATH}")
    print(f"New cell count: {total_cells}")
# Apply the notebook edits only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()