opensoc-env / scripts /run_resume_stage3.sh
shivam2k3's picture
add scripts/run_resume_stage3.sh
76ba348
#!/usr/bin/env bash
# Resume the GRPO curriculum from stage 3.
# Pulls SFT and stage2_multi adapters from the Hub, then runs the
# remaining stages (3 and 4), eval, bake-demo, and final push.
#
# steps_per_stage is intentionally lowered to 150 (vs 200 elsewhere)
# because we want to finish before the next /data wipe.
#
# Working directory: every path used below (requirements.txt, data/,
# checkpoints/, eval/results) is relative and the train.* / eval.* modules
# are started with `python -m`, so run this from the directory that holds
# them (presumably the repository root -- confirm).
#
# Environment variables read by the final upload step:
#   HF_TOKEN         Hub token. The upload runs only when this and
#                    HF_PUSH_TARGET are both non-empty.
#   HF_PUSH_TARGET   Repo id that receives the demo / eval artifacts.
#   HF_ADAPTER_REPO  Optional repo id for the trained adapter. Defaults to
#                    "<part of HF_PUSH_TARGET before the first '/'>/opensoc-defender-grpo".

# Abort on the first failing command (-e), on expansion of an unset
# variable (-u), and when any stage of a pipeline fails (pipefail).
set -euo pipefail
# Exported so every child process sees the hf_transfer backend switched off.
# NOTE(review): the reason for disabling it is not recorded here -- confirm
# before re-enabling.
export HF_HUB_ENABLE_HF_TRANSFER=0
# Stage 1: install the Python GPU training stack with pip.
# torch / torchvision are taken from the PyTorch cu128 wheel index. unsloth,
# unsloth_zoo, trl, peft, accelerate and bitsandbytes are installed with
# --no-deps (presumably so they cannot pull in different versions of the
# packages pinned just before them -- confirm before reordering).
echo "[1/6] Installing GPU stack ..."
pip install -q --upgrade pip setuptools wheel
pip install -q "future>=1.0.0"
pip install -q "torch==2.10.0" "torchvision==0.25.0" \
  --index-url https://download.pytorch.org/whl/cu128
pip install -q "transformers>=4.46,<5"
pip install -q --no-deps "unsloth @ git+https://github.com/unslothai/unsloth.git"
pip install -q --no-deps "unsloth_zoo"
pip install -q --no-deps "trl>=0.18.2,<=0.24.0" peft accelerate bitsandbytes
pip install -q "datasets>=3.4.1,<4.4.0" tyro tensorboard matplotlib sentencepiece protobuf huggingface_hub
# Best-effort extras: a failure here must not abort the run under `set -e`,
# but it is reported on stderr instead of being silently discarded.
if ! pip install -q hf_transfer msgspec "torchao>=0.13.0" cut_cross_entropy; then
  echo "[1/6] WARNING: optional packages (hf_transfer, msgspec, torchao, cut_cross_entropy) failed to install; continuing." >&2
fi
pip install -q -r requirements.txt
# Stage 2: (re)generate the SFT training file and the hold-out file used by
# the eval stage. Arguments are kept in arrays so each invocation reads as
# "module + argument list".
printf '%s\n' "[2/6] Building / verifying datasets ..."
sft_dataset_args=(--n 600 --out data/sft_train.jsonl)
python -m train.make_sft_dataset "${sft_dataset_args[@]}"
holdout_args=(--out data/holdout.jsonl)
python -m eval.make_holdout "${holdout_args[@]}"
# Stage 3 is only announced: nothing is trained here because the SFT and
# stage 1/2 results are fetched from the Hub instead.
printf '%s\n' "[3/6] Skipped (SFT + Stages 1,2 already on Hub)."
# Stage 3.5: download the previously trained adapters from the Hub into the
# local checkpoint tree so that training and eval can read them from disk.
echo "[3.5/6] Pulling SFT and Stage 2 adapters from the Hub ..."
python - <<'PY'
import os

# Set before importing huggingface_hub, which reads this flag at import time.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0"

from huggingface_hub import snapshot_download

# (label used in the log line, Hub repo id, local destination directory)
ADAPTERS = [
    (
        "SFT adapter",
        "shivam2k3/opensoc-defender-grpo-sft",
        "checkpoints/defender_sft_adapter",
    ),
    (
        "Stage 2 adapter",
        "shivam2k3/opensoc-defender-grpo-stage2_multi",
        "checkpoints/defender_grpo/stage2_multi/adapter",
    ),
]

for label, repo_id, local_dir in ADAPTERS:
    # `local_dir_use_symlinks` is intentionally not passed: it is deprecated
    # and ignored by the huggingface_hub releases that transformers>=4.46
    # depends on (downloads into `local_dir` are real files already), so
    # passing it only produces a warning.
    path = snapshot_download(repo_id=repo_id, local_dir=local_dir)
    print(label + " ->", path)
