File size: 14,314 Bytes
2c84a70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
"""GCP Vertex AI Custom Training Job entrypoint.

Mirrors the colab notebook's setup (cells: paths, cfg, resume, stage1):
  1. Download dataset payload from HF Hub (if not cached on disk)
  2. Patch configs/{train,model}_config.yaml for GPU profile + paths + HF Hub
  3. Pin run_id.txt for --mode resume
  4. Run `python -m training.train --mode {fresh,resume}` as a child process and exit with its status

The container's command is expected to clone the project source (including this
file) into /workspace/code, `cd /workspace/code`, and then run this script.

Required env vars:
  HF_TOKEN         — HuggingFace token (read access for code+data, write for runs)
  DATASET_NAME     — 'IU-Xray' | 'MIMIC-CXR' | 'MIMIC-CXR_resized'

Optional env vars (defaults shown):
  HF_USER             = hieu3636
  REPORT_MODE         = split_cascade
  IMAGE_MODE          = all_views_split
  S1_EPOCHS           = 2
  S2_EPOCHS           = 7
  MODE                = resume        # 'fresh' | 'resume'
  EXPLICIT_RUN_ID     = ''            # only matters when MODE=resume
  HF_RUNS_REPO        = hieu3636/cxr-vlm-runs
  WORK                = /workspace
"""

from __future__ import annotations

import os
import shutil
import subprocess
import sys
import tarfile
import zipfile
from pathlib import Path

# ── Tame HF/transformers chatter so logs are readable in Cloud Logging ────────
# Defaults only: any value already exported by the job spec takes precedence.
# The training subprocess launched at the end of this script inherits them too.
# CUDA_VISIBLE_DEVICES="0" exposes only GPU 0 unless the job overrides it.
_ENV_DEFAULTS = {
    "TOKENIZERS_PARALLELISM":       "false",
    "BITSANDBYTES_NOWELCOME":       "1",
    "HF_HUB_DISABLE_PROGRESS_BARS": "1",
    "TRANSFORMERS_VERBOSITY":       "warning",
    "PYTHONUNBUFFERED":             "1",
    "CUDA_VISIBLE_DEVICES":         "0",
}
for _name, _value in _ENV_DEFAULTS.items():
    os.environ.setdefault(_name, _value)


def env(name: str, default: str | None = None, *, required: bool = False) -> str:
    """Return environment variable *name*, falling back to *default*.

    The result is always a string: ``""`` when the variable is unset and no
    default is given. When *required* is true and the resolved value is empty
    or missing, the process exits with an error message instead.
    """
    value = os.environ.get(name, default)
    if value:
        return value
    if required:
        sys.exit(f"[gcp_entrypoint] ERROR: required env var {name} not set")
    return ""


# ── 1) Resolve config from env ────────────────────────────────────────────────
_VALID_DATASETS = ("IU-Xray", "MIMIC-CXR", "MIMIC-CXR_resized")
_VALID_MODES    = ("fresh", "resume")


def _env_int(name: str, default: str) -> int:
    """Return env var *name* parsed as an int; exit with a clear error if it is not one."""
    raw = env(name, default)
    try:
        return int(raw)
    except ValueError:
        sys.exit(f"[gcp_entrypoint] ERROR: env var {name} must be an integer, got {raw!r}")


HF_TOKEN        = env("HF_TOKEN", required=True)
DATASET_NAME    = env("DATASET_NAME", required=True)
HF_USER         = env("HF_USER", "hieu3636")
REPORT_MODE     = env("REPORT_MODE", "split_cascade")
IMAGE_MODE      = env("IMAGE_MODE", "all_views_split")
S1_EPOCHS       = _env_int("S1_EPOCHS", "2")
S2_EPOCHS       = _env_int("S2_EPOCHS", "7")
MODE            = env("MODE", "resume")
EXPLICIT_RUN_ID = env("EXPLICIT_RUN_ID", "")
HF_RUNS_REPO    = env("HF_RUNS_REPO", "hieu3636/cxr-vlm-runs")
WORK            = Path(env("WORK", "/workspace"))

# Validate with explicit exits rather than `assert`: asserts are stripped under
# `python -O`, which would let a typo'd value slip through to the trainer.
if DATASET_NAME not in _VALID_DATASETS:
    sys.exit(
        f"[gcp_entrypoint] ERROR: DATASET_NAME must be one of {_VALID_DATASETS}, "
        f"got {DATASET_NAME!r}"
    )
