"""GCP Vertex AI Custom Training Job entrypoint.

Mirrors the colab notebook's setup (cells: paths, cfg, resume, stage1):
  1. Download dataset payload from HF Hub (if not cached on disk)
  2. Patch configs/{train,model}_config.yaml for GPU profile + paths + HF Hub
  3. Pin run_id.txt for --mode resume
  4. Exec `python -m training.train --mode {fresh,resume}`

The container's command is expected to have already cloned the project source
(this file) into /workspace/code, then `cd /workspace/code` and run this script.

Required env vars:
  HF_TOKEN       — HuggingFace token (read access for code+data, write for runs)
  DATASET_NAME   — 'IU-Xray' | 'MIMIC-CXR' | 'MIMIC-CXR_resized'

Optional env vars (defaults shown):
  HF_USER         = hieu3636
  REPORT_MODE     = split_cascade
  IMAGE_MODE      = all_views_split
  S1_EPOCHS       = 2
  S2_EPOCHS       = 7
  MODE            = resume           # 'fresh' | 'resume'
  EXPLICIT_RUN_ID = ''               # only matters when MODE=resume
  HF_RUNS_REPO    = hieu3636/cxr-vlm-runs
  WORK            = /workspace

A variable that is set but empty is treated the same as an unset one: required
variables abort the job, optional variables take their default.
"""
from __future__ import annotations

import os
import shutil
import subprocess
import sys
import tarfile
import zipfile
from pathlib import Path

# ── Tame HF/transformers chatter so logs are readable in Cloud Logging ────────
# setdefault() keeps any value the job spec already provides.
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
os.environ.setdefault("BITSANDBYTES_NOWELCOME", "1")
os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")
os.environ.setdefault("TRANSFORMERS_VERBOSITY", "warning")
os.environ.setdefault("PYTHONUNBUFFERED", "1")
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")


def env(name: str, default: str | None = None, *, required: bool = False) -> str:
    """Return environment variable *name* as a string.

    Args:
        name: Name of the environment variable to read.
        default: Value used when the variable is unset or set to the empty
            string.
        required: When true, terminate the process via ``sys.exit`` if neither
            the environment nor *default* yields a non-empty value.

    Returns:
        The resolved value, or ``""`` when nothing non-empty is available and
        the variable is not required.
    """
    # `or default` (rather than os.environ.get(name, default)) makes an empty
    # value fall back to the default, consistent with the `required` check
    # below, which already reports an empty value as "not set".
    val = os.environ.get(name) or default
    if required and not val:
        sys.exit(f"[gcp_entrypoint] ERROR: required env var {name} not set")
    return val or ""


# ── 1) Resolve config from env ────────────────────────────────────────────────
HF_TOKEN = env("HF_TOKEN", required=True)
DATASET_NAME = env("DATASET_NAME", required=True)
HF_USER = env("HF_USER", "hieu3636")
REPORT_MODE = env("REPORT_MODE", "split_cascade")
IMAGE_MODE = env("IMAGE_MODE", "all_views_split")
S1_EPOCHS = int(env("S1_EPOCHS", "2"))
S2_EPOCHS = int(env("S2_EPOCHS", "7"))
MODE = env("MODE", "resume")
EXPLICIT_RUN_ID = env("EXPLICIT_RUN_ID", "")
HF_RUNS_REPO = env("HF_RUNS_REPO", "hieu3636/cxr-vlm-runs")
WORK = Path(env("WORK", "/workspace"))

# Validate with explicit exits rather than `assert`: asserts are removed when
# Python runs with -O, which would let an unsupported value reach the dataset
# and launch logic further down.
if DATASET_NAME not in ("IU-Xray", "MIMIC-CXR", "MIMIC-CXR_resized"):
    sys.exit(
        f"[gcp_entrypoint] ERROR: unsupported DATASET_NAME {DATASET_NAME!r} "
        "(expected 'IU-Xray', 'MIMIC-CXR' or 'MIMIC-CXR_resized')"
    )
if MODE not in ("fresh", "resume"):
    sys.exit(
        f"[gcp_entrypoint] ERROR: unsupported MODE {MODE!r} "
        "(expected 'fresh' or 'resume')"
    )

# Two directory levels above this file.
PROJECT = Path(__file__).resolve().parent.parent  # /workspace/code
DATA_SRC = WORK / "data"
CKPT_ROOT = WORK / "ckpt"
DATA_SRC.mkdir(parents=True, exist_ok=True)
CKPT_ROOT.mkdir(parents=True, exist_ok=True)

print(f"[gcp_entrypoint] PROJECT = {PROJECT}")
print(f"[gcp_entrypoint] WORK = {WORK}")
print(f"[gcp_entrypoint] DATA_SRC = {DATA_SRC}")
print(f"[gcp_entrypoint] DATASET = {DATASET_NAME} ({REPORT_MODE} / {IMAGE_MODE})")
print(f"[gcp_entrypoint] MODE = {MODE} run_id={EXPLICIT_RUN_ID or '(auto)'}")


