training: push adapters to HF Hub after SFT + each GRPO stage
Browse files
HF Spaces /data is ephemeral and silently got wiped mid-run during
testing, killing the SFT adapter and the stage1 GRPO checkpoint.
Push each adapter to a dedicated HF Hub model repo as soon as it's
saved:
shivam2k3/opensoc-defender-grpo-sft (after SFT)
shivam2k3/opensoc-defender-grpo-stage1_basic (after stage1)
shivam2k3/opensoc-defender-grpo-stage2_multi (after stage2)
...
shivam2k3/opensoc-defender-grpo (final, written by
scripts/run_full_pipeline.sh)
After a restart on a fresh container, the run can resume by downloading
the latest bookmark instead of redoing 12 min of SFT or 16 min per GRPO
stage.
Made-with: Cursor
- scripts/run_full_pipeline.sh +18 -0
- train/train_grpo.py +25 -2
scripts/run_full_pipeline.sh
CHANGED
|
@@ -46,6 +46,24 @@ python -m train.sft_warmstart \
|
|
| 46 |
--epochs 1 --batch-size 4 --grad-accum 4 --lr 2e-4 \
|
| 47 |
--out checkpoints/defender_sft_adapter
|
| 48 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
echo "[4/6] GRPO curriculum (~3 hr on L4) ..."
|
| 50 |
python -m train.train_grpo \
|
| 51 |
--sft-adapter checkpoints/defender_sft_adapter \
|
|
|
|
| 46 |
--epochs 1 --batch-size 4 --grad-accum 4 --lr 2e-4 \
|
| 47 |
--out checkpoints/defender_sft_adapter
|
| 48 |
|
| 49 |
+
# Cheap insurance: HF Spaces may wipe /data mid-run. Push the SFT
|
| 50 |
+
# adapter to a dedicated HF Hub model repo immediately so a restart
|
| 51 |
+
# can skip step 3 by re-downloading it.
|
| 52 |
+
if [ -n "${HF_TOKEN:-}" ] && [ -n "${HF_ADAPTER_REPO:-}" ]; then
|
| 53 |
+
echo " pushing SFT adapter to ${HF_ADAPTER_REPO}-sft for resilience ..."
|
| 54 |
+
python -c "
|
| 55 |
+
import os
|
| 56 |
+
from huggingface_hub import HfApi, upload_folder
|
| 57 |
+
token = os.environ['HF_TOKEN']
|
| 58 |
+
sft_repo = os.environ['HF_ADAPTER_REPO'] + '-sft'
|
| 59 |
+
api = HfApi(token=token)
|
| 60 |
+
api.create_repo(sft_repo, exist_ok=True, private=False)
|
| 61 |
+
upload_folder(repo_id=sft_repo, folder_path='checkpoints/defender_sft_adapter',
|
| 62 |
+
commit_message='SFT warm-start checkpoint (resume bookmark)', token=token)
|
| 63 |
+
print(' SFT bookmark ->', sft_repo)
|
| 64 |
+
"
|
| 65 |
+
fi
|
| 66 |
+
|
| 67 |
echo "[4/6] GRPO curriculum (~3 hr on L4) ..."
|
| 68 |
python -m train.train_grpo \
|
| 69 |
--sft-adapter checkpoints/defender_sft_adapter \
|
train/train_grpo.py
CHANGED
|
@@ -182,8 +182,31 @@ def main() -> None:
|
|
| 182 |
callbacks=[json_logger],
|
| 183 |
)
|
| 184 |
trainer.train()
|
| 185 |
-
|
| 186 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 187 |
|
| 188 |
print(f"GRPO curriculum complete. Final adapter: {out_dir}/{stages[-1]}/adapter")
|
| 189 |
|
|
|
|
| 182 |
callbacks=[json_logger],
|
| 183 |
)
|
| 184 |
trainer.train()
|
| 185 |
+
adapter_path = os.path.join(out_dir, stage_id, "adapter")
|
| 186 |
+
model.save_pretrained(adapter_path)
|
| 187 |
+
print(f"Saved {stage_id} adapter to {adapter_path}")
|
| 188 |
+
|
| 189 |
+
# Resilience: HF Spaces /data is ephemeral, so push each stage
|
| 190 |
+
# adapter to a stage-specific HF Hub repo path immediately. A
|
| 191 |
+
# later /data wipe can then resume from the latest stage by
|
| 192 |
+
# downloading from the Hub.
|
| 193 |
+
token = os.environ.get("HF_TOKEN")
|
| 194 |
+
adapter_repo = os.environ.get("HF_ADAPTER_REPO")
|
| 195 |
+
if token and adapter_repo:
|
| 196 |
+
try:
|
| 197 |
+
from huggingface_hub import HfApi, upload_folder
|
| 198 |
+
staged_repo = f"{adapter_repo}-{stage_id}"
|
| 199 |
+
api = HfApi(token=token)
|
| 200 |
+
api.create_repo(staged_repo, exist_ok=True, private=False)
|
| 201 |
+
upload_folder(
|
| 202 |
+
repo_id=staged_repo,
|
| 203 |
+
folder_path=adapter_path,
|
| 204 |
+
commit_message=f"GRPO {stage_id} adapter checkpoint",
|
| 205 |
+
token=token,
|
| 206 |
+
)
|
| 207 |
+
print(f" bookmark -> https://huggingface.co/{staged_repo}")
|
| 208 |
+
except Exception as e:
|
| 209 |
+
print(f" (per-stage push failed: {e}); continuing")
|
| 210 |
|
| 211 |
print(f"GRPO curriculum complete. Final adapter: {out_dir}/{stages[-1]}/adapter")
|
| 212 |
|