| """GCP Vertex AI Custom Training Job entrypoint. |
| |
| Mirrors the colab notebook's setup (cells: paths, cfg, resume, stage1): |
| 1. Download dataset payload from HF Hub (if not cached on disk) |
| 2. Patch configs/{train,model}_config.yaml for GPU profile + paths + HF Hub |
3. Pin run_id.txt (only when MODE=resume and EXPLICIT_RUN_ID is set)
4. Run `python -m training.train --mode {fresh,resume}` as a subprocess and
   exit with its status
| |
The container's command is expected to have already cloned the project source
(which contains this file, one directory below the project root) into
/workspace/code, then `cd /workspace/code` and run this script.
| |
| Required env vars: |
| HF_TOKEN — HuggingFace token (read access for code+data, write for runs) |
| DATASET_NAME — 'IU-Xray' | 'MIMIC-CXR' | 'MIMIC-CXR_resized' |
| |
| Optional env vars (defaults shown): |
| HF_USER = hieu3636 |
| REPORT_MODE = split_cascade |
| IMAGE_MODE = all_views_split |
| S1_EPOCHS = 2 |
| S2_EPOCHS = 7 |
| MODE = resume # 'fresh' | 'resume' |
| EXPLICIT_RUN_ID = '' # only matters when MODE=resume |
| HF_RUNS_REPO = hieu3636/cxr-vlm-runs |
| WORK = /workspace |
| """ |
|
|
| from __future__ import annotations |
|
|
| import os |
| import shutil |
| import subprocess |
| import sys |
| import tarfile |
| import zipfile |
| from pathlib import Path |
|
|
| |
# Environment defaults for this process. They are also inherited by the
# training subprocess launched at the end of this script. A value already
# present in the environment always wins (setdefault semantics).
_ENV_DEFAULTS = {
    "TOKENIZERS_PARALLELISM": "false",
    "BITSANDBYTES_NOWELCOME": "1",
    "HF_HUB_DISABLE_PROGRESS_BARS": "1",
    "TRANSFORMERS_VERBOSITY": "warning",
    "PYTHONUNBUFFERED": "1",
    "CUDA_VISIBLE_DEVICES": "0",
}
for _var, _default in _ENV_DEFAULTS.items():
    os.environ.setdefault(_var, _default)
del _var, _default
|
|
|
|
def env(name: str, default: str | None = None, *, required: bool = False) -> str:
    """Return the value of environment variable *name*.

    A variable that is unset *or set to the empty string* falls back to
    *default*. Without that, ``NAME=`` would silently override the default
    with ``""`` (and, e.g., reach ``int()`` as an empty string).

    Args:
        name: Environment variable to read.
        default: Value used when the variable is unset or empty.
        required: If true and neither a non-empty value nor a non-empty
            default is available, exit the process with an error message.

    Returns:
        The resolved value, or ``""`` when nothing is available and
        *required* is false.
    """
    val = os.environ.get(name) or default
    if required and not val:
        sys.exit(f"[gcp_entrypoint] ERROR: required env var {name} not set")
    return val or ""
|
|
|
|
| |
def _int_env(name: str, default: str) -> int:
    """Read env var *name* as an int, exiting with a readable error if malformed."""
    raw = env(name, default)
    try:
        return int(raw)
    except ValueError:
        sys.exit(f"[gcp_entrypoint] ERROR: env var {name} must be an integer, got {raw!r}")


HF_TOKEN = env("HF_TOKEN", required=True)
DATASET_NAME = env("DATASET_NAME", required=True)
HF_USER = env("HF_USER", "hieu3636")
REPORT_MODE = env("REPORT_MODE", "split_cascade")
IMAGE_MODE = env("IMAGE_MODE", "all_views_split")
S1_EPOCHS = _int_env("S1_EPOCHS", "2")
S2_EPOCHS = _int_env("S2_EPOCHS", "7")
MODE = env("MODE", "resume")
EXPLICIT_RUN_ID = env("EXPLICIT_RUN_ID", "")
HF_RUNS_REPO = env("HF_RUNS_REPO", "hieu3636/cxr-vlm-runs")
WORK = Path(env("WORK", "/workspace"))

