convitom commited on
Commit
9dadb47
·
1 Parent(s): c369576
Files changed (2) hide show
  1. scripts/cxrvlm_colab_train.ipynb +107 -365
  2. training/train.py +214 -15
scripts/cxrvlm_colab_train.ipynb CHANGED
@@ -966,174 +966,91 @@
966
  {
967
  "cell_type": "markdown",
968
  "metadata": {
969
- "id": "cell-resume-md"
970
  },
971
  "source": [
972
- "## 5b. Resume a previous run (only if you were interrupted)\n",
973
  "\n",
974
- "**Skip this section if you're starting fresh.** Set `RESUME_STAGE = None` in the cell below and run Stage 1 → Stage 2 normally.\n",
 
975
  "\n",
976
- "### Khi nào cần resume\n",
 
 
 
977
  "\n",
978
- "| Tình huống | Cần làm |\n",
979
- "|---|---|\n",
980
- "| Stage 1 đang train dở, cùng VM | `RESUME_STAGE=1`, `EXPLICIT_RUN_ID=None` |\n",
981
- "| Stage 1 dở, **VM mới** (Colab disconnect) | `RESUME_STAGE=1`, `EXPLICIT_RUN_ID='IU-Xray_run_1'` |\n",
982
- "| Stage 1 xong, chạy tiếp stage 2 | `RESUME_STAGE=2`, `EXPLICIT_RUN_ID=None` (cùng VM) hoặc set (VM mới) |\n",
983
- "| Stage 2 đang dở, cùng VM | `RESUME_STAGE=2`, `EXPLICIT_RUN_ID=None` |\n",
984
- "| Stage 2 dở, **VM mới** | `RESUME_STAGE=2`, `EXPLICIT_RUN_ID='IU-Xray_run_1'` |\n",
985
  "\n",
986
- "### Nếu VM mới (Colab đã disconnect ít nhất 1 lần)\n",
987
- "\n",
988
- "Phải chạy lại **tất cả các cell từ đầu** trước khi đến đây:\n",
989
- "1. `cell-select` → `cell-env` → `cell-paths` (pull code + data từ HF)\n",
990
- "2. `cell-pip` → **Runtime → Restart session** → `cell-pip-verify`\n",
991
- "3. `cell-hf-token` → `cell-cfg` → `cell-sanity`\n",
992
- "4. **Skip `cell-stage1` / `cell-stage2` mặc định** — chạy `cell-resume` ngay bên dưới, rồi chạy cell train tương ứng.\n",
993
- "\n",
994
- "Cell resume bên dưới tự:\n",
995
- "- Pull các folder `checkpoint-<step>` mới nhất từ HF Hub về (từ `stage1/` hoặc `stage2/` trên hub).\n",
996
- "- Rename `stage1/` → `stage1_projection/`, `stage2/` → `stage2_instruct/` cho khớp layout local mà `train.py` kỳ vọng.\n",
997
- "- Ghi lại `run_id.txt` trong `CKPT_ROOT`.\n",
998
- "- Tìm checkpoint có step cao nhất và set `RESUME_FROM` cho cell train.\n",
999
- "\n",
1000
- "### Train tiếp sẽ lưu ở đâu trên HF?\n",
1001
- "\n",
1002
- "**Vẫn cùng folder `<RUN_ID>/` trên HF.** `HFRunTracker` tái sử dụng `run_id` đã resolve nên:\n",
1003
- "- `hieu3636/cxr-vlm-runs/IU-Xray_run_1/stage2/checkpoint-XXXX/` (intermediate, callback mỗi 200 step)\n",
1004
- "- `hieu3636/cxr-vlm-runs/IU-Xray_run_1/stage2/stage2_final.pt` (final)\n",
1005
- "- `hieu3636/cxr-vlm-runs/IU-Xray_run_1/meta.json` (merge với meta cũ, tăng `resume_count`)\n",
1006
- "\n",
1007
- "Không tạo `run_2` mới.\n",
1008
- "\n",
1009
- "### Auto-detect run_id\n",
1010
- "\n",
1011
- "`train.py` resolve `run_id` theo thứ tự:\n",
1012
- "1. `--run_id` CLI (cell resume truyền cái này khi bạn set `EXPLICIT_RUN_ID`).\n",
1013
- "2. `run_id.txt` trong `CKPT_ROOT` (cell resume ghi lại).\n",
1014
- "3. Nếu cả 2 trống + `--resume_from` → tự tìm run mới nhất trên HF.\n"
1015
  ],
1016
  "id": "cell-resume-md"
1017
  },
1018
  {
1019
  "cell_type": "code",
1020
  "metadata": {
1021
- "id": "cell-resume",
1022
  "colab": {
1023
  "base_uri": "https://localhost:8080/"
1024
  },
1025
  "outputId": "6f1b15fb-e751-4209-ea52-7a98a8c40212"
1026
  },
1027
  "execution_count": null,
1028
- "outputs": [
1029
- {
1030
- "output_type": "stream",
1031
- "name": "stdout",
1032
- "text": [
1033
- "RESUME_STAGE=None — fresh run. Skip this cell; go straight to cell-stage1.\n"
1034
- ]
1035
- }
1036
- ],
1037
  "source": [
1038
- "# Resume controller — edit 2 variables, run once, then run cell-stage1 or cell-stage2.\n",
1039
- "RESUME_STAGE = None # None | 1 | 2 (None = fresh run, skip this cell)\n",
1040
- "EXPLICIT_RUN_ID = \"IU-Xray_run_1\" # None | \"IU-Xray_run_1\" (set this if VM is fresh)\n",
1041
- "\n",
1042
- "RESUME_FROM = None\n",
1043
- "RESUME_RUN_ID = None\n",
1044
- "\n",
1045
- "if RESUME_STAGE is not None:\n",
1046
- " assert RESUME_STAGE in (1, 2), \"RESUME_STAGE must be 1 or 2\"\n",
1047
- "\n",
1048
- " # 1) Resolve run_id: explicit > local state_file\n",
1049
- " if EXPLICIT_RUN_ID:\n",
1050
- " RESUME_RUN_ID = EXPLICIT_RUN_ID\n",
1051
- " CKPT_ROOT.mkdir(parents=True, exist_ok=True)\n",
1052
- " (CKPT_ROOT / \"run_id.txt\").write_text(RESUME_RUN_ID)\n",
1053
- " print(f\"Using EXPLICIT_RUN_ID = {RESUME_RUN_ID} (wrote run_id.txt)\")\n",
1054
- " else:\n",
1055
- " state_file = CKPT_ROOT / \"run_id.txt\"\n",
1056
- " assert state_file.exists(), (\n",
1057
- " \"No local run_id.txt — looks like a fresh VM. \"\n",
1058
- " \"Set EXPLICIT_RUN_ID to the run folder on HF (e.g. \\\"IU-Xray_run_1\\\").\"\n",
1059
- " )\n",
1060
- " RESUME_RUN_ID = state_file.read_text().strip()\n",
1061
- " print(f\"Using run_id from state file: {RESUME_RUN_ID}\")\n",
1062
- "\n",
1063
- " # 2) Local subdir names (code expects long names; HF uses short \"stage1\"/\"stage2\")\n",
1064
- " local_subdir = \"stage1_projection\" if RESUME_STAGE == 1 else \"stage2_instruct\"\n",
1065
- " remote_subdir = \"stage1\" if RESUME_STAGE == 1 else \"stage2\"\n",
1066
- " local_stage_dir = CKPT_ROOT / RESUME_RUN_ID / local_subdir\n",
1067
- " local_stage_dir.mkdir(parents=True, exist_ok=True)\n",
1068
- "\n",
1069
- " # 3) If no local checkpoints, pull from HF\n",
1070
- " existing = sorted(local_stage_dir.glob(\"checkpoint-*\"),\n",
1071
- " key=lambda p: int(p.name.split(\"-\")[1]))\n",
1072
- " if not existing:\n",
1073
- " print(f\"No local checkpoints under {local_stage_dir} — pulling from HF Hub …\")\n",
1074
- " from huggingface_hub import snapshot_download\n",
1075
- " hub_prefix = f\"{RESUME_RUN_ID}/{remote_subdir}/\"\n",
1076
- " pulled = snapshot_download(\n",
1077
- " repo_id = train_cfg.hf_hub.repo_id,\n",
1078
- " repo_type = \"model\",\n",
1079
- " token = os.environ[\"HF_TOKEN\"],\n",
1080
- " allow_patterns = [f\"{hub_prefix}checkpoint-*/**\"],\n",
1081
- " local_dir = str(WORK / \"hf_pull\"),\n",
1082
- " )\n",
1083
- " hub_stage_dir = Path(pulled) / RESUME_RUN_ID / remote_subdir\n",
1084
- " assert hub_stage_dir.exists() and any(hub_stage_dir.glob(\"checkpoint-*\")), (\n",
1085
- " f\"No checkpoints found under {hub_prefix} on HF repo \"\n",
1086
- " f\"{train_cfg.hf_hub.repo_id}. Did you set the right EXPLICIT_RUN_ID?\"\n",
1087
- " )\n",
1088
- " for ck in hub_stage_dir.glob(\"checkpoint-*\"):\n",
1089
- " dst = local_stage_dir / ck.name\n",
1090
- " if dst.exists():\n",
1091
- " shutil.rmtree(dst)\n",
1092
- " shutil.move(str(ck), str(dst))\n",
1093
- " print(f\"Pulled {len(list(local_stage_dir.glob('checkpoint-*')))} checkpoint(s) → {local_stage_dir}\")\n",
1094
- " existing = sorted(local_stage_dir.glob(\"checkpoint-*\"),\n",
1095
- " key=lambda p: int(p.name.split(\"-\")[1]))\n",
1096
- "\n",
1097
- " assert existing, f\"Still no checkpoints under {local_stage_dir}\"\n",
1098
- "\n",
1099
- " # 4) Latest checkpoint = highest global_step\n",
1100
- " RESUME_FROM = existing[-1]\n",
1101
- " print()\n",
1102
- " print(f\"✓ Ready to resume STAGE {RESUME_STAGE} of run {RESUME_RUN_ID}\")\n",
1103
- " print(f\" checkpoints on disk : {[c.name for c in existing]}\")\n",
1104
- " print(f\" will resume from : {RESUME_FROM}\")\n",
1105
- " print()\n",
1106
- " print(f\"→ Now run cell-stage{RESUME_STAGE} below.\")\n",
1107
  "else:\n",
1108
- " print(\"RESUME_STAGE=Nonefresh run. Skip this cell; go straight to cell-stage1.\")\n"
1109
  ],
1110
  "id": "cell-resume"
1111
  },