# ── 2) Download dataset payload from HF Hub ───────────────────────────────────
# Mirrors cell-paths logic for each dataset shape.
from huggingface_hub import HfApi, hf_hub_download, snapshot_download  # noqa: E402

# Opt in to tarfile's "data" extraction filter where this interpreter provides
# it (detected via tarfile.data_filter). The filter rejects members that would
# be written outside the destination directory. Interpreters without the
# feature keep their default extraction behaviour.
_tar_extract_kwargs = {"filter": "data"} if hasattr(tarfile, "data_filter") else {}

if DATASET_NAME == "MIMIC-CXR_resized":
    mr_dir = DATA_SRC / "MIMIC-CXR_resized"
    mr_dir.mkdir(parents=True, exist_ok=True)
    files_dir = mr_dir / "files"
    manifests_present = all(
        (mr_dir / f).is_file()
        for f in ("manifest_train.csv", "manifest_val.csv", "manifest_test.csv")
    )
    # Cache check: all three manifests plus at least one `p*` entry under files/.
    if manifests_present and files_dir.is_dir() and any(files_dir.glob("p*")):
        print(f"[gcp_entrypoint] {mr_dir} already populated — skipping download.")
    else:
        api = HfApi(token=HF_TOKEN)
        all_files = api.list_repo_files(
            repo_id=f"{HF_USER}/cxr-vlm-data", repo_type="dataset"
        )
        mr_files = [f for f in all_files if f.startswith("MIMIC-CXR_resized/")]
        tar_files = sorted(f for f in mr_files if f.endswith(".tar"))
        print(f"[gcp_entrypoint] {len(tar_files)} tar shards on HF")

        # Metadata (manifests, vqa, SHARDS.txt, _manifest.json) — small
        snapshot_download(
            repo_id=f"{HF_USER}/cxr-vlm-data",
            repo_type="dataset",
            allow_patterns=[
                "MIMIC-CXR_resized/*.csv",
                "MIMIC-CXR_resized/*.json",
                "MIMIC-CXR_resized/*.txt",
                "MIMIC-CXR_resized/vqa/**",
            ],
            token=HF_TOKEN,
            local_dir=str(DATA_SRC),
        )

        # Image shards — download, extract, delete to keep peak disk down
        for i, tf in enumerate(tar_files, 1):
            print(f"[gcp_entrypoint] [{i}/{len(tar_files)}] {tf}")
            tp = Path(hf_hub_download(
                repo_id=f"{HF_USER}/cxr-vlm-data",
                repo_type="dataset",
                filename=tf,
                token=HF_TOKEN,
                local_dir=str(DATA_SRC),
            ))
            with tarfile.open(tp) as t:
                t.extractall(mr_dir, **_tar_extract_kwargs)
            # NOTE(review): unlink placed after the `with` block; the original
            # indentation was lost — confirm.
            tp.unlink(missing_ok=True)
        print(f"[gcp_entrypoint] {mr_dir} ready.")
    DATA_ROOT_RESIZED = mr_dir
else:
    # MIMIC-CXR / IU-Xray: single zip per dataset
    zip_name = f"{DATASET_NAME}.zip"
    marker = DATA_SRC / DATASET_NAME
    if not marker.exists():
        print(f"[gcp_entrypoint] downloading {zip_name} ...")
        zpath = hf_hub_download(
            repo_id=f"{HF_USER}/cxr-vlm-data",
            filename=zip_name,
            repo_type="dataset",
            token=HF_TOKEN,
            local_dir=str(DATA_SRC),
        )
        with zipfile.ZipFile(zpath) as zf:
            zf.extractall(DATA_SRC)
        # Removing the archive is best-effort: a failure only costs disk space,
        # so report it and carry on instead of aborting the job.
        try:
            os.remove(zpath)
        except OSError as exc:
            print(f"[gcp_entrypoint] WARNING: could not remove {zpath}: {exc}")
    else:
        print(f"[gcp_entrypoint] {marker} already present — skipping download.")

print(f"[gcp_entrypoint] DATA_SRC contents: {sorted(os.listdir(DATA_SRC))}")


# ── 3) Patch configs (mirrors cell-cfg) ───────────────────────────────────────
import torch  # noqa: E402
from omegaconf import OmegaConf  # noqa: E402

train_cfg_path = PROJECT / "configs" / "train_config.yaml"
model_cfg_path = PROJECT / "configs" / "model_config.yaml"
train_cfg = OmegaConf.load(train_cfg_path)
model_cfg = OmegaConf.load(model_cfg_path)