# Validate with explicit exits rather than `assert`: asserts are stripped when
# Python runs with -O, which would let a bad value through to later steps.
_VALID_DATASETS = ("IU-Xray", "MIMIC-CXR", "MIMIC-CXR_resized")
_VALID_MODES = ("fresh", "resume")
if DATASET_NAME not in _VALID_DATASETS:
    sys.exit(
        f"[gcp_entrypoint] ERROR: DATASET_NAME must be one of {_VALID_DATASETS}, "
        f"got {DATASET_NAME!r}"
    )
if MODE not in _VALID_MODES:
    sys.exit(f"[gcp_entrypoint] ERROR: MODE must be one of {_VALID_MODES}, got {MODE!r}")
|
|
# Layout: the project root is the parent of the directory holding this file;
# downloaded data and checkpoints both live under WORK.
PROJECT = Path(__file__).resolve().parent.parent
DATA_SRC, CKPT_ROOT = WORK / "data", WORK / "ckpt"
for _dir in (DATA_SRC, CKPT_ROOT):
    _dir.mkdir(parents=True, exist_ok=True)
del _dir

# Startup banner (one line per setting).
print(
    "\n".join(
        [
            f"[gcp_entrypoint] PROJECT = {PROJECT}",
            f"[gcp_entrypoint] WORK = {WORK}",
            f"[gcp_entrypoint] DATA_SRC = {DATA_SRC}",
            f"[gcp_entrypoint] DATASET = {DATASET_NAME} ({REPORT_MODE} / {IMAGE_MODE})",
            f"[gcp_entrypoint] MODE = {MODE} run_id={EXPLICIT_RUN_ID or '(auto)'}",
        ]
    )
)
|
|
| |
| |
| from huggingface_hub import HfApi, hf_hub_download, snapshot_download |
|
|
def _extract_tar(tar_path: Path, dest: Path) -> None:
    """Extract every member of *tar_path* into *dest*.

    Uses the stdlib ``filter="data"`` extraction filter when this interpreter
    provides it (Python 3.12+ and security backports). That filter rejects
    absolute paths, ``..`` escapes and special files, and avoids the
    DeprecationWarning a filter-less ``extractall`` emits on 3.12+. Older
    interpreters fall back to plain ``extractall``.
    """
    with tarfile.open(tar_path) as t:
        if hasattr(tarfile, "data_filter"):
            t.extractall(dest, filter="data")
        else:
            t.extractall(dest)


if DATASET_NAME == "MIMIC-CXR_resized":
    mr_dir = DATA_SRC / "MIMIC-CXR_resized"
    mr_dir.mkdir(parents=True, exist_ok=True)
    files_dir = mr_dir / "files"
    # Present only while a download/extraction is running. If a previous job
    # died part-way, the marker is still here. Its partially populated
    # directory must not be mistaken for a complete cache. Directories
    # populated some other way (no marker) are still honoured.
    in_progress = mr_dir / ".download_in_progress"
    manifests_present = all(
        (mr_dir / f).is_file()
        for f in ("manifest_train.csv", "manifest_val.csv", "manifest_test.csv")
    )
    populated = manifests_present and files_dir.is_dir() and any(files_dir.glob("p*"))
    if populated and not in_progress.exists():
        print(f"[gcp_entrypoint] {mr_dir} already populated — skipping download.")
    else:
        if in_progress.exists():
            print(f"[gcp_entrypoint] {mr_dir}: previous download was interrupted — redoing it.")
        in_progress.touch()

        api = HfApi(token=HF_TOKEN)
        all_files = api.list_repo_files(
            repo_id=f"{HF_USER}/cxr-vlm-data", repo_type="dataset"
        )
        tar_files = sorted(
            f for f in all_files
            if f.startswith("MIMIC-CXR_resized/") and f.endswith(".tar")
        )
        print(f"[gcp_entrypoint] {len(tar_files)} tar shards on HF")
        if not tar_files:
            print(
                "[gcp_entrypoint] WARNING: no MIMIC-CXR_resized/*.tar shards on HF; "
                "no image shards will be extracted."
            )

        # Small metadata files (manifests, configs, VQA) in one snapshot call.
        snapshot_download(
            repo_id=f"{HF_USER}/cxr-vlm-data",
            repo_type="dataset",
            allow_patterns=[
                "MIMIC-CXR_resized/*.csv",
                "MIMIC-CXR_resized/*.json",
                "MIMIC-CXR_resized/*.txt",
                "MIMIC-CXR_resized/vqa/**",
            ],
            token=HF_TOKEN,
            local_dir=str(DATA_SRC),
        )

        # Shards are fetched and extracted one at a time, deleting each tar
        # after extraction so only one shard's archive is on disk at once.
        for i, tf in enumerate(tar_files, 1):
            print(f"[gcp_entrypoint] [{i}/{len(tar_files)}] {tf}")
            tp = Path(hf_hub_download(
                repo_id=f"{HF_USER}/cxr-vlm-data",
                repo_type="dataset",
                filename=tf,
                token=HF_TOKEN,
                local_dir=str(DATA_SRC),
            ))
            _extract_tar(tp, mr_dir)
            tp.unlink(missing_ok=True)

        in_progress.unlink(missing_ok=True)
        print(f"[gcp_entrypoint] {mr_dir} ready.")

    DATA_ROOT_RESIZED = mr_dir