if MODE not in _VALID_MODES:
    sys.exit(f"[gcp_entrypoint] ERROR: MODE must be one of {_VALID_MODES}, got {MODE!r}")

PROJECT   = Path(__file__).resolve().parent.parent      # /workspace/code
DATA_SRC  = WORK / "data"
CKPT_ROOT = WORK / "ckpt"
DATA_SRC.mkdir(parents=True, exist_ok=True)
CKPT_ROOT.mkdir(parents=True, exist_ok=True)

print(f"[gcp_entrypoint] PROJECT  = {PROJECT}")
print(f"[gcp_entrypoint] WORK     = {WORK}")
print(f"[gcp_entrypoint] DATA_SRC = {DATA_SRC}")
print(f"[gcp_entrypoint] DATASET  = {DATASET_NAME}  ({REPORT_MODE} / {IMAGE_MODE})")
print(f"[gcp_entrypoint] MODE     = {MODE}  run_id={EXPLICIT_RUN_ID or '(auto)'}")

# ── 2) Download dataset payload from HF Hub ───────────────────────────────────
# Mirrors cell-paths logic for each dataset shape. Each dataset gets a hidden
# completion marker that is written only after a successful download+extract.
# An interrupted (e.g. preempted) job therefore re-downloads instead of
# silently training on a partial dataset. Data extracted by older versions of
# this script has no marker and is fetched once more.
from huggingface_hub import HfApi, hf_hub_download, snapshot_download  # noqa: E402

_DATA_REPO = f"{HF_USER}/cxr-vlm-data"


def _extract_tar(tar_path: Path, dest: Path) -> None:
    """Extract every member of *tar_path* into *dest*.

    Uses tarfile's ``"data"`` extraction filter when this interpreter provides
    it (Python 3.12+, and backports in some later 3.8-3.11 patch releases).
    That filter refuses members that would be written outside *dest*. Older
    interpreters fall back to an unfiltered extract.
    """
    with tarfile.open(tar_path) as t:
        if hasattr(tarfile, "data_filter"):
            t.extractall(dest, filter="data")
        else:
            t.extractall(dest)


if DATASET_NAME == "MIMIC-CXR_resized":
    mr_dir = DATA_SRC / "MIMIC-CXR_resized"
    mr_dir.mkdir(parents=True, exist_ok=True)
    complete_marker = mr_dir / ".download_complete"
    # One repo path per line for every shard already extracted, so a restarted
    # job resumes with the next shard instead of starting over.
    shards_done_log = mr_dir / ".shards_done"
    if complete_marker.is_file():
        print(f"[gcp_entrypoint] {mr_dir} already populated — skipping download.")
    else:
        api = HfApi(token=HF_TOKEN)
        all_files = api.list_repo_files(repo_id=_DATA_REPO, repo_type="dataset")
        tar_files = sorted(
            f for f in all_files
            if f.startswith("MIMIC-CXR_resized/") and f.endswith(".tar")
        )
        print(f"[gcp_entrypoint] {len(tar_files)} tar shards on HF")
        if not tar_files:
            # No shards means no images; fail here rather than deep in the trainer.
            sys.exit(f"[gcp_entrypoint] ERROR: no MIMIC-CXR_resized/*.tar shards in {_DATA_REPO}")

        # Metadata (manifests, vqa, SHARDS.txt, _manifest.json) — small
        snapshot_download(
            repo_id=_DATA_REPO,
            repo_type="dataset",
            allow_patterns=[
                "MIMIC-CXR_resized/*.csv",
                "MIMIC-CXR_resized/*.json",
                "MIMIC-CXR_resized/*.txt",
                "MIMIC-CXR_resized/vqa/**",
            ],
            token=HF_TOKEN,
            local_dir=str(DATA_SRC),
        )

        done = (
            set(shards_done_log.read_text(encoding="utf-8").splitlines())
            if shards_done_log.is_file()
            else set()
        )
        # Image shards — download, extract, delete to keep peak disk down
        for i, tf in enumerate(tar_files, 1):
            if tf in done:
                print(f"[gcp_entrypoint]  [{i}/{len(tar_files)}] {tf} (already extracted)")
                continue
            print(f"[gcp_entrypoint]  [{i}/{len(tar_files)}] {tf}")
            tp = Path(hf_hub_download(
                repo_id=_DATA_REPO,
                repo_type="dataset",
                filename=tf,
                token=HF_TOKEN,
                local_dir=str(DATA_SRC),
            ))
            _extract_tar(tp, mr_dir)
            tp.unlink(missing_ok=True)
            # Record the shard only after a full extract, so a crash mid-extract
            # makes that shard be redone on restart.
            with shards_done_log.open("a", encoding="utf-8") as fh:
                fh.write(tf + "\n")
        complete_marker.touch()
        print(f"[gcp_entrypoint] {mr_dir} ready.")

    DATA_ROOT_RESIZED = mr_dir

