#!/usr/bin/env bash
#
# GPU pipeline driver: installs the Python stack, builds the datasets, pulls
# previously trained adapters from the Hugging Face Hub, resumes GRPO training
# for the remaining stages, evaluates, bakes demo data and optionally uploads
# the resulting artifacts.
#
# All paths used by the steps (requirements.txt, data/, checkpoints/,
# eval/results) are relative, so the script is presumably meant to be started
# from the repository root -- confirm before running it from elsewhere.
#
# Optional environment:
#   HF_TOKEN, HF_PUSH_TARGET   the upload step runs only when both are non-empty
#   HF_ADAPTER_REPO            overrides the repo the trained adapter is pushed to

# Strict mode: stop on the first failing command, on use of an unset variable,
# and when any stage of a pipeline fails.
set -o errexit -o nounset -o pipefail

# Exported so that every child process (pip, python) sees
# HF_HUB_ENABLE_HF_TRANSFER=0.
export HF_HUB_ENABLE_HF_TRANSFER=0
|
|
# Step 1: install the Python packages the rest of the pipeline imports.
echo "[1/6] Installing GPU stack ..."

# Refresh the packaging toolchain before anything else is installed.
pip install -q --upgrade pip setuptools wheel
pip install -q "future>=1.0.0"

# torch/torchvision are pinned and fetched from the PyTorch cu128 wheel index
# rather than the default index.
pip install -q "torch==2.10.0" "torchvision==0.25.0" \
  --index-url https://download.pytorch.org/whl/cu128
pip install -q "transformers>=4.46,<5"

# --no-deps: pip installs only the named packages and skips their declared
# dependencies (presumably so they cannot replace the torch/transformers
# versions installed above -- confirm).
pip install -q --no-deps "unsloth @ git+https://github.com/unslothai/unsloth.git"
pip install -q --no-deps "unsloth_zoo"
pip install -q --no-deps "trl>=0.18.2,<=0.24.0" peft accelerate bitsandbytes
pip install -q "datasets>=3.4.1,<4.4.0" tyro tensorboard matplotlib sentencepiece protobuf huggingface_hub

# Best-effort extras: a failure here must not abort the run under errexit, but
# it is reported on stderr instead of being swallowed silently.
if ! pip install -q hf_transfer msgspec "torchao>=0.13.0" cut_cross_entropy; then
  echo "[1/6] WARNING: optional packages (hf_transfer msgspec torchao cut_cross_entropy) failed to install; continuing." >&2
fi

# Project requirements go last.
pip install -q -r requirements.txt
|
|
# Step 2: run the two dataset generators. The hold-out file written here is the
# one the evaluation step reads back via --holdout.
printf '%s\n' "[2/6] Building / verifying datasets ..."
python -m train.make_sft_dataset \
  --n 600 \
  --out data/sft_train.jsonl
python -m eval.make_holdout \
  --out data/holdout.jsonl
|
|
# Step 3 is a no-op in this resume script; only the notice is printed.
echo "[3/6] Skipped (SFT + Stages 1,2 already on Hub)."


# Step 3.5: download the SFT adapter and the stage 2 GRPO adapter into the
# local checkpoint directories. The heredoc delimiter is quoted, so the shell
# performs no expansion inside the Python source.
echo "[3.5/6] Pulling SFT and Stage 2 adapters from the Hub ..."
python - <<'PY'
import os

# Set before huggingface_hub is imported; mirrors the value exported by the
# surrounding shell script.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0"

from huggingface_hub import snapshot_download

# local_dir_use_symlinks is deliberately not passed: it is a deprecated
# argument that recent huggingface_hub releases ignore (with a warning), and
# files are written into local_dir as regular files.
# NOTE(review): assumes huggingface_hub >= 0.23 is installed -- confirm against
# requirements.txt.
sft = snapshot_download(
    "shivam2k3/opensoc-defender-grpo-sft",
    local_dir="checkpoints/defender_sft_adapter",
)
print("SFT adapter ->", sft)