else:
    zip_name = f"{DATASET_NAME}.zip"
    marker = DATA_SRC / DATASET_NAME
    # Same interrupted-download guard as above. The extracted directory
    # appears as soon as extraction starts, so its existence alone does not
    # prove the archive was fully unpacked.
    in_progress = DATA_SRC / f".{DATASET_NAME}.download_in_progress"
    if marker.exists() and not in_progress.exists():
        print(f"[gcp_entrypoint] {marker} already present — skipping download.")
    else:
        in_progress.touch()
        print(f"[gcp_entrypoint] downloading {zip_name} ...")
        zpath = hf_hub_download(
            repo_id=f"{HF_USER}/cxr-vlm-data",
            filename=zip_name,
            repo_type="dataset",
            token=HF_TOKEN,
            local_dir=str(DATA_SRC),
        )
        with zipfile.ZipFile(zpath) as zf:
            zf.extractall(DATA_SRC)
        # Deleting the archive is best-effort: failing to reclaim the space
        # must not fail the job.
        try:
            os.remove(zpath)
        except OSError:
            pass
        in_progress.unlink(missing_ok=True)


print(f"[gcp_entrypoint] DATA_SRC contents: {sorted(os.listdir(DATA_SRC))}")
|
|
| |
| import torch |
| from omegaconf import OmegaConf |
|
|
# The project's YAML configs are loaded here and patched (and re-saved in
# place) further down before training is launched.
_configs_dir = PROJECT / "configs"
train_cfg_path = _configs_dir / "train_config.yaml"
model_cfg_path = _configs_dir / "model_config.yaml"
train_cfg, model_cfg = (OmegaConf.load(p) for p in (train_cfg_path, model_cfg_path))
|
|
| |
# Dataset selection knobs shared by every dataset.
train_cfg.data.dataset_name = DATASET_NAME
train_cfg.data.report_mode = REPORT_MODE
train_cfg.data.image_mode = IMAGE_MODE
train_cfg.data.max_images_per_sample = 2

# Directory inside the project tree that the instruct_json paths below point to.
out_dir = PROJECT / "data" / "data_files"
out_dir.mkdir(parents=True, exist_ok=True)


def _find_mimic_root(root: Path) -> Path:
    """Return the directory under *root* that holds train/, valid/ and test/.

    ``root/MIMIC-CXR`` and ``root`` itself are tried first. After that, a
    recursive search looks for a ``train`` directory with ``valid`` and
    ``test`` siblings.

    Raises:
        FileNotFoundError: if no such directory exists under *root*.
    """
    for cand in [root / "MIMIC-CXR", root]:
        if (cand / "train").exists() and (cand / "valid").exists() and (cand / "test").exists():
            return cand
    for p in root.rglob("train"):
        if p.is_dir() and (p.parent / "valid").exists() and (p.parent / "test").exists():
            return p.parent
    raise FileNotFoundError(f"MIMIC-CXR train/valid/test not found under {root}")


if DATASET_NAME == "MIMIC-CXR_resized":
    mr_json_path = out_dir / "mimic_cxr_resized_instruct.json"
    train_cfg.data.mimic_cxr_resized.root = str(DATA_ROOT_RESIZED)
    train_cfg.data.mimic_cxr_resized.manifest_dir = None
    train_cfg.data.mimic_cxr_resized.vqa_dir = None
    train_cfg.data.mimic_cxr_resized.reports_root = None
    train_cfg.data.mimic_cxr_resized.instruct_json = str(mr_json_path)
    train_cfg.data.mimic_cxr_resized.auto_build = True
