cxr-vlm-code / scripts /gcp_entrypoint.py
convitom
f
2c84a70
"""GCP Vertex AI Custom Training Job entrypoint.
Mirrors the colab notebook's setup (cells: paths, cfg, resume, stage1):
1. Download dataset payload from HF Hub (if not cached on disk)
2. Patch configs/{train,model}_config.yaml for GPU profile + paths + HF Hub
3. Pin run_id.txt for --mode resume
4. Exec `python -m training.train --mode {fresh,resume}`
The container command is expected to have already cloned the project source
(including this file) into /workspace/code, changed into that directory
(`cd /workspace/code`), and then run this script.
Required env vars:
HF_TOKEN — HuggingFace token (read access for code+data, write for runs)
DATASET_NAME — 'IU-Xray' | 'MIMIC-CXR' | 'MIMIC-CXR_resized'
Optional env vars (defaults shown):
HF_USER = hieu3636
REPORT_MODE = split_cascade
IMAGE_MODE = all_views_split
S1_EPOCHS = 2
S2_EPOCHS = 7
MODE = resume # 'fresh' | 'resume'
EXPLICIT_RUN_ID = '' # only matters when MODE=resume
HF_RUNS_REPO = hieu3636/cxr-vlm-runs
WORK = /workspace
"""
from __future__ import annotations
import os
import shutil
import subprocess
import sys
import tarfile
import zipfile
from pathlib import Path
# ── Tame HF/transformers chatter so logs are readable in Cloud Logging ────────
# setdefault semantics: a value already provided by the job spec always wins.
# PYTHONUNBUFFERED is read at interpreter start-up, so it only affects Python
# child processes launched later (e.g. the training subprocess).
for _var, _value in (
    ("TOKENIZERS_PARALLELISM", "false"),
    ("BITSANDBYTES_NOWELCOME", "1"),
    ("HF_HUB_DISABLE_PROGRESS_BARS", "1"),
    ("TRANSFORMERS_VERBOSITY", "warning"),
    ("PYTHONUNBUFFERED", "1"),
    ("CUDA_VISIBLE_DEVICES", "0"),  # expose only the first GPU
):
    os.environ.setdefault(_var, _value)
del _var, _value
def env(name: str, default: str | None = None, *, required: bool = False) -> str:
    """Read environment variable *name*, treating an empty value as unset.

    Args:
        name: Environment variable to read.
        default: Fallback used when the variable is missing or empty.
        required: If True, abort the process when neither the variable nor
            *default* yields a non-empty value.

    Returns:
        The variable's value, else *default*, else ``""``.

    Raises:
        SystemExit: if *required* is True and no non-empty value is available.
    """
    # `or` rather than os.environ.get(name, default): a variable declared but
    # left blank in the job spec falls back to the default instead of leaking
    # "" into int() / repo-id construction. This matches the `required` check
    # below, which already treats "" as unset.
    val = os.environ.get(name) or default
    if required and not val:
        sys.exit(f"[gcp_entrypoint] ERROR: required env var {name} not set")
    return val or ""
# ── 1) Resolve config from env ────────────────────────────────────────────────
_VALID_DATASETS = ("IU-Xray", "MIMIC-CXR", "MIMIC-CXR_resized")
_VALID_MODES = ("fresh", "resume")


def _env_int(name: str, default: str) -> int:
    """Read *name* via env() and parse it as an int; exit with a clear message if it is not one."""
    raw = env(name, default)
    try:
        return int(raw)
    except ValueError:
        sys.exit(f"[gcp_entrypoint] ERROR: env var {name}={raw!r} is not an integer")


HF_TOKEN = env("HF_TOKEN", required=True)
DATASET_NAME = env("DATASET_NAME", required=True)
HF_USER = env("HF_USER", "hieu3636")
REPORT_MODE = env("REPORT_MODE", "split_cascade")
IMAGE_MODE = env("IMAGE_MODE", "all_views_split")
S1_EPOCHS = _env_int("S1_EPOCHS", "2")
S2_EPOCHS = _env_int("S2_EPOCHS", "7")
MODE = env("MODE", "resume")
EXPLICIT_RUN_ID = env("EXPLICIT_RUN_ID", "")  # only used when MODE == "resume"
HF_RUNS_REPO = env("HF_RUNS_REPO", "hieu3636/cxr-vlm-runs")
WORK = Path(env("WORK", "/workspace"))