stage2 = snapshot_download(
    "shivam2k3/opensoc-defender-grpo-stage2_multi",
    local_dir="checkpoints/defender_grpo/stage2_multi/adapter",
)
print("Stage 2 adapter ->", stage2)
PY
|
|
# Step 4: run the GRPO trainer for the two remaining stages. The adapter passed
# as --sft-adapter is the stage 2 adapter directory, so training continues from
# that checkpoint.
printf '%s\n' "[4/6] Resuming GRPO from stage 3 (2 remaining stages, 150 steps each) ..."
grpo_args=(
  --model unsloth/Qwen2.5-3B-Instruct
  --sft-adapter checkpoints/defender_grpo/stage2_multi/adapter
  --stages stage3_mixed,stage4_adversarial
  --steps-per-stage 150
  --num-generations 8
  --batch-size 2
  --grad-accum 4
  --lr 5e-6
  --report-to tensorboard
  --out checkpoints/defender_grpo
)
python -m train.train_grpo "${grpo_args[@]}"
|
|
# Step 5: evaluate the base model against the stage 4 adapter on the hold-out
# set, then run the two plotting modules. All three commands write into the
# same results directory.
printf '%s\n' "[5/6] Eval + plots ..."
eval_out_dir=eval/results
python -m eval.eval \
  --baseline unsloth/Qwen2.5-3B-Instruct \
  --trained-adapter checkpoints/defender_grpo/stage4_adversarial/adapter \
  --holdout data/holdout.jsonl \
  --out-dir "${eval_out_dir}"
python -m eval.plot_results \
  --in "${eval_out_dir}/summary.json" \
  --out-dir "${eval_out_dir}"
python -m eval.plot_training \
  --grpo-root checkpoints/defender_grpo \
  --out-dir "${eval_out_dir}"
|
|
# Step 6: run eval.bake_demo with the base model and the stage 4 adapter,
# writing data/demo_examples.json.
printf '%s\n' "[6/6] Baking demo data for the Gradio /demo Space ..."
bake_args=(
  --baseline unsloth/Qwen2.5-3B-Instruct
  --trained-adapter checkpoints/defender_grpo/stage4_adversarial/adapter
  --n 50
  --out data/demo_examples.json
)
python -m eval.bake_demo "${bake_args[@]}"
|
|
# Optional step 7: push the trained adapter and the evaluation/demo artifacts
# to the Hub. Runs only when both HF_TOKEN and HF_PUSH_TARGET are non-empty;
# the ":-" defaults keep `set -u` from aborting when they are unset.
if [[ -n "${HF_TOKEN:-}" && -n "${HF_PUSH_TARGET:-}" ]]; then
  echo "[7/7] Uploading artifacts back to ${HF_PUSH_TARGET} ..."
  # Quoted delimiter: the Python source is passed through verbatim and reads
  # the token from the environment, so it never appears on a command line.
  # The heredoc body and its terminator must stay at column 0.
  python - <<'PY'
import os

from huggingface_hub import HfApi, upload_folder

token = os.environ["HF_TOKEN"]
target = os.environ["HF_PUSH_TARGET"]
# Default adapter repo: the part of HF_PUSH_TARGET before the first "/" plus
# "/opensoc-defender-grpo", unless HF_ADAPTER_REPO overrides it.
adapter_repo = os.environ.get(
    "HF_ADAPTER_REPO", target.split("/")[0] + "/opensoc-defender-grpo"
)
api = HfApi(token=token)

# Upload the final-stage adapter only if training actually produced it.
adapter_dir = "checkpoints/defender_grpo/stage4_adversarial/adapter"
if os.path.isdir(adapter_dir):
    api.create_repo(adapter_repo, exist_ok=True, private=False)
    upload_folder(repo_id=adapter_repo, folder_path=adapter_dir,
                  commit_message="GRPO-trained Qwen2.5-3B-Instruct LoRA defender adapter",
                  token=token)
    print(" adapter ->", adapter_repo)

# Each artifact that exists is uploaded to the push target under the same
# relative path, one commit per file; missing files are skipped silently.
# NOTE(review): no repo_type is passed, so upload_file uses its default repo
# type. If HF_PUSH_TARGET names a Space, confirm this is what is intended.
for p in [
    "data/demo_examples.json",
    "eval/results/summary.json",
    "eval/results/bar_macro_f1.png",
    "eval/results/bar_dismiss_on_malicious.png",
    "eval/results/confusion_baseline_zero_shot.png",
    "eval/results/confusion_opensoc_grpo.png",
    "eval/results/training_curves.png",
    "eval/results/training_kl_loss.png",
]:
    if os.path.exists(p):
        api.upload_file(
            path_or_fileobj=p, path_in_repo=p,
            repo_id=target,
            commit_message="trained: refresh " + os.path.basename(p),
            token=token,
        )
        print(" ", p, "->", target)

# Completion marker; printed only when the upload step runs.
print("PIPELINE_DONE")
PY
fi
|
|