elif DATASET_NAME == "MIMIC-CXR":
    cxr_root = _find_mimic_root(DATA_SRC)
    train_cfg.data.mimic_cxr_root = str(cxr_root)
    train_cfg.data.instruct_json = str(out_dir / "mimic_cxr_instruct_unified.json")
    train_cfg.data.mimic_auto_build = True
    # Optional label CSV: prefer a *chexpert* file, fall back to *chexbert*.
    _cx = sorted(DATA_SRC.rglob("*chexpert*.csv")) or sorted(DATA_SRC.rglob("*chexbert*.csv"))
    train_cfg.data.mimic_chexpert_csv = str(_cx[0]) if _cx else None
    # Optional VQA data: only directories named "vqa" qualify (a file named
    # "vqa" must not). Sort so the pick is deterministic when several exist,
    # because rglob yields paths in filesystem enumeration order.
    _vqa_candidates = sorted(p for p in DATA_SRC.rglob("vqa") if p.is_dir())
    train_cfg.data.mimic_vqa_root = str(_vqa_candidates[0]) if _vqa_candidates else None
else:
    iu_root = DATA_SRC / "IU-Xray"
    train_cfg.data.iu_xray.images_dir = str(iu_root / "images")
    train_cfg.data.iu_xray.labels_dir = str(iu_root / "labels")
    train_cfg.data.iu_xray.instruct_json = str(out_dir / "iu_xray_instruct.json")
    train_cfg.data.iu_xray.auto_build = True


train_cfg.data.train_split = "train"
train_cfg.data.val_split = "validate"
train_cfg.data.test_split = "test"
train_cfg.training.output_root = str(CKPT_ROOT)
|
|
| |
# Probe GPU 0 and derive the capability flags used to pick the hardware
# profile. Exit explicitly instead of `assert`, which -O would strip.
if not torch.cuda.is_available():
    sys.exit("[gcp_entrypoint] ERROR: CUDA not available in container")
_props = torch.cuda.get_device_properties(0)
_cap = (_props.major, _props.minor)
_vram_gb = _props.total_memory / 1e9
# Native bf16 needs compute capability >= 8.0. Recent PyTorch releases make
# is_bf16_supported() also count *emulated* bf16, so it can report True on
# older GPUs (e.g. T4/V100). That would select bf16 training there, so gate on
# the capability as well.
_bf16_ok = _cap >= (8, 0) and torch.cuda.is_bf16_supported()
_fa2_ok = _cap >= (8, 0)


print(f"[gcp_entrypoint] GPU: {_props.name} {_vram_gb:.1f}GB sm_{_cap[0]}{_cap[1]} bf16={_bf16_ok} fa2_capable={_fa2_ok}")


# flash_attn is optional: only probed on capable GPUs, and any import failure
# just leaves _flash_attn_installed False. The except clause is deliberately
# broad because a broken build may fail with errors other than ImportError.
# The reason is logged rather than swallowed silently.
_flash_attn_installed = False
if _fa2_ok:
    try:
        import flash_attn  # noqa: F401 -- imported only to probe availability
    except Exception as exc:
        print(f"[gcp_entrypoint] flash_attn not usable ({type(exc).__name__}: {exc})")
    else:
        _flash_attn_installed = True
|
|
# VRAM tiers, largest first. Each row is (min GB, label, per-device batch
# size, gradient accumulation steps, dataloader workers, gradient
# checkpointing). GB means 1e9 bytes, matching _vram_gb.
_VRAM_TIERS = (
    (70, "A100/H100 80GB", 8, 2, 16, False),
    (35, "A100 40GB", 8, 2, 12, False),
    (22, "3090 / L4 / A10 (24GB)", 8, 2, 8, True),
    (14, "T4 / V100 (15-16GB)", 1, 16, 2, True),
)
_tier = next((row[1:] for row in _VRAM_TIERS if _vram_gb >= row[0]), None)
if _tier is None:
    # Below every known tier: same settings as the smallest one.
    _tier = (f"unknown ({_vram_gb:.0f}GB)", 1, 16, 2, True)