1112
  {
1113
  "cell_type": "markdown",
1114
  "metadata": {
1115
- "id": "cell-stage1-md"
1116
  },
1117
  "source": [
1118
- "## 6. Stage 1 projection layer only (~2 epochs)\n",
1119
  "\n",
1120
- "First launch creates `{DATASET_NAME}_run_1/` on HF and on disk. Subsequent fresh launches auto-increment to `run_2`, `run_3`, … — tracked via `ckpt/run_id.txt`.\n",
1121
  "\n",
1122
- "If you need to continue training from an existing checkpoint, pass `--resume_from <ckpt>` that reuses the same `run_N` folder."
 
 
 
 
 
 
 
 
 
1123
  ],
1124
  "id": "cell-stage1-md"
1125
  },
1126
  {
1127
  "cell_type": "code",
1128
  "metadata": {
1129
- "id": "cell-stage1",
1130
  "colab": {
1131
  "base_uri": "https://localhost:8080/"
1132
  },
1133
  "outputId": "c7d6c209-6790-473c-c1b7-a44441141785"
1134
  },
1135
  "source": [
1136
- "# Picks up RESUME_FROM / RESUME_RUN_ID from cell-resume (None if fresh run).\n",
1137
  "import time as _time, json as _json\n",
1138
  "from datetime import datetime as _dt, timezone as _tz\n",
1139
  "from pathlib import Path as _Path\n",
@@ -1146,16 +1063,13 @@
1146
  " try:\n",
1147
  " from huggingface_hub import hf_hub_download\n",
1148
  " hf_hub_download(\n",
1149
- " repo_id = repo_id,\n",
1150
- " repo_type = \"model\",\n",
1151
- " filename = f\"{run_id}/timing.json\",\n",
1152
- " token = token,\n",
1153
- " local_dir = str(ckpt_root),\n",
1154
  " )\n",
1155
  " print(f\"[TIMING] pulled previous timing.json from HF → {local}\")\n",
1156
- " except Exception as e:\n",
1157
- " # First run for this run_id → no remote file yet. That's fine.\n",
1158
- " pass\n",
1159
  "\n",
1160
  "def _push_timing_to_hf(run_id, ckpt_root, repo_id, token):\n",
1161
  " # Upload local timing.json to HF Hub under {run_id}/timing.json.\n",
@@ -1165,247 +1079,87 @@
1165
  " try:\n",
1166
  " from huggingface_hub import HfApi\n",
1167
  " HfApi(token=token).upload_file(\n",
1168
- " path_or_fileobj = str(local),\n",
1169
- " path_in_repo = f\"{run_id}/timing.json\",\n",
1170
- " repo_id = repo_id,\n",
1171
- " repo_type = \"model\",\n",
1172
- " commit_message = f\"timing.json @ {run_id}\",\n",
1173
  " )\n",
1174
  " print(f\"[TIMING] uploaded timing.json to HF → {repo_id}/{run_id}/timing.json\")\n",
1175
  " except Exception as e:\n",
1176
  " print(f\"[TIMING] upload failed (non-fatal): {e}\")\n",
1177
  "\n",
1178
  "\n",
1179
- "_resume_args = \"\"\n",
1180
- "_is_resume = False\n",
1181
- "if \"RESUME_FROM\" in dir() and RESUME_FROM and RESUME_STAGE == 1:\n",
1182
- " _resume_args = f'--resume_from \"{RESUME_FROM}\" --run_id \"{RESUME_RUN_ID}\"'\n",
1183
- " _is_resume = True\n",
1184
- " print(\"▶ STAGE 1 resuming from\", RESUME_FROM)\n",
1185
- "else:\n",
1186
- " print(\"▶ STAGE 1 fresh run\")\n",
1187
- "\n",
1188
- "# ─── Pre-pull timing.json from HF if resuming (best-effort) ────────────────\n",
1189
  "_hf_repo = getattr(train_cfg.hf_hub, \"repo_id\", None) if train_cfg.hf_hub.enabled else None\n",
1190
  "_hf_token = os.environ.get(\"HF_TOKEN\")\n",
1191
- "if _is_resume and \"RESUME_RUN_ID\" in dir() and RESUME_RUN_ID:\n",
1192
- " _pull_timing_from_hf(RESUME_RUN_ID, CKPT_ROOT, _hf_repo, _hf_token)\n",
1193
  "\n",
1194
- "# ─── Start-of-stage timer ──────────────────────────────────────────────────\n",
1195
- "_t0_stage1 = _time.time()\n",
1196
- "_iso_start_stage1 = _dt.now(_tz.utc).isoformat(timespec=\"seconds\")\n",
 
 
1197
  "\n",
1198
- "!HF_HUB_DISABLE_PROGRESS_BARS=1 TRANSFORMERS_VERBOSITY=warning TOKENIZERS_PARALLELISM=false BITSANDBYTES_NOWELCOME=1 PYTHONUNBUFFERED=1 \\\n",
1199
- "python -u -m training.train \\\n",
1200
- " --model_config configs/model_config.yaml \\\n",
1201
- " --train_config configs/train_config.yaml \\\n",
1202
- " --stage 1 {_resume_args}\n",
1203
- "\n",
1204
- "# ─── End-of-stage timer + persist cumulative time ──────────────────────────\n",
1205
- "_elapsed_stage1 = _time.time() - _t0_stage1\n",
1206
- "_run_id_file = CKPT_ROOT / \"run_id.txt\"\n",
1207
- "if _run_id_file.exists():\n",
1208
- " _run_id_now = _run_id_file.read_text().strip()\n",
1209
- " _timing_path = CKPT_ROOT / _run_id_now / \"timing.json\"\n",
1210
- " _timing_path.parent.mkdir(parents=True, exist_ok=True)\n",
1211
- " _t = _json.loads(_timing_path.read_text()) if _timing_path.exists() else {\n",
1212
- " \"stage1_elapsed_sec\": 0.0,\n",
1213
- " \"stage2_elapsed_sec\": 0.0,\n",
1214
- " \"resume_count_stage1\": 0,\n",
1215
- " \"resume_count_stage2\": 0,\n",
1216
- " \"first_started_at\": None,\n",
1217
- " \"last_finished_at\": None,\n",
1218
- " \"session_history\": [],\n",
1219
- " }\n",
1220
- " if _t.get(\"first_started_at\") is None:\n",
1221
- " _t[\"first_started_at\"] = _iso_start_stage1\n",
1222
- " _t[\"stage1_elapsed_sec\"] = float(_t.get(\"stage1_elapsed_sec\", 0.0)) + _elapsed_stage1\n",
1223
- " _t[\"resume_count_stage1\"] = int(_t.get(\"resume_count_stage1\", 0)) + (1 if _is_resume else 0)\n",
1224
- " _t[\"last_finished_at\"] = _dt.now(_tz.utc).isoformat(timespec=\"seconds\")\n",
1225
- " _t.setdefault(\"session_history\", []).append({\n",
1226
- " \"stage\": 1,\n",
1227
- " \"resumed\": _is_resume,\n",
1228
- " \"started\": _iso_start_stage1,\n",
1229
- " \"finished\": _t[\"last_finished_at\"],\n",
1230
- " \"elapsed_sec\": _elapsed_stage1,\n",
1231
- " })\n",
1232
- " _timing_path.write_text(_json.dumps(_t, indent=2))\n",
1233
- "\n",
1234
- " # ─── Push to HF Hub so the timer survives a fresh VM ─────────────────\n",
1235
- " _push_timing_to_hf(_run_id_now, CKPT_ROOT, _hf_repo, _hf_token)\n",
1236
  "\n",
1237
- " def _fmt(sec):\n",
1238
- " h, r = divmod(int(sec), 3600); m, s = divmod(r, 60); return f\"{h:d}h {m:02d}m {s:02d}s\"\n",
1239
- " print()\n",
1240
- " print(f\"[TIMING] Stage 1 this session : {_fmt(_elapsed_stage1)}\")\n",
1241
- " print(f\"[TIMING] Stage 1 cumulative : {_fmt(_t['stage1_elapsed_sec'])} (resumes so far: {_t['resume_count_stage1']})\")\n",
1242
- " print(f\"[TIMING] persisted to : {_timing_path}\")\n",
1243
- "else:\n",
1244
- " print(\"[TIMING] run_id.txt missing — could not persist timing (training likely failed before resolve_run_id ran).\")\n"
1245
- ],
1246
- "execution_count": null,
1247
- "outputs": [],
1248
- "id": "cell-stage1"
1249
- },
1250
- {
1251
- "cell_type": "markdown",
1252
- "metadata": {
1253
- "id": "cell-stage2-md"
1254
- },
1255
- "source": [
1256
- "## 7. Stage 2 — projection + LoRA instruction tuning\n",
1257
- "\n",
1258
- "Kaggle caps a GPU session at 9h. If Stage 2 doesn't finish, Persistence keeps the Trainer checkpoints in `/kaggle/working/ckpt/{RUN_ID}/stage2_instruct/checkpoint-XXXX/` — resume next session with:\n",
1259
- "```\n",
1260
- "!python -m training.train --stage 2 \\\n",
1261
- " --model_config configs/model_config.yaml --train_config configs/train_config.yaml \\\n",
1262
- " --resume_from /kaggle/working/ckpt/{RUN_ID}/stage2_instruct/checkpoint-XXXX\n",
1263
- "```\n",
1264
- "This **does not** create a new `run_N+1` on HF — it reuses the existing run."
1265
- ],
1266
- "id": "cell-stage2-md"
1267
- },
1268
- {
1269
- "cell_type": "code",
1270
- "metadata": {
1271
- "id": "cell-stage2",
1272
- "colab": {
1273
- "base_uri": "https://localhost:8080/"
1274
- },
1275
- "outputId": "d9cd4a96-ec88-4907-fc0e-38b6eb2f66a7"
1276
- },
1277
- "source": [
1278
- "# Picks up RESUME_FROM / RESUME_RUN_ID from cell-resume (None if fresh run).\n",
1279
- "import time as _time, json as _json\n",
1280
- "from datetime import datetime as _dt, timezone as _tz\n",
1281
- "from pathlib import Path as _Path\n",
1282
- "\n",
1283
- "def _pull_timing_from_hf(run_id, ckpt_root, repo_id, token):\n",
1284
- " # Pull timing.json from HF Hub for this run if not present locally.\n",
1285
- " local = ckpt_root / run_id / \"timing.json\"\n",
1286
- " if local.exists() or not repo_id or not token:\n",
1287
- " return\n",
1288
- " try:\n",
1289
- " from huggingface_hub import hf_hub_download\n",
1290
- " hf_hub_download(\n",
1291
- " repo_id = repo_id,\n",
1292
- " repo_type = \"model\",\n",
1293
- " filename = f\"{run_id}/timing.json\",\n",
1294
- " token = token,\n",
1295
- " local_dir = str(ckpt_root),\n",
1296
- " )\n",
1297
- " print(f\"[TIMING] pulled previous timing.json from HF → {local}\")\n",
1298
- " except Exception as e:\n",
1299
- " # First run for this run_id → no remote file yet. That's fine.\n",
1300
- " pass\n",
1301
- "\n",
1302
- "def _push_timing_to_hf(run_id, ckpt_root, repo_id, token):\n",
1303
- " # Upload local timing.json to HF Hub under {run_id}/timing.json.\n",
1304
- " local = ckpt_root / run_id / \"timing.json\"\n",
1305
- " if not local.exists() or not repo_id or not token:\n",
1306
- " return\n",
1307
- " try:\n",
1308
- " from huggingface_hub import HfApi\n",
1309
- " HfApi(token=token).upload_file(\n",
1310
- " path_or_fileobj = str(local),\n",
1311
- " path_in_repo = f\"{run_id}/timing.json\",\n",
1312
- " repo_id = repo_id,\n",
1313
- " repo_type = \"model\",\n",
1314
- " commit_message = f\"timing.json @ {run_id}\",\n",
1315
- " )\n",
1316
- " print(f\"[TIMING] uploaded timing.json to HF → {repo_id}/{run_id}/timing.json\")\n",
1317
- " except Exception as e:\n",
1318
- " print(f\"[TIMING] upload failed (non-fatal): {e}\")\n",
1319
- "\n",
1320
- "\n",
1321
- "_resume_args = \"\"\n",
1322
- "_is_resume = False\n",
1323
- "if \"RESUME_FROM\" in dir() and RESUME_FROM and RESUME_STAGE == 2:\n",
1324
- " _resume_args = f'--resume_from \"{RESUME_FROM}\" --run_id \"{RESUME_RUN_ID}\"'\n",
1325
- " _is_resume = True\n",
1326
- " print(\"▶ STAGE 2 resuming from\", RESUME_FROM)\n",
1327
- "elif \"RESUME_RUN_ID\" in dir() and RESUME_RUN_ID:\n",
1328
- " _resume_args = f'--run_id \"{RESUME_RUN_ID}\"'\n",
1329
- " print(\"▶ STAGE 2 fresh start, pinned to run_id\", RESUME_RUN_ID)\n",
1330
- "else:\n",
1331
- " # ─── FIX: pin stage 2 to the run_id stage 1 just wrote ────────────────\n",
1332
- " # Without this, train.py treats stage 2 as a brand-new launch and\n",
1333
- " # allocates a NEW run_N folder, splitting stage1/stage2 across two runs.\n",
1334
- " _state_file = CKPT_ROOT / \"run_id.txt\"\n",
1335
- " if _state_file.exists():\n",
1336
- " _pinned = _state_file.read_text().strip()\n",
1337
- " _resume_args = f'--run_id \"{_pinned}\"'\n",
1338
- " print(f\"▶ STAGE 2 fresh, auto-pinned to run_id from state file: {_pinned}\")\n",
1339
- " else:\n",
1340
- " print(\"▶ STAGE 2 fresh (no state file — train.py will allocate a new run_id)\")\n",
1341
- "\n",
1342
- "# ─── Pre-pull timing.json from HF (in case of fresh VM) ───────────────────\n",
1343
- "_hf_repo = getattr(train_cfg.hf_hub, \"repo_id\", None) if train_cfg.hf_hub.enabled else None\n",
1344
- "_hf_token = os.environ.get(\"HF_TOKEN\")\n",
1345
- "# Best guess at run_id BEFORE training (may be missing if stage 1 wasn't run here)\n",
1346
- "_pre_state = CKPT_ROOT / \"run_id.txt\"\n",
1347
- "if _pre_state.exists():\n",
1348
- " _pull_timing_from_hf(_pre_state.read_text().strip(), CKPT_ROOT, _hf_repo, _hf_token)\n",
1349
  "\n",
1350
- "# ─── Start-of-stage timer ──────────────────────────────────────────────────\n",
1351
- "_t0_stage2 = _time.time()\n",
1352
- "_iso_start_stage2 = _dt.now(_tz.utc).isoformat(timespec=\"seconds\")\n",
1353
  "\n",
1354
  "!HF_HUB_DISABLE_PROGRESS_BARS=1 TRANSFORMERS_VERBOSITY=warning TOKENIZERS_PARALLELISM=false BITSANDBYTES_NOWELCOME=1 PYTHONUNBUFFERED=1 \\\n",
1355
  "python -u -m training.train \\\n",
1356
  " --model_config configs/model_config.yaml \\\n",
1357
  " --train_config configs/train_config.yaml \\\n",
1358
- " --stage 2 {_resume_args}\n",
 
 
1359
  "\n",
1360
- "# ─── End-of-stage timer + persist cumulative time ──────────────────────────\n",
1361
- "_elapsed_stage2 = _time.time() - _t0_stage2\n",
1362
  "_run_id_file = CKPT_ROOT / \"run_id.txt\"\n",
1363
  "if _run_id_file.exists():\n",
1364
  " _run_id_now = _run_id_file.read_text().strip()\n",
1365
  " _timing_path = CKPT_ROOT / _run_id_now / \"timing.json\"\n",
1366
  " _timing_path.parent.mkdir(parents=True, exist_ok=True)\n",
1367
  " _t = _json.loads(_timing_path.read_text()) if _timing_path.exists() else {\n",
1368
- " \"stage1_elapsed_sec\": 0.0,\n",
1369
- " \"stage2_elapsed_sec\": 0.0,\n",
1370
- " \"resume_count_stage1\": 0,\n",
1371
- " \"resume_count_stage2\": 0,\n",
1372
- " \"first_started_at\": None,\n",
1373
- " \"last_finished_at\": None,\n",
1374
- " \"session_history\": [],\n",
1375
  " }\n",
1376
  " if _t.get(\"first_started_at\") is None:\n",
1377
- " _t[\"first_started_at\"] = _iso_start_stage2\n",
1378
- " _t[\"stage2_elapsed_sec\"] = float(_t.get(\"stage2_elapsed_sec\", 0.0)) + _elapsed_stage2\n",
1379
- " _t[\"resume_count_stage2\"] = int(_t.get(\"resume_count_stage2\", 0)) + (1 if _is_resume else 0)\n",
1380
- " _t[\"last_finished_at\"] = _dt.now(_tz.utc).isoformat(timespec=\"seconds\")\n",
1381
  " _t.setdefault(\"session_history\", []).append({\n",
1382
- " \"stage\": 2,\n",
1383
- " \"resumed\": _is_resume,\n",
1384
- " \"started\": _iso_start_stage2,\n",
1385
- " \"finished\": _t[\"last_finished_at\"],\n",
1386
- " \"elapsed_sec\": _elapsed_stage2,\n",
1387
  " })\n",
1388
  " _timing_path.write_text(_json.dumps(_t, indent=2))\n",
1389
- "\n",
1390
- " # ─── Push to HF Hub ──────────────────────────────────────────────────\n",
1391
  " _push_timing_to_hf(_run_id_now, CKPT_ROOT, _hf_repo, _hf_token)\n",
1392
  "\n",
1393
  " def _fmt(sec):\n",
1394
  " h, r = divmod(int(sec), 3600); m, s = divmod(r, 60); return f\"{h:d}h {m:02d}m {s:02d}s\"\n",
1395
- " _total = _t[\"stage1_elapsed_sec\"] + _t[\"stage2_elapsed_sec\"]\n",
1396
  " print()\n",
1397
- " print(f\"[TIMING] Stage 2 this session : {_fmt(_elapsed_stage2)}\")\n",
1398
- " print(f\"[TIMING] Stage 2 cumulative : {_fmt(_t['stage2_elapsed_sec'])} (resumes so far: {_t['resume_count_stage2']})\")\n",
1399
- " print(f\"[TIMING] Stage 1 + Stage 2 : {_fmt(_total)}\")\n",
1400
- " print(f\"[TIMING] first started at : {_t.get('first_started_at')}\")\n",
1401
- " print(f\"[TIMING] last finished at : {_t.get('last_finished_at')}\")\n",
1402
- " print(f\"[TIMING] persisted to : {_timing_path}\")\n",
1403
  "else:\n",
1404
  " print(\"[TIMING] run_id.txt missing — could not persist timing.\")\n"
1405
  ],
1406
  "execution_count": null,
1407
  "outputs": [],
1408
- "id": "cell-stage2"
1409
  },