# Dataset + training-scheme switches
train_cfg.data.dataset_name = DATASET_NAME
train_cfg.data.report_mode = REPORT_MODE
train_cfg.data.image_mode = IMAGE_MODE
train_cfg.data.max_images_per_sample = 2

out_dir = PROJECT / "data" / "data_files"
out_dir.mkdir(parents=True, exist_ok=True)

if DATASET_NAME == "MIMIC-CXR_resized":
    mr_json_path = out_dir / "mimic_cxr_resized_instruct.json"
    train_cfg.data.mimic_cxr_resized.root = str(DATA_ROOT_RESIZED)
    train_cfg.data.mimic_cxr_resized.manifest_dir = None
    train_cfg.data.mimic_cxr_resized.vqa_dir = None
    train_cfg.data.mimic_cxr_resized.reports_root = None
    train_cfg.data.mimic_cxr_resized.instruct_json = str(mr_json_path)
    train_cfg.data.mimic_cxr_resized.auto_build = True
elif DATASET_NAME == "MIMIC-CXR":
    # Find the canonical {train,valid,test}/pXX/... layout
    def _find_mimic_root(root: Path) -> Path:
        """Return the directory under *root* that holds train/, valid/ and test/.

        Tries ``root / "MIMIC-CXR"`` and *root* itself first, then falls back
        to a recursive search for any ``train`` directory whose parent also
        contains ``valid`` and ``test``.

        Raises:
            FileNotFoundError: If no such directory exists under *root*.
        """
        for cand in [root / "MIMIC-CXR", root]:
            if (cand / "train").exists() and (cand / "valid").exists() and (cand / "test").exists():
                return cand
        # Deliberately left lazy (unsorted): returns on the first match without
        # walking the whole tree.
        for p in root.rglob("train"):
            if p.is_dir() and (p.parent / "valid").exists() and (p.parent / "test").exists():
                return p.parent
        raise FileNotFoundError(f"MIMIC-CXR train/valid/test not found under {root}")

    cxr_root = _find_mimic_root(DATA_SRC)
    train_cfg.data.mimic_cxr_root = str(cxr_root)
    train_cfg.data.instruct_json = str(out_dir / "mimic_cxr_instruct_unified.json")
    train_cfg.data.mimic_auto_build = True
    # Prefer a *chexpert* CSV, fall back to *chexbert*; None when neither exists.
    _cx = sorted(DATA_SRC.rglob("*chexpert*.csv")) or sorted(DATA_SRC.rglob("*chexbert*.csv"))
    train_cfg.data.mimic_chexpert_csv = str(_cx[0]) if _cx else None
    # Sorted like the CSV lookup above, so the pick does not depend on
    # filesystem enumeration order when several `vqa` entries exist.
    _vqa_candidates = sorted(DATA_SRC.rglob("vqa"))
    train_cfg.data.mimic_vqa_root = str(_vqa_candidates[0]) if _vqa_candidates else None
else:  # IU-Xray
    iu_root = DATA_SRC / "IU-Xray"
    train_cfg.data.iu_xray.images_dir = str(iu_root / "images")
    train_cfg.data.iu_xray.labels_dir = str(iu_root / "labels")
    train_cfg.data.iu_xray.instruct_json = str(out_dir / "iu_xray_instruct.json")
    train_cfg.data.iu_xray.auto_build = True

# NOTE(review): these split names are applied for every dataset here; the
# original indentation was lost, so confirm they were not meant to sit inside
# the IU-Xray branch only.
train_cfg.data.train_split = "train"
train_cfg.data.val_split = "validate"
train_cfg.data.test_split = "test"
train_cfg.training.output_root = str(CKPT_ROOT)

# ── GPU auto-profile (verbatim from cell-cfg) ────────────────────────────────
# Explicit exit instead of `assert`, which is removed under `python -O`.
if not torch.cuda.is_available():
    sys.exit("[gcp_entrypoint] ERROR: CUDA not available in container")
_props = torch.cuda.get_device_properties(0)
_cap = (_props.major, _props.minor)
_vram_gb = _props.total_memory / 1e9  # decimal GB, not GiB
_bf16_ok = torch.cuda.is_bf16_supported()
_fa2_ok = _cap >= (8, 0)
print(f"[gcp_entrypoint] GPU: {_props.name} {_vram_gb:.1f}GB sm_{_cap[0]}{_cap[1]} bf16={_bf16_ok} fa2_capable={_fa2_ok}")

_flash_attn_installed = False
if _fa2_ok:
    # Any import failure (not only ImportError) simply selects the sdpa
    # fallback below, hence the deliberately broad except.
    try:
        import flash_attn  # noqa: F401
        _flash_attn_installed = True
    except Exception:
        _flash_attn_installed = False

# Batch size / accumulation / workers / checkpointing chosen by VRAM tier.
if _vram_gb >= 70:
    _profile = dict(
        label="A100/H100 80GB",
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        gradient_accumulation_steps=2,
        dataloader_num_workers=16,
        gradient_checkpointing=False,
    )