PY
# Stage 4: continue GRPO training for the two remaining curriculum stages.
# The stage-2 adapter directory is what gets passed as --sft-adapter, i.e.
# training starts from the stage-2 weights rather than from the SFT adapter.
printf '%s\n' "[4/6] Resuming GRPO from stage 3 (2 remaining stages, 150 steps each) ..."
grpo_args=(
  --model unsloth/Qwen2.5-3B-Instruct
  --sft-adapter checkpoints/defender_grpo/stage2_multi/adapter
  --stages stage3_mixed,stage4_adversarial
  --steps-per-stage 150
  --num-generations 8
  --batch-size 2
  --grad-accum 4
  --lr 5e-6
  --report-to tensorboard
  --out checkpoints/defender_grpo
)
python -m train.train_grpo "${grpo_args[@]}"
# Stage 5: evaluate the baseline model against the stage-4 adapter on the
# hold-out file, then render plots from the eval summary and from the GRPO
# checkpoint tree. All outputs go to one results directory.
printf '%s\n' "[5/6] Eval + plots ..."
eval_results_dir=eval/results
eval_args=(
  --baseline unsloth/Qwen2.5-3B-Instruct
  --trained-adapter checkpoints/defender_grpo/stage4_adversarial/adapter
  --holdout data/holdout.jsonl
  --out-dir "${eval_results_dir}"
)
python -m eval.eval "${eval_args[@]}"
python -m eval.plot_results --in "${eval_results_dir}/summary.json" --out-dir "${eval_results_dir}"
python -m eval.plot_training --grpo-root checkpoints/defender_grpo --out-dir "${eval_results_dir}"
# Stage 6: produce the demo examples file from the baseline model and the
# stage-4 adapter (--n 50), written to data/demo_examples.json.
printf '%s\n' "[6/6] Baking demo data for the Gradio /demo Space ..."
bake_demo_args=(
  --baseline unsloth/Qwen2.5-3B-Instruct
  --trained-adapter checkpoints/defender_grpo/stage4_adversarial/adapter
  --n 50
  --out data/demo_examples.json
)
python -m eval.bake_demo "${bake_demo_args[@]}"
# Final step (optional): push the trained adapter and the demo / eval
# artifacts to the Hub. Runs only when HF_TOKEN and HF_PUSH_TARGET are both
# non-empty; otherwise the skip is reported on stderr and the script still
# exits successfully.
if [[ -n "${HF_TOKEN:-}" && -n "${HF_PUSH_TARGET:-}" ]]; then
  echo "[7/7] Uploading artifacts back to ${HF_PUSH_TARGET} ..."
  python - <<'PY'
import os
import sys

from huggingface_hub import HfApi

token = os.environ["HF_TOKEN"]
target = os.environ["HF_PUSH_TARGET"]
# Adapter repo defaults to "<namespace of the push target>/opensoc-defender-grpo".
adapter_repo = os.environ.get(
    "HF_ADAPTER_REPO", target.split("/")[0] + "/opensoc-defender-grpo"
)
api = HfApi(token=token)

# 1) Trained adapter -> its own (public) repo, created on first use.
adapter_dir = "checkpoints/defender_grpo/stage4_adversarial/adapter"
if os.path.isdir(adapter_dir):
    api.create_repo(adapter_repo, exist_ok=True, private=False)
    api.upload_folder(
        repo_id=adapter_repo,
        folder_path=adapter_dir,
        commit_message="GRPO-trained Qwen2.5-3B-Instruct LoRA defender adapter",
        token=token,
    )
    print(" adapter ->", adapter_repo)
else:
    # Still non-fatal, but no longer silent.
    print(
        " WARNING: " + adapter_dir + " not found; adapter not uploaded",
        file=sys.stderr,
    )

# 2) Demo / eval artifacts -> the push target, one commit per file.
# NOTE(review): upload_file uses its default repo_type here. If
# HF_PUSH_TARGET is a Space rather than a model repo, repo_type="space"
# would be required -- confirm against the actual target.
artifacts = [
    "data/demo_examples.json",
    "eval/results/summary.json",
    "eval/results/bar_macro_f1.png",
    "eval/results/bar_dismiss_on_malicious.png",
    "eval/results/confusion_baseline_zero_shot.png",
    "eval/results/confusion_opensoc_grpo.png",
    "eval/results/training_curves.png",
    "eval/results/training_kl_loss.png",
]
for p in artifacts:
    if not os.path.exists(p):
        # Missing artifacts are skipped as before, but reported.
        print(" WARNING: " + p + " not found; skipped", file=sys.stderr)
        continue
    api.upload_file(
        path_or_fileobj=p,
        path_in_repo=p,
        repo_id=target,
        commit_message="trained: refresh " + os.path.basename(p),
        token=token,
    )
    print(" ", p, "->", target)

print("PIPELINE_DONE")
PY
else
  echo "[7/7] Skipping upload: HF_TOKEN and/or HF_PUSH_TARGET is not set." >&2
fi