# Explicit checks instead of `assert`: asserts disappear under `python -O`,
# and a typo in the job spec should fail fast with a readable message.
if DATASET_NAME not in _VALID_DATASETS:
    sys.exit(f"[gcp_entrypoint] ERROR: DATASET_NAME={DATASET_NAME!r} not in {_VALID_DATASETS}")
if MODE not in _VALID_MODES:
    sys.exit(f"[gcp_entrypoint] ERROR: MODE={MODE!r} not in {_VALID_MODES}")
# Project checkout = grandparent of this file (scripts/ -> repo root), e.g.
# /workspace/code. Dataset and checkpoints live under WORK.
PROJECT = Path(__file__).resolve().parents[1]  # /workspace/code
DATA_SRC = WORK / "data"
CKPT_ROOT = WORK / "ckpt"
for _dir in (DATA_SRC, CKPT_ROOT):
    _dir.mkdir(parents=True, exist_ok=True)
del _dir

# Start-up banner so the resolved setup is visible in Cloud Logging.
for _line in (
    f"PROJECT = {PROJECT}",
    f"WORK = {WORK}",
    f"DATA_SRC = {DATA_SRC}",
    f"DATASET = {DATASET_NAME} ({REPORT_MODE} / {IMAGE_MODE})",
    f"MODE = {MODE} run_id={EXPLICIT_RUN_ID or '(auto)'}",
):
    print(f"[gcp_entrypoint] {_line}")
del _line
# ── 2) Download dataset payload from HF Hub ───────────────────────────────────
# Mirrors cell-paths logic for each dataset shape.
from huggingface_hub import HfApi, hf_hub_download, snapshot_download  # noqa: E402

_DATA_REPO = f"{HF_USER}/cxr-vlm-data"

# Shards are fetched over the network, so extract with tarfile's "data" filter.
# It keeps every member inside the destination and rejects links escaping it.
# Only interpreters that ship the filter get it (3.12+ and the security
# backports to 3.8-3.11); older ones fall back to plain extraction.
_TAR_EXTRACT_KW = {"filter": "data"} if hasattr(tarfile, "data_filter") else {}

if DATASET_NAME == "MIMIC-CXR_resized":
    mr_dir = DATA_SRC / "MIMIC-CXR_resized"
    mr_dir.mkdir(parents=True, exist_ok=True)
    files_dir = mr_dir / "files"
    manifests_present = all(
        (mr_dir / f).is_file()
        for f in ("manifest_train.csv", "manifest_val.csv", "manifest_test.csv")
    )
    # Cache heuristic: all three manifests plus at least one p* directory.
    # NOTE(review): a shard loop interrupted part-way also passes this check.
    if manifests_present and files_dir.is_dir() and any(files_dir.glob("p*")):
        print(f"[gcp_entrypoint] {mr_dir} already populated — skipping download.")
    else:
        api = HfApi(token=HF_TOKEN)
        all_files = api.list_repo_files(repo_id=_DATA_REPO, repo_type="dataset")
        mr_files = [f for f in all_files if f.startswith("MIMIC-CXR_resized/")]
        tar_files = sorted(f for f in mr_files if f.endswith(".tar"))
        print(f"[gcp_entrypoint] {len(tar_files)} tar shards on HF")
        if not tar_files:
            # Without shards there are no images; fail now rather than at training time.
            sys.exit(f"[gcp_entrypoint] ERROR: no MIMIC-CXR_resized/*.tar shards in {_DATA_REPO}")
        # Metadata (manifests, vqa, SHARDS.txt, _manifest.json) — small
        snapshot_download(
            repo_id=_DATA_REPO,
            repo_type="dataset",
            allow_patterns=[
                "MIMIC-CXR_resized/*.csv",
                "MIMIC-CXR_resized/*.json",
                "MIMIC-CXR_resized/*.txt",
                "MIMIC-CXR_resized/vqa/**",
            ],
            token=HF_TOKEN,
            local_dir=str(DATA_SRC),
        )
        # Image shards: download, extract, then delete each .tar before the
        # next one, so at most one shard sits on disk next to the extracted data.
        for i, tf in enumerate(tar_files, 1):
            print(f"[gcp_entrypoint] [{i}/{len(tar_files)}] {tf}")
            tp = Path(hf_hub_download(
                repo_id=_DATA_REPO,
                repo_type="dataset",
                filename=tf,
                token=HF_TOKEN,
                local_dir=str(DATA_SRC),
            ))
            with tarfile.open(tp) as t:
                t.extractall(mr_dir, **_TAR_EXTRACT_KW)
            tp.unlink(missing_ok=True)
        print(f"[gcp_entrypoint] {mr_dir} ready.")
    DATA_ROOT_RESIZED = mr_dir