elif _vram_gb >= 35:
    _profile = dict(
        label="A100 40GB",
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        gradient_accumulation_steps=2,
        dataloader_num_workers=12,
        gradient_checkpointing=False,
    )
elif _vram_gb >= 22:
    _profile = dict(
        label="3090 / L4 / A10 (24GB)",
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        gradient_accumulation_steps=2,
        dataloader_num_workers=8,
        gradient_checkpointing=True,
    )
elif _vram_gb >= 14:
    _profile = dict(
        label="T4 / V100 (15-16GB)",
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        gradient_accumulation_steps=16,
        dataloader_num_workers=2,
        gradient_checkpointing=True,
    )
else:
    _profile = dict(
        label=f"unknown ({_vram_gb:.0f}GB)",
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        gradient_accumulation_steps=16,
        dataloader_num_workers=2,
        gradient_checkpointing=True,
    )

# Precision, attention backend and optimizer follow device capability.
_profile["bf16"] = bool(_bf16_ok)
_profile["fp16"] = not _bf16_ok
_profile["attn_implementation"] = (
    "flash_attention_2" if (_fa2_ok and _flash_attn_installed) else "sdpa"
)
_profile["optim"] = "paged_adamw_8bit" if _cap >= (8, 0) else "adamw_torch"
_profile["bnb_4bit_compute_dtype"] = "bfloat16" if _bf16_ok else "float16"
_profile["torch_dtype"] = "bfloat16" if _bf16_ok else "float16"
print(f"[gcp_entrypoint] → Profile: {_profile['label']}")

train_cfg.training.per_device_train_batch_size = _profile["per_device_train_batch_size"]
train_cfg.training.per_device_eval_batch_size = _profile["per_device_eval_batch_size"]
train_cfg.training.gradient_accumulation_steps = _profile["gradient_accumulation_steps"]
train_cfg.training.dataloader_num_workers = _profile["dataloader_num_workers"]
train_cfg.training.fp16 = _profile["fp16"]
train_cfg.training.bf16 = _profile["bf16"]
train_cfg.training.dataloader_pin_memory = True
train_cfg.training.dataloader_persistent_workers = True
train_cfg.training.optim = _profile["optim"]
train_cfg.stage1.num_epochs = S1_EPOCHS
train_cfg.stage2.num_epochs = S2_EPOCHS

model_cfg.llm.attn_implementation = _profile["attn_implementation"]
model_cfg.llm.gradient_checkpointing = _profile["gradient_checkpointing"]
model_cfg.llm.torch_dtype = _profile["torch_dtype"]
model_cfg.llm.bnb_4bit_compute_dtype = _profile["bnb_4bit_compute_dtype"]
model_cfg.llm.bnb_4bit_quant_type = "nf4"
model_cfg.llm.bnb_4bit_use_double_quant = True
model_cfg.llm.load_in_8bit = False
model_cfg.llm.load_in_4bit = True
model_cfg.chexpert_classifier.enabled = False

train_cfg.wandb.enabled = False
train_cfg.hf_hub.enabled = True
train_cfg.hf_hub.repo_id = HF_RUNS_REPO
train_cfg.hf_hub.token_env = "HF_TOKEN"
train_cfg.hf_hub.private = True
train_cfg.hf_hub.run_state_file = str(CKPT_ROOT / "run_id.txt")

# The patched configs overwrite the files they were loaded from.
OmegaConf.save(train_cfg, train_cfg_path)
OmegaConf.save(model_cfg, model_cfg_path)
print("[gcp_entrypoint] configs patched.")


# ── 4) Pin run_id.txt if resuming with an explicit id ─────────────────────────
if MODE == "resume" and EXPLICIT_RUN_ID:
    (CKPT_ROOT / "run_id.txt").write_text(EXPLICIT_RUN_ID)
    print(f"[gcp_entrypoint] pinned run_id = {EXPLICIT_RUN_ID}")


# ── 5) Launch training ────────────────────────────────────────────────────────
# Use the interpreter running this script (the one that just imported torch and
# omegaconf) rather than whichever `python` is first on PATH; fall back to
# "python" only if sys.executable is empty.
cmd = [
    sys.executable or "python", "-u", "-m", "training.train",
    "--model_config", str(model_cfg_path),
    "--train_config", str(train_cfg_path),
    "--mode", MODE,
]
if MODE == "resume" and EXPLICIT_RUN_ID:
    cmd += ["--run_id", EXPLICIT_RUN_ID]

print(f"[gcp_entrypoint] launching: {' '.join(cmd)}", flush=True)
os.chdir(PROJECT)
# Propagate the trainer's exit status as this job's exit status.
sys.exit(subprocess.call(cmd))