1410
  {
1411
  "cell_type": "markdown",
@@ -1557,16 +1311,13 @@
1557
  " try:\n",
1558
  " from huggingface_hub import hf_hub_download\n",
1559
  " hf_hub_download(\n",
1560
- " repo_id = repo_id,\n",
1561
- " repo_type = \"model\",\n",
1562
- " filename = f\"{run_id}/timing.json\",\n",
1563
- " token = token,\n",
1564
- " local_dir = str(ckpt_root),\n",
1565
  " )\n",
1566
  " print(f\"[TIMING] pulled previous timing.json from HF → {local}\")\n",
1567
- " except Exception as e:\n",
1568
- " # First run for this run_id → no remote file yet. That's fine.\n",
1569
- " pass\n",
1570
  "\n",
1571
  "def _push_timing_to_hf(run_id, ckpt_root, repo_id, token):\n",
1572
  " # Upload local timing.json to HF Hub under {run_id}/timing.json.\n",
@@ -1576,11 +1327,10 @@
1576
  " try:\n",
1577
  " from huggingface_hub import HfApi\n",
1578
  " HfApi(token=token).upload_file(\n",
1579
- " path_or_fileobj = str(local),\n",
1580
- " path_in_repo = f\"{run_id}/timing.json\",\n",
1581
- " repo_id = repo_id,\n",
1582
- " repo_type = \"model\",\n",
1583
- " commit_message = f\"timing.json @ {run_id}\",\n",
1584
  " )\n",
1585
  " print(f\"[TIMING] uploaded timing.json to HF → {repo_id}/{run_id}/timing.json\")\n",
1586
  " except Exception as e:\n",
@@ -1588,10 +1338,10 @@
1588
  "\n",
1589
  "\n",
1590
  "_run_id_file = CKPT_ROOT / \"run_id.txt\"\n",
1591
- "assert _run_id_file.exists(), \"No run_id.txt — train at least one stage first.\"\n",
1592
  "_run_id = _run_id_file.read_text().strip()\n",
1593
  "\n",
1594
- "# Pull the latest timing.json from HF in case we're on a fresh VM.\n",
1595
  "_hf_repo = getattr(train_cfg.hf_hub, \"repo_id\", None) if train_cfg.hf_hub.enabled else None\n",
1596
  "_hf_token = os.environ.get(\"HF_TOKEN\")\n",
1597
  "_pull_timing_from_hf(_run_id, CKPT_ROOT, _hf_repo, _hf_token)\n",
@@ -1599,32 +1349,24 @@
1599
  "_timing_path = CKPT_ROOT / _run_id / \"timing.json\"\n",
1600
  "assert _timing_path.exists(), (\n",
1601
  " f\"No timing.json under {_timing_path.parent} (also not on HF). \"\n",
1602
- " f\"Was the stage cell run via the wrapped version?\"\n",
1603
  ")\n",
1604
  "\n",
1605
  "_t = _json.loads(_timing_path.read_text())\n",
1606
  "\n",
1607
  "def _fmt(sec):\n",
1608
- " h, r = divmod(int(sec or 0), 3600)\n",
1609
- " m, s = divmod(r, 60)\n",
1610
- " return f\"{h:d}h {m:02d}m {s:02d}s\"\n",
1611
- "\n",
1612
- "_s1 = float(_t.get(\"stage1_elapsed_sec\", 0.0))\n",
1613
- "_s2 = float(_t.get(\"stage2_elapsed_sec\", 0.0))\n",
1614
- "_total = _s1 + _s2\n",
1615
  "\n",
1616
  "print(f\"Run : {_run_id}\")\n",
1617
  "print(f\"First started at : {_t.get('first_started_at')}\")\n",
1618
  "print(f\"Last finished at : {_t.get('last_finished_at')}\")\n",
1619
- "print()\n",
1620
- "print(f\"Stage 1 cumulative : {_fmt(_s1)} (resumes: {_t.get('resume_count_stage1', 0)})\")\n",
1621
- "print(f\"Stage 2 cumulative : {_fmt(_s2)} (resumes: {_t.get('resume_count_stage2', 0)})\")\n",
1622
- "print(f\"TOTAL : {_fmt(_total)}\")\n",
1623
  "print()\n",
1624
  "print(\"Session history :\")\n",
1625
  "for _i, _s in enumerate(_t.get(\"session_history\", []), 1):\n",
1626
- " _tag = \"(resume)\" if _s.get(\"resumed\") else \"(fresh) \"\n",
1627
- " print(f\" {_i:2d}. stage {_s['stage']} {_tag} {_fmt(_s['elapsed_sec'])} {_s['started']} → {_s['finished']}\")\n"
1628
  ],