else:
    # MIMIC-CXR / IU-Xray: single zip per dataset, expected to unpack into
    # DATA_SRC/<DATASET_NAME> (that directory doubles as the "cached" marker).
    zip_name = f"{DATASET_NAME}.zip"
    marker = DATA_SRC / DATASET_NAME
    if not marker.exists():
        print(f"[gcp_entrypoint] downloading {zip_name} ...")
        zpath = hf_hub_download(
            repo_id=_DATA_REPO,
            filename=zip_name,
            repo_type="dataset",
            token=HF_TOKEN,
            local_dir=str(DATA_SRC),
        )
        with zipfile.ZipFile(zpath) as zf:
            zf.extractall(DATA_SRC)
        # Best-effort cleanup: a zip that cannot be removed only wastes disk.
        try:
            os.remove(zpath)
        except OSError:
            pass
    else:
        print(f"[gcp_entrypoint] {marker} already present — skipping download.")
print(f"[gcp_entrypoint] DATA_SRC contents: {sorted(os.listdir(DATA_SRC))}")
# ── 3) Patch configs (mirrors cell-cfg) ───────────────────────────────────────
import torch  # noqa: E402
from omegaconf import OmegaConf  # noqa: E402

# Both YAML files are patched in place inside the project checkout and later
# handed to training.train via --train_config / --model_config.
_configs_dir = PROJECT / "configs"
train_cfg_path = _configs_dir / "train_config.yaml"
model_cfg_path = _configs_dir / "model_config.yaml"
train_cfg, model_cfg = map(OmegaConf.load, (train_cfg_path, model_cfg_path))
# Dataset + training-scheme switches shared by every dataset.
for _key, _val in (
    ("dataset_name", DATASET_NAME),
    ("report_mode", REPORT_MODE),
    ("image_mode", IMAGE_MODE),
    ("max_images_per_sample", 2),
):
    train_cfg.data[_key] = _val
del _key, _val

# Instruct JSONs are written into the project checkout.
out_dir = PROJECT.joinpath("data", "data_files")
out_dir.mkdir(exist_ok=True, parents=True)
if DATASET_NAME == "MIMIC-CXR_resized":
    # Root = shards extracted in step 2. manifest/vqa/report dirs are nulled.
    # NOTE(review): presumably the loader derives them from `root` — confirm.
    mr_json_path = out_dir / "mimic_cxr_resized_instruct.json"
    train_cfg.data.mimic_cxr_resized.root = str(DATA_ROOT_RESIZED)
    train_cfg.data.mimic_cxr_resized.manifest_dir = None
    train_cfg.data.mimic_cxr_resized.vqa_dir = None
    train_cfg.data.mimic_cxr_resized.reports_root = None
    train_cfg.data.mimic_cxr_resized.instruct_json = str(mr_json_path)
    train_cfg.data.mimic_cxr_resized.auto_build = True
