convitom committed on
Commit ·
9dadb47
1
Parent(s): c369576
- scripts/cxrvlm_colab_train.ipynb +107 -365
- training/train.py +214 -15
scripts/cxrvlm_colab_train.ipynb
CHANGED
|
@@ -966,174 +966,91 @@
|
|
| 966 |
{
|
| 967 |
"cell_type": "markdown",
|
| 968 |
"metadata": {
|
| 969 |
-
"id": "cell-
|
| 970 |
},
|
| 971 |
"source": [
|
| 972 |
-
"## 5b. Resume
|
| 973 |
"\n",
|
| 974 |
-
"
|
|
|
|
| 975 |
"\n",
|
| 976 |
-
"
|
|
|
|
|
|
|
|
|
|
| 977 |
"\n",
|
| 978 |
-
"
|
| 979 |
-
"
|
| 980 |
-
"| Stage 1 đang train dở, cùng VM | `RESUME_STAGE=1`, `EXPLICIT_RUN_ID=None` |\n",
|
| 981 |
-
"| Stage 1 dở, **VM mới** (Colab disconnect) | `RESUME_STAGE=1`, `EXPLICIT_RUN_ID='IU-Xray_run_1'` |\n",
|
| 982 |
-
"| Stage 1 xong, chạy tiếp stage 2 | `RESUME_STAGE=2`, `EXPLICIT_RUN_ID=None` (cùng VM) hoặc set (VM mới) |\n",
|
| 983 |
-
"| Stage 2 đang dở, cùng VM | `RESUME_STAGE=2`, `EXPLICIT_RUN_ID=None` |\n",
|
| 984 |
-
"| Stage 2 dở, **VM mới** | `RESUME_STAGE=2`, `EXPLICIT_RUN_ID='IU-Xray_run_1'` |\n",
|
| 985 |
"\n",
|
| 986 |
-
"
|
| 987 |
-
"\n",
|
| 988 |
-
"
|
| 989 |
-
"1. `cell-select` → `cell-env` → `cell-paths` (pull code + data từ HF)\n",
|
| 990 |
-
"2. `cell-pip` → **Runtime → Restart session** → `cell-pip-verify`\n",
|
| 991 |
-
"3. `cell-hf-token` → `cell-cfg` → `cell-sanity`\n",
|
| 992 |
-
"4. **Skip `cell-stage1` / `cell-stage2` mặc định** — chạy `cell-resume` ngay bên dưới, rồi chạy cell train tương ứng.\n",
|
| 993 |
-
"\n",
|
| 994 |
-
"Cell resume bên dưới tự:\n",
|
| 995 |
-
"- Pull các folder `checkpoint-<step>` mới nhất từ HF Hub về (từ `stage1/` hoặc `stage2/` trên hub).\n",
|
| 996 |
-
"- Rename `stage1/` → `stage1_projection/`, `stage2/` → `stage2_instruct/` cho khớp layout local mà `train.py` kỳ vọng.\n",
|
| 997 |
-
"- Ghi lại `run_id.txt` trong `CKPT_ROOT`.\n",
|
| 998 |
-
"- Tìm checkpoint có step cao nhất và set `RESUME_FROM` cho cell train.\n",
|
| 999 |
-
"\n",
|
| 1000 |
-
"### Train tiếp sẽ lưu ở đâu trên HF?\n",
|
| 1001 |
-
"\n",
|
| 1002 |
-
"**Vẫn cùng folder `<RUN_ID>/` trên HF.** `HFRunTracker` tái sử dụng `run_id` đã resolve nên:\n",
|
| 1003 |
-
"- `hieu3636/cxr-vlm-runs/IU-Xray_run_1/stage2/checkpoint-XXXX/` (intermediate, callback mỗi 200 step)\n",
|
| 1004 |
-
"- `hieu3636/cxr-vlm-runs/IU-Xray_run_1/stage2/stage2_final.pt` (final)\n",
|
| 1005 |
-
"- `hieu3636/cxr-vlm-runs/IU-Xray_run_1/meta.json` (merge với meta cũ, tăng `resume_count`)\n",
|
| 1006 |
-
"\n",
|
| 1007 |
-
"Không tạo `run_2` mới.\n",
|
| 1008 |
-
"\n",
|
| 1009 |
-
"### Auto-detect run_id\n",
|
| 1010 |
-
"\n",
|
| 1011 |
-
"`train.py` resolve `run_id` theo thứ tự:\n",
|
| 1012 |
-
"1. `--run_id` CLI (cell resume truyền cái này khi bạn set `EXPLICIT_RUN_ID`).\n",
|
| 1013 |
-
"2. `run_id.txt` trong `CKPT_ROOT` (cell resume ghi lại).\n",
|
| 1014 |
-
"3. Nếu cả 2 trống + `--resume_from` → tự tìm run mới nhất trên HF.\n"
|
| 1015 |
],
|
| 1016 |
"id": "cell-resume-md"
|
| 1017 |
},
|
| 1018 |
{
|
| 1019 |
"cell_type": "code",
|
| 1020 |
"metadata": {
|
| 1021 |
-
"id": "cell-
|
| 1022 |
"colab": {
|
| 1023 |
"base_uri": "https://localhost:8080/"
|
| 1024 |
},
|
| 1025 |
"outputId": "6f1b15fb-e751-4209-ea52-7a98a8c40212"
|
| 1026 |
},
|
| 1027 |
"execution_count": null,
|
| 1028 |
-
"outputs": [
|
| 1029 |
-
{
|
| 1030 |
-
"output_type": "stream",
|
| 1031 |
-
"name": "stdout",
|
| 1032 |
-
"text": [
|
| 1033 |
-
"RESUME_STAGE=None — fresh run. Skip this cell; go straight to cell-stage1.\n"
|
| 1034 |
-
]
|
| 1035 |
-
}
|
| 1036 |
-
],
|
| 1037 |
"source": [
|
| 1038 |
-
"# Resume controller —
|
| 1039 |
-
"
|
| 1040 |
-
"EXPLICIT_RUN_ID =
|
| 1041 |
-
"\n",
|
| 1042 |
-
"
|
| 1043 |
-
"
|
| 1044 |
-
"\n",
|
| 1045 |
-
"
|
| 1046 |
-
"
|
| 1047 |
-
"\n",
|
| 1048 |
-
"
|
| 1049 |
-
"
|
| 1050 |
-
" RESUME_RUN_ID = EXPLICIT_RUN_ID\n",
|
| 1051 |
-
" CKPT_ROOT.mkdir(parents=True, exist_ok=True)\n",
|
| 1052 |
-
" (CKPT_ROOT / \"run_id.txt\").write_text(RESUME_RUN_ID)\n",
|
| 1053 |
-
" print(f\"Using EXPLICIT_RUN_ID = {RESUME_RUN_ID} (wrote run_id.txt)\")\n",
|
| 1054 |
-
" else:\n",
|
| 1055 |
-
" state_file = CKPT_ROOT / \"run_id.txt\"\n",
|
| 1056 |
-
" assert state_file.exists(), (\n",
|
| 1057 |
-
" \"No local run_id.txt — looks like a fresh VM. \"\n",
|
| 1058 |
-
" \"Set EXPLICIT_RUN_ID to the run folder on HF (e.g. \\\"IU-Xray_run_1\\\").\"\n",
|
| 1059 |
-
" )\n",
|
| 1060 |
-
" RESUME_RUN_ID = state_file.read_text().strip()\n",
|
| 1061 |
-
" print(f\"Using run_id from state file: {RESUME_RUN_ID}\")\n",
|
| 1062 |
-
"\n",
|
| 1063 |
-
" # 2) Local subdir names (code expects long names; HF uses short \"stage1\"/\"stage2\")\n",
|
| 1064 |
-
" local_subdir = \"stage1_projection\" if RESUME_STAGE == 1 else \"stage2_instruct\"\n",
|
| 1065 |
-
" remote_subdir = \"stage1\" if RESUME_STAGE == 1 else \"stage2\"\n",
|
| 1066 |
-
" local_stage_dir = CKPT_ROOT / RESUME_RUN_ID / local_subdir\n",
|
| 1067 |
-
" local_stage_dir.mkdir(parents=True, exist_ok=True)\n",
|
| 1068 |
-
"\n",
|
| 1069 |
-
" # 3) If no local checkpoints, pull from HF\n",
|
| 1070 |
-
" existing = sorted(local_stage_dir.glob(\"checkpoint-*\"),\n",
|
| 1071 |
-
" key=lambda p: int(p.name.split(\"-\")[1]))\n",
|
| 1072 |
-
" if not existing:\n",
|
| 1073 |
-
" print(f\"No local checkpoints under {local_stage_dir} — pulling from HF Hub …\")\n",
|
| 1074 |
-
" from huggingface_hub import snapshot_download\n",
|
| 1075 |
-
" hub_prefix = f\"{RESUME_RUN_ID}/{remote_subdir}/\"\n",
|
| 1076 |
-
" pulled = snapshot_download(\n",
|
| 1077 |
-
" repo_id = train_cfg.hf_hub.repo_id,\n",
|
| 1078 |
-
" repo_type = \"model\",\n",
|
| 1079 |
-
" token = os.environ[\"HF_TOKEN\"],\n",
|
| 1080 |
-
" allow_patterns = [f\"{hub_prefix}checkpoint-*/**\"],\n",
|
| 1081 |
-
" local_dir = str(WORK / \"hf_pull\"),\n",
|
| 1082 |
-
" )\n",
|
| 1083 |
-
" hub_stage_dir = Path(pulled) / RESUME_RUN_ID / remote_subdir\n",
|
| 1084 |
-
" assert hub_stage_dir.exists() and any(hub_stage_dir.glob(\"checkpoint-*\")), (\n",
|
| 1085 |
-
" f\"No checkpoints found under {hub_prefix} on HF repo \"\n",
|
| 1086 |
-
" f\"{train_cfg.hf_hub.repo_id}. Did you set the right EXPLICIT_RUN_ID?\"\n",
|
| 1087 |
-
" )\n",
|
| 1088 |
-
" for ck in hub_stage_dir.glob(\"checkpoint-*\"):\n",
|
| 1089 |
-
" dst = local_stage_dir / ck.name\n",
|
| 1090 |
-
" if dst.exists():\n",
|
| 1091 |
-
" shutil.rmtree(dst)\n",
|
| 1092 |
-
" shutil.move(str(ck), str(dst))\n",
|
| 1093 |
-
" print(f\"Pulled {len(list(local_stage_dir.glob('checkpoint-*')))} checkpoint(s) → {local_stage_dir}\")\n",
|
| 1094 |
-
" existing = sorted(local_stage_dir.glob(\"checkpoint-*\"),\n",
|
| 1095 |
-
" key=lambda p: int(p.name.split(\"-\")[1]))\n",
|
| 1096 |
-
"\n",
|
| 1097 |
-
" assert existing, f\"Still no checkpoints under {local_stage_dir}\"\n",
|
| 1098 |
-
"\n",
|
| 1099 |
-
" # 4) Latest checkpoint = highest global_step\n",
|
| 1100 |
-
" RESUME_FROM = existing[-1]\n",
|
| 1101 |
-
" print()\n",
|
| 1102 |
-
" print(f\"✓ Ready to resume STAGE {RESUME_STAGE} of run {RESUME_RUN_ID}\")\n",
|
| 1103 |
-
" print(f\" checkpoints on disk : {[c.name for c in existing]}\")\n",
|
| 1104 |
-
" print(f\" will resume from : {RESUME_FROM}\")\n",
|
| 1105 |
-
" print()\n",
|
| 1106 |
-
" print(f\"→ Now run cell-stage{RESUME_STAGE} below.\")\n",
|
| 1107 |
"else:\n",
|
| 1108 |
-
" print(\"
|
| 1109 |
],
|
| 1110 |
"id": "cell-resume"
|
| 1111 |
},
|
| 1112 |
{
|
| 1113 |
"cell_type": "markdown",
|
| 1114 |
"metadata": {
|
| 1115 |
-
"id": "cell-
|
| 1116 |
},
|
| 1117 |
"source": [
|
| 1118 |
-
"## 6.
|
| 1119 |
"\n",
|
| 1120 |
-
"
|
| 1121 |
"\n",
|
| 1122 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1123 |
],
|
| 1124 |
"id": "cell-stage1-md"
|
| 1125 |
},
|
| 1126 |
{
|
| 1127 |
"cell_type": "code",
|
| 1128 |
"metadata": {
|
| 1129 |
-
"id": "cell-
|
| 1130 |
"colab": {
|
| 1131 |
"base_uri": "https://localhost:8080/"
|
| 1132 |
},
|
| 1133 |
"outputId": "c7d6c209-6790-473c-c1b7-a44441141785"
|
| 1134 |
},
|
| 1135 |
"source": [
|
| 1136 |
-
"#
|
| 1137 |
"import time as _time, json as _json\n",
|
| 1138 |
"from datetime import datetime as _dt, timezone as _tz\n",
|
| 1139 |
"from pathlib import Path as _Path\n",
|
|
@@ -1146,16 +1063,13 @@
|
|
| 1146 |
" try:\n",
|
| 1147 |
" from huggingface_hub import hf_hub_download\n",
|
| 1148 |
" hf_hub_download(\n",
|
| 1149 |
-
" repo_id
|
| 1150 |
-
"
|
| 1151 |
-
"
|
| 1152 |
-
" token = token,\n",
|
| 1153 |
-
" local_dir = str(ckpt_root),\n",
|
| 1154 |
" )\n",
|
| 1155 |
" print(f\"[TIMING] pulled previous timing.json from HF → {local}\")\n",
|
| 1156 |
-
" except Exception
|
| 1157 |
-
" #
|
| 1158 |
-
" pass\n",
|
| 1159 |
"\n",
|
| 1160 |
"def _push_timing_to_hf(run_id, ckpt_root, repo_id, token):\n",
|
| 1161 |
" # Upload local timing.json to HF Hub under {run_id}/timing.json.\n",
|
|
@@ -1165,247 +1079,87 @@
|
|
| 1165 |
" try:\n",
|
| 1166 |
" from huggingface_hub import HfApi\n",
|
| 1167 |
" HfApi(token=token).upload_file(\n",
|
| 1168 |
-
" path_or_fileobj
|
| 1169 |
-
" path_in_repo
|
| 1170 |
-
" repo_id
|
| 1171 |
-
"
|
| 1172 |
-
" commit_message = f\"timing.json @ {run_id}\",\n",
|
| 1173 |
" )\n",
|
| 1174 |
" print(f\"[TIMING] uploaded timing.json to HF → {repo_id}/{run_id}/timing.json\")\n",
|
| 1175 |
" except Exception as e:\n",
|
| 1176 |
" print(f\"[TIMING] upload failed (non-fatal): {e}\")\n",
|
| 1177 |
"\n",
|
| 1178 |
"\n",
|
| 1179 |
-
"
|
| 1180 |
-
"_is_resume = False\n",
|
| 1181 |
-
"if \"RESUME_FROM\" in dir() and RESUME_FROM and RESUME_STAGE == 1:\n",
|
| 1182 |
-
" _resume_args = f'--resume_from \"{RESUME_FROM}\" --run_id \"{RESUME_RUN_ID}\"'\n",
|
| 1183 |
-
" _is_resume = True\n",
|
| 1184 |
-
" print(\"▶ STAGE 1 resuming from\", RESUME_FROM)\n",
|
| 1185 |
-
"else:\n",
|
| 1186 |
-
" print(\"▶ STAGE 1 fresh run\")\n",
|
| 1187 |
-
"\n",
|
| 1188 |
-
"# ─── Pre-pull timing.json from HF if resuming (best-effort) ────────────────\n",
|
| 1189 |
"_hf_repo = getattr(train_cfg.hf_hub, \"repo_id\", None) if train_cfg.hf_hub.enabled else None\n",
|
| 1190 |
"_hf_token = os.environ.get(\"HF_TOKEN\")\n",
|
| 1191 |
-
"if _is_resume and \"RESUME_RUN_ID\" in dir() and RESUME_RUN_ID:\n",
|
| 1192 |
-
" _pull_timing_from_hf(RESUME_RUN_ID, CKPT_ROOT, _hf_repo, _hf_token)\n",
|
| 1193 |
"\n",
|
| 1194 |
-
"# ───
|
| 1195 |
-
"
|
| 1196 |
-
"
|
|
|
|
|
|
|
| 1197 |
"\n",
|
| 1198 |
-
"
|
| 1199 |
-
"
|
| 1200 |
-
" --
|
| 1201 |
-
" --train_config configs/train_config.yaml \\\n",
|
| 1202 |
-
" --stage 1 {_resume_args}\n",
|
| 1203 |
-
"\n",
|
| 1204 |
-
"# ─── End-of-stage timer + persist cumulative time ──────────────────────────\n",
|
| 1205 |
-
"_elapsed_stage1 = _time.time() - _t0_stage1\n",
|
| 1206 |
-
"_run_id_file = CKPT_ROOT / \"run_id.txt\"\n",
|
| 1207 |
-
"if _run_id_file.exists():\n",
|
| 1208 |
-
" _run_id_now = _run_id_file.read_text().strip()\n",
|
| 1209 |
-
" _timing_path = CKPT_ROOT / _run_id_now / \"timing.json\"\n",
|
| 1210 |
-
" _timing_path.parent.mkdir(parents=True, exist_ok=True)\n",
|
| 1211 |
-
" _t = _json.loads(_timing_path.read_text()) if _timing_path.exists() else {\n",
|
| 1212 |
-
" \"stage1_elapsed_sec\": 0.0,\n",
|
| 1213 |
-
" \"stage2_elapsed_sec\": 0.0,\n",
|
| 1214 |
-
" \"resume_count_stage1\": 0,\n",
|
| 1215 |
-
" \"resume_count_stage2\": 0,\n",
|
| 1216 |
-
" \"first_started_at\": None,\n",
|
| 1217 |
-
" \"last_finished_at\": None,\n",
|
| 1218 |
-
" \"session_history\": [],\n",
|
| 1219 |
-
" }\n",
|
| 1220 |
-
" if _t.get(\"first_started_at\") is None:\n",
|
| 1221 |
-
" _t[\"first_started_at\"] = _iso_start_stage1\n",
|
| 1222 |
-
" _t[\"stage1_elapsed_sec\"] = float(_t.get(\"stage1_elapsed_sec\", 0.0)) + _elapsed_stage1\n",
|
| 1223 |
-
" _t[\"resume_count_stage1\"] = int(_t.get(\"resume_count_stage1\", 0)) + (1 if _is_resume else 0)\n",
|
| 1224 |
-
" _t[\"last_finished_at\"] = _dt.now(_tz.utc).isoformat(timespec=\"seconds\")\n",
|
| 1225 |
-
" _t.setdefault(\"session_history\", []).append({\n",
|
| 1226 |
-
" \"stage\": 1,\n",
|
| 1227 |
-
" \"resumed\": _is_resume,\n",
|
| 1228 |
-
" \"started\": _iso_start_stage1,\n",
|
| 1229 |
-
" \"finished\": _t[\"last_finished_at\"],\n",
|
| 1230 |
-
" \"elapsed_sec\": _elapsed_stage1,\n",
|
| 1231 |
-
" })\n",
|
| 1232 |
-
" _timing_path.write_text(_json.dumps(_t, indent=2))\n",
|
| 1233 |
-
"\n",
|
| 1234 |
-
" # ─── Push to HF Hub so the timer survives a fresh VM ─────────────────\n",
|
| 1235 |
-
" _push_timing_to_hf(_run_id_now, CKPT_ROOT, _hf_repo, _hf_token)\n",
|
| 1236 |
"\n",
|
| 1237 |
-
"
|
| 1238 |
-
"
|
| 1239 |
-
"
|
| 1240 |
-
" print(f\"[TIMING] Stage 1 this session : {_fmt(_elapsed_stage1)}\")\n",
|
| 1241 |
-
" print(f\"[TIMING] Stage 1 cumulative : {_fmt(_t['stage1_elapsed_sec'])} (resumes so far: {_t['resume_count_stage1']})\")\n",
|
| 1242 |
-
" print(f\"[TIMING] persisted to : {_timing_path}\")\n",
|
| 1243 |
-
"else:\n",
|
| 1244 |
-
" print(\"[TIMING] run_id.txt missing — could not persist timing (training likely failed before resolve_run_id ran).\")\n"
|
| 1245 |
-
],
|
| 1246 |
-
"execution_count": null,
|
| 1247 |
-
"outputs": [],
|
| 1248 |
-
"id": "cell-stage1"
|
| 1249 |
-
},
|
| 1250 |
-
{
|
| 1251 |
-
"cell_type": "markdown",
|
| 1252 |
-
"metadata": {
|
| 1253 |
-
"id": "cell-stage2-md"
|
| 1254 |
-
},
|
| 1255 |
-
"source": [
|
| 1256 |
-
"## 7. Stage 2 — projection + LoRA instruction tuning\n",
|
| 1257 |
-
"\n",
|
| 1258 |
-
"Kaggle caps a GPU session at 9h. If Stage 2 doesn't finish, Persistence keeps the Trainer checkpoints in `/kaggle/working/ckpt/{RUN_ID}/stage2_instruct/checkpoint-XXXX/` — resume next session with:\n",
|
| 1259 |
-
"```\n",
|
| 1260 |
-
"!python -m training.train --stage 2 \\\n",
|
| 1261 |
-
" --model_config configs/model_config.yaml --train_config configs/train_config.yaml \\\n",
|
| 1262 |
-
" --resume_from /kaggle/working/ckpt/{RUN_ID}/stage2_instruct/checkpoint-XXXX\n",
|
| 1263 |
-
"```\n",
|
| 1264 |
-
"This **does not** create a new `run_N+1` on HF — it reuses the existing run."
|
| 1265 |
-
],
|
| 1266 |
-
"id": "cell-stage2-md"
|
| 1267 |
-
},
|
| 1268 |
-
{
|
| 1269 |
-
"cell_type": "code",
|
| 1270 |
-
"metadata": {
|
| 1271 |
-
"id": "cell-stage2",
|
| 1272 |
-
"colab": {
|
| 1273 |
-
"base_uri": "https://localhost:8080/"
|
| 1274 |
-
},
|
| 1275 |
-
"outputId": "d9cd4a96-ec88-4907-fc0e-38b6eb2f66a7"
|
| 1276 |
-
},
|
| 1277 |
-
"source": [
|
| 1278 |
-
"# Picks up RESUME_FROM / RESUME_RUN_ID from cell-resume (None if fresh run).\n",
|
| 1279 |
-
"import time as _time, json as _json\n",
|
| 1280 |
-
"from datetime import datetime as _dt, timezone as _tz\n",
|
| 1281 |
-
"from pathlib import Path as _Path\n",
|
| 1282 |
-
"\n",
|
| 1283 |
-
"def _pull_timing_from_hf(run_id, ckpt_root, repo_id, token):\n",
|
| 1284 |
-
" # Pull timing.json from HF Hub for this run if not present locally.\n",
|
| 1285 |
-
" local = ckpt_root / run_id / \"timing.json\"\n",
|
| 1286 |
-
" if local.exists() or not repo_id or not token:\n",
|
| 1287 |
-
" return\n",
|
| 1288 |
-
" try:\n",
|
| 1289 |
-
" from huggingface_hub import hf_hub_download\n",
|
| 1290 |
-
" hf_hub_download(\n",
|
| 1291 |
-
" repo_id = repo_id,\n",
|
| 1292 |
-
" repo_type = \"model\",\n",
|
| 1293 |
-
" filename = f\"{run_id}/timing.json\",\n",
|
| 1294 |
-
" token = token,\n",
|
| 1295 |
-
" local_dir = str(ckpt_root),\n",
|
| 1296 |
-
" )\n",
|
| 1297 |
-
" print(f\"[TIMING] pulled previous timing.json from HF → {local}\")\n",
|
| 1298 |
-
" except Exception as e:\n",
|
| 1299 |
-
" # First run for this run_id → no remote file yet. That's fine.\n",
|
| 1300 |
-
" pass\n",
|
| 1301 |
-
"\n",
|
| 1302 |
-
"def _push_timing_to_hf(run_id, ckpt_root, repo_id, token):\n",
|
| 1303 |
-
" # Upload local timing.json to HF Hub under {run_id}/timing.json.\n",
|
| 1304 |
-
" local = ckpt_root / run_id / \"timing.json\"\n",
|
| 1305 |
-
" if not local.exists() or not repo_id or not token:\n",
|
| 1306 |
-
" return\n",
|
| 1307 |
-
" try:\n",
|
| 1308 |
-
" from huggingface_hub import HfApi\n",
|
| 1309 |
-
" HfApi(token=token).upload_file(\n",
|
| 1310 |
-
" path_or_fileobj = str(local),\n",
|
| 1311 |
-
" path_in_repo = f\"{run_id}/timing.json\",\n",
|
| 1312 |
-
" repo_id = repo_id,\n",
|
| 1313 |
-
" repo_type = \"model\",\n",
|
| 1314 |
-
" commit_message = f\"timing.json @ {run_id}\",\n",
|
| 1315 |
-
" )\n",
|
| 1316 |
-
" print(f\"[TIMING] uploaded timing.json to HF → {repo_id}/{run_id}/timing.json\")\n",
|
| 1317 |
-
" except Exception as e:\n",
|
| 1318 |
-
" print(f\"[TIMING] upload failed (non-fatal): {e}\")\n",
|
| 1319 |
-
"\n",
|
| 1320 |
-
"\n",
|
| 1321 |
-
"_resume_args = \"\"\n",
|
| 1322 |
-
"_is_resume = False\n",
|
| 1323 |
-
"if \"RESUME_FROM\" in dir() and RESUME_FROM and RESUME_STAGE == 2:\n",
|
| 1324 |
-
" _resume_args = f'--resume_from \"{RESUME_FROM}\" --run_id \"{RESUME_RUN_ID}\"'\n",
|
| 1325 |
-
" _is_resume = True\n",
|
| 1326 |
-
" print(\"▶ STAGE 2 resuming from\", RESUME_FROM)\n",
|
| 1327 |
-
"elif \"RESUME_RUN_ID\" in dir() and RESUME_RUN_ID:\n",
|
| 1328 |
-
" _resume_args = f'--run_id \"{RESUME_RUN_ID}\"'\n",
|
| 1329 |
-
" print(\"▶ STAGE 2 fresh start, pinned to run_id\", RESUME_RUN_ID)\n",
|
| 1330 |
-
"else:\n",
|
| 1331 |
-
" # ─── FIX: pin stage 2 to the run_id stage 1 just wrote ────────────────\n",
|
| 1332 |
-
" # Without this, train.py treats stage 2 as a brand-new launch and\n",
|
| 1333 |
-
" # allocates a NEW run_N folder, splitting stage1/stage2 across two runs.\n",
|
| 1334 |
-
" _state_file = CKPT_ROOT / \"run_id.txt\"\n",
|
| 1335 |
-
" if _state_file.exists():\n",
|
| 1336 |
-
" _pinned = _state_file.read_text().strip()\n",
|
| 1337 |
-
" _resume_args = f'--run_id \"{_pinned}\"'\n",
|
| 1338 |
-
" print(f\"▶ STAGE 2 fresh, auto-pinned to run_id from state file: {_pinned}\")\n",
|
| 1339 |
-
" else:\n",
|
| 1340 |
-
" print(\"▶ STAGE 2 fresh (no state file — train.py will allocate a new run_id)\")\n",
|
| 1341 |
-
"\n",
|
| 1342 |
-
"# ─── Pre-pull timing.json from HF (in case of fresh VM) ───────────────────\n",
|
| 1343 |
-
"_hf_repo = getattr(train_cfg.hf_hub, \"repo_id\", None) if train_cfg.hf_hub.enabled else None\n",
|
| 1344 |
-
"_hf_token = os.environ.get(\"HF_TOKEN\")\n",
|
| 1345 |
-
"# Best guess at run_id BEFORE training (may be missing if stage 1 wasn't run here)\n",
|
| 1346 |
-
"_pre_state = CKPT_ROOT / \"run_id.txt\"\n",
|
| 1347 |
-
"if _pre_state.exists():\n",
|
| 1348 |
-
" _pull_timing_from_hf(_pre_state.read_text().strip(), CKPT_ROOT, _hf_repo, _hf_token)\n",
|
| 1349 |
"\n",
|
| 1350 |
-
"# ───
|
| 1351 |
-
"
|
| 1352 |
-
"
|
| 1353 |
"\n",
|
| 1354 |
"!HF_HUB_DISABLE_PROGRESS_BARS=1 TRANSFORMERS_VERBOSITY=warning TOKENIZERS_PARALLELISM=false BITSANDBYTES_NOWELCOME=1 PYTHONUNBUFFERED=1 \\\n",
|
| 1355 |
"python -u -m training.train \\\n",
|
| 1356 |
" --model_config configs/model_config.yaml \\\n",
|
| 1357 |
" --train_config configs/train_config.yaml \\\n",
|
| 1358 |
-
" --
|
|
|
|
|
|
|
| 1359 |
"\n",
|
| 1360 |
-
"# ───
|
| 1361 |
-
"_elapsed_stage2 = _time.time() - _t0_stage2\n",
|
| 1362 |
"_run_id_file = CKPT_ROOT / \"run_id.txt\"\n",
|
| 1363 |
"if _run_id_file.exists():\n",
|
| 1364 |
" _run_id_now = _run_id_file.read_text().strip()\n",
|
| 1365 |
" _timing_path = CKPT_ROOT / _run_id_now / \"timing.json\"\n",
|
| 1366 |
" _timing_path.parent.mkdir(parents=True, exist_ok=True)\n",
|
| 1367 |
" _t = _json.loads(_timing_path.read_text()) if _timing_path.exists() else {\n",
|
| 1368 |
-
" \"
|
| 1369 |
-
" \"
|
| 1370 |
-
" \"
|
| 1371 |
-
" \"
|
| 1372 |
-
" \"
|
| 1373 |
-
" \"last_finished_at\": None,\n",
|
| 1374 |
-
" \"session_history\": [],\n",
|
| 1375 |
" }\n",
|
| 1376 |
" if _t.get(\"first_started_at\") is None:\n",
|
| 1377 |
-
" _t[\"first_started_at\"] =
|
| 1378 |
-
" _t[\"
|
| 1379 |
-
" _t[\"
|
| 1380 |
-
" _t[\"last_finished_at\"]
|
| 1381 |
" _t.setdefault(\"session_history\", []).append({\n",
|
| 1382 |
-
" \"
|
| 1383 |
-
" \"
|
| 1384 |
-
" \"
|
| 1385 |
-
" \"
|
| 1386 |
-
" \"elapsed_sec\": _elapsed_stage2,\n",
|
| 1387 |
" })\n",
|
| 1388 |
" _timing_path.write_text(_json.dumps(_t, indent=2))\n",
|
| 1389 |
-
"\n",
|
| 1390 |
-
" # ─── Push to HF Hub ──────────────────────────────────────────────────\n",
|
| 1391 |
" _push_timing_to_hf(_run_id_now, CKPT_ROOT, _hf_repo, _hf_token)\n",
|
| 1392 |
"\n",
|
| 1393 |
" def _fmt(sec):\n",
|
| 1394 |
" h, r = divmod(int(sec), 3600); m, s = divmod(r, 60); return f\"{h:d}h {m:02d}m {s:02d}s\"\n",
|
| 1395 |
-
" _total = _t[\"stage1_elapsed_sec\"] + _t[\"stage2_elapsed_sec\"]\n",
|
| 1396 |
" print()\n",
|
| 1397 |
-
" print(f\"[TIMING]
|
| 1398 |
-
" print(f\"[TIMING]
|
| 1399 |
-
" print(f\"[TIMING]
|
| 1400 |
-
" print(f\"[TIMING]
|
| 1401 |
-
" print(f\"[TIMING]
|
| 1402 |
-
" print(f\"[TIMING] persisted to : {_timing_path}\")\n",
|
| 1403 |
"else:\n",
|
| 1404 |
" print(\"[TIMING] run_id.txt missing — could not persist timing.\")\n"
|
| 1405 |
],
|
| 1406 |
"execution_count": null,
|
| 1407 |
"outputs": [],
|
| 1408 |
-
"id": "cell-
|
| 1409 |
},
|
| 1410 |
{
|
| 1411 |
"cell_type": "markdown",
|
|
@@ -1557,16 +1311,13 @@
|
|
| 1557 |
" try:\n",
|
| 1558 |
" from huggingface_hub import hf_hub_download\n",
|
| 1559 |
" hf_hub_download(\n",
|
| 1560 |
-
" repo_id
|
| 1561 |
-
"
|
| 1562 |
-
"
|
| 1563 |
-
" token = token,\n",
|
| 1564 |
-
" local_dir = str(ckpt_root),\n",
|
| 1565 |
" )\n",
|
| 1566 |
" print(f\"[TIMING] pulled previous timing.json from HF → {local}\")\n",
|
| 1567 |
-
" except Exception
|
| 1568 |
-
" #
|
| 1569 |
-
" pass\n",
|
| 1570 |
"\n",
|
| 1571 |
"def _push_timing_to_hf(run_id, ckpt_root, repo_id, token):\n",
|
| 1572 |
" # Upload local timing.json to HF Hub under {run_id}/timing.json.\n",
|
|
@@ -1576,11 +1327,10 @@
|
|
| 1576 |
" try:\n",
|
| 1577 |
" from huggingface_hub import HfApi\n",
|
| 1578 |
" HfApi(token=token).upload_file(\n",
|
| 1579 |
-
" path_or_fileobj
|
| 1580 |
-
" path_in_repo
|
| 1581 |
-
" repo_id
|
| 1582 |
-
"
|
| 1583 |
-
" commit_message = f\"timing.json @ {run_id}\",\n",
|
| 1584 |
" )\n",
|
| 1585 |
" print(f\"[TIMING] uploaded timing.json to HF → {repo_id}/{run_id}/timing.json\")\n",
|
| 1586 |
" except Exception as e:\n",
|
|
@@ -1588,10 +1338,10 @@
|
|
| 1588 |
"\n",
|
| 1589 |
"\n",
|
| 1590 |
"_run_id_file = CKPT_ROOT / \"run_id.txt\"\n",
|
| 1591 |
-
"assert _run_id_file.exists(), \"No run_id.txt — train at least
|
| 1592 |
"_run_id = _run_id_file.read_text().strip()\n",
|
| 1593 |
"\n",
|
| 1594 |
-
"# Pull
|
| 1595 |
"_hf_repo = getattr(train_cfg.hf_hub, \"repo_id\", None) if train_cfg.hf_hub.enabled else None\n",
|
| 1596 |
"_hf_token = os.environ.get(\"HF_TOKEN\")\n",
|
| 1597 |
"_pull_timing_from_hf(_run_id, CKPT_ROOT, _hf_repo, _hf_token)\n",
|
|
@@ -1599,32 +1349,24 @@
|
|
| 1599 |
"_timing_path = CKPT_ROOT / _run_id / \"timing.json\"\n",
|
| 1600 |
"assert _timing_path.exists(), (\n",
|
| 1601 |
" f\"No timing.json under {_timing_path.parent} (also not on HF). \"\n",
|
| 1602 |
-
" f\"
|
| 1603 |
")\n",
|
| 1604 |
"\n",
|
| 1605 |
"_t = _json.loads(_timing_path.read_text())\n",
|
| 1606 |
"\n",
|
| 1607 |
"def _fmt(sec):\n",
|
| 1608 |
-
" h, r = divmod(int(sec or 0), 3600)\n",
|
| 1609 |
-
" m, s = divmod(r, 60)\n",
|
| 1610 |
-
" return f\"{h:d}h {m:02d}m {s:02d}s\"\n",
|
| 1611 |
-
"\n",
|
| 1612 |
-
"_s1 = float(_t.get(\"stage1_elapsed_sec\", 0.0))\n",
|
| 1613 |
-
"_s2 = float(_t.get(\"stage2_elapsed_sec\", 0.0))\n",
|
| 1614 |
-
"_total = _s1 + _s2\n",
|
| 1615 |
"\n",
|
| 1616 |
"print(f\"Run : {_run_id}\")\n",
|
| 1617 |
"print(f\"First started at : {_t.get('first_started_at')}\")\n",
|
| 1618 |
"print(f\"Last finished at : {_t.get('last_finished_at')}\")\n",
|
| 1619 |
-
"print()\n",
|
| 1620 |
-
"print(f\"
|
| 1621 |
-
"print(f\"Stage 2 cumulative : {_fmt(_s2)} (resumes: {_t.get('resume_count_stage2', 0)})\")\n",
|
| 1622 |
-
"print(f\"TOTAL : {_fmt(_total)}\")\n",
|
| 1623 |
"print()\n",
|
| 1624 |
"print(\"Session history :\")\n",
|
| 1625 |
"for _i, _s in enumerate(_t.get(\"session_history\", []), 1):\n",
|
| 1626 |
-
"
|
| 1627 |
-
"
|
| 1628 |
],
|
| 1629 |
"outputs": [],
|
| 1630 |
"execution_count": null
|
|
|
|
| 966 |
{
|
| 967 |
"cell_type": "markdown",
|
| 968 |
"metadata": {
|
| 969 |
+
"id": "cell-mode-md"
|
| 970 |
},
|
| 971 |
"source": [
|
| 972 |
+
"## 5b. Resume controller\n",
|
| 973 |
"\n",
|
| 974 |
+
"Single switch: you no longer pick which stage to resume — `train.py` auto-detects which stage\n",
|
| 975 |
+
"to continue from by inspecting checkpoints on disk.\n",
|
| 976 |
"\n",
|
| 977 |
+
"| MODE | What happens |\n",
|
| 978 |
+
"|---------------------|--------------|\n",
|
| 979 |
+
"| `'fresh'` | Allocate a brand-new `{DATASET}_run_N+1` folder. Train both stages from scratch. |\n",
|
| 980 |
+
"| `'resume'` | Reuse latest matching `{DATASET}_run_N` (or `EXPLICIT_RUN_ID`). Auto-detect: stage 1 mid-checkpoint, stage 1 done → stage 2 fresh, stage 2 mid-checkpoint, or both done. |\n",
|
| 981 |
"\n",
|
| 982 |
+
"`EXPLICIT_RUN_ID` is optional (set to `None` to auto-pick the latest run on\n",
|
| 983 |
+
"disk or HF Hub that matches the current dataset prefix).\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 984 |
"\n",
|
| 985 |
+
"When `MODE='resume'` on a fresh VM the train cell will pull the previous\n",
|
| 986 |
+
"run's checkpoints from HF before training. The `--mode resume` flag in\n",
|
| 987 |
+
"`train.py` does the auto-detect — no further action needed in the notebook."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 988 |
],
|
| 989 |
"id": "cell-resume-md"
|
| 990 |
},
|
| 991 |
{
|
| 992 |
"cell_type": "code",
|
| 993 |
"metadata": {
|
| 994 |
+
"id": "cell-mode",
|
| 995 |
"colab": {
|
| 996 |
"base_uri": "https://localhost:8080/"
|
| 997 |
},
|
| 998 |
"outputId": "6f1b15fb-e751-4209-ea52-7a98a8c40212"
|
| 999 |
},
|
| 1000 |
"execution_count": null,
|
| 1001 |
+
"outputs": [],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1002 |
"source": [
|
| 1003 |
+
"# Resume controller — set MODE once, run the train cell below.\n",
|
| 1004 |
+
"MODE = 'fresh' # 'fresh' | 'resume'\n",
|
| 1005 |
+
"EXPLICIT_RUN_ID = None # None | 'IU-Xray_run_5' (only matters when MODE='resume')\n",
|
| 1006 |
+
"\n",
|
| 1007 |
+
"assert MODE in ('fresh', 'resume'), \"MODE must be 'fresh' or 'resume'\"\n",
|
| 1008 |
+
"if MODE == 'resume' and EXPLICIT_RUN_ID:\n",
|
| 1009 |
+
" CKPT_ROOT.mkdir(parents=True, exist_ok=True)\n",
|
| 1010 |
+
" (CKPT_ROOT / 'run_id.txt').write_text(EXPLICIT_RUN_ID)\n",
|
| 1011 |
+
" print(f\"MODE=resume, pinning run_id = {EXPLICIT_RUN_ID}\")\n",
|
| 1012 |
+
"elif MODE == 'resume':\n",
|
| 1013 |
+
" print(\"MODE=resume, run_id will be auto-resolved to the latest \"\n",
|
| 1014 |
+
" f\"'{DATASET_NAME}_run_*' on disk (or HF Hub).\")\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1015 |
"else:\n",
|
| 1016 |
+
" print(\"MODE=fresh — train.py will allocate a new run folder.\")\n"
|
| 1017 |
],
|
| 1018 |
"id": "cell-resume"
|
| 1019 |
},
|
| 1020 |
{
|
| 1021 |
"cell_type": "markdown",
|
| 1022 |
"metadata": {
|
| 1023 |
+
"id": "cell-train-md"
|
| 1024 |
},
|
| 1025 |
"source": [
|
| 1026 |
+
"## 6. Train (both stages, one shot)\n",
|
| 1027 |
"\n",
|
| 1028 |
+
"Single cell runs `train.py --mode {MODE}` which:\n",
|
| 1029 |
"\n",
|
| 1030 |
+
"1. Resolves `run_id` (new vs. latest matching).\n",
|
| 1031 |
+
"2. Prints the total step plan (stage 1 + stage 2 = global step count).\n",
|
| 1032 |
+
"3. Auto-detects which stage to resume from by scanning the run folder.\n",
|
| 1033 |
+
"4. Runs stage 1 (skip if already done), then stage 2.\n",
|
| 1034 |
+
"5. Pushes intermediate `checkpoint-NNN/` + `best/` of each stage to HF Hub\n",
|
| 1035 |
+
" as before. `timing.json` is updated and uploaded after each stage.\n",
|
| 1036 |
+
"\n",
|
| 1037 |
+
"Kaggle / Colab caps a GPU session at ~9h. If the session dies mid-stage,\n",
|
| 1038 |
+
"set `MODE='resume'` in the controller cell, then re-run this cell — `train.py` picks up at the\n",
|
| 1039 |
+
"last `checkpoint-NNN/` of whichever stage was in progress."
|
| 1040 |
],
|
| 1041 |
"id": "cell-stage1-md"
|
| 1042 |
},
|
| 1043 |
{
|
| 1044 |
"cell_type": "code",
|
| 1045 |
"metadata": {
|
| 1046 |
+
"id": "cell-train",
|
| 1047 |
"colab": {
|
| 1048 |
"base_uri": "https://localhost:8080/"
|
| 1049 |
},
|
| 1050 |
"outputId": "c7d6c209-6790-473c-c1b7-a44441141785"
|
| 1051 |
},
|
| 1052 |
"source": [
|
| 1053 |
+
"# Unified train cell — drives both stages in one shot, with auto-resume.\n",
|
| 1054 |
"import time as _time, json as _json\n",
|
| 1055 |
"from datetime import datetime as _dt, timezone as _tz\n",
|
| 1056 |
"from pathlib import Path as _Path\n",
|
|
|
|
| 1063 |
" try:\n",
|
| 1064 |
" from huggingface_hub import hf_hub_download\n",
|
| 1065 |
" hf_hub_download(\n",
|
| 1066 |
+
" repo_id=repo_id, repo_type=\"model\",\n",
|
| 1067 |
+
" filename=f\"{run_id}/timing.json\",\n",
|
| 1068 |
+
" token=token, local_dir=str(ckpt_root),\n",
|
|
|
|
|
|
|
| 1069 |
" )\n",
|
| 1070 |
" print(f\"[TIMING] pulled previous timing.json from HF → {local}\")\n",
|
| 1071 |
+
" except Exception:\n",
|
| 1072 |
+
" pass # first time for this run_id → no remote file yet, fine\n",
|
|
|
|
| 1073 |
"\n",
|
| 1074 |
"def _push_timing_to_hf(run_id, ckpt_root, repo_id, token):\n",
|
| 1075 |
" # Upload local timing.json to HF Hub under {run_id}/timing.json.\n",
|
|
|
|
| 1079 |
" try:\n",
|
| 1080 |
" from huggingface_hub import HfApi\n",
|
| 1081 |
" HfApi(token=token).upload_file(\n",
|
| 1082 |
+
" path_or_fileobj=str(local),\n",
|
| 1083 |
+
" path_in_repo=f\"{run_id}/timing.json\",\n",
|
| 1084 |
+
" repo_id=repo_id, repo_type=\"model\",\n",
|
| 1085 |
+
" commit_message=f\"timing.json @ {run_id}\",\n",
|
|
|
|
| 1086 |
" )\n",
|
| 1087 |
" print(f\"[TIMING] uploaded timing.json to HF → {repo_id}/{run_id}/timing.json\")\n",
|
| 1088 |
" except Exception as e:\n",
|
| 1089 |
" print(f\"[TIMING] upload failed (non-fatal): {e}\")\n",
|
| 1090 |
"\n",
|
| 1091 |
"\n",
|
| 1092 |
+
"assert MODE in ('fresh', 'resume')\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1093 |
"_hf_repo = getattr(train_cfg.hf_hub, \"repo_id\", None) if train_cfg.hf_hub.enabled else None\n",
|
| 1094 |
"_hf_token = os.environ.get(\"HF_TOKEN\")\n",
|
|
|
|
|
|
|
| 1095 |
"\n",
|
| 1096 |
+
"# ─── Pre-pull timing.json if resuming ──────────────────────────────────────\n",
|
| 1097 |
+
"if MODE == 'resume':\n",
|
| 1098 |
+
" _pre_state = CKPT_ROOT / \"run_id.txt\"\n",
|
| 1099 |
+
" if _pre_state.exists():\n",
|
| 1100 |
+
" _pull_timing_from_hf(_pre_state.read_text().strip(), CKPT_ROOT, _hf_repo, _hf_token)\n",
|
| 1101 |
"\n",
|
| 1102 |
+
"_run_id_arg = \"\"\n",
|
| 1103 |
+
"if MODE == 'resume' and EXPLICIT_RUN_ID:\n",
|
| 1104 |
+
" _run_id_arg = f'--run_id \"{EXPLICIT_RUN_ID}\"'\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1105 |
"\n",
|
| 1106 |
+
"print(f\"▶ Launching train.py --mode {MODE} {_run_id_arg}\")\n",
|
| 1107 |
+
"print(f\" (Both stages dispatched in one call. Stage will be auto-detected for resume.)\")\n",
|
| 1108 |
+
"print()\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1109 |
"\n",
|
| 1110 |
+
"# ─── Wall-clock timing (entire training session) ──────────────────────────\n",
|
| 1111 |
+
"_t0 = _time.time()\n",
|
| 1112 |
+
"_iso_start = _dt.now(_tz.utc).isoformat(timespec=\"seconds\")\n",
|
| 1113 |
"\n",
|
| 1114 |
"!HF_HUB_DISABLE_PROGRESS_BARS=1 TRANSFORMERS_VERBOSITY=warning TOKENIZERS_PARALLELISM=false BITSANDBYTES_NOWELCOME=1 PYTHONUNBUFFERED=1 \\\n",
|
| 1115 |
"python -u -m training.train \\\n",
|
| 1116 |
" --model_config configs/model_config.yaml \\\n",
|
| 1117 |
" --train_config configs/train_config.yaml \\\n",
|
| 1118 |
+
" --mode {MODE} {_run_id_arg}\n",
|
| 1119 |
+
"\n",
|
| 1120 |
+
"_elapsed = _time.time() - _t0\n",
|
| 1121 |
"\n",
|
| 1122 |
+
"# ─── Persist timing.json (cumulative across resumes) ──────────────────────\n",
|
|
|
|
| 1123 |
"_run_id_file = CKPT_ROOT / \"run_id.txt\"\n",
|
| 1124 |
"if _run_id_file.exists():\n",
|
| 1125 |
" _run_id_now = _run_id_file.read_text().strip()\n",
|
| 1126 |
" _timing_path = CKPT_ROOT / _run_id_now / \"timing.json\"\n",
|
| 1127 |
" _timing_path.parent.mkdir(parents=True, exist_ok=True)\n",
|
| 1128 |
" _t = _json.loads(_timing_path.read_text()) if _timing_path.exists() else {\n",
|
| 1129 |
+
" \"total_elapsed_sec\": 0.0,\n",
|
| 1130 |
+
" \"session_count\": 0,\n",
|
| 1131 |
+
" \"first_started_at\": None,\n",
|
| 1132 |
+
" \"last_finished_at\": None,\n",
|
| 1133 |
+
" \"session_history\": [],\n",
|
|
|
|
|
|
|
| 1134 |
" }\n",
|
| 1135 |
" if _t.get(\"first_started_at\") is None:\n",
|
| 1136 |
+
" _t[\"first_started_at\"] = _iso_start\n",
|
| 1137 |
+
" _t[\"total_elapsed_sec\"] = float(_t.get(\"total_elapsed_sec\", 0.0)) + _elapsed\n",
|
| 1138 |
+
" _t[\"session_count\"] = int(_t.get(\"session_count\", 0)) + 1\n",
|
| 1139 |
+
" _t[\"last_finished_at\"] = _dt.now(_tz.utc).isoformat(timespec=\"seconds\")\n",
|
| 1140 |
" _t.setdefault(\"session_history\", []).append({\n",
|
| 1141 |
+
" \"mode\": MODE,\n",
|
| 1142 |
+
" \"started\": _iso_start,\n",
|
| 1143 |
+
" \"finished\": _t[\"last_finished_at\"],\n",
|
| 1144 |
+
" \"elapsed_sec\": _elapsed,\n",
|
|
|
|
| 1145 |
" })\n",
|
| 1146 |
" _timing_path.write_text(_json.dumps(_t, indent=2))\n",
|
|
|
|
|
|
|
| 1147 |
" _push_timing_to_hf(_run_id_now, CKPT_ROOT, _hf_repo, _hf_token)\n",
|
| 1148 |
"\n",
|
| 1149 |
" def _fmt(sec):\n",
|
| 1150 |
" h, r = divmod(int(sec), 3600); m, s = divmod(r, 60); return f\"{h:d}h {m:02d}m {s:02d}s\"\n",
|
|
|
|
| 1151 |
" print()\n",
|
| 1152 |
+
" print(f\"[TIMING] This session : {_fmt(_elapsed)}\")\n",
|
| 1153 |
+
" print(f\"[TIMING] Cumulative : {_fmt(_t['total_elapsed_sec'])} ({_t['session_count']} session(s))\")\n",
|
| 1154 |
+
" print(f\"[TIMING] first started : {_t.get('first_started_at')}\")\n",
|
| 1155 |
+
" print(f\"[TIMING] last finished : {_t.get('last_finished_at')}\")\n",
|
| 1156 |
+
" print(f\"[TIMING] persisted to : {_timing_path}\")\n",
|
|
|
|
| 1157 |
"else:\n",
|
| 1158 |
" print(\"[TIMING] run_id.txt missing — could not persist timing.\")\n"
|
| 1159 |
],
|
| 1160 |
"execution_count": null,
|
| 1161 |
"outputs": [],
|
| 1162 |
+
"id": "cell-stage1"
|
| 1163 |
},
|
| 1164 |
{
|
| 1165 |
"cell_type": "markdown",
|
|
|
|
| 1311 |
" try:\n",
|
| 1312 |
" from huggingface_hub import hf_hub_download\n",
|
| 1313 |
" hf_hub_download(\n",
|
| 1314 |
+
" repo_id=repo_id, repo_type=\"model\",\n",
|
| 1315 |
+
" filename=f\"{run_id}/timing.json\",\n",
|
| 1316 |
+
" token=token, local_dir=str(ckpt_root),\n",
|
|
|
|
|
|
|
| 1317 |
" )\n",
|
| 1318 |
" print(f\"[TIMING] pulled previous timing.json from HF → {local}\")\n",
|
| 1319 |
+
" except Exception:\n",
|
| 1320 |
+
" pass # first time for this run_id → no remote file yet, fine\n",
|
|
|
|
| 1321 |
"\n",
|
| 1322 |
"def _push_timing_to_hf(run_id, ckpt_root, repo_id, token):\n",
|
| 1323 |
" # Upload local timing.json to HF Hub under {run_id}/timing.json.\n",
|
|
|
|
| 1327 |
" try:\n",
|
| 1328 |
" from huggingface_hub import HfApi\n",
|
| 1329 |
" HfApi(token=token).upload_file(\n",
|
| 1330 |
+
" path_or_fileobj=str(local),\n",
|
| 1331 |
+
" path_in_repo=f\"{run_id}/timing.json\",\n",
|
| 1332 |
+
" repo_id=repo_id, repo_type=\"model\",\n",
|
| 1333 |
+
" commit_message=f\"timing.json @ {run_id}\",\n",
|
|
|
|
| 1334 |
" )\n",
|
| 1335 |
" print(f\"[TIMING] uploaded timing.json to HF → {repo_id}/{run_id}/timing.json\")\n",
|
| 1336 |
" except Exception as e:\n",
|
|
|
|
| 1338 |
"\n",
|
| 1339 |
"\n",
|
| 1340 |
"_run_id_file = CKPT_ROOT / \"run_id.txt\"\n",
|
| 1341 |
+
"assert _run_id_file.exists(), \"No run_id.txt — run the train cell at least once.\"\n",
|
| 1342 |
"_run_id = _run_id_file.read_text().strip()\n",
|
| 1343 |
"\n",
|
| 1344 |
+
"# Pull latest timing.json from HF in case this is a fresh VM\n",
|
| 1345 |
"_hf_repo = getattr(train_cfg.hf_hub, \"repo_id\", None) if train_cfg.hf_hub.enabled else None\n",
|
| 1346 |
"_hf_token = os.environ.get(\"HF_TOKEN\")\n",
|
| 1347 |
"_pull_timing_from_hf(_run_id, CKPT_ROOT, _hf_repo, _hf_token)\n",
|
|
|
|
| 1349 |
"_timing_path = CKPT_ROOT / _run_id / \"timing.json\"\n",
|
| 1350 |
"assert _timing_path.exists(), (\n",
|
| 1351 |
" f\"No timing.json under {_timing_path.parent} (also not on HF). \"\n",
|
| 1352 |
+
" f\"Did the train cell run?\"\n",
|
| 1353 |
")\n",
|
| 1354 |
"\n",
|
| 1355 |
"_t = _json.loads(_timing_path.read_text())\n",
|
| 1356 |
"\n",
|
| 1357 |
"def _fmt(sec):\n",
|
| 1358 |
+
" h, r = divmod(int(sec or 0), 3600); m, s = divmod(r, 60); return f\"{h:d}h {m:02d}m {s:02d}s\"\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1359 |
"\n",
|
| 1360 |
"print(f\"Run : {_run_id}\")\n",
|
| 1361 |
"print(f\"First started at : {_t.get('first_started_at')}\")\n",
|
| 1362 |
"print(f\"Last finished at : {_t.get('last_finished_at')}\")\n",
|
| 1363 |
+
"print(f\"Session count : {_t.get('session_count', 0)}\")\n",
|
| 1364 |
+
"print(f\"TOTAL elapsed : {_fmt(_t.get('total_elapsed_sec', 0.0))}\")\n",
|
|
|
|
|
|
|
| 1365 |
"print()\n",
|
| 1366 |
"print(\"Session history :\")\n",
|
| 1367 |
"for _i, _s in enumerate(_t.get(\"session_history\", []), 1):\n",
|
| 1368 |
+
" print(f\" {_i:2d}. mode={_s.get('mode','?'):6s} {_fmt(_s['elapsed_sec'])} \"\n",
|
| 1369 |
+
" f\"{_s['started']} → {_s['finished']}\")\n"
|
| 1370 |
],
|
| 1371 |
"outputs": [],
|
| 1372 |
"execution_count": null
|
training/train.py
CHANGED
|
@@ -69,11 +69,19 @@ def parse_args():
|
|
| 69 |
)
|
| 70 |
parser.add_argument(
|
| 71 |
"--stage", type=int, default=None,
|
| 72 |
-
help="Run only stage 1 or stage 2 (default: run both)"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
)
|
| 74 |
parser.add_argument(
|
| 75 |
"--resume_from", type=str, default=None,
|
| 76 |
-
help="Path to checkpoint to resume from"
|
| 77 |
)
|
| 78 |
parser.add_argument(
|
| 79 |
"--run_id", type=str, default=None,
|
|
@@ -87,6 +95,134 @@ def parse_args():
|
|
| 87 |
return parser.parse_args()
|
| 88 |
|
| 89 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
def get_trainer(
|
| 91 |
model,
|
| 92 |
train_dataset,
|
|
@@ -469,11 +605,21 @@ def main():
|
|
| 469 |
train_cfg.hf_hub.token_env, os.environ.get("HF_TOKEN")
|
| 470 |
) if train_cfg.hf_hub.enabled else None
|
| 471 |
hf_repo_id = train_cfg.hf_hub.repo_id if train_cfg.hf_hub.enabled else None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 472 |
run_id = resolve_run_id(
|
| 473 |
dataset_name = spec.dataset_name,
|
| 474 |
output_root = output_root,
|
| 475 |
state_file = state_file,
|
| 476 |
-
resuming =
|
| 477 |
explicit = args.run_id,
|
| 478 |
hf_repo_id = hf_repo_id,
|
| 479 |
hf_token = hf_token,
|
|
@@ -509,6 +655,35 @@ def main():
|
|
| 509 |
stage2_out = stage_dir(output_root, run_id,
|
| 510 |
str(train_cfg.stage2.get("subdir", "stage2_instruct")))
|
| 511 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 512 |
# ── Snapshot resolved config into the run dir ────────────────────
|
| 513 |
# Every run gets its own self-describing folder so we never have to ask
|
| 514 |
# "what config did IU-Xray_run_3 actually use?" — open run_meta.json.
|
|
@@ -560,19 +735,41 @@ def main():
|
|
| 560 |
load_checkpoint(model, args.resume_from)
|
| 561 |
|
| 562 |
# Run training stages
|
| 563 |
-
|
| 564 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 565 |
|
| 566 |
if run_s1:
|
| 567 |
-
# Only pass resume_from to stage1 if stage1 is the explicit target
|
| 568 |
-
# (`--stage 1`). If user runs both stages with --resume_from, we assume
|
| 569 |
-
# the checkpoint is for stage2 and let stage1 run fresh (or finish fast
|
| 570 |
-
# if projection was already trained).
|
| 571 |
-
s1_resume = args.resume_from if args.stage == 1 else None
|
| 572 |
model = run_stage1(
|
| 573 |
model, train_cfg, model_cfg, spec, stage1_out, logger,
|
| 574 |
tracker = tracker,
|
| 575 |
-
resume_from =
|
| 576 |
)
|
| 577 |
|
| 578 |
if run_s2:
|
|
@@ -580,15 +777,17 @@ def main():
|
|
| 580 |
# Priority:
|
| 581 |
# 1. Just finished stage1 in this run → use stage1_out/stage1_final.pt
|
| 582 |
# 2. Not running stage1 but stage1_final.pt exists on disk → load it
|
| 583 |
-
# 3.
|
|
|
|
|
|
|
| 584 |
stage1_ckpt = Path(stage1_out) / "stage1_final.pt"
|
| 585 |
if run_s1:
|
| 586 |
load_checkpoint(model, str(stage1_ckpt))
|
| 587 |
logger.info(f"Loaded stage1 weights from this run: {stage1_ckpt}")
|
| 588 |
-
elif stage1_ckpt.exists() and not
|
| 589 |
load_checkpoint(model, str(stage1_ckpt))
|
| 590 |
logger.info(f"Auto-loaded existing stage1 weights: {stage1_ckpt}")
|
| 591 |
-
elif not
|
| 592 |
logger.warning(
|
| 593 |
"⚠ No stage1 weights found and not resuming. Projection layer "
|
| 594 |
"will start RANDOMLY for stage2. Expect degraded convergence. "
|
|
@@ -597,7 +796,7 @@ def main():
|
|
| 597 |
|
| 598 |
model = run_stage2(
|
| 599 |
model, train_cfg, model_cfg, spec, stage2_out, logger,
|
| 600 |
-
resume_from =
|
| 601 |
tracker = tracker,
|
| 602 |
)
|
| 603 |
|
|
|
|
| 69 |
)
|
| 70 |
parser.add_argument(
|
| 71 |
"--stage", type=int, default=None,
|
| 72 |
+
help="Run only stage 1 or stage 2 (default: run both). With --mode resume, "
|
| 73 |
+
"the stage is auto-detected and this flag should be left unset."
|
| 74 |
+
)
|
| 75 |
+
parser.add_argument(
|
| 76 |
+
"--mode", type=str, default=None, choices=["fresh", "resume"],
|
| 77 |
+
help="Unified resume controller. 'fresh' → new run_N folder. "
|
| 78 |
+
"'resume' → reuse latest matching run_id (or --run_id), auto-detect "
|
| 79 |
+
"which stage to continue from based on checkpoints on disk. "
|
| 80 |
+
"If unset, behaviour is inferred from --resume_from / --run_id (legacy)."
|
| 81 |
)
|
| 82 |
parser.add_argument(
|
| 83 |
"--resume_from", type=str, default=None,
|
| 84 |
+
help="Path to checkpoint to resume from (legacy; prefer --mode resume)"
|
| 85 |
)
|
| 86 |
parser.add_argument(
|
| 87 |
"--run_id", type=str, default=None,
|
|
|
|
| 95 |
return parser.parse_args()
|
| 96 |
|
| 97 |
|
| 98 |
+
# ─── Resume-point auto-detection ────────────────────────────────────────────
|
| 99 |
+
|
| 100 |
+
def _list_checkpoints(stage_dir):
|
| 101 |
+
"""Return [Path, …] of `checkpoint-NNN` folders sorted ascending by step."""
|
| 102 |
+
if not stage_dir.is_dir():
|
| 103 |
+
return []
|
| 104 |
+
out = []
|
| 105 |
+
for p in stage_dir.iterdir():
|
| 106 |
+
if not p.is_dir() or not p.name.startswith("checkpoint-"):
|
| 107 |
+
continue
|
| 108 |
+
suffix = p.name.split("-", 1)[1]
|
| 109 |
+
if suffix.isdigit():
|
| 110 |
+
out.append((int(suffix), p))
|
| 111 |
+
return [p for _, p in sorted(out)]
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def detect_resume_point(run_dir_path, stage1_subdir, stage2_subdir):
    """
    Decide, from the files present in a run directory, where training resumes.

    The checks run in this order and the first match wins:
      1. `<stage2>/stage2_final_projection.pt` exists -> ("done", None)
      2. `<stage2>` holds `checkpoint-N` folders       -> ("stage2", highest N)
      3. `<stage1>/stage1_final_projection.pt` exists -> ("stage2", None)
      4. `<stage1>` holds `checkpoint-N` folders       -> ("stage1", highest N)
      5. none of the above                              -> ("stage1", None)

    Returns a tuple `(target_stage, ckpt_path)`:
      target_stage : "stage1" | "stage2" | "done"
      ckpt_path    : Path of the checkpoint folder to resume from, or None
                     when the stage should start from scratch.
    """
    from pathlib import Path as _P

    root = _P(run_dir_path)

    # Stage 2 is probed before stage 1 so the furthest progress wins.
    # Each row: (stage folder, final-weights file, result if the final file
    # exists, result if only intermediate checkpoints exist).
    probes = (
        (root / stage2_subdir, "stage2_final_projection.pt", "done", "stage2"),
        (root / stage1_subdir, "stage1_final_projection.pt", "stage2", "stage1"),
    )
    for stage_path, final_name, when_final, when_partial in probes:
        if (stage_path / final_name).exists():
            return (when_final, None)
        checkpoints = _list_checkpoints(stage_path)
        if checkpoints:
            return (when_partial, checkpoints[-1])

    return ("stage1", None)
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def compute_training_plan(train_cfg, instruct_json_path):
    """
    Estimate the optimizer-step budget for stage 1 + stage 2.

    The estimate is derived from `train_cfg` and from the number of entries in
    the instruct JSON whose "split" field equals "train". It only feeds the
    human-readable startup banner, so it is best-effort: if the JSON cannot be
    read or parsed, a warning is logged and the sample count falls back to 0.

    Args:
        train_cfg: Config object exposing `.training`, `.stage1` and `.stage2`
            attributes; missing fields fall back to defaults via getattr.
        instruct_json_path: Path of the instruct JSON file (a list of sample
            dicts is expected; entries that are not dicts are ignored).

    Returns:
        dict of ints with keys: train_samples, effective_batch,
        steps_per_epoch, stage1_steps, stage2_steps, total_steps,
        stage1_epochs, stage2_epochs.
    """
    import json as _json
    import logging as _logging

    tr = train_cfg.training

    train_count = 0
    try:
        with open(instruct_json_path, "r", encoding="utf-8") as f:
            all_samples = _json.load(f)
    except (OSError, ValueError, TypeError) as exc:
        # OSError: missing/unreadable file; ValueError: invalid JSON or bad
        # encoding; TypeError: path is None or otherwise not path-like.
        _logging.getLogger(__name__).warning(
            "Could not read train samples from %r (%s); plan assumes 0 samples.",
            instruct_json_path, exc,
        )
    else:
        if isinstance(all_samples, list):
            # Skip malformed (non-dict) entries instead of discarding the
            # whole count because of one bad record.
            train_count = sum(
                1 for s in all_samples
                if isinstance(s, dict) and s.get("split") == "train"
            )

    bs = int(getattr(tr, "per_device_train_batch_size", 1))
    ga = int(getattr(tr, "gradient_accumulation_steps", 1))
    eff = max(1, bs * ga)
    # Ceiling division, floored at 1 so the plan never reports zero steps.
    steps_per_epoch = max(1, (train_count + eff - 1) // eff)

    s1_enabled = bool(getattr(train_cfg.stage1, "enabled", True))
    s2_enabled = bool(getattr(train_cfg.stage2, "enabled", True))
    s1_epochs = int(getattr(train_cfg.stage1, "num_epochs", 0)) if s1_enabled else 0
    s2_epochs = int(getattr(train_cfg.stage2, "num_epochs", 0)) if s2_enabled else 0

    s1_steps = steps_per_epoch * s1_epochs
    s2_steps = steps_per_epoch * s2_epochs
    return {
        "train_samples": train_count,
        "effective_batch": eff,
        "steps_per_epoch": steps_per_epoch,
        "stage1_steps": s1_steps,
        "stage2_steps": s2_steps,
        "total_steps": s1_steps + s2_steps,
        "stage1_epochs": s1_epochs,
        "stage2_epochs": s2_epochs,
    }
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def _fmt_plan_banner(plan, run_id, target_stage, resume_ckpt):
|
| 194 |
+
s1, s2, tot = plan["stage1_steps"], plan["stage2_steps"], plan["total_steps"]
|
| 195 |
+
head = f"TRAINING PLAN — {run_id}"
|
| 196 |
+
sep = "=" * max(len(head) + 4, 60)
|
| 197 |
+
cur = ""
|
| 198 |
+
if target_stage == "stage1":
|
| 199 |
+
offset = 0
|
| 200 |
+
if resume_ckpt and str(resume_ckpt).split("-")[-1].isdigit():
|
| 201 |
+
offset = int(str(resume_ckpt).split("-")[-1])
|
| 202 |
+
cur = f"Resuming at step {offset} / {tot} (inside stage 1)"
|
| 203 |
+
elif target_stage == "stage2":
|
| 204 |
+
offset = s1
|
| 205 |
+
if resume_ckpt and str(resume_ckpt).split("-")[-1].isdigit():
|
| 206 |
+
offset = s1 + int(str(resume_ckpt).split("-")[-1])
|
| 207 |
+
cur = f"Resuming at step {offset} / {tot} (inside stage 2)"
|
| 208 |
+
elif target_stage == "done":
|
| 209 |
+
cur = f"All {tot} steps already complete — nothing to do"
|
| 210 |
+
|
| 211 |
+
lines = [
|
| 212 |
+
sep, f" {head}", sep,
|
| 213 |
+
f" Train samples : {plan['train_samples']:,}",
|
| 214 |
+
f" Effective batch : {plan['effective_batch']}",
|
| 215 |
+
f" Steps / epoch : {plan['steps_per_epoch']}",
|
| 216 |
+
f" Stage 1 : {plan['stage1_epochs']} epochs → {s1} steps (global steps 1–{s1})",
|
| 217 |
+
f" Stage 2 : {plan['stage2_epochs']} epochs → {s2} steps (global steps {s1+1}–{tot})",
|
| 218 |
+
f" TOTAL : {tot} optimizer steps",
|
| 219 |
+
]
|
| 220 |
+
if cur:
|
| 221 |
+
lines += [" " + "─" * (len(sep) - 4), f" {cur}"]
|
| 222 |
+
lines.append(sep)
|
| 223 |
+
return "\n".join(lines)
|
| 224 |
+
|
| 225 |
+
|
| 226 |
def get_trainer(
|
| 227 |
model,
|
| 228 |
train_dataset,
|
|
|
|
| 605 |
train_cfg.hf_hub.token_env, os.environ.get("HF_TOKEN")
|
| 606 |
) if train_cfg.hf_hub.enabled else None
|
| 607 |
hf_repo_id = train_cfg.hf_hub.repo_id if train_cfg.hf_hub.enabled else None
|
| 608 |
+
|
| 609 |
+
# Unified --mode controller. Falls back to the legacy inference (any of
|
| 610 |
+
# --resume_from / --resume_from_hf set ⇒ resuming) when --mode is unset.
|
| 611 |
+
if args.mode == "resume":
|
| 612 |
+
resuming = True
|
| 613 |
+
elif args.mode == "fresh":
|
| 614 |
+
resuming = False
|
| 615 |
+
else:
|
| 616 |
+
resuming = bool(args.resume_from) or args.resume_from_hf
|
| 617 |
+
|
| 618 |
run_id = resolve_run_id(
|
| 619 |
dataset_name = spec.dataset_name,
|
| 620 |
output_root = output_root,
|
| 621 |
state_file = state_file,
|
| 622 |
+
resuming = resuming,
|
| 623 |
explicit = args.run_id,
|
| 624 |
hf_repo_id = hf_repo_id,
|
| 625 |
hf_token = hf_token,
|
|
|
|
| 655 |
stage2_out = stage_dir(output_root, run_id,
|
| 656 |
str(train_cfg.stage2.get("subdir", "stage2_instruct")))
|
| 657 |
|
| 658 |
+
# ── Auto-detect where to resume from (when --mode resume) ─────────
|
| 659 |
+
# Examines disk state inside {output_root}/{run_id}/ and chooses:
|
| 660 |
+
# • stage1 from scratch / stage1 mid-checkpoint
|
| 661 |
+
# • stage2 from scratch (stage1 done) / stage2 mid-checkpoint
|
| 662 |
+
# • done (both stages finished — skip everything)
|
| 663 |
+
# If the user passed --stage explicitly, that wins over auto-detect.
|
| 664 |
+
auto_target_stage = None
|
| 665 |
+
auto_resume_ckpt = None
|
| 666 |
+
if args.mode == "resume" and args.stage is None:
|
| 667 |
+
auto_target_stage, auto_resume_ckpt = detect_resume_point(
|
| 668 |
+
run_dir(output_root, run_id),
|
| 669 |
+
str(train_cfg.stage1.get("subdir", "stage1_projection")),
|
| 670 |
+
str(train_cfg.stage2.get("subdir", "stage2_instruct")),
|
| 671 |
+
)
|
| 672 |
+
logger.info(
|
| 673 |
+
f"[resume autodetect] target={auto_target_stage} "
|
| 674 |
+
f"ckpt={auto_resume_ckpt}"
|
| 675 |
+
)
|
| 676 |
+
|
| 677 |
+
# ── Pretty plan banner (total steps across both stages) ───────────
|
| 678 |
+
plan = compute_training_plan(train_cfg, spec.instruct_json)
|
| 679 |
+
logger.info("\n" + _fmt_plan_banner(plan, run_id,
|
| 680 |
+
auto_target_stage or "stage1",
|
| 681 |
+
auto_resume_ckpt))
|
| 682 |
+
|
| 683 |
+
if auto_target_stage == "done":
|
| 684 |
+
logger.info("Both stages already complete for this run. Exiting cleanly.")
|
| 685 |
+
return
|
| 686 |
+
|
| 687 |
# ── Snapshot resolved config into the run dir ────────────────────
|
| 688 |
# Every run gets its own self-describing folder so we never have to ask
|
| 689 |
# "what config did IU-Xray_run_3 actually use?" — open run_meta.json.
|
|
|
|
| 735 |
load_checkpoint(model, args.resume_from)
|
| 736 |
|
| 737 |
# Run training stages
|
| 738 |
+
#
|
| 739 |
+
# Stage selection priority:
|
| 740 |
+
# 1. Explicit --stage from CLI wins.
|
| 741 |
+
# 2. --mode resume + auto-detect: skip stage1 when its final ckpt exists,
|
| 742 |
+
# resume stage1/stage2 from `auto_resume_ckpt` as detected above.
|
| 743 |
+
# 3. Otherwise: enabled flags from train_cfg drive it (legacy: run both).
|
| 744 |
+
if args.stage is not None:
|
| 745 |
+
run_s1 = (args.stage == 1) and train_cfg.stage1.enabled
|
| 746 |
+
run_s2 = (args.stage == 2) and train_cfg.stage2.enabled
|
| 747 |
+
elif auto_target_stage == "stage2":
|
| 748 |
+
# Stage 1 finished previously — skip it entirely.
|
| 749 |
+
run_s1 = False
|
| 750 |
+
run_s2 = train_cfg.stage2.enabled
|
| 751 |
+
else:
|
| 752 |
+
run_s1 = train_cfg.stage1.enabled
|
| 753 |
+
run_s2 = train_cfg.stage2.enabled
|
| 754 |
+
|
| 755 |
+
# Decide the resume checkpoint each stage should use.
|
| 756 |
+
# Manual --resume_from still wins when --stage is given explicitly.
|
| 757 |
+
s1_resume_path = None
|
| 758 |
+
s2_resume_path = None
|
| 759 |
+
if args.stage == 1:
|
| 760 |
+
s1_resume_path = args.resume_from
|
| 761 |
+
elif args.stage == 2:
|
| 762 |
+
s2_resume_path = args.resume_from
|
| 763 |
+
elif auto_target_stage == "stage1":
|
| 764 |
+
s1_resume_path = str(auto_resume_ckpt) if auto_resume_ckpt else None
|
| 765 |
+
elif auto_target_stage == "stage2":
|
| 766 |
+
s2_resume_path = str(auto_resume_ckpt) if auto_resume_ckpt else None
|
| 767 |
|
| 768 |
if run_s1:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 769 |
model = run_stage1(
|
| 770 |
model, train_cfg, model_cfg, spec, stage1_out, logger,
|
| 771 |
tracker = tracker,
|
| 772 |
+
resume_from = s1_resume_path,
|
| 773 |
)
|
| 774 |
|
| 775 |
if run_s2:
|
|
|
|
| 777 |
# Priority:
|
| 778 |
# 1. Just finished stage1 in this run → use stage1_out/stage1_final.pt
|
| 779 |
# 2. Not running stage1 but stage1_final.pt exists on disk → load it
|
| 780 |
+
# 3. s2_resume_path set (we're mid-stage2) → Trainer will reload from
|
| 781 |
+
# the checkpoint itself; no need to seed stage1 weights here.
|
| 782 |
+
# 4. Nothing → warn loudly; stage2 starts with random projection.
|
| 783 |
stage1_ckpt = Path(stage1_out) / "stage1_final.pt"
|
| 784 |
if run_s1:
|
| 785 |
load_checkpoint(model, str(stage1_ckpt))
|
| 786 |
logger.info(f"Loaded stage1 weights from this run: {stage1_ckpt}")
|
| 787 |
+
elif stage1_ckpt.exists() and not s2_resume_path:
|
| 788 |
load_checkpoint(model, str(stage1_ckpt))
|
| 789 |
logger.info(f"Auto-loaded existing stage1 weights: {stage1_ckpt}")
|
| 790 |
+
elif not s2_resume_path:
|
| 791 |
logger.warning(
|
| 792 |
"⚠ No stage1 weights found and not resuming. Projection layer "
|
| 793 |
"will start RANDOMLY for stage2. Expect degraded convergence. "
|
|
|
|
| 796 |
|
| 797 |
model = run_stage2(
|
| 798 |
model, train_cfg, model_cfg, spec, stage2_out, logger,
|
| 799 |
+
resume_from = s2_resume_path,
|
| 800 |
tracker = tracker,
|
| 801 |
)
|
| 802 |
|