1629
  "outputs": [],
1630
  "execution_count": null
 
966
  {
967
  "cell_type": "markdown",
968
  "metadata": {
969
+ "id": "cell-mode-md"
970
  },
971
  "source": [
972
+ "## 5b. Resume controller\n",
973
  "\n",
974
+ "Single switch. No more \"which stage\" `train.py` auto-detects which stage\n",
975
+ "to continue from by inspecting checkpoints on disk.\n",
976
  "\n",
977
+ "| MODE | What happens |\n",
978
+ "|---------------------|--------------|\n",
979
+ "| `'fresh'` | Allocate a brand-new `{DATASET}_run_N+1` folder. Train both stages from scratch. |\n",
980
+ "| `'resume'` | Reuse latest matching `{DATASET}_run_N` (or `EXPLICIT_RUN_ID`). Auto-detect: stage 1 mid-checkpoint, stage 1 done → stage 2 fresh, stage 2 mid-checkpoint, or both done. |\n",
981
  "\n",
982
+ "`EXPLICIT_RUN_ID` is optional (set to `None` to auto-pick the latest run on\n",
983
+ "disk or HF Hub that matches the current dataset prefix).\n",
 
 
 
 
 
984
  "\n",
985
+ "When `MODE='resume'` on a fresh VM the train cell will pull the previous\n",
986
+ "run's checkpoints from HF before training. The `--mode resume` flag in\n",
987
+ "`train.py` does the auto-detect no further action needed in the notebook."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
988
  ],
989
  "id": "cell-resume-md"
990
  },