else:
    # MIMIC-CXR / IU-Xray: single zip per dataset
    zip_name = f"{DATASET_NAME}.zip"
    marker = DATA_SRC / DATASET_NAME
    # The extracted directory alone is not proof of success: it also exists
    # after an interrupted extract. Only this marker means the extract finished.
    complete_marker = DATA_SRC / f".{DATASET_NAME}.download_complete"
    if not complete_marker.is_file():
        print(f"[gcp_entrypoint] downloading {zip_name} ...")
        zpath = hf_hub_download(
            repo_id=_DATA_REPO,
            filename=zip_name,
            repo_type="dataset",
            token=HF_TOKEN,
            local_dir=str(DATA_SRC),
        )
        with zipfile.ZipFile(zpath) as zf:
            zf.extractall(DATA_SRC)
        try:
            os.remove(zpath)
        except OSError:
            pass  # best-effort cleanup; a leftover zip only wastes disk space
        complete_marker.touch()
    else:
        print(f"[gcp_entrypoint] {marker} already present — skipping download.")

print(f"[gcp_entrypoint] DATA_SRC contents: {sorted(os.listdir(DATA_SRC))}")

# ── 3) Patch configs (mirrors cell-cfg) ───────────────────────────────────────
import torch  # noqa: E402
from omegaconf import OmegaConf  # noqa: E402

train_cfg_path = PROJECT / "configs" / "train_config.yaml"
model_cfg_path = PROJECT / "configs" / "model_config.yaml"
train_cfg = OmegaConf.load(train_cfg_path)
model_cfg = OmegaConf.load(model_cfg_path)

# Dataset + training-scheme switches
train_cfg.data.dataset_name           = DATASET_NAME
train_cfg.data.report_mode            = REPORT_MODE
train_cfg.data.image_mode             = IMAGE_MODE
train_cfg.data.max_images_per_sample  = 2

out_dir = PROJECT / "data" / "data_files"
out_dir.mkdir(parents=True, exist_ok=True)

if DATASET_NAME == "MIMIC-CXR_resized":
    mr_json_path = out_dir / "mimic_cxr_resized_instruct.json"
    train_cfg.data.mimic_cxr_resized.root          = str(DATA_ROOT_RESIZED)
    train_cfg.data.mimic_cxr_resized.manifest_dir  = None
    train_cfg.data.mimic_cxr_resized.vqa_dir       = None
    train_cfg.data.mimic_cxr_resized.reports_root  = None
    train_cfg.data.mimic_cxr_resized.instruct_json = str(mr_json_path)
    train_cfg.data.mimic_cxr_resized.auto_build    = True
elif DATASET_NAME == "MIMIC-CXR":
    def _find_mimic_root(root: Path) -> Path:
        """Return the directory under *root* that holds train/, valid/ and test/ side by side.

        Checks ``root/MIMIC-CXR`` and ``root`` first, then searches recursively.
        Raises FileNotFoundError if no such directory exists.
        """
        for cand in [root / "MIMIC-CXR", root]:
            if (cand / "train").exists() and (cand / "valid").exists() and (cand / "test").exists():
                return cand
        for p in root.rglob("train"):
            if p.is_dir() and (p.parent / "valid").exists() and (p.parent / "test").exists():
                return p.parent
        raise FileNotFoundError(f"MIMIC-CXR train/valid/test not found under {root}")
    cxr_root = _find_mimic_root(DATA_SRC)
    train_cfg.data.mimic_cxr_root = str(cxr_root)
    train_cfg.data.instruct_json  = str(out_dir / "mimic_cxr_instruct_unified.json")
    train_cfg.data.mimic_auto_build = True
    # Prefer CheXpert label CSVs; fall back to CheXbert ones. Sorted for determinism.
    _cx = sorted(DATA_SRC.rglob("*chexpert*.csv")) or sorted(DATA_SRC.rglob("*chexbert*.csv"))
    train_cfg.data.mimic_chexpert_csv = str(_cx[0]) if _cx else None
    # First *directory* named "vqa" (a file with that name is not a VQA root).
    # next() stops at the first hit instead of walking the whole image tree.
    _vqa_dir = next((p for p in DATA_SRC.rglob("vqa") if p.is_dir()), None)
    train_cfg.data.mimic_vqa_root = str(_vqa_dir) if _vqa_dir is not None else None