_label, _batch, _accum, _workers, _ckpt = _tier
_half_dtype = "bfloat16" if _bf16_ok else "float16"

_profile = {
    "label": _label,
    "per_device_train_batch_size": _batch,
    "per_device_eval_batch_size": _batch,
    "gradient_accumulation_steps": _accum,
    "dataloader_num_workers": _workers,
    "gradient_checkpointing": _ckpt,
    # Precision, attention and optimizer depend on GPU capability, not VRAM.
    "bf16": bool(_bf16_ok),
    "fp16": not _bf16_ok,
    "attn_implementation": (
        "flash_attention_2" if (_fa2_ok and _flash_attn_installed) else "sdpa"
    ),
    "optim": "paged_adamw_8bit" if _cap >= (8, 0) else "adamw_torch",
    "bnb_4bit_compute_dtype": _half_dtype,
    "torch_dtype": _half_dtype,
}


print(f"[gcp_entrypoint] → Profile: {_profile['label']}")
|
|
# Training knobs: hardware-dependent values come straight from the profile;
# the rest are fixed for this launcher. Assignment order is kept stable.
_training = train_cfg.training
for _key in (
    "per_device_train_batch_size",
    "per_device_eval_batch_size",
    "gradient_accumulation_steps",
    "dataloader_num_workers",
    "fp16",
    "bf16",
):
    setattr(_training, _key, _profile[_key])
_training.dataloader_pin_memory = True
_training.dataloader_persistent_workers = True
_training.optim = _profile["optim"]
train_cfg.stage1.num_epochs = S1_EPOCHS
train_cfg.stage2.num_epochs = S2_EPOCHS

# LLM loading: dtype/attention/checkpointing from the profile, 4-bit NF4 with
# double quantization fixed; the CheXpert classifier head is disabled.
_llm = model_cfg.llm
for _key in (
    "attn_implementation",
    "gradient_checkpointing",
    "torch_dtype",
    "bnb_4bit_compute_dtype",
):
    setattr(_llm, _key, _profile[_key])
_llm.bnb_4bit_quant_type = "nf4"
_llm.bnb_4bit_use_double_quant = True
_llm.load_in_8bit = False
_llm.load_in_4bit = True
model_cfg.chexpert_classifier.enabled = False

# Tracking: W&B off; HF Hub sync on, against the private HF_RUNS_REPO, with the
# token read from $HF_TOKEN and the run id kept in CKPT_ROOT/run_id.txt.
train_cfg.wandb.enabled = False
_hub = train_cfg.hf_hub
_hub.enabled = True
_hub.repo_id = HF_RUNS_REPO
_hub.token_env = "HF_TOKEN"
_hub.private = True
_hub.run_state_file = str(CKPT_ROOT / "run_id.txt")

# Overwrite the project's YAML files so training.train reads the patched values.
for _cfg, _path in ((train_cfg, train_cfg_path), (model_cfg, model_cfg_path)):
    OmegaConf.save(_cfg, _path)
print("[gcp_entrypoint] configs patched.")
|
|
| |
# Resuming a specific run: record its id in CKPT_ROOT/run_id.txt and also pass
# it explicitly on the command line.
_pin_run_id = MODE == "resume" and bool(EXPLICIT_RUN_ID)
if _pin_run_id:
    (CKPT_ROOT / "run_id.txt").write_text(EXPLICIT_RUN_ID)
    print(f"[gcp_entrypoint] pinned run_id = {EXPLICIT_RUN_ID}")


# Launch with the interpreter running this script, so the child sees the same
# installed packages. A bare "python" depends on PATH and may be missing or a
# different installation inside the container image.
cmd = [
    sys.executable or "python", "-u", "-m", "training.train",
    "--model_config", str(model_cfg_path),
    "--train_config", str(train_cfg_path),
    "--mode", MODE,
]
if _pin_run_id:
    cmd += ["--run_id", EXPLICIT_RUN_ID]


print(f"[gcp_entrypoint] launching: {' '.join(cmd)}", flush=True)
os.chdir(PROJECT)
_rc = subprocess.call(cmd)
# subprocess reports death-by-signal as -N, and sys.exit(-N) would surface as
# 256-N. Map it to the shell convention 128+N (e.g. 137 for SIGKILL/OOM kill).
sys.exit(128 - _rc if _rc < 0 else _rc)
|
|