991
  {
992
  "cell_type": "code",
993
  "metadata": {
994
+ "id": "cell-mode",
995
  "colab": {
996
  "base_uri": "https://localhost:8080/"
997
  },
998
  "outputId": "6f1b15fb-e751-4209-ea52-7a98a8c40212"
999
  },
1000
  "execution_count": null,
1001
+ "outputs": [],
 
 
 
 
 
 
 
 
1002
  "source": [
1003
+ "# Resume controller — set MODE once, run the train cell below.\n",
1004
+ "MODE = 'fresh' # 'fresh' | 'resume'\n",
1005
+ "EXPLICIT_RUN_ID = None # None | 'IU-Xray_run_5' (only matters when MODE='resume')\n",
1006
+ "\n",
1007
+ "assert MODE in ('fresh', 'resume'), \"MODE must be 'fresh' or 'resume'\"\n",
1008
+ "if MODE == 'resume' and EXPLICIT_RUN_ID:\n",
1009
+ " CKPT_ROOT.mkdir(parents=True, exist_ok=True)\n",
1010
+ " (CKPT_ROOT / 'run_id.txt').write_text(EXPLICIT_RUN_ID)\n",
1011
+ " print(f\"MODE=resume, pinning run_id = {EXPLICIT_RUN_ID}\")\n",
1012
+ "elif MODE == 'resume':\n",
1013
+ " print(\"MODE=resume, run_id will be auto-resolved to the latest \"\n",
1014
+ " f\"'{DATASET_NAME}_run_*' on disk (or HF Hub).\")\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1015
  "else:\n",
1016
+ " print(\"MODE=freshtrain.py will allocate a new run folder.\")\n"
1017
  ],
1018
  "id": "cell-resume"
1019
  },
1020
  {
1021
  "cell_type": "markdown",
1022
  "metadata": {
1023
+ "id": "cell-train-md"
1024
  },
1025
  "source": [
1026
+ "## 6. Train (both stages, one shot)\n",
1027
  "\n",
1028
+ "Single cell runs `train.py --mode {MODE}` which:\n",
1029
  "\n",
1030
+ "1. Resolves `run_id` (new vs. latest matching).\n",
1031
+ "2. Prints the total step plan (stage 1 + stage 2 = global step count).\n",
1032
+ "3. Auto-detects which stage to resume from by scanning the run folder.\n",
1033
+ "4. Runs stage 1 (skip if already done), then stage 2.\n",
1034
+ "5. Pushes intermediate `checkpoint-NNN/` + `best/` of each stage to HF Hub\n",
1035
+ " as before. `timing.json` is updated and uploaded after each stage.\n",
1036
+ "\n",
1037
+ "Kaggle / Colab caps a GPU session at ~9h. If the session dies mid-stage,\n",
1038
+ "just re-run this cell with `MODE='resume'` — `train.py` picks up at the\n",
1039
+ "last `checkpoint-NNN/` of whichever stage was in progress."
1040
  ],
1041
  "id": "cell-stage1-md"
1042
  },