else:  # IU-Xray
    iu_root = DATA_SRC / "IU-Xray"
    train_cfg.data.iu_xray.images_dir    = str(iu_root / "images")
    train_cfg.data.iu_xray.labels_dir    = str(iu_root / "labels")
    train_cfg.data.iu_xray.instruct_json = str(out_dir / "iu_xray_instruct.json")
    train_cfg.data.iu_xray.auto_build    = True

train_cfg.data.train_split = "train"
train_cfg.data.val_split   = "validate"
train_cfg.data.test_split  = "test"
train_cfg.training.output_root = str(CKPT_ROOT)

# ── GPU auto-profile (adapted from cell-cfg) ─────────────────────────────────
# Explicit exit instead of `assert`: asserts are stripped under `python -O`, and
# get_device_properties(0) below would then fail with a far less helpful error.
if not torch.cuda.is_available():
    sys.exit("[gcp_entrypoint] ERROR: CUDA not available in container")
_props   = torch.cuda.get_device_properties(0)
_cap     = (_props.major, _props.minor)
_vram_gb = _props.total_memory / 1e9          # decimal GB (a 16 GiB card reads ~17.2)
_bf16_ok = torch.cuda.is_bf16_supported()
_fa2_ok  = _cap >= (8, 0)

print(f"[gcp_entrypoint] GPU: {_props.name}  {_vram_gb:.1f}GB  sm_{_cap[0]}{_cap[1]}  bf16={_bf16_ok}  fa2_capable={_fa2_ok}")

_flash_attn_installed = False
if _fa2_ok:
    try:
        import flash_attn  # noqa: F401
        _flash_attn_installed = True
    except Exception:
        # Deliberately broad, best-effort probe: a broken flash-attn build can
        # fail with more than ImportError. Any failure just selects "sdpa" below.
        _flash_attn_installed = False

# Batch/accumulation/worker settings keyed on VRAM, largest tier first.
if _vram_gb >= 70:
    _profile = dict(label="A100/H100 80GB",
                    per_device_train_batch_size=8, per_device_eval_batch_size=8,
                    gradient_accumulation_steps=2, dataloader_num_workers=16,
                    gradient_checkpointing=False)
elif _vram_gb >= 35:
    _profile = dict(label="A100 40GB",
                    per_device_train_batch_size=8, per_device_eval_batch_size=8,
                    gradient_accumulation_steps=2, dataloader_num_workers=12,
                    gradient_checkpointing=False)
elif _vram_gb >= 22:
    _profile = dict(label="3090 / L4 / A10 (24GB)",
                    per_device_train_batch_size=8, per_device_eval_batch_size=8,
                    gradient_accumulation_steps=2, dataloader_num_workers=8,
                    gradient_checkpointing=True)
elif _vram_gb >= 14:
    _profile = dict(label="T4 / V100 (15-16GB)",
                    per_device_train_batch_size=1, per_device_eval_batch_size=1,
                    gradient_accumulation_steps=16, dataloader_num_workers=2,
                    gradient_checkpointing=True)
else:
    _profile = dict(label=f"unknown ({_vram_gb:.0f}GB)",
                    per_device_train_batch_size=1, per_device_eval_batch_size=1,
                    gradient_accumulation_steps=16, dataloader_num_workers=2,
                    gradient_checkpointing=True)

# Precision/attention/optimizer follow device capability rather than VRAM.
_profile["bf16"] = bool(_bf16_ok)
_profile["fp16"] = not _bf16_ok
_profile["attn_implementation"] = (
    "flash_attention_2" if (_fa2_ok and _flash_attn_installed) else "sdpa"
)
_profile["optim"] = "paged_adamw_8bit" if _cap >= (8, 0) else "adamw_torch"
_profile["bnb_4bit_compute_dtype"] = "bfloat16" if _bf16_ok else "float16"
_profile["torch_dtype"]            = "bfloat16" if _bf16_ok else "float16"