elif DATASET_NAME == "MIMIC-CXR":
    def _find_mimic_root(root: Path) -> Path:
        """Return the directory under *root* that holds train/, valid/ and test/.

        Tries root/MIMIC-CXR and root itself first, then searches recursively
        for a `train` directory with `valid` and `test` siblings.

        Raises:
            FileNotFoundError: if no such directory exists under *root*.
        """
        for cand in (root / "MIMIC-CXR", root):
            if all((cand / split).exists() for split in ("train", "valid", "test")):
                return cand
        for p in root.rglob("train"):
            if p.is_dir() and (p.parent / "valid").exists() and (p.parent / "test").exists():
                return p.parent
        raise FileNotFoundError(f"MIMIC-CXR train/valid/test not found under {root}")

    cxr_root = _find_mimic_root(DATA_SRC)
    train_cfg.data.mimic_cxr_root = str(cxr_root)
    train_cfg.data.instruct_json = str(out_dir / "mimic_cxr_instruct_unified.json")
    train_cfg.data.mimic_auto_build = True
    # Optional label CSV: first *chexpert*.csv, else first *chexbert*.csv (sorted).
    _cx = sorted(DATA_SRC.rglob("*chexpert*.csv")) or sorted(DATA_SRC.rglob("*chexbert*.csv"))
    train_cfg.data.mimic_chexpert_csv = str(_cx[0]) if _cx else None
    # rglob("vqa") also matches plain files named "vqa" and yields entries in
    # filesystem order; keep directories only and sort for a deterministic pick.
    _vqa_candidates = sorted(p for p in DATA_SRC.rglob("vqa") if p.is_dir())
    train_cfg.data.mimic_vqa_root = str(_vqa_candidates[0]) if _vqa_candidates else None
else:  # IU-Xray
    iu_root = DATA_SRC / "IU-Xray"
    train_cfg.data.iu_xray.images_dir = str(iu_root / "images")
    train_cfg.data.iu_xray.labels_dir = str(iu_root / "labels")
    train_cfg.data.iu_xray.instruct_json = str(out_dir / "iu_xray_instruct.json")
    train_cfg.data.iu_xray.auto_build = True

# Split names and checkpoint root are set for every dataset.
# NOTE(review): reconstructed from an unindented copy — confirm the three split
# lines were not meant to be IU-Xray-only.
train_cfg.data.train_split = "train"
train_cfg.data.val_split = "validate"
train_cfg.data.test_split = "test"
train_cfg.training.output_root = str(CKPT_ROOT)
# ── GPU auto-profile (adapted from cell-cfg) ─────────────────────────────────
# Explicit check instead of `assert` so it still fires under `python -O`.
if not torch.cuda.is_available():
    sys.exit("[gcp_entrypoint] ERROR: CUDA not available in container")
_props = torch.cuda.get_device_properties(0)
_cap = (_props.major, _props.minor)
_vram_gb = _props.total_memory / 1e9  # decimal GB (1e9 bytes), not GiB
_bf16_ok = torch.cuda.is_bf16_supported()
_fa2_ok = _cap >= (8, 0)  # this script gates FlashAttention-2 on sm_80+
print(f"[gcp_entrypoint] GPU: {_props.name} {_vram_gb:.1f}GB sm_{_cap[0]}{_cap[1]} bf16={_bf16_ok} fa2_capable={_fa2_ok}")

_flash_attn_installed = False
if _fa2_ok:
    try:
        import flash_attn  # noqa: F401
    except Exception as exc:  # broad on purpose: any import failure just means "use sdpa"
        print(f"[gcp_entrypoint] flash_attn unavailable ({type(exc).__name__}: {exc}) — using sdpa")
    else:
        _flash_attn_installed = True

# Batch/accumulation/worker tiers by VRAM, largest first.
if _vram_gb >= 70:
    _profile = dict(label="A100/H100 80GB",
                    per_device_train_batch_size=8, per_device_eval_batch_size=8,
                    gradient_accumulation_steps=2, dataloader_num_workers=16,
                    gradient_checkpointing=False)
elif _vram_gb >= 35:
    _profile = dict(label="A100 40GB",
                    per_device_train_batch_size=8, per_device_eval_batch_size=8,
                    gradient_accumulation_steps=2, dataloader_num_workers=12,
                    gradient_checkpointing=False)
elif _vram_gb >= 22:
    _profile = dict(label="3090 / L4 / A10 (24GB)",
                    per_device_train_batch_size=8, per_device_eval_batch_size=8,
                    gradient_accumulation_steps=2, dataloader_num_workers=8,
                    gradient_checkpointing=True)
elif _vram_gb >= 14:
    _profile = dict(label="T4 / V100 (15-16GB)",
                    per_device_train_batch_size=1, per_device_eval_batch_size=1,
                    gradient_accumulation_steps=16, dataloader_num_workers=2,
                    gradient_checkpointing=True)