1043
  {
1044
  "cell_type": "code",
1045
  "metadata": {
1046
+ "id": "cell-train",
1047
  "colab": {
1048
  "base_uri": "https://localhost:8080/"
1049
  },
1050
  "outputId": "c7d6c209-6790-473c-c1b7-a44441141785"
1051
  },
1052
  "source": [
1053
+ "# Unified train cell drives both stages in one shot, with auto-resume.\n",
1054
  "import time as _time, json as _json\n",
1055
  "from datetime import datetime as _dt, timezone as _tz\n",
1056
  "from pathlib import Path as _Path\n",
 
1063
  " try:\n",
1064
  " from huggingface_hub import hf_hub_download\n",
1065
  " hf_hub_download(\n",
1066
+ " repo_id=repo_id, repo_type=\"model\",\n",
1067
+ " filename=f\"{run_id}/timing.json\",\n",
1068
+ " token=token, local_dir=str(ckpt_root),\n",
 
 
1069
  " )\n",
1070
  " print(f\"[TIMING] pulled previous timing.json from HF → {local}\")\n",
1071
+ " except Exception:\n",
1072
+ " pass # first time for this run_id → no remote file yet, fine\n",
 
1073
  "\n",
1074
  "def _push_timing_to_hf(run_id, ckpt_root, repo_id, token):\n",
1075
  " # Upload local timing.json to HF Hub under {run_id}/timing.json.\n",
 
1079
  " try:\n",
1080
  " from huggingface_hub import HfApi\n",
1081
  " HfApi(token=token).upload_file(\n",
1082
+ " path_or_fileobj=str(local),\n",
1083
+ " path_in_repo=f\"{run_id}/timing.json\",\n",
1084
+ " repo_id=repo_id, repo_type=\"model\",\n",
1085
+ " commit_message=f\"timing.json @ {run_id}\",\n",
 
1086
  " )\n",
1087
  " print(f\"[TIMING] uploaded timing.json to HF → {repo_id}/{run_id}/timing.json\")\n",
1088
  " except Exception as e:\n",
1089
  " print(f\"[TIMING] upload failed (non-fatal): {e}\")\n",
1090
  "\n",
1091
  "\n",
1092
+ "assert MODE in ('fresh', 'resume')\n",
 
 
 
 
 
 
 
 
 
1093
  "_hf_repo = getattr(train_cfg.hf_hub, \"repo_id\", None) if train_cfg.hf_hub.enabled else None\n",
1094
  "_hf_token = os.environ.get(\"HF_TOKEN\")\n",
 
 
1095
  "\n",
1096
+ "# ─── Pre-pull timing.json if resuming ──────────────────────────────────────\n",
1097
+ "if MODE == 'resume':\n",
1098
+ " _pre_state = CKPT_ROOT / \"run_id.txt\"\n",
1099
+ " if _pre_state.exists():\n",
1100
+ " _pull_timing_from_hf(_pre_state.read_text().strip(), CKPT_ROOT, _hf_repo, _hf_token)\n",
1101
  "\n",
1102
+ "_run_id_arg = \"\"\n",
1103
+ "if MODE == 'resume' and EXPLICIT_RUN_ID:\n",
1104
+ " _run_id_arg = f'--run_id \"{EXPLICIT_RUN_ID}\"'\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1105
  "\n",
1106
+ "print(f\"▶ Launching train.py --mode {MODE} {_run_id_arg}\")\n",
1107
+ "print(f\" (Both stages dispatched in one call. Stage will be auto-detected for resume.)\")\n",
1108
+ "print()\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1109
  "\n",
1110
+ "# ─── Wall-clock timing (entire training session) ──────────────────────────\n",
1111
+ "_t0 = _time.time()\n",
1112
+ "_iso_start = _dt.now(_tz.utc).isoformat(timespec=\"seconds\")\n",
1113
  "\n",
1114
  "!HF_HUB_DISABLE_PROGRESS_BARS=1 TRANSFORMERS_VERBOSITY=warning TOKENIZERS_PARALLELISM=false BITSANDBYTES_NOWELCOME=1 PYTHONUNBUFFERED=1 \\\n",
1115
  "python -u -m training.train \\\n",
1116
  " --model_config configs/model_config.yaml \\\n",
1117
  " --train_config configs/train_config.yaml \\\n",
1118
+ " --mode {MODE} {_run_id_arg}\n",
1119
+ "\n",
1120
+ "_elapsed = _time.time() - _t0\n",
1121
  "\n",
1122
+ "# ─── Persist timing.json (cumulative across resumes) ──────────────────────\n",
 
1123
  "_run_id_file = CKPT_ROOT / \"run_id.txt\"\n",
1124
  "if _run_id_file.exists():\n",
1125
  " _run_id_now = _run_id_file.read_text().strip()\n",
1126
  " _timing_path = CKPT_ROOT / _run_id_now / \"timing.json\"\n",
1127
  " _timing_path.parent.mkdir(parents=True, exist_ok=True)\n",
1128
  " _t = _json.loads(_timing_path.read_text()) if _timing_path.exists() else {\n",
1129
+ " \"total_elapsed_sec\": 0.0,\n",
1130
+ " \"session_count\": 0,\n",
1131
+ " \"first_started_at\": None,\n",
1132
+ " \"last_finished_at\": None,\n",
1133
+ " \"session_history\": [],\n",
 
 
1134
  " }\n",
1135
  " if _t.get(\"first_started_at\") is None:\n",
1136
+ " _t[\"first_started_at\"] = _iso_start\n",
1137
+ " _t[\"total_elapsed_sec\"] = float(_t.get(\"total_elapsed_sec\", 0.0)) + _elapsed\n",
1138
+ " _t[\"session_count\"] = int(_t.get(\"session_count\", 0)) + 1\n",
1139
+ " _t[\"last_finished_at\"] = _dt.now(_tz.utc).isoformat(timespec=\"seconds\")\n",
1140
  " _t.setdefault(\"session_history\", []).append({\n",
1141
+ " \"mode\": MODE,\n",
1142
+ " \"started\": _iso_start,\n",
1143
+ " \"finished\": _t[\"last_finished_at\"],\n",
1144
+ " \"elapsed_sec\": _elapsed,\n",
 
1145
  " })\n",
1146
  " _timing_path.write_text(_json.dumps(_t, indent=2))\n",
 
 
1147
  " _push_timing_to_hf(_run_id_now, CKPT_ROOT, _hf_repo, _hf_token)\n",
1148
  "\n",
1149
  " def _fmt(sec):\n",
1150
  " h, r = divmod(int(sec), 3600); m, s = divmod(r, 60); return f\"{h:d}h {m:02d}m {s:02d}s\"\n",
 
1151
  " print()\n",
1152
+ " print(f\"[TIMING] This session : {_fmt(_elapsed)}\")\n",
1153
+ " print(f\"[TIMING] Cumulative : {_fmt(_t['total_elapsed_sec'])} ({_t['session_count']} session(s))\")\n",
1154
+ " print(f\"[TIMING] first started : {_t.get('first_started_at')}\")\n",
1155
+ " print(f\"[TIMING] last finished : {_t.get('last_finished_at')}\")\n",
1156
+ " print(f\"[TIMING] persisted to : {_timing_path}\")\n",
 
1157
  "else:\n",
1158
  " print(\"[TIMING] run_id.txt missing — could not persist timing.\")\n"
1159
  ],
1160
  "execution_count": null,
1161
  "outputs": [],
1162
+ "id": "cell-stage1"
1163
  },
1164
  {
1165
  "cell_type": "markdown",
 
1311
  " try:\n",
1312
  " from huggingface_hub import hf_hub_download\n",
1313
  " hf_hub_download(\n",
1314
+ " repo_id=repo_id, repo_type=\"model\",\n",
1315
+ " filename=f\"{run_id}/timing.json\",\n",
1316
+ " token=token, local_dir=str(ckpt_root),\n",
 
 
1317
  " )\n",
1318
  " print(f\"[TIMING] pulled previous timing.json from HF → {local}\")\n",
1319
+ " except Exception:\n",
1320
+ " pass # first time for this run_id → no remote file yet, fine\n",
 
1321
  "\n",
1322
  "def _push_timing_to_hf(run_id, ckpt_root, repo_id, token):\n",
1323
  " # Upload local timing.json to HF Hub under {run_id}/timing.json.\n",
 
1327
  " try:\n",
1328
  " from huggingface_hub import HfApi\n",
1329
  " HfApi(token=token).upload_file(\n",
1330
+ " path_or_fileobj=str(local),\n",
1331
+ " path_in_repo=f\"{run_id}/timing.json\",\n",
1332
+ " repo_id=repo_id, repo_type=\"model\",\n",
1333
+ " commit_message=f\"timing.json @ {run_id}\",\n",
 
1334
  " )\n",
1335
  " print(f\"[TIMING] uploaded timing.json to HF → {repo_id}/{run_id}/timing.json\")\n",
1336
  " except Exception as e:\n",
 
1338
  "\n",
1339
  "\n",
1340
  "_run_id_file = CKPT_ROOT / \"run_id.txt\"\n",
1341
+ "assert _run_id_file.exists(), \"No run_id.txt — run the train cell at least once.\"\n",
1342
  "_run_id = _run_id_file.read_text().strip()\n",
1343
  "\n",
1344
+ "# Pull latest timing.json from HF in case this is a fresh VM\n",
1345
  "_hf_repo = getattr(train_cfg.hf_hub, \"repo_id\", None) if train_cfg.hf_hub.enabled else None\n",
1346
  "_hf_token = os.environ.get(\"HF_TOKEN\")\n",
1347
  "_pull_timing_from_hf(_run_id, CKPT_ROOT, _hf_repo, _hf_token)\n",
 
1349
  "_timing_path = CKPT_ROOT / _run_id / \"timing.json\"\n",
1350
  "assert _timing_path.exists(), (\n",
1351
  " f\"No timing.json under {_timing_path.parent} (also not on HF). \"\n",
1352
+ " f\"Did the train cell run?\"\n",
1353
  ")\n",
1354
  "\n",
1355
  "_t = _json.loads(_timing_path.read_text())\n",
1356
  "\n",
1357
  "def _fmt(sec):\n",
1358
+ " h, r = divmod(int(sec or 0), 3600); m, s = divmod(r, 60); return f\"{h:d}h {m:02d}m {s:02d}s\"\n",
 
 
 
 
 
 
1359
  "\n",
1360
  "print(f\"Run : {_run_id}\")\n",
1361
  "print(f\"First started at : {_t.get('first_started_at')}\")\n",
1362
  "print(f\"Last finished at : {_t.get('last_finished_at')}\")\n",
1363
+ "print(f\"Session count : {_t.get('session_count', 0)}\")\n",
1364
+ "print(f\"TOTAL elapsed : {_fmt(_t.get('total_elapsed_sec', 0.0))}\")\n",
 
 
1365
  "print()\n",
1366
  "print(\"Session history :\")\n",
1367
  "for _i, _s in enumerate(_t.get(\"session_history\", []), 1):\n",
1368
+ " print(f\" {_i:2d}. mode={_s.get('mode','?'):6s} {_fmt(_s['elapsed_sec'])} \"\n",
1369
+ " f\"{_s['started']} → {_s['finished']}\")\n"
1370
  ],
1371
  "outputs": [],
1372
  "execution_count": null
training/train.py CHANGED
@@ -69,11 +69,19 @@ def parse_args():
69
  )
70
  parser.add_argument(
71
  "--stage", type=int, default=None,
72
- help="Run only stage 1 or stage 2 (default: run both)"
 
 
 
 
 
 
 
 
73
  )
74
  parser.add_argument(
75
  "--resume_from", type=str, default=None,
76
- help="Path to checkpoint to resume from"
77
  )