print(f"[gcp_entrypoint] → Profile: {_profile['label']}")

# Trainer knobs, assigned in the order listed. Entries not read from _profile
# are fixed for every GPU class.
_training_overrides = {
    "per_device_train_batch_size":   _profile["per_device_train_batch_size"],
    "per_device_eval_batch_size":    _profile["per_device_eval_batch_size"],
    "gradient_accumulation_steps":   _profile["gradient_accumulation_steps"],
    "dataloader_num_workers":        _profile["dataloader_num_workers"],
    "fp16":                          _profile["fp16"],
    "bf16":                          _profile["bf16"],
    "dataloader_pin_memory":         True,
    "dataloader_persistent_workers": True,
    "optim":                         _profile["optim"],
}
for _key, _value in _training_overrides.items():
    setattr(train_cfg.training, _key, _value)
train_cfg.stage1.num_epochs = S1_EPOCHS
train_cfg.stage2.num_epochs = S2_EPOCHS

# LLM loading: device-dependent attention/checkpointing/dtypes, plus a fixed
# 4-bit NF4 (double-quant) bitsandbytes setup with 8-bit loading disabled.
_llm_overrides = {
    "attn_implementation":       _profile["attn_implementation"],
    "gradient_checkpointing":    _profile["gradient_checkpointing"],
    "torch_dtype":               _profile["torch_dtype"],
    "bnb_4bit_compute_dtype":    _profile["bnb_4bit_compute_dtype"],
    "bnb_4bit_quant_type":       "nf4",
    "bnb_4bit_use_double_quant": True,
    "load_in_8bit":              False,
    "load_in_4bit":              True,
}
for _key, _value in _llm_overrides.items():
    setattr(model_cfg.llm, _key, _value)
model_cfg.chexpert_classifier.enabled = False

# Experiment tracking: W&B off; HF Hub integration on, against a private repo,
# with the token read from $HF_TOKEN and run state kept beside the checkpoints.
train_cfg.wandb.enabled = False
_hub_overrides = {
    "enabled":        True,
    "repo_id":        HF_RUNS_REPO,
    "token_env":      "HF_TOKEN",
    "private":        True,
    "run_state_file": str(CKPT_ROOT / "run_id.txt"),
}
for _key, _value in _hub_overrides.items():
    setattr(train_cfg.hf_hub, _key, _value)

# Overwrite the repo's YAML files in place; the trainer is pointed at these paths.
for _cfg_obj, _cfg_file in ((train_cfg, train_cfg_path), (model_cfg, model_cfg_path)):
    OmegaConf.save(_cfg_obj, _cfg_file)
print("[gcp_entrypoint] configs patched.")

# ── 4) Pin run_id.txt if resuming with an explicit id ─────────────────────────
# run_id.txt is the path configured as hf_hub.run_state_file. Seeding it keeps
# that file in line with the --run_id flag passed at launch. In fresh mode, or
# without EXPLICIT_RUN_ID, whatever is already on disk is left untouched.
_pinned_id = EXPLICIT_RUN_ID if MODE == "resume" else ""
if _pinned_id:
    _state_file = CKPT_ROOT / "run_id.txt"
    _state_file.write_text(_pinned_id)
    print(f"[gcp_entrypoint] pinned run_id = {_pinned_id}")

# ── 5) Launch training ────────────────────────────────────────────────────────
# Run the trainer under the same interpreter as this script. A bare "python" is
# resolved via PATH, which may not exist on python3-only images or may point at
# a different environment than the one torch/omegaconf were imported from above.
_python = sys.executable or "python"
cmd = [
    _python, "-u", "-m", "training.train",
    "--model_config", str(model_cfg_path),
    "--train_config", str(train_cfg_path),
    "--mode", MODE,
]
if MODE == "resume" and EXPLICIT_RUN_ID:
    cmd += ["--run_id", EXPLICIT_RUN_ID]

print(f"[gcp_entrypoint] launching: {' '.join(cmd)}", flush=True)
# `-m training.train` is resolved relative to the working directory.
os.chdir(PROJECT)
_rc = subprocess.call(cmd)
# subprocess reports death-by-signal as -N. sys.exit(-N) would surface as 256-N
# (e.g. 247 for SIGKILL), so use the shell convention 128+N (137) instead.
sys.exit(128 - _rc if _rc < 0 else _rc)