else:
    _profile = dict(label=f"unknown ({_vram_gb:.0f}GB)",
                    per_device_train_batch_size=1, per_device_eval_batch_size=1,
                    gradient_accumulation_steps=16, dataloader_num_workers=2,
                    gradient_checkpointing=True)

# Precision: bf16 where supported, otherwise fp16 (exactly one is enabled).
_profile["bf16"] = bool(_bf16_ok)
_profile["fp16"] = not _bf16_ok
_profile["attn_implementation"] = (
    "flash_attention_2" if (_fa2_ok and _flash_attn_installed) else "sdpa"
)
_profile["optim"] = "paged_adamw_8bit" if _cap >= (8, 0) else "adamw_torch"
_profile["bnb_4bit_compute_dtype"] = "bfloat16" if _bf16_ok else "float16"
_profile["torch_dtype"] = "bfloat16" if _bf16_ok else "float16"
print(f"[gcp_entrypoint] → Profile: {_profile['label']}")
# Copy the GPU profile onto the training section (same key order as before).
for _key in (
    "per_device_train_batch_size",
    "per_device_eval_batch_size",
    "gradient_accumulation_steps",
    "dataloader_num_workers",
    "fp16",
    "bf16",
):
    train_cfg.training[_key] = _profile[_key]
train_cfg.training.dataloader_pin_memory = True
train_cfg.training.dataloader_persistent_workers = True
train_cfg.training.optim = _profile["optim"]
train_cfg.stage1.num_epochs = S1_EPOCHS
train_cfg.stage2.num_epochs = S2_EPOCHS

# LLM loading: attention backend, checkpointing and dtypes come from the
# profile; quantization is always 4-bit NF4 with double quantization.
_llm = model_cfg.llm
for _key in (
    "attn_implementation",
    "gradient_checkpointing",
    "torch_dtype",
    "bnb_4bit_compute_dtype",
):
    _llm[_key] = _profile[_key]
del _key
_llm.bnb_4bit_quant_type = "nf4"
_llm.bnb_4bit_use_double_quant = True
_llm.load_in_8bit = False
_llm.load_in_4bit = True
model_cfg.chexpert_classifier.enabled = False

# Tracking: W&B off; runs go to a private HF Hub repo, token read from $HF_TOKEN,
# run id persisted next to the checkpoints.
train_cfg.wandb.enabled = False
_hub = train_cfg.hf_hub
_hub.enabled = True
_hub.repo_id = HF_RUNS_REPO
_hub.token_env = "HF_TOKEN"
_hub.private = True
_hub.run_state_file = str(CKPT_ROOT / "run_id.txt")

for _cfg, _path in ((train_cfg, train_cfg_path), (model_cfg, model_cfg_path)):
    OmegaConf.save(_cfg, _path)
print("[gcp_entrypoint] configs patched.")
# ── 4) Pin run_id.txt if resuming with an explicit id ─────────────────────────
# Same path as hf_hub.run_state_file in the patched train config.
# NOTE(review): presumably training.train reads it on --mode resume — confirm.
_pin_run = MODE == "resume" and bool(EXPLICIT_RUN_ID)
if _pin_run:
    _run_state = CKPT_ROOT / "run_id.txt"
    _run_state.write_text(EXPLICIT_RUN_ID)
    print(f"[gcp_entrypoint] pinned run_id = {EXPLICIT_RUN_ID}")
# ── 5) Launch training ────────────────────────────────────────────────────────
# sys.executable rather than a bare "python": the child runs under the same
# interpreter/venv as this script, even on images that only provide `python3`.
cmd = [
    sys.executable, "-u", "-m", "training.train",
    "--model_config", str(model_cfg_path),
    "--train_config", str(train_cfg_path),
    "--mode", MODE,
]
if MODE == "resume" and EXPLICIT_RUN_ID:
    cmd += ["--run_id", EXPLICIT_RUN_ID]
print(f"[gcp_entrypoint] launching: {' '.join(cmd)}", flush=True)
# `python -m` puts the working directory on sys.path, so run from the project
# root for `training.train` to resolve.
os.chdir(PROJECT)
_rc = subprocess.call(cmd)
# A negative return code means the child was killed by signal N. Report it as
# 128+N (shell convention) instead of letting sys.exit(-N) wrap to 256-N.
sys.exit(128 - _rc if _rc < 0 else _rc)