78
  parser.add_argument(
79
  "--run_id", type=str, default=None,
@@ -87,6 +95,134 @@ def parse_args():
87
  return parser.parse_args()
88
 
89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  def get_trainer(
91
  model,
92
  train_dataset,
@@ -469,11 +605,21 @@ def main():
469
  train_cfg.hf_hub.token_env, os.environ.get("HF_TOKEN")
470
  ) if train_cfg.hf_hub.enabled else None
471
  hf_repo_id = train_cfg.hf_hub.repo_id if train_cfg.hf_hub.enabled else None
 
 
 
 
 
 
 
 
 
 
472
  run_id = resolve_run_id(
473
  dataset_name = spec.dataset_name,
474
  output_root = output_root,
475
  state_file = state_file,
476
- resuming = bool(args.resume_from) or args.resume_from_hf,
477
  explicit = args.run_id,
478
  hf_repo_id = hf_repo_id,
479
  hf_token = hf_token,
@@ -509,6 +655,35 @@ def main():
509
  stage2_out = stage_dir(output_root, run_id,
510
  str(train_cfg.stage2.get("subdir", "stage2_instruct")))
511
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
512
  # ── Snapshot resolved config into the run dir ────────────────────
513
  # Every run gets its own self-describing folder so we never have to ask
514
  # "what config did IU-Xray_run_3 actually use?" — open run_meta.json.
@@ -560,19 +735,41 @@ def main():
560
  load_checkpoint(model, args.resume_from)
561
 
562
  # Run training stages
563
- run_s1 = (args.stage is None or args.stage == 1) and train_cfg.stage1.enabled
564
- run_s2 = (args.stage is None or args.stage == 2) and train_cfg.stage2.enabled
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
565
 
566
  if run_s1:
567
- # Only pass resume_from to stage1 if stage1 is the explicit target
568
- # (`--stage 1`). If user runs both stages with --resume_from, we assume
569
- # the checkpoint is for stage2 and let stage1 run fresh (or finish fast
570
- # if projection was already trained).
571
- s1_resume = args.resume_from if args.stage == 1 else None
572
  model = run_stage1(
573
  model, train_cfg, model_cfg, spec, stage1_out, logger,
574
  tracker = tracker,
575
- resume_from = s1_resume,
576
  )
577
 
578
  if run_s2:
@@ -580,15 +777,17 @@ def main():
580
  # Priority:
581
  # 1. Just finished stage1 in this run → use stage1_out/stage1_final.pt
582
  # 2. Not running stage1 but stage1_final.pt exists on disk → load it
583
- # 3. Nothing warn loudly; stage2 starts with random projection.
 
 
584
  stage1_ckpt = Path(stage1_out) / "stage1_final.pt"
585
  if run_s1:
586
  load_checkpoint(model, str(stage1_ckpt))
587
  logger.info(f"Loaded stage1 weights from this run: {stage1_ckpt}")
588
- elif stage1_ckpt.exists() and not args.resume_from:
589
  load_checkpoint(model, str(stage1_ckpt))
590
  logger.info(f"Auto-loaded existing stage1 weights: {stage1_ckpt}")
591
- elif not args.resume_from:
592
  logger.warning(
593
  "⚠ No stage1 weights found and not resuming. Projection layer "
594
  "will start RANDOMLY for stage2. Expect degraded convergence. "
@@ -597,7 +796,7 @@ def main():
597
 
598
  model = run_stage2(
599
  model, train_cfg, model_cfg, spec, stage2_out, logger,
600
- resume_from = args.resume_from if not run_s1 else None,
601
  tracker = tracker,
602
  )
603
 
 
69
  )
70
  parser.add_argument(
71
  "--stage", type=int, default=None,
72
+ help="Run only stage 1 or stage 2 (default: run both). With --mode resume, "
73
+ "the stage is auto-detected and this flag should be left unset."
74
+ )
75
+ parser.add_argument(
76
+ "--mode", type=str, default=None, choices=["fresh", "resume"],
77
+ help="Unified resume controller. 'fresh' → new run_N folder. "
78
+ "'resume' → reuse latest matching run_id (or --run_id), auto-detect "
79
+ "which stage to continue from based on checkpoints on disk. "
80
+ "If unset, behaviour is inferred from --resume_from / --run_id (legacy)."
81
  )
82
  parser.add_argument(
83
  "--resume_from", type=str, default=None,
84
+ help="Path to checkpoint to resume from (legacy; prefer --mode resume)"
85
  )
86
  parser.add_argument(
87
  "--run_id", type=str, default=None,
 
95
  return parser.parse_args()
96
 
97
 
98
+ # ─── Resume-point auto-detection ────────────────────────────────────────────
99
+
100
+ def _list_checkpoints(stage_dir):
101
+ """Return [Path, …] of `checkpoint-NNN` folders sorted ascending by step."""
102
+ if not stage_dir.is_dir():
103
+ return []
104
+ out = []
105
+ for p in stage_dir.iterdir():
106
+ if not p.is_dir() or not p.name.startswith("checkpoint-"):
107
+ continue
108
+ suffix = p.name.split("-", 1)[1]
109
+ if suffix.isdigit():
110
+ out.append((int(suffix), p))
111
+ return [p for _, p in sorted(out)]
112
+
113
+
114
+ def detect_resume_point(run_dir_path, stage1_subdir, stage2_subdir):
115
+ """
116
+ Inspect the run dir on disk and decide where to pick up training.
117
+
118
+ Returns a tuple `(target_stage, ckpt_path)` where:
119
+ target_stage : "stage1" | "stage2" | "done"
120
+ ckpt_path : Path to the checkpoint folder to pass to HF Trainer,
121
+ or None if the stage should start from scratch.
122
+
123
+ Priority:
124
+ 1. stage 2 final saved → ("done", None) everything finished
125
+ 2. stage 2 has ckpts → ("stage2", latest) resume mid-stage2
126
+ 3. stage 1 final saved → ("stage2", None) stage 1 done; start stage 2
127
+ 4. stage 1 has ckpts → ("stage1", latest) resume mid-stage1
128
+ 5. otherwise → ("stage1", None) brand-new run
129
+ """
130
+ from pathlib import Path as _P
131
+ run_dir_path = _P(run_dir_path)
132
+ s1d = run_dir_path / stage1_subdir
133
+ s2d = run_dir_path / stage2_subdir
134
+
135
+ if (s2d / "stage2_final_projection.pt").exists():
136
+ return ("done", None)
137
+
138
+ s2_ckpts = _list_checkpoints(s2d)
139
+ if s2_ckpts:
140
+ return ("stage2", s2_ckpts[-1])
141
+
142
+ if (s1d / "stage1_final_projection.pt").exists():
143
+ return ("stage2", None)
144
+
145
+ s1_ckpts = _list_checkpoints(s1d)
146
+ if s1_ckpts:
147
+ return ("stage1", s1_ckpts[-1])
148
+
149
+ return ("stage1", None)
150
+
151
+
152
+ def compute_training_plan(train_cfg, instruct_json_path):
153
+ """
154
+ Compute a coarse plan of total optimizer steps across stage 1 + stage 2,
155
+ derived from the train_config + the train-split sample count in the
156
+ instruct JSON. Used to print a human-readable summary at startup.
157
+
158
+ Returns a dict (all ints) — gracefully handles missing fields.
159
+ """
160
+ import json as _json
161
+ tr = train_cfg.training
162
+ try:
163
+ with open(instruct_json_path, "r", encoding="utf-8") as f:
164
+ all_samples = _json.load(f)
165
+ train_count = sum(1 for s in all_samples if s.get("split") == "train")
166
+ except Exception:
167
+ train_count = 0
168
+
169
+ bs = int(getattr(tr, "per_device_train_batch_size", 1))
170
+ ga = int(getattr(tr, "gradient_accumulation_steps", 1))
171
+ eff = max(1, bs * ga)
172
+ steps_per_epoch = max(1, (train_count + eff - 1) // eff)
173
+
174
+ s1_enabled = bool(getattr(train_cfg.stage1, "enabled", True))
175
+ s2_enabled = bool(getattr(train_cfg.stage2, "enabled", True))
176
+ s1_epochs = int(getattr(train_cfg.stage1, "num_epochs", 0)) if s1_enabled else 0
177
+ s2_epochs = int(getattr(train_cfg.stage2, "num_epochs", 0)) if s2_enabled else 0
178
+
179
+ s1_steps = steps_per_epoch * s1_epochs
180
+ s2_steps = steps_per_epoch * s2_epochs
181
+ return {
182
+ "train_samples": train_count,
183
+ "effective_batch": eff,
184
+ "steps_per_epoch": steps_per_epoch,
185
+ "stage1_steps": s1_steps,
186
+ "stage2_steps": s2_steps,
187
+ "total_steps": s1_steps + s2_steps,
188
+ "stage1_epochs": s1_epochs,
189
+ "stage2_epochs": s2_epochs,
190
+ }
191
+
192
+
193
+ def _fmt_plan_banner(plan, run_id, target_stage, resume_ckpt):
194
+ s1, s2, tot = plan["stage1_steps"], plan["stage2_steps"], plan["total_steps"]
195
+ head = f"TRAINING PLAN — {run_id}"
196
+ sep = "=" * max(len(head) + 4, 60)
197
+ cur = ""
198
+ if target_stage == "stage1":
199
+ offset = 0
200
+ if resume_ckpt and str(resume_ckpt).split("-")[-1].isdigit():
201
+ offset = int(str(resume_ckpt).split("-")[-1])
202
+ cur = f"Resuming at step {offset} / {tot} (inside stage 1)"
203
+ elif target_stage == "stage2":
204
+ offset = s1
205
+ if resume_ckpt and str(resume_ckpt).split("-")[-1].isdigit():
206
+ offset = s1 + int(str(resume_ckpt).split("-")[-1])
207
+ cur = f"Resuming at step {offset} / {tot} (inside stage 2)"
208
+ elif target_stage == "done":
209
+ cur = f"All {tot} steps already complete — nothing to do"
210
+
211
+ lines = [
212
+ sep, f" {head}", sep,
213
+ f" Train samples : {plan['train_samples']:,}",
214
+ f" Effective batch : {plan['effective_batch']}",
215
+ f" Steps / epoch : {plan['steps_per_epoch']}",
216
+ f" Stage 1 : {plan['stage1_epochs']} epochs → {s1} steps (global steps 1–{s1})",
217
+ f" Stage 2 : {plan['stage2_epochs']} epochs → {s2} steps (global steps {s1+1}–{tot})",
218
+ f" TOTAL : {tot} optimizer steps",
219
+ ]
220
+ if cur:
221
+ lines += [" " + "─" * (len(sep) - 4), f" {cur}"]
222
+ lines.append(sep)
223
+ return "\n".join(lines)
224
+
225
+
226
  def get_trainer(
227
  model,
228
  train_dataset,
 
605
  train_cfg.hf_hub.token_env, os.environ.get("HF_TOKEN")
606
  ) if train_cfg.hf_hub.enabled else None
607
  hf_repo_id = train_cfg.hf_hub.repo_id if train_cfg.hf_hub.enabled else None
608
+
609
+ # Unified --mode controller. Falls back to the legacy inference (any of
610
+ # --resume_from / --resume_from_hf set ⇒ resuming) when --mode is unset.
611
+ if args.mode == "resume":
612
+ resuming = True
613
+ elif args.mode == "fresh":
614
+ resuming = False
615
+ else:
616
+ resuming = bool(args.resume_from) or args.resume_from_hf
617
+
618
  run_id = resolve_run_id(
619
  dataset_name = spec.dataset_name,
620
  output_root = output_root,
621
  state_file = state_file,
622
+ resuming = resuming,
623
  explicit = args.run_id,
624
  hf_repo_id = hf_repo_id,
625
  hf_token = hf_token,
 
655
  stage2_out = stage_dir(output_root, run_id,
656
  str(train_cfg.stage2.get("subdir", "stage2_instruct")))
657
 
658
+ # ── Auto-detect where to resume from (when --mode resume) ─────────
659
+ # Examines disk state inside {output_root}/{run_id}/ and chooses:
660
+ # • stage1 from scratch / stage1 mid-checkpoint
661
+ # • stage2 from scratch (stage1 done) / stage2 mid-checkpoint
662
+ # • done (both stages finished — skip everything)
663
+ # If the user passed --stage explicitly, that wins over auto-detect.
664
+ auto_target_stage = None
665
+ auto_resume_ckpt = None
666
+ if args.mode == "resume" and args.stage is None:
667
+ auto_target_stage, auto_resume_ckpt = detect_resume_point(
668
+ run_dir(output_root, run_id),
669
+ str(train_cfg.stage1.get("subdir", "stage1_projection")),
670
+ str(train_cfg.stage2.get("subdir", "stage2_instruct")),
671
+ )
672
+ logger.info(
673
+ f"[resume autodetect] target={auto_target_stage} "
674
+ f"ckpt={auto_resume_ckpt}"
675
+ )
676
+
677
+ # ── Pretty plan banner (total steps across both stages) ───────────
678
+ plan = compute_training_plan(train_cfg, spec.instruct_json)
679
+ logger.info("\n" + _fmt_plan_banner(plan, run_id,
680
+ auto_target_stage or "stage1",
681
+ auto_resume_ckpt))
682
+
683
+ if auto_target_stage == "done":
684
+ logger.info("Both stages already complete for this run. Exiting cleanly.")
685
+ return
686
+
687
  # ── Snapshot resolved config into the run dir ────────────────────
688
  # Every run gets its own self-describing folder so we never have to ask
689
  # "what config did IU-Xray_run_3 actually use?" — open run_meta.json.
 
735
  load_checkpoint(model, args.resume_from)
736
 
737
  # Run training stages
738
+ #
739
+ # Stage selection priority:
740
+ # 1. Explicit --stage from CLI wins.
741
+ # 2. --mode resume + auto-detect: skip stage1 when its final ckpt exists,
742
+ # resume stage1/stage2 from `auto_resume_ckpt` as detected above.
743
+ # 3. Otherwise: enabled flags from train_cfg drive it (legacy: run both).
744
+ if args.stage is not None:
745
+ run_s1 = (args.stage == 1) and train_cfg.stage1.enabled
746
+ run_s2 = (args.stage == 2) and train_cfg.stage2.enabled
747
+ elif auto_target_stage == "stage2":
748
+ # Stage 1 finished previously — skip it entirely.
749
+ run_s1 = False
750
+ run_s2 = train_cfg.stage2.enabled
751
+ else:
752
+ run_s1 = train_cfg.stage1.enabled
753
+ run_s2 = train_cfg.stage2.enabled
754
+
755
+ # Decide the resume checkpoint each stage should use.
756
+ # Manual --resume_from still wins when --stage is given explicitly.
757
+ s1_resume_path = None
758
+ s2_resume_path = None
759
+ if args.stage == 1:
760
+ s1_resume_path = args.resume_from
761
+ elif args.stage == 2:
762
+ s2_resume_path = args.resume_from
763
+ elif auto_target_stage == "stage1":
764
+ s1_resume_path = str(auto_resume_ckpt) if auto_resume_ckpt else None
765
+ elif auto_target_stage == "stage2":
766
+ s2_resume_path = str(auto_resume_ckpt) if auto_resume_ckpt else None
767
 
768
  if run_s1:
 
 
 
 
 
769
  model = run_stage1(
770
  model, train_cfg, model_cfg, spec, stage1_out, logger,
771
  tracker = tracker,
772
+ resume_from = s1_resume_path,
773
  )
774
 
775
  if run_s2:
 
777
  # Priority:
778
  # 1. Just finished stage1 in this run → use stage1_out/stage1_final.pt
779
  # 2. Not running stage1 but stage1_final.pt exists on disk → load it
780
+ # 3. s2_resume_path set (we're mid-stage2) Trainer will reload from
781
+ # the checkpoint itself; no need to seed stage1 weights here.
782
+ # 4. Nothing → warn loudly; stage2 starts with random projection.
783
  stage1_ckpt = Path(stage1_out) / "stage1_final.pt"
784
  if run_s1:
785
  load_checkpoint(model, str(stage1_ckpt))
786
  logger.info(f"Loaded stage1 weights from this run: {stage1_ckpt}")
787
+ elif stage1_ckpt.exists() and not s2_resume_path:
788
  load_checkpoint(model, str(stage1_ckpt))
789
  logger.info(f"Auto-loaded existing stage1 weights: {stage1_ckpt}")
790
+ elif not s2_resume_path:
791
  logger.warning(
792
  "⚠ No stage1 weights found and not resuming. Projection layer "
793
  "will start RANDOMLY for stage2. Expect degraded convergence. "
 
796
 
797
  model = run_stage2(
798
  model, train_cfg, model_cfg, spec, stage2_out, logger,
799
+ resume_from = s2_resume_path,
800
  tracker = tracker,
801
  )
802