Commit ·
dbc69f3
0
Parent(s):
Export neuralese codebase (cache and .env excluded).
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +108 -0
- .gitignore +15 -0
- AGENTS.md +56 -0
- README.md +164 -0
- configs/accelerate_ddp_2gpu.yaml +17 -0
- configs/accelerate_ddp_4gpu.yaml +17 -0
- configs/accelerate_ddp_7gpu.yaml +17 -0
- configs/accelerate_ddp_8gpu.yaml +17 -0
- configs/grpo_llama32_3b_bf16.yaml +70 -0
- docs/repository-map.md +50 -0
- out.jsonl +0 -0
- requirements.txt +9 -0
- scripts/conda_env.sh +40 -0
- scripts/eval_gsm8k_think_entropy_4gpu.sh +42 -0
- scripts/eval_gsm8k_truncated_thinking.sh +51 -0
- scripts/eval_gsm8k_zeroshot_4gpu.sh +36 -0
- scripts/eval_gsm8k_zeroshot_train_4gpu.sh +54 -0
- scripts/eval_length_penalty_ablation_offline.sh +39 -0
- scripts/eval_math_level1_4gpu.sh +55 -0
- scripts/eval_math_level1_thinking_zeroshot_4gpu.sh +48 -0
- scripts/eval_math_level1_zeroshot_4gpu.sh +49 -0
- scripts/eval_math_level2_thinking_zeroshot_4gpu.sh +48 -0
- scripts/eval_math_level2_zeroshot_4gpu.sh +49 -0
- scripts/eval_permanent_root_acc_cot.sh +53 -0
- scripts/eval_sweep_models_offline.sh +34 -0
- scripts/eval_twostage_permanent_checkpoints.sh +47 -0
- scripts/hf_download_repo.py +72 -0
- scripts/hf_upload_repo.py +128 -0
- scripts/resume_grpo_8gpu.sh +27 -0
- scripts/run_grpo.sh +19 -0
- scripts/run_grpo_2gpu.sh +22 -0
- scripts/run_grpo_4gpu.sh +22 -0
- scripts/run_grpo_8gpu.sh +25 -0
- scripts/run_grpo_thinking_kl_masked_resume_4gpu.sh +139 -0
- scripts/run_lambda_0p1_existing_gate_token_util.sh +107 -0
- scripts/run_reward_variants_and_eval.sh +146 -0
- scripts/run_sft_gsm8k_boxed_7gpu.sh +44 -0
- scripts/run_twostage_correctness1.sh +79 -0
- scripts/run_twostage_correctness5.sh +79 -0
- scripts/sweep_length_penalty_lambda.sh +88 -0
- src/eval_gsm8k_truncated_thinking.py +321 -0
- src/eval_gsm8k_zeroshot.py +76 -0
- src/eval_math_level1_thinking_zeroshot.py +312 -0
- src/eval_permanent_checkpoints.py +433 -0
- src/eval_sweep_models.py +386 -0
- src/hackable/__init__.py +6 -0
- src/hackable/backends.py +62 -0
- src/hackable/config.py +183 -0
- src/hackable/data_plugins.py +291 -0
- src/hackable/interfaces.py +48 -0
.gitattributes
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.avro filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.lz4 filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.mds filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
# Audio files - uncompressed
|
| 40 |
+
*.pcm filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
*.sam filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
*.raw filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
# Audio files - compressed
|
| 44 |
+
*.aac filter=lfs diff=lfs merge=lfs -text
|
| 45 |
+
*.flac filter=lfs diff=lfs merge=lfs -text
|
| 46 |
+
*.mp3 filter=lfs diff=lfs merge=lfs -text
|
| 47 |
+
*.ogg filter=lfs diff=lfs merge=lfs -text
|
| 48 |
+
*.wav filter=lfs diff=lfs merge=lfs -text
|
| 49 |
+
# Image files - uncompressed
|
| 50 |
+
*.bmp filter=lfs diff=lfs merge=lfs -text
|
| 51 |
+
*.gif filter=lfs diff=lfs merge=lfs -text
|
| 52 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
| 53 |
+
*.tiff filter=lfs diff=lfs merge=lfs -text
|
| 54 |
+
# Image files - compressed
|
| 55 |
+
*.jpg filter=lfs diff=lfs merge=lfs -text
|
| 56 |
+
*.jpeg filter=lfs diff=lfs merge=lfs -text
|
| 57 |
+
*.webp filter=lfs diff=lfs merge=lfs -text
|
| 58 |
+
# Video files - compressed
|
| 59 |
+
*.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 60 |
+
*.webm filter=lfs diff=lfs merge=lfs -text
|
| 61 |
+
cache/models/models--unsloth--llama-3.2-3b-instruct-unsloth-bnb-4bit/blobs/9151ba0c695e63ba3dea7436a55bb6aa9f2d0a3d7b7eead62086959354bd6c67 filter=lfs diff=lfs merge=lfs -text
|
| 62 |
+
cache/models/models--unsloth--llama-3.2-3b-instruct-unsloth-bnb-4bit/blobs/6b9e4e7fb171f92fd137b777cc2714bf87d11576700a1dcd7a399e7bbe39537b filter=lfs diff=lfs merge=lfs -text
|
| 63 |
+
cache/models/models--unsloth--llama-3.2-3b-instruct-unsloth-bnb-4bit/snapshots/19846d3f624f3eb96f3bdd275620c6bc7e21e1f8/tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
| 64 |
+
cache/artifacts/runs/llama32-3b-grpo/tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
| 65 |
+
cache/artifacts/runs/llama32-3b-grpo/checkpoint-275/tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
| 66 |
+
cache/cache/artifacts/sweeps/reward_variants_lambda_0p1/run_lambda_0p1_existing_gate_token_util/checkpoint-1350/tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
| 67 |
+
cache/cache/artifacts/sweeps/twostage_correctness_weight/run_twostage_correctness5/tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
| 68 |
+
cache/cache/artifacts/sweeps/twostage_correctness_weight/run_twostage_correctness5/checkpoint-3736/tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
| 69 |
+
cache/cache/artifacts/sweeps/twostage_correctness_weight/run_twostage_correctness5/permanent_checkpoints/checkpoint-stage1-boundary-epoch-1p0-step-1868/tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
| 70 |
+
cache/cache/artifacts/sweeps/twostage_correctness_weight/run_twostage_correctness1/checkpoint-2925/tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
| 71 |
+
cache/cache/artifacts/sweeps/twostage_correctness_weight/run_twostage_correctness1/permanent_checkpoints/checkpoint-stage1-boundary-epoch-1p0-step-1868/tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
| 72 |
+
cache/wandb/wandb/offline-run-20260218_164737-w4awch9k/run-w4awch9k.wandb filter=lfs diff=lfs merge=lfs -text
|
| 73 |
+
cache/wandb/wandb/offline-run-20260218_172032-b9o54gvf/run-b9o54gvf.wandb filter=lfs diff=lfs merge=lfs -text
|
| 74 |
+
cache/wandb/wandb/offline-run-20260218_170233-ydg9el7o/run-ydg9el7o.wandb filter=lfs diff=lfs merge=lfs -text
|
| 75 |
+
cache/wandb/wandb/offline-run-20260218_165455-zj6y61aw/run-zj6y61aw.wandb filter=lfs diff=lfs merge=lfs -text
|
| 76 |
+
cache/wandb/wandb/offline-run-20260326_035115-p68cj9a2/run-p68cj9a2.wandb filter=lfs diff=lfs merge=lfs -text
|
| 77 |
+
cache/wandb/wandb/offline-run-20260331_190529-w0p8d22o/run-w0p8d22o.wandb filter=lfs diff=lfs merge=lfs -text
|
| 78 |
+
cache/wandb/wandb/offline-run-20260326_235622-lo899m3z/run-lo899m3z.wandb filter=lfs diff=lfs merge=lfs -text
|
| 79 |
+
cache/wandb/wandb/run-20260218_150224-l50vxprm/run-l50vxprm.wandb filter=lfs diff=lfs merge=lfs -text
|
| 80 |
+
cache/wandb/wandb/offline-run-20260327_235926-kx2wc1tp/run-kx2wc1tp.wandb filter=lfs diff=lfs merge=lfs -text
|
| 81 |
+
cache/wandb/wandb/offline-run-20260329_125618-21yt0wn1/run-21yt0wn1.wandb filter=lfs diff=lfs merge=lfs -text
|
| 82 |
+
cache/wandb/wandb/run-20260218_145059-onpmdiea/run-onpmdiea.wandb filter=lfs diff=lfs merge=lfs -text
|
| 83 |
+
cache/wandb/wandb/offline-run-20260218_172557-kst72vbj/run-kst72vbj.wandb filter=lfs diff=lfs merge=lfs -text
|
| 84 |
+
cache/wandb/wandb/offline-run-20260325_203352-8ncdpjq1/run-8ncdpjq1.wandb filter=lfs diff=lfs merge=lfs -text
|
| 85 |
+
cache/wandb/wandb/offline-run-20260329_125618-nuzntjga/run-nuzntjga.wandb filter=lfs diff=lfs merge=lfs -text
|
| 86 |
+
cache/wandb/wandb/offline-run-20260328_025352-z84jpmb8/run-z84jpmb8.wandb filter=lfs diff=lfs merge=lfs -text
|
| 87 |
+
cache/wandb/wandb/offline-run-20260328_142710-or2vd7o6/run-or2vd7o6.wandb filter=lfs diff=lfs merge=lfs -text
|
| 88 |
+
cache/wandb/wandb/offline-run-20260329_124408-koirzdhw/run-koirzdhw.wandb filter=lfs diff=lfs merge=lfs -text
|
| 89 |
+
cache/wandb/wandb/offline-run-20260325_052529-dz6rih26/run-dz6rih26.wandb filter=lfs diff=lfs merge=lfs -text
|
| 90 |
+
cache/wandb/wandb/offline-run-20260325_015712-7ialnk1w/run-7ialnk1w.wandb filter=lfs diff=lfs merge=lfs -text
|
| 91 |
+
cache/wandb/wandb/offline-run-20260327_235600-s2pr0n6l/run-s2pr0n6l.wandb filter=lfs diff=lfs merge=lfs -text
|
| 92 |
+
cache/wandb/wandb/offline-run-20260218_154633-igplmjku/run-igplmjku.wandb filter=lfs diff=lfs merge=lfs -text
|
| 93 |
+
cache/wandb/wandb/offline-run-20260218_204810-jg5x3bn6/run-jg5x3bn6.wandb filter=lfs diff=lfs merge=lfs -text
|
| 94 |
+
cache/wandb/wandb/offline-run-20260218_205814-3cvhbicy/run-3cvhbicy.wandb filter=lfs diff=lfs merge=lfs -text
|
| 95 |
+
cache/wandb/wandb/offline-run-20260326_000646-w5fgyzj3/run-w5fgyzj3.wandb filter=lfs diff=lfs merge=lfs -text
|
| 96 |
+
cache/wandb/wandb/offline-run-20260326_201451-g14348cv/run-g14348cv.wandb filter=lfs diff=lfs merge=lfs -text
|
| 97 |
+
neuralese_latent/wandb/offline-run-20260219_214352-58nmo2ma/run-58nmo2ma.wandb filter=lfs diff=lfs merge=lfs -text
|
| 98 |
+
cache/models/datasets--openai--gsm8k/blobs/ee7b8da9e381df27b9e3f7758a159ab2bdaa4dbaa910546cbbc47e0cb44e4f59 filter=lfs diff=lfs merge=lfs -text
|
| 99 |
+
cache/models/datasets--openai--gsm8k/blobs/ea82612ea9582142387730c793eb67d3b12849002bc0b7fa6f8efafa7351419d filter=lfs diff=lfs merge=lfs -text
|
| 100 |
+
cache/models/models--unsloth--Llama-3.2-3B-Instruct/blobs/6b9e4e7fb171f92fd137b777cc2714bf87d11576700a1dcd7a399e7bbe39537b filter=lfs diff=lfs merge=lfs -text
|
| 101 |
+
cache/models/models--unsloth--Llama-3.2-3B-Instruct/blobs/7b770216613ac5c34d7c54bdff1fa616bc4e338a9d0b20af6303e48c295ee23c filter=lfs diff=lfs merge=lfs -text
|
| 102 |
+
cache/models/models--unsloth--Llama-3.2-3B-Instruct/blobs/13cbd6d16e927a0c5bad54102514e6e18b4a47b3a6eb911e39d678d328d19f55 filter=lfs diff=lfs merge=lfs -text
|
| 103 |
+
cache/models/models--unsloth--Llama-3.2-3B-Instruct/snapshots/006f5dcd1393c3add266de40994ba96225e9689d/tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
| 104 |
+
cache/wandb/wandb/offline-run-20260219_155344-i4fzutup/run-i4fzutup.wandb filter=lfs diff=lfs merge=lfs -text
|
| 105 |
+
cache/wandb/wandb/offline-run-20260328_061445-eyjtftlz/run-eyjtftlz.wandb filter=lfs diff=lfs merge=lfs -text
|
| 106 |
+
cache/wandb/wandb/offline-run-20260327_234927-gbd6yui5/run-gbd6yui5.wandb filter=lfs diff=lfs merge=lfs -text
|
| 107 |
+
cache/wandb/wandb/offline-run-20260326_154346-zf18ek2n/run-zf18ek2n.wandb filter=lfs diff=lfs merge=lfs -text
|
| 108 |
+
cache/wandb/wandb/offline-run-20260325_011840-yvcn1q14/run-yvcn1q14.wandb filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.env
|
| 2 |
+
|
| 3 |
+
# Large local caches and artifacts (not for Hub)
|
| 4 |
+
cache/
|
| 5 |
+
**/__pycache__/
|
| 6 |
+
*.py[cod]
|
| 7 |
+
.pytest_cache/
|
| 8 |
+
.mypy_cache/
|
| 9 |
+
.ruff_cache/
|
| 10 |
+
*.egg-info/
|
| 11 |
+
.eggs/
|
| 12 |
+
.wandb/
|
| 13 |
+
*.swp
|
| 14 |
+
.DS_Store
|
| 15 |
+
|
AGENTS.md
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# AGENTS.md
|
| 2 |
+
|
| 3 |
+
## Purpose
|
| 4 |
+
|
| 5 |
+
This repo is a GRPO-style math-reasoning training workspace. Agents should preserve a clean split between tracked source code and untracked runtime artifacts.
|
| 6 |
+
|
| 7 |
+
## Canonical Layout
|
| 8 |
+
|
| 9 |
+
Tracked code and docs:
|
| 10 |
+
|
| 11 |
+
- `src/`: Python entrypoints and the `hackable` library.
|
| 12 |
+
- `configs/`: Accelerate and experiment YAMLs.
|
| 13 |
+
- `scripts/`: launchers and workflow helpers.
|
| 14 |
+
- `docs/`: durable repo documentation.
|
| 15 |
+
- `README.md`: user-facing overview.
|
| 16 |
+
|
| 17 |
+
Runtime storage under `cache/`:
|
| 18 |
+
|
| 19 |
+
- `cache/datasets/`: Hugging Face dataset cache and prepared dataset assets.
|
| 20 |
+
- `cache/models/`: local model weights, checkpoint loads, tokenizer cache.
|
| 21 |
+
- `cache/hf/`: extra Hugging Face runtime cache.
|
| 22 |
+
- `cache/artifacts/runs/`: single training or SFT runs.
|
| 23 |
+
- `cache/artifacts/sweeps/`: parameter sweeps and multi-run experiments.
|
| 24 |
+
- `cache/artifacts/eval/`: evaluation outputs, JSONL, CSV, plots.
|
| 25 |
+
- `cache/logs/wandb/`: W&B offline or online run logs.
|
| 26 |
+
|
| 27 |
+
Within a run directory, use:
|
| 28 |
+
|
| 29 |
+
- `checkpoint-*`: rotating trainer checkpoints.
|
| 30 |
+
- `checkpoints/permanent/`: non-rotating full-model snapshots saved by the callback.
|
| 31 |
+
- run-local analysis files only when they belong to that run.
|
| 32 |
+
|
| 33 |
+
## Path Rules
|
| 34 |
+
|
| 35 |
+
- Treat `storage.cache_dir` as the runtime root. Relative runtime paths such as `artifacts/runs/foo` resolve under that root.
|
| 36 |
+
- Do not introduce new top-level folders for datasets, checkpoints, models, or logs.
|
| 37 |
+
- Prefer `hackable.paths.resolve_storage_path()` and `hackable.paths.storage_layout()` over manual string concatenation.
|
| 38 |
+
- Keep tracked docs and code out of `cache/`. Keep large generated outputs out of `src/`, `scripts/`, and the repo root.
|
| 39 |
+
|
| 40 |
+
## Code Map
|
| 41 |
+
|
| 42 |
+
- `src/train_grpo.py`: main GRPO training entrypoint.
|
| 43 |
+
- `src/sft_gsm8k_boxed.py`: SFT baseline for boxed-only GSM8K.
|
| 44 |
+
- `src/eval_*.py`: evaluation entrypoints.
|
| 45 |
+
- `src/hackable/config.py`: YAML dataclasses.
|
| 46 |
+
- `src/hackable/paths.py`: canonical storage layout helpers.
|
| 47 |
+
- `src/hackable/data_plugins.py`: dataset providers.
|
| 48 |
+
- `src/hackable/reward_plugins.py`: reward functions.
|
| 49 |
+
- `src/hackable/objectives.py`: reward combination logic.
|
| 50 |
+
- `src/hackable/permanent_checkpoint_callback.py`: permanent checkpoint writer.
|
| 51 |
+
|
| 52 |
+
## Working Agreement
|
| 53 |
+
|
| 54 |
+
- Before changing structure, inspect existing path assumptions in both `src/` and `scripts/`.
|
| 55 |
+
- When adding a new workflow, document it in `README.md` or `docs/repository-map.md`.
|
| 56 |
+
- Preserve user changes already present in the worktree unless explicitly asked to rewrite them.
|
README.md
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Neuralese / hackable GRPO baseline
|
| 2 |
+
|
| 3 |
+
This repository is a **hackable GRPO-style training baseline** for math reasoning with a fixed chain-of-thought format: reasoning inside `<redacted_thinking>...</redacted_thinking>` and a final `\boxed{...}` answer. The design keeps **objectives**, **rewards**, and **data providers** in small pluggable modules so you can experiment without rewriting a monolithic trainer.
|
| 4 |
+
|
| 5 |
+
## Repository layout
|
| 6 |
+
|
| 7 |
+
| Path | Role |
|
| 8 |
+
|------|------|
|
| 9 |
+
| `src/hackable/` | Core library: config types, registries, objectives, rewards, data providers, model loading. |
|
| 10 |
+
| `src/train_grpo.py` | Training entrypoint expected by `scripts/*.sh` (may be absent in a partial checkout; scripts default to `--config configs/grpo_llama32_3b_bf16.yaml`). |
|
| 11 |
+
| `docs/repository-map.md` | Maintainer-focused map of entrypoints, modules, and storage conventions. |
|
| 12 |
+
| `AGENTS.md` | Agent-facing workflow and folder contract for future repo changes. |
|
| 13 |
+
| `configs/` | YAML experiment configs (referenced by scripts; may need to be added locally). |
|
| 14 |
+
| `scripts/` | Bash launchers for single/multi-GPU training, sweeps, and evaluation. |
|
| 15 |
+
| `cache/` | Runtime storage root: `datasets/`, `models/`, `hf/`, `artifacts/{runs,sweeps,eval}/`, and `logs/wandb/`. |
|
| 16 |
+
| `requirements.txt` | Python dependencies (Torch, TRL, Transformers, Accelerate, etc.). |
|
| 17 |
+
| `note.txt` | Unrelated environment/pip noise from a past install; not project documentation. |
|
| 18 |
+
|
| 19 |
+
### Storage conventions
|
| 20 |
+
|
| 21 |
+
- `storage.cache_dir` is the canonical runtime root. Relative runtime paths like `artifacts/runs/grpo-llama32-3b` resolve under that directory.
|
| 22 |
+
- Use `cache/datasets` for dataset cache, `cache/models` for model/tokenizer cache, and `cache/logs/wandb` for W&B logs.
|
| 23 |
+
- Training runs live in `cache/artifacts/runs`, sweep outputs in `cache/artifacts/sweeps`, eval outputs in `cache/artifacts/eval`.
|
| 24 |
+
- Permanent checkpoints now live under each run at `checkpoints/permanent/`.
|
| 25 |
+
|
| 26 |
+
### `src/hackable/` modules
|
| 27 |
+
|
| 28 |
+
- **`config.py`** — Loads YAML into typed dataclasses (`ExperimentConfig`, `ModelConfig`, `TrainerConfig`, `GenerationConfig`, `ObjectiveConfig`, `RewardsConfig`, …). Normalizes optimizer/scheduler aliases and numeric fields.
|
| 29 |
+
- **`registry.py`** — `@register_data_provider`, `@register_reward`, `@register_objective` plus `build_*` factories. Reward kwargs from YAML are partially applied via `build_reward` so each reward can receive static options (e.g. tokenizer name for length penalty).
|
| 30 |
+
- **`interfaces.py`** — `TrainingSample` and protocol shapes for providers/rewards/objectives.
|
| 31 |
+
- **`objectives.py`** — `TokenGRPOObjective` (main token-level GRPO recipe) and `LatentNeuraleseObjective` (stub: format reward only, `extra_reward` no-op for future latent scoring).
|
| 32 |
+
- **`reward_plugins.py`** — Registered rewards: strict format, GSM8K/MATH-style correctness, **length penalty**, optional token-utilisation shaping. See [Length penalty rewards](#length-penalty-rewards) below.
|
| 33 |
+
- **`data_plugins.py`** — GSM8K, Hendrycks MATH by level, interleaved curricula (`gsm8k_math_curriculum`, etc.). Shared prompt prefix matches the strict completion format expected by rewards.
|
| 34 |
+
- **`backends.py`** — Loads causal LMs with Liger Llama patches when applicable, FlashAttention2 when importable, else SDPA. Exposes `generation_kwargs` from config.
|
| 35 |
+
- **`utils.py`** — `import_from_path` for `objective.class_path` style `"module:Class"` loading.
|
| 36 |
+
|
| 37 |
+
Importing `hackable` registers default plugins (`hackable/__init__.py` imports `data_plugins`, `objectives`, `reward_plugins`).
|
| 38 |
+
|
| 39 |
+
### `src/` evaluation utilities
|
| 40 |
+
|
| 41 |
+
- **`eval_sweep_models.py`** — Distributed evaluation of every `run_*` directory under a sweep root: loads checkpoints, runs GSM8K test generations, scores correctness, records CoT word length stats, writes CSV/JSON summaries (used after lambda or reward-variant sweeps).
|
| 42 |
+
- **`eval_permanent_checkpoints.py`** — Walks `checkpoints/permanent` folders, evaluates each checkpoint, can emit simple SVG learning curves.
|
| 43 |
+
- **`eval_math_level1_thinking_zeroshot.py`** — Zero-shot / thinking-format eval on MATH-style data with JSONL output (for downstream rewards or analysis).
|
| 44 |
+
|
| 45 |
+
### `scripts/` (high level)
|
| 46 |
+
|
| 47 |
+
| Script | Purpose |
|
| 48 |
+
|--------|---------|
|
| 49 |
+
| `run_grpo.sh`, `run_grpo_2gpu.sh`, `run_grpo_4gpu.sh`, `run_grpo_8gpu.sh` | Launch training with Accelerate. |
|
| 50 |
+
| `resume_grpo_8gpu.sh` | Resume from latest or explicit checkpoint. |
|
| 51 |
+
| `sweep_length_penalty_lambda.sh` | Trains multiple runs with different `length_penalty_lambda` (weighted length-penalty mode). |
|
| 52 |
+
| `run_reward_variants_and_eval.sh` | Trains three interaction/gating variants, then runs `eval_sweep_models.py`. |
|
| 53 |
+
| `run_twostage_correctness1.sh`, `run_twostage_correctness5.sh` | Two-stage schedules; YAML is expected to set `correctness_weight` and optionally stage-2 length-penalty fields (see below). |
|
| 54 |
+
| `run_lambda_0p1_existing_gate_token_util.sh` | Example run with low λ and `token_utilisation_reward` enabled. |
|
| 55 |
+
| `eval_sweep_models_offline.sh` | Offline eval driver for a length-penalty λ sweep directory. |
|
| 56 |
+
| `eval_length_penalty_ablation_offline.sh` | Launches `src/eval_length_penalty_ablation.py` (script must exist alongside training code). |
|
| 57 |
+
| `eval_twostage_permanent_checkpoints.sh` | Eval for two-stage permanent checkpoint trees. |
|
| 58 |
+
| `eval_gsm8k_*.sh`, `eval_math_level*_*.sh` | Dataset-specific eval launchers. |
|
| 59 |
+
| `hf_upload_repo.py`, `hf_download_repo.py` | Push/pull Hugging Face dataset repo snapshots. |
|
| 60 |
+
|
| 61 |
+
---
|
| 62 |
+
|
| 63 |
+
## Length penalty rewards
|
| 64 |
+
|
| 65 |
+
Training prefers **shorter thinking traces** *within each GRPO group* (same prompt, multiple sampled completions). The signal is implemented as a **reward** (`length_penalty_reward`) and combined with correctness and format either **additively** or via a **weighted multiplicative** term controlled by `TokenGRPOObjective`.
|
| 66 |
+
|
| 67 |
+
### What gets measured
|
| 68 |
+
|
| 69 |
+
1. **Strict format only** — `_think_length_tokens` in `reward_plugins.py` parses the completion with the same regex as `format_tag_reward`: a single block `<redacted_thinking>...</redacted_thinking>` followed by `\boxed{...}`. If the completion does not match, thinking length is treated as **0** (so length reward does not reward malformed outputs on length grounds).
|
| 70 |
+
2. **Token count** — If `tokenizer_name` is passed (via `rewards.kwargs.length_penalty_reward` in YAML and `build_reward`), the **thinking substring** is encoded with that tokenizer (`add_special_tokens=False`) and the length is the number of token IDs. If no tokenizer is configured, length falls back to **whitespace-split words** inside the thinking block.
|
| 71 |
+
3. **Per-group normalization** — For each group \(G\) of completions (see grouping below), let \(L_i\) be the thinking length of completion \(i\), and \(\bar{L} = \frac{1}{|G|}\sum_j L_j\). If \(\bar{L} \le 0\), every score in the group is **0**. Otherwise:
|
| 72 |
+
|
| 73 |
+
\[
|
| 74 |
+
R^{\mathrm{length}}_i = \max\left(0,\ 1 - \frac{L_i}{\bar{L}}\right).
|
| 75 |
+
\]
|
| 76 |
+
|
| 77 |
+
So **shorter-than-average** thinking in the group scores closer to **1**, **longer-than-average** scores closer to **0**, and everyone at the average gets **0**. This is a **relative** length preference, not an absolute token budget.
|
| 78 |
+
|
| 79 |
+
### How groups are formed
|
| 80 |
+
|
| 81 |
+
`length_penalty_reward` assigns the same group normalization to completions that belong to the same GRPO comparison group:
|
| 82 |
+
|
| 83 |
+
- If `group_size` is set (again, typically under `rewards.kwargs.length_penalty_reward`), the flat batch is chunked in order: `[0:group_size)`, `[group_size:2*group_size)`, …
|
| 84 |
+
- Else groups are inferred by **contiguous runs of identical `prompt` text** in the parallel lists passed to the reward.
|
| 85 |
+
|
| 86 |
+
### How `TokenGRPOObjective` combines length with other rewards
|
| 87 |
+
|
| 88 |
+
Configured under `objective.name: token_grpo` with `objective.kwargs` (see `src/hackable/objectives.py`).
|
| 89 |
+
|
| 90 |
+
**Registered reward names** (when enabled):
|
| 91 |
+
|
| 92 |
+
- `format_tag_reward` — 1.0 if strict thinking + non-empty boxed answer, else 0.0.
|
| 93 |
+
- `gsm8k_correctness_reward` — Parses `\boxed{...}` vs reference (GSM8K `####` answers or MATH boxed solutions), numeric normalization and tolerant float compare.
|
| 94 |
+
- `length_penalty_reward` — if `enable_length_penalty: true`.
|
| 95 |
+
- `token_utilisation_reward` — optional; shapes training vs a frozen zero-shot correctness JSONL (see docstring in `reward_plugins.py`).
|
| 96 |
+
|
| 97 |
+
**`reward_mode`**
|
| 98 |
+
|
| 99 |
+
1. **`additive`** — Total score is the **sum** of all enabled reward outputs for that sample. If `strict_format_gate: true`, any sample with `format_tag_reward ≤ 0.5` is replaced by `non_strict_penalty` (default **-1.0**) instead of the sum.
|
| 100 |
+
2. **`weighted_length_penalty`** — Correctness and length interact multiplicatively; format is added (and optionally multiplied into the interaction):
|
| 101 |
+
|
| 102 |
+
- Let \(r_c, r_f, r_\ell\) be correctness, format, and length scores in \([0,1]\) (length may be 0 if disabled or malformed).
|
| 103 |
+
- Base interaction: \(r_c \times r_\ell\).
|
| 104 |
+
- If `length_penalty_interaction` is `correctness_length_format`, the interaction is \(r_c \times r_\ell \times r_f\).
|
| 105 |
+
- Total (before optional token-util term):
|
| 106 |
+
|
| 107 |
+
\[
|
| 108 |
+
\texttt{correctness\_weight} \cdot r_c + \texttt{length\_penalty\_lambda} \cdot \text{interaction} + r_f + r_{\mathrm{util}}.
|
| 109 |
+
\]
|
| 110 |
+
|
| 111 |
+
If `strict_format_gate` is true and \(r_f \le 0.5\), the total is **`non_strict_penalty`** and the formula above is skipped.
|
| 112 |
+
|
| 113 |
+
**Important knobs**
|
| 114 |
+
|
| 115 |
+
| KWarg | Meaning |
|
| 116 |
+
|-------|---------|
|
| 117 |
+
| `enable_length_penalty` | When false, length reward is not registered; in weighted mode \(r_\ell\) is treated as 0. |
|
| 118 |
+
| `length_penalty_lambda` | Scales the \(r_c \cdot r_\ell\) (or \(r_c \cdot r_\ell \cdot r_f\)) term in weighted mode. Sweeps often try e.g. 0.25, 0.5, 0.75, 1.0. |
|
| 119 |
+
| `correctness_weight` | Scales the standalone correctness term in weighted mode. |
|
| 120 |
+
| `length_penalty_interaction` | `correctness_length` vs `correctness_length_format` (whether format enters the product). |
|
| 121 |
+
| `strict_format_gate` / `non_strict_penalty` | Hard gate on format before crediting other terms. |
|
| 122 |
+
| `stage2_length_penalty_lambda`, `stage2_start_epoch` | Stored on `TokenGRPOObjective` for **two-stage** schedules; **changing λ at epoch boundaries must be implemented in the trainer** (`combine_rewards` itself only reads the current `length_penalty_lambda` on the instance). |
|
| 123 |
+
|
| 124 |
+
**YAML wiring example** (illustrative; paths depend on your repo):
|
| 125 |
+
|
| 126 |
+
```yaml
|
| 127 |
+
objective:
|
| 128 |
+
name: token_grpo
|
| 129 |
+
kwargs:
|
| 130 |
+
reward_mode: weighted_length_penalty
|
| 131 |
+
enable_length_penalty: true
|
| 132 |
+
correctness_weight: 1.0
|
| 133 |
+
length_penalty_lambda: 0.5
|
| 134 |
+
length_penalty_interaction: correctness_length
|
| 135 |
+
strict_format_gate: true
|
| 136 |
+
non_strict_penalty: -1.0
|
| 137 |
+
|
| 138 |
+
rewards:
|
| 139 |
+
kwargs:
|
| 140 |
+
length_penalty_reward:
|
| 141 |
+
tokenizer_name: meta-llama/Llama-3.2-3B-Instruct # example
|
| 142 |
+
group_size: 4 # often match generation.num_generations
|
| 143 |
+
cache_dir: cache/models
|
| 144 |
+
```
|
| 145 |
+
|
| 146 |
+
---
|
| 147 |
+
|
| 148 |
+
## Hackable extension path
|
| 149 |
+
|
| 150 |
+
- Implement a new objective class with `reward_names()` and `combine_rewards`-equivalent behavior (your trainer must call into it the same way as the baseline), or add new `@register_reward` functions.
|
| 151 |
+
- Point config at a custom class: `objective.class_path: "my_package.my_module:MyObjective"`.
|
| 152 |
+
- `LatentNeuraleseObjective` is a stub: keep token rewards while you add latent signals via `extra_reward` or a future trainer hook.
|
| 153 |
+
|
| 154 |
+
## Quickstart
|
| 155 |
+
|
| 156 |
+
1. `pip install -r requirements.txt` (you may need a CUDA-matched Torch and optional FlashAttention build).
|
| 157 |
+
2. Set `HF_TOKEN` / `WANDB_API_KEY` if needed (names configurable under `auth` in YAML).
|
| 158 |
+
3. Single GPU: `bash scripts/run_grpo.sh` (requires `src/train_grpo.py` and `configs/grpo_llama32_3b_bf16.yaml` present).
|
| 159 |
+
4. Many scripts default to offline Hub/datasets/W&B; override `WANDB_MODE`, `HF_HUB_OFFLINE`, etc. if you need network access.
|
| 160 |
+
5. Runtime outputs default to `cache/artifacts/...` through `storage.cache_dir`; agent docs live in `AGENTS.md` and `docs/repository-map.md`.
|
| 161 |
+
|
| 162 |
+
## Attention backend
|
| 163 |
+
|
| 164 |
+
`backends.py` prefers **FlashAttention2** when `flash_attn` imports cleanly; otherwise **SDPA**. Llama models optionally use **Liger** kernels when installed.
|
configs/accelerate_ddp_2gpu.yaml
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
compute_environment: LOCAL_MACHINE
|
| 2 |
+
debug: false
|
| 3 |
+
distributed_type: MULTI_GPU
|
| 4 |
+
downcast_bf16: "no"
|
| 5 |
+
enable_cpu_affinity: false
|
| 6 |
+
gpu_ids: all
|
| 7 |
+
machine_rank: 0
|
| 8 |
+
main_training_function: main
|
| 9 |
+
mixed_precision: bf16
|
| 10 |
+
num_machines: 1
|
| 11 |
+
num_processes: 2
|
| 12 |
+
rdzv_backend: static
|
| 13 |
+
same_network: true
|
| 14 |
+
tpu_env: []
|
| 15 |
+
tpu_use_cluster: false
|
| 16 |
+
tpu_use_sudo: false
|
| 17 |
+
use_cpu: false
|
configs/accelerate_ddp_4gpu.yaml
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
compute_environment: LOCAL_MACHINE
|
| 2 |
+
debug: false
|
| 3 |
+
distributed_type: MULTI_GPU
|
| 4 |
+
downcast_bf16: 'no'
|
| 5 |
+
enable_cpu_affinity: false
|
| 6 |
+
gpu_ids: all
|
| 7 |
+
machine_rank: 0
|
| 8 |
+
main_training_function: main
|
| 9 |
+
mixed_precision: bf16
|
| 10 |
+
num_machines: 1
|
| 11 |
+
num_processes: 4
|
| 12 |
+
rdzv_backend: static
|
| 13 |
+
same_network: true
|
| 14 |
+
tpu_env: []
|
| 15 |
+
tpu_use_cluster: false
|
| 16 |
+
tpu_use_sudo: false
|
| 17 |
+
use_cpu: false
|
configs/accelerate_ddp_7gpu.yaml
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
compute_environment: LOCAL_MACHINE
|
| 2 |
+
debug: false
|
| 3 |
+
distributed_type: MULTI_GPU
|
| 4 |
+
downcast_bf16: 'no'
|
| 5 |
+
enable_cpu_affinity: false
|
| 6 |
+
gpu_ids: all
|
| 7 |
+
machine_rank: 0
|
| 8 |
+
main_training_function: main
|
| 9 |
+
mixed_precision: bf16
|
| 10 |
+
num_machines: 1
|
| 11 |
+
num_processes: 7
|
| 12 |
+
rdzv_backend: static
|
| 13 |
+
same_network: true
|
| 14 |
+
tpu_env: []
|
| 15 |
+
tpu_use_cluster: false
|
| 16 |
+
tpu_use_sudo: false
|
| 17 |
+
use_cpu: false
|
configs/accelerate_ddp_8gpu.yaml
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
compute_environment: LOCAL_MACHINE
|
| 2 |
+
debug: false
|
| 3 |
+
distributed_type: MULTI_GPU
|
| 4 |
+
downcast_bf16: 'no'
|
| 5 |
+
enable_cpu_affinity: false
|
| 6 |
+
gpu_ids: all
|
| 7 |
+
machine_rank: 0
|
| 8 |
+
main_training_function: main
|
| 9 |
+
mixed_precision: bf16
|
| 10 |
+
num_machines: 1
|
| 11 |
+
num_processes: 8
|
| 12 |
+
rdzv_backend: static
|
| 13 |
+
same_network: true
|
| 14 |
+
tpu_env: []
|
| 15 |
+
tpu_use_cluster: false
|
| 16 |
+
tpu_use_sudo: false
|
| 17 |
+
use_cpu: false
|
configs/grpo_llama32_3b_bf16.yaml
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Baseline GRPO experiment (paths are relative to storage.cache_dir unless absolute).
|
| 2 |
+
model:
|
| 3 |
+
name: meta-llama/Llama-3.2-3B-Instruct
|
| 4 |
+
trust_remote_code: false
|
| 5 |
+
|
| 6 |
+
trainer:
|
| 7 |
+
output_dir: artifacts/runs/grpo-llama32-3b
|
| 8 |
+
run_name: grpo-llama32-3b
|
| 9 |
+
num_train_epochs: 1.0
|
| 10 |
+
max_steps: -1
|
| 11 |
+
per_device_train_batch_size: 1
|
| 12 |
+
gradient_accumulation_steps: 8
|
| 13 |
+
learning_rate: 1.0e-6
|
| 14 |
+
logging_steps: 1
|
| 15 |
+
save_steps: 200
|
| 16 |
+
save_total_limit: 3
|
| 17 |
+
permanent_checkpoint_steps: 300
|
| 18 |
+
permanent_checkpoint_dir: checkpoints/permanent
|
| 19 |
+
bf16: true
|
| 20 |
+
shuffle_dataset: false
|
| 21 |
+
report_to: wandb
|
| 22 |
+
warmup_steps: 20
|
| 23 |
+
gradient_checkpointing: true
|
| 24 |
+
max_grad_norm: 1.0
|
| 25 |
+
lr_scheduler_type: cosine
|
| 26 |
+
lr_scheduler_kwargs: {}
|
| 27 |
+
|
| 28 |
+
data:
|
| 29 |
+
provider: gsm8k_math_curriculum
|
| 30 |
+
split: train
|
| 31 |
+
max_samples: null
|
| 32 |
+
|
| 33 |
+
generation:
|
| 34 |
+
max_prompt_length: 512
|
| 35 |
+
max_completion_length: 256
|
| 36 |
+
num_generations: 4
|
| 37 |
+
temperature: 0.9
|
| 38 |
+
top_p: 0.95
|
| 39 |
+
|
| 40 |
+
objective:
|
| 41 |
+
name: token_grpo
|
| 42 |
+
kwargs:
|
| 43 |
+
reward_mode: weighted_length_penalty
|
| 44 |
+
enable_length_penalty: true
|
| 45 |
+
length_penalty_lambda: 0.1
|
| 46 |
+
correctness_weight: 1.0
|
| 47 |
+
strict_format_gate: true
|
| 48 |
+
|
| 49 |
+
rewards:
|
| 50 |
+
kwargs:
|
| 51 |
+
length_penalty_reward:
|
| 52 |
+
group_size: 4
|
| 53 |
+
|
| 54 |
+
# Multiply KL on tokens inside <redacted_thinking>...</redacted_thinking> (inner body) by this weight.
|
| 55 |
+
# 1.0 = default TRL behavior; 0.0 = no KL penalty on those tokens (policy loss still applies everywhere).
|
| 56 |
+
thinking_kl:
|
| 57 |
+
inner_kl_weight: 1.0
|
| 58 |
+
|
| 59 |
+
# TRL GRPO-specific flags (see GRPOConfig). When thinking_kl.inner_kl_weight != 1 and beta != 0,
|
| 60 |
+
# train_grpo.py forces use_liger_kernel: false.
|
| 61 |
+
grpo:
|
| 62 |
+
beta: 0.001
|
| 63 |
+
loss_type: grpo
|
| 64 |
+
importance_sampling_level: token
|
| 65 |
+
use_liger_kernel: false
|
| 66 |
+
|
| 67 |
+
storage:
|
| 68 |
+
cache_dir: cache
|
| 69 |
+
|
| 70 |
+
auth: {}
|
docs/repository-map.md
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Repository Map
|
| 2 |
+
|
| 3 |
+
## What Lives Where
|
| 4 |
+
|
| 5 |
+
- `src/train_grpo.py`: launches GRPO training from a YAML experiment config.
|
| 6 |
+
- `src/sft_gsm8k_boxed.py`: supervised fine-tuning entrypoint for boxed-answer GSM8K.
|
| 7 |
+
- `src/eval_sweep_models.py`: evaluates every `run_*` directory in a sweep root on GSM8K.
|
| 8 |
+
- `src/eval_permanent_checkpoints.py`: evaluates a run's permanent checkpoints and writes summaries and plots.
|
| 9 |
+
- `src/eval_gsm8k_zeroshot.py`: single-checkpoint GSM8K evaluation.
|
| 10 |
+
- `src/eval_gsm8k_truncated_thinking.py`: GSM8K eval with forced truncation of the thinking span.
|
| 11 |
+
- `src/eval_math_level1_thinking_zeroshot.py`: MATH level-1 zero-shot thinking-format evaluation.
|
| 12 |
+
|
| 13 |
+
## `src/hackable/`
|
| 14 |
+
|
| 15 |
+
- `config.py`: experiment dataclasses and YAML loading.
|
| 16 |
+
- `paths.py`: canonical runtime storage layout and path resolution.
|
| 17 |
+
- `registry.py`: plugin registration and construction for objectives, rewards, and data providers.
|
| 18 |
+
- `interfaces.py`: shared sample and protocol types.
|
| 19 |
+
- `data_plugins.py`: GSM8K, MATH, and curriculum data providers.
|
| 20 |
+
- `reward_plugins.py`: format, correctness, length, and token-utilization rewards.
|
| 21 |
+
- `objectives.py`: objective classes and reward-combination strategies.
|
| 22 |
+
- `backends.py`: model-loading helpers and backend selection.
|
| 23 |
+
- `thinking_kl_grpo_trainer.py`: custom GRPO trainer with masked thinking-KL handling.
|
| 24 |
+
- `thinking_kl_mask.py`: token masking logic for inner-thinking KL control.
|
| 25 |
+
- `permanent_checkpoint_callback.py`: periodic permanent checkpoint export.
|
| 26 |
+
- `utils.py`: repository-relative import and path utilities.
|
| 27 |
+
|
| 28 |
+
## Config And Launch Surface
|
| 29 |
+
|
| 30 |
+
- `configs/grpo_llama32_3b_bf16.yaml`: baseline experiment config.
|
| 31 |
+
- `configs/accelerate_ddp_*.yaml`: multi-GPU Accelerate launch configs.
|
| 32 |
+
- `scripts/run_grpo*.sh`: baseline training launchers.
|
| 33 |
+
- `scripts/run_twostage_*.sh`: two-stage experiments.
|
| 34 |
+
- `scripts/sweep_length_penalty_lambda.sh`: lambda sweep generator.
|
| 35 |
+
- `scripts/run_reward_variants_and_eval.sh`: reward variant sweep plus evaluation.
|
| 36 |
+
- `scripts/eval_*.sh`: distributed evaluation wrappers.
|
| 37 |
+
|
| 38 |
+
## Runtime Storage
|
| 39 |
+
|
| 40 |
+
All generated assets should live under `cache/`:
|
| 41 |
+
|
| 42 |
+
- `cache/datasets/`: dataset downloads and prepared data.
|
| 43 |
+
- `cache/models/`: Hugging Face model cache and local model roots.
|
| 44 |
+
- `cache/hf/`: extra Hugging Face cache state.
|
| 45 |
+
- `cache/artifacts/runs/`: one-off runs.
|
| 46 |
+
- `cache/artifacts/sweeps/`: grouped experiments.
|
| 47 |
+
- `cache/artifacts/eval/`: evaluation outputs.
|
| 48 |
+
- `cache/logs/wandb/`: W&B logs.
|
| 49 |
+
|
| 50 |
+
Relative paths like `artifacts/runs/foo` are intentionally interpreted relative to `storage.cache_dir`, so they land at `cache/artifacts/runs/foo` by default.
|
out.jsonl
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
requirements.txt
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
flash-attn
|
| 3 |
+
liger-kernel
|
| 4 |
+
transformers
|
| 5 |
+
trl>=1.0.0
|
| 6 |
+
datasets
|
| 7 |
+
accelerate
|
| 8 |
+
pyyaml
|
| 9 |
+
wandb
|
scripts/conda_env.sh
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Source from other scripts: default training env is conda "sidharth".
|
| 2 |
+
# Override: CONDA_ENV=myenv bash scripts/run_grpo_8gpu.sh
|
| 3 |
+
#
|
| 4 |
+
# Conda and nvidia packages (e.g. cuda-nvcc) ship deactivate hooks that use unset variables
|
| 5 |
+
# (CUDAARCHS_BACKUP, etc.). Parent scripts often use `set -u`, which breaks those hooks.
|
| 6 |
+
# We disable nounset for this entire file, then restore the parent's setting at the end.
|
| 7 |
+
|
| 8 |
+
_nounset_was_on=0
|
| 9 |
+
case $- in *u*) _nounset_was_on=1 ;; esac
|
| 10 |
+
set +u
|
| 11 |
+
|
| 12 |
+
: "${CONDA_ENV:=sidharth}"
|
| 13 |
+
|
| 14 |
+
_conda_try_init() {
|
| 15 |
+
if command -v conda >/dev/null 2>&1; then
|
| 16 |
+
# shellcheck disable=SC2312
|
| 17 |
+
eval "$(conda shell.bash hook)"
|
| 18 |
+
return 0
|
| 19 |
+
fi
|
| 20 |
+
local d
|
| 21 |
+
for d in "${CONDA_PREFIX:-}" "$HOME/miniconda3" "$HOME/mambaforge" "$HOME/anaconda3" "/opt/conda"; do
|
| 22 |
+
if [[ -n "$d" && -f "$d/etc/profile.d/conda.sh" ]]; then
|
| 23 |
+
# shellcheck disable=SC1090
|
| 24 |
+
source "$d/etc/profile.d/conda.sh"
|
| 25 |
+
return 0
|
| 26 |
+
fi
|
| 27 |
+
done
|
| 28 |
+
return 1
|
| 29 |
+
}
|
| 30 |
+
|
| 31 |
+
if ! _conda_try_init; then
|
| 32 |
+
echo "conda_env.sh: could not initialize conda (not on PATH and no known install)." >&2
|
| 33 |
+
exit 1
|
| 34 |
+
fi
|
| 35 |
+
|
| 36 |
+
conda activate "${CONDA_ENV}"
|
| 37 |
+
|
| 38 |
+
if [[ "$_nounset_was_on" -eq 1 ]]; then
|
| 39 |
+
set -u
|
| 40 |
+
fi
|
scripts/eval_gsm8k_think_entropy_4gpu.sh
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -eo pipefail
|
| 3 |
+
|
| 4 |
+
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
| 5 |
+
# shellcheck disable=SC1091
|
| 6 |
+
source "${SCRIPT_DIR}/conda_env.sh"
|
| 7 |
+
set -u
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
|
| 11 |
+
cd "${ROOT}"
|
| 12 |
+
|
| 13 |
+
BASE_CONFIG="${BASE_CONFIG:-configs/grpo_llama32_3b_bf16.yaml}"
|
| 14 |
+
ACCELERATE_CONFIG="${ACCELERATE_CONFIG:-configs/accelerate_ddp_4gpu.yaml}"
|
| 15 |
+
NUM_PROCESSES="${NUM_PROCESSES:-4}"
|
| 16 |
+
|
| 17 |
+
# Required:
|
| 18 |
+
# MODEL_DIR=/path/to/checkpoint
|
| 19 |
+
MODEL_DIR="${MODEL_DIR:-}"
|
| 20 |
+
|
| 21 |
+
OUTPUT_PATH="${OUTPUT_PATH:-artifacts/eval/gsm8k_think_entropy/entropy.json}"
|
| 22 |
+
BATCH_SIZE="${BATCH_SIZE:-4}"
|
| 23 |
+
MAX_SAMPLES="${MAX_SAMPLES:--1}"
|
| 24 |
+
SAVE_PER_SAMPLE="${SAVE_PER_SAMPLE:-0}"
|
| 25 |
+
|
| 26 |
+
export PYTHONPATH="${ROOT}/src"
|
| 27 |
+
export BASE_CONFIG
|
| 28 |
+
export MODEL_DIR
|
| 29 |
+
export OUTPUT_PATH
|
| 30 |
+
export BATCH_SIZE
|
| 31 |
+
export MAX_SAMPLES
|
| 32 |
+
export SAVE_PER_SAMPLE
|
| 33 |
+
|
| 34 |
+
export WANDB_MODE="${WANDB_MODE:-offline}"
|
| 35 |
+
export HF_HUB_OFFLINE="${HF_HUB_OFFLINE:-1}"
|
| 36 |
+
export HF_DATASETS_OFFLINE="${HF_DATASETS_OFFLINE:-1}"
|
| 37 |
+
export TRANSFORMERS_OFFLINE="${TRANSFORMERS_OFFLINE:-1}"
|
| 38 |
+
|
| 39 |
+
accelerate launch \
|
| 40 |
+
--config_file "${ACCELERATE_CONFIG}" \
|
| 41 |
+
--num_processes "${NUM_PROCESSES}" \
|
| 42 |
+
src/eval_gsm8k_think_entropy.py
|
scripts/eval_gsm8k_truncated_thinking.sh
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
# GSM8K eval with forced early </redacted_thinking> after N inner thinking tokens (default 5).
|
| 3 |
+
#
|
| 4 |
+
# Example:
|
| 5 |
+
# MODEL_PATH=artifacts/sweeps/twostage_correctness_weight/run_twostage_correctness1/checkpoints/permanent/checkpoint-stage1-boundary-epoch-1p0-step-1868 \
|
| 6 |
+
# bash scripts/eval_gsm8k_truncated_thinking.sh
|
| 7 |
+
set -eo pipefail
|
| 8 |
+
|
| 9 |
+
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
| 10 |
+
REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
|
| 11 |
+
cd "${REPO_ROOT}"
|
| 12 |
+
# shellcheck disable=SC1091
|
| 13 |
+
source "${SCRIPT_DIR}/conda_env.sh"
|
| 14 |
+
set -u
|
| 15 |
+
|
| 16 |
+
BASE_CONFIG="${BASE_CONFIG:-configs/grpo_llama32_3b_bf16.yaml}"
|
| 17 |
+
ACCELERATE_CONFIG="${ACCELERATE_CONFIG:-configs/accelerate_ddp_4gpu.yaml}"
|
| 18 |
+
NUM_PROCESSES="${NUM_PROCESSES:-4}"
|
| 19 |
+
MODEL_PATH="${MODEL_PATH:?Set MODEL_PATH to your checkpoint directory}"
|
| 20 |
+
THINK_INNER_TOKEN_LIMIT="${THINK_INNER_TOKEN_LIMIT:-5}"
|
| 21 |
+
MAX_SAMPLES="${MAX_SAMPLES:--1}"
|
| 22 |
+
MAX_NEW_TOKENS="${MAX_NEW_TOKENS:-256}"
|
| 23 |
+
OUTPUT_PATH="${OUTPUT_PATH:-}"
|
| 24 |
+
|
| 25 |
+
export PYTHONPATH="${REPO_ROOT}/src"
|
| 26 |
+
export PYTHONUNBUFFERED="${PYTHONUNBUFFERED:-1}"
|
| 27 |
+
export HF_DATASETS_OFFLINE="${HF_DATASETS_OFFLINE:-0}"
|
| 28 |
+
|
| 29 |
+
CMD_ENV=(
|
| 30 |
+
PYTHONPATH="${REPO_ROOT}/src"
|
| 31 |
+
BASE_CONFIG="${BASE_CONFIG}"
|
| 32 |
+
MODEL_PATH="${MODEL_PATH}"
|
| 33 |
+
THINK_INNER_TOKEN_LIMIT="${THINK_INNER_TOKEN_LIMIT}"
|
| 34 |
+
MAX_SAMPLES="${MAX_SAMPLES}"
|
| 35 |
+
MAX_NEW_TOKENS="${MAX_NEW_TOKENS}"
|
| 36 |
+
)
|
| 37 |
+
if [[ -n "${OUTPUT_PATH}" ]]; then
|
| 38 |
+
CMD_ENV+=(OUTPUT_PATH="${OUTPUT_PATH}")
|
| 39 |
+
fi
|
| 40 |
+
|
| 41 |
+
env "${CMD_ENV[@]}" accelerate launch \
|
| 42 |
+
--config_file "${ACCELERATE_CONFIG}" \
|
| 43 |
+
--num_processes "${NUM_PROCESSES}" \
|
| 44 |
+
src/eval_gsm8k_truncated_thinking.py \
|
| 45 |
+
--config "${BASE_CONFIG}" \
|
| 46 |
+
--model_path "${MODEL_PATH}" \
|
| 47 |
+
--think_inner_token_limit "${THINK_INNER_TOKEN_LIMIT}" \
|
| 48 |
+
--max_samples "${MAX_SAMPLES}" \
|
| 49 |
+
--max_new_tokens "${MAX_NEW_TOKENS}"
|
| 50 |
+
|
| 51 |
+
echo "Done."
|
scripts/eval_gsm8k_zeroshot_4gpu.sh
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
# Zero-shot GSM-8K eval: system prompt asks for answer only in \\boxed{}; chat template via tokenizer.
|
| 3 |
+
# Usage:
|
| 4 |
+
# MODEL_PATH=/path/to/checkpoint OUTPUT_PATH=out.jsonl ./scripts/eval_gsm8k_zeroshot_4gpu.sh
|
| 5 |
+
# Or with base YAML (uses model.name if MODEL_PATH unset):
|
| 6 |
+
# BASE_CONFIG=configs/grpo_llama32_3b_bf16.yaml OUTPUT_PATH=out.jsonl ./scripts/eval_gsm8k_zeroshot_4gpu.sh
|
| 7 |
+
set -eo pipefail
|
| 8 |
+
|
| 9 |
+
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
| 10 |
+
# shellcheck disable=SC1091
|
| 11 |
+
source "${SCRIPT_DIR}/conda_env.sh"
|
| 12 |
+
set -u
|
| 13 |
+
|
| 14 |
+
ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
|
| 15 |
+
cd "${ROOT}"
|
| 16 |
+
|
| 17 |
+
BASE_CONFIG="${BASE_CONFIG:-configs/grpo_llama32_3b_bf16.yaml}"
|
| 18 |
+
ACCELERATE_CONFIG="${ACCELERATE_CONFIG:-configs/accelerate_ddp_4gpu.yaml}"
|
| 19 |
+
NUM_PROCESSES="${NUM_PROCESSES:-4}"
|
| 20 |
+
OUTPUT_PATH="${OUTPUT_PATH:-artifacts/eval/gsm8k_zeroshot_answers.jsonl}"
|
| 21 |
+
BATCH_SIZE="${BATCH_SIZE:-4}"
|
| 22 |
+
MAX_NEW_TOKENS="${MAX_NEW_TOKENS:-128}"
|
| 23 |
+
MAX_SAMPLES="${MAX_SAMPLES:--1}"
|
| 24 |
+
|
| 25 |
+
export PYTHONPATH="${ROOT}/src"
|
| 26 |
+
export BASE_CONFIG
|
| 27 |
+
export OUTPUT_PATH
|
| 28 |
+
export BATCH_SIZE
|
| 29 |
+
export MAX_NEW_TOKENS
|
| 30 |
+
export MAX_SAMPLES
|
| 31 |
+
# Optional: MODEL_PATH, SYSTEM_PROMPT, GSM8K_SPLIT, MAX_PROMPT_LENGTH, LOCAL_FILES_ONLY=1
|
| 32 |
+
|
| 33 |
+
accelerate launch \
|
| 34 |
+
--config_file "${ACCELERATE_CONFIG}" \
|
| 35 |
+
--num_processes "${NUM_PROCESSES}" \
|
| 36 |
+
src/eval_gsm8k_zeroshot.py
|
scripts/eval_gsm8k_zeroshot_train_4gpu.sh
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -eo pipefail
|
| 3 |
+
|
| 4 |
+
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
| 5 |
+
# shellcheck disable=SC1091
|
| 6 |
+
source "${SCRIPT_DIR}/conda_env.sh"
|
| 7 |
+
set -u
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
|
| 11 |
+
cd "${ROOT}"
|
| 12 |
+
|
| 13 |
+
BASE_CONFIG="${BASE_CONFIG:-configs/grpo_llama32_3b_bf16.yaml}"
|
| 14 |
+
ACCELERATE_CONFIG="${ACCELERATE_CONFIG:-configs/accelerate_ddp_4gpu.yaml}"
|
| 15 |
+
NUM_PROCESSES="${NUM_PROCESSES:-4}"
|
| 16 |
+
|
| 17 |
+
MODEL_PATH="${MODEL_PATH:-}"
|
| 18 |
+
if [[ -z "${MODEL_PATH}" ]]; then
|
| 19 |
+
echo "Missing MODEL_PATH."
|
| 20 |
+
echo "Example:"
|
| 21 |
+
echo " MODEL_PATH=/path/to/checkpoint OUTPUT_PATH=artifacts/eval/gsm8k_train_zeroshot/results.jsonl ./scripts/eval_gsm8k_zeroshot_train_4gpu.sh"
|
| 22 |
+
exit 1
|
| 23 |
+
fi
|
| 24 |
+
|
| 25 |
+
# Train split, entire dataset by default.
|
| 26 |
+
GSM8K_SPLIT="${GSM8K_SPLIT:-train}"
|
| 27 |
+
MAX_SAMPLES="${MAX_SAMPLES:--1}"
|
| 28 |
+
OUTPUT_PATH="${OUTPUT_PATH:-artifacts/eval/gsm8k_train_zeroshot/results.jsonl}"
|
| 29 |
+
BATCH_SIZE="${BATCH_SIZE:-4}"
|
| 30 |
+
MAX_NEW_TOKENS="${MAX_NEW_TOKENS:-128}"
|
| 31 |
+
MAX_PROMPT_LENGTH="${MAX_PROMPT_LENGTH:-1024}"
|
| 32 |
+
|
| 33 |
+
export PYTHONPATH="${ROOT}/src"
|
| 34 |
+
export BASE_CONFIG
|
| 35 |
+
export MODEL_PATH
|
| 36 |
+
export GSM8K_SPLIT
|
| 37 |
+
export MAX_SAMPLES
|
| 38 |
+
export OUTPUT_PATH
|
| 39 |
+
export BATCH_SIZE
|
| 40 |
+
export MAX_NEW_TOKENS
|
| 41 |
+
export MAX_PROMPT_LENGTH
|
| 42 |
+
|
| 43 |
+
# Optional override for prompt behavior.
|
| 44 |
+
# export SYSTEM_PROMPT="..."
|
| 45 |
+
|
| 46 |
+
export WANDB_MODE="${WANDB_MODE:-offline}"
|
| 47 |
+
export HF_HUB_OFFLINE="${HF_HUB_OFFLINE:-1}"
|
| 48 |
+
export HF_DATASETS_OFFLINE="${HF_DATASETS_OFFLINE:-1}"
|
| 49 |
+
export TRANSFORMERS_OFFLINE="${TRANSFORMERS_OFFLINE:-1}"
|
| 50 |
+
|
| 51 |
+
accelerate launch \
|
| 52 |
+
--config_file "${ACCELERATE_CONFIG}" \
|
| 53 |
+
--num_processes "${NUM_PROCESSES}" \
|
| 54 |
+
src/eval_gsm8k_zeroshot.py
|
scripts/eval_length_penalty_ablation_offline.sh
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -eo pipefail
|
| 3 |
+
|
| 4 |
+
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
| 5 |
+
# shellcheck disable=SC1091
|
| 6 |
+
source "${SCRIPT_DIR}/conda_env.sh"
|
| 7 |
+
set -u
|
| 8 |
+
|
| 9 |
+
REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
|
| 10 |
+
cd "${REPO_ROOT}"
|
| 11 |
+
export PYTHONPATH="${REPO_ROOT}/src"
|
| 12 |
+
|
| 13 |
+
BASE_CONFIG="${BASE_CONFIG:-configs/grpo_llama32_3b_bf16.yaml}"
|
| 14 |
+
ACCELERATE_CONFIG="${ACCELERATE_CONFIG:-configs/accelerate_ddp_4gpu.yaml}"
|
| 15 |
+
NUM_PROCESSES="${NUM_PROCESSES:-2}"
|
| 16 |
+
OUT_ROOT="${OUT_ROOT:-artifacts/ablation_length_penalty}"
|
| 17 |
+
|
| 18 |
+
# Update these if your model dirs differ.
|
| 19 |
+
MODEL_LEN_ON="${MODEL_LEN_ON:-${OUT_ROOT}/run_len_on}"
|
| 20 |
+
MODEL_LEN_OFF="${MODEL_LEN_OFF:-${OUT_ROOT}/run_len_off}"
|
| 21 |
+
|
| 22 |
+
EVAL_MAX_SAMPLES="${EVAL_MAX_SAMPLES:-200}"
|
| 23 |
+
EVAL_BATCH_SIZE="${EVAL_BATCH_SIZE:-4}"
|
| 24 |
+
|
| 25 |
+
export WANDB_MODE="${WANDB_MODE:-offline}"
|
| 26 |
+
export HF_HUB_OFFLINE="${HF_HUB_OFFLINE:-1}"
|
| 27 |
+
export HF_DATASETS_OFFLINE="${HF_DATASETS_OFFLINE:-1}"
|
| 28 |
+
export TRANSFORMERS_OFFLINE="${TRANSFORMERS_OFFLINE:-1}"
|
| 29 |
+
|
| 30 |
+
BASE_CONFIG="${BASE_CONFIG}" \
|
| 31 |
+
OUT_ROOT="${OUT_ROOT}" \
|
| 32 |
+
MODEL_LEN_ON="${MODEL_LEN_ON}" \
|
| 33 |
+
MODEL_LEN_OFF="${MODEL_LEN_OFF}" \
|
| 34 |
+
EVAL_MAX_SAMPLES="${EVAL_MAX_SAMPLES}" \
|
| 35 |
+
EVAL_BATCH_SIZE="${EVAL_BATCH_SIZE}" \
|
| 36 |
+
accelerate launch \
|
| 37 |
+
--config_file "${ACCELERATE_CONFIG}" \
|
| 38 |
+
--num_processes "${NUM_PROCESSES}" \
|
| 39 |
+
src/eval_length_penalty_ablation.py
|
scripts/eval_math_level1_4gpu.sh
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -eo pipefail
|
| 3 |
+
|
| 4 |
+
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
| 5 |
+
# shellcheck disable=SC1091
|
| 6 |
+
source "${SCRIPT_DIR}/conda_env.sh"
|
| 7 |
+
set -u
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
|
| 11 |
+
cd "${ROOT}"
|
| 12 |
+
|
| 13 |
+
BASE_CONFIG="${BASE_CONFIG:-configs/grpo_llama32_3b_bf16.yaml}"
|
| 14 |
+
ACCELERATE_CONFIG="${ACCELERATE_CONFIG:-configs/accelerate_ddp_4gpu.yaml}"
|
| 15 |
+
NUM_PROCESSES="${NUM_PROCESSES:-4}"
|
| 16 |
+
|
| 17 |
+
MODEL_DIR="${MODEL_DIR:-}"
|
| 18 |
+
if [[ -z "${MODEL_DIR}" ]]; then
|
| 19 |
+
echo "Missing MODEL_DIR. Example:"
|
| 20 |
+
echo " MODEL_DIR=/path/to/checkpoint OUTPUT_PATH=out.json ./scripts/eval_math_level1_4gpu.sh"
|
| 21 |
+
exit 1
|
| 22 |
+
fi
|
| 23 |
+
|
| 24 |
+
OUTPUT_PATH="${OUTPUT_PATH:-artifacts/eval/math_level1/accuracy.json}"
|
| 25 |
+
MAX_SAMPLES="${MAX_SAMPLES:--1}"
|
| 26 |
+
BATCH_SIZE="${BATCH_SIZE:-4}"
|
| 27 |
+
MATH_SPLIT="${MATH_SPLIT:-test}"
|
| 28 |
+
SAVE_ROLLOUTS="${SAVE_ROLLOUTS:-0}"
|
| 29 |
+
ROLLOUTS_MAX_SAMPLES="${ROLLOUTS_MAX_SAMPLES:-50}"
|
| 30 |
+
ROLLOUTS_OUT_PATH="${ROLLOUTS_OUT_PATH:-artifacts/eval/math_level1/rollouts.jsonl}"
|
| 31 |
+
MAX_PROMPT_LENGTH="${MAX_PROMPT_LENGTH:-}"
|
| 32 |
+
MAX_COMPLETION_LENGTH="${MAX_COMPLETION_LENGTH:-}"
|
| 33 |
+
|
| 34 |
+
export PYTHONPATH="${ROOT}/src"
|
| 35 |
+
export BASE_CONFIG
|
| 36 |
+
export MODEL_DIR
|
| 37 |
+
export OUTPUT_PATH
|
| 38 |
+
export MAX_SAMPLES
|
| 39 |
+
export BATCH_SIZE
|
| 40 |
+
export MATH_SPLIT
|
| 41 |
+
export SAVE_ROLLOUTS
|
| 42 |
+
export ROLLOUTS_MAX_SAMPLES
|
| 43 |
+
export ROLLOUTS_OUT_PATH
|
| 44 |
+
if [[ -n "${MAX_PROMPT_LENGTH}" ]]; then export MAX_PROMPT_LENGTH; fi
|
| 45 |
+
if [[ -n "${MAX_COMPLETION_LENGTH}" ]]; then export MAX_COMPLETION_LENGTH; fi
|
| 46 |
+
|
| 47 |
+
export WANDB_MODE="${WANDB_MODE:-offline}"
|
| 48 |
+
export HF_HUB_OFFLINE="${HF_HUB_OFFLINE:-1}"
|
| 49 |
+
export HF_DATASETS_OFFLINE="${HF_DATASETS_OFFLINE:-1}"
|
| 50 |
+
export TRANSFORMERS_OFFLINE="${TRANSFORMERS_OFFLINE:-1}"
|
| 51 |
+
|
| 52 |
+
accelerate launch \
|
| 53 |
+
--config_file "${ACCELERATE_CONFIG}" \
|
| 54 |
+
--num_processes "${NUM_PROCESSES}" \
|
| 55 |
+
src/eval_math_level1.py
|
scripts/eval_math_level1_thinking_zeroshot_4gpu.sh
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -eo pipefail
|
| 3 |
+
|
| 4 |
+
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
| 5 |
+
# shellcheck disable=SC1091
|
| 6 |
+
source "${SCRIPT_DIR}/conda_env.sh"
|
| 7 |
+
set -u
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
|
| 11 |
+
cd "${ROOT}"
|
| 12 |
+
|
| 13 |
+
BASE_CONFIG="${BASE_CONFIG:-configs/grpo_llama32_3b_bf16.yaml}"
|
| 14 |
+
ACCELERATE_CONFIG="${ACCELERATE_CONFIG:-configs/accelerate_ddp_4gpu.yaml}"
|
| 15 |
+
NUM_PROCESSES="${NUM_PROCESSES:-4}"
|
| 16 |
+
|
| 17 |
+
MODEL_DIR="${MODEL_DIR:-}"
|
| 18 |
+
if [[ -z "${MODEL_DIR}" ]]; then
|
| 19 |
+
echo "Missing MODEL_DIR."
|
| 20 |
+
exit 1
|
| 21 |
+
fi
|
| 22 |
+
|
| 23 |
+
OUTPUT_PATH="${OUTPUT_PATH:-artifacts/eval/math_level1_thinking_zeroshot/answers.jsonl}"
|
| 24 |
+
MAX_SAMPLES="${MAX_SAMPLES:--1}"
|
| 25 |
+
BATCH_SIZE="${BATCH_SIZE:-4}"
|
| 26 |
+
MAX_NEW_TOKENS="${MAX_NEW_TOKENS:-128}"
|
| 27 |
+
MAX_PROMPT_LENGTH="${MAX_PROMPT_LENGTH:-512}"
|
| 28 |
+
MATH_SPLIT="${MATH_SPLIT:-test}"
|
| 29 |
+
|
| 30 |
+
export PYTHONPATH="${ROOT}/src"
|
| 31 |
+
export BASE_CONFIG
|
| 32 |
+
export MODEL_DIR
|
| 33 |
+
export OUTPUT_PATH
|
| 34 |
+
export MAX_SAMPLES
|
| 35 |
+
export BATCH_SIZE
|
| 36 |
+
export MAX_NEW_TOKENS
|
| 37 |
+
export MAX_PROMPT_LENGTH
|
| 38 |
+
export MATH_SPLIT
|
| 39 |
+
|
| 40 |
+
export WANDB_MODE="${WANDB_MODE:-offline}"
|
| 41 |
+
export HF_HUB_OFFLINE="${HF_HUB_OFFLINE:-1}"
|
| 42 |
+
export HF_DATASETS_OFFLINE="${HF_DATASETS_OFFLINE:-1}"
|
| 43 |
+
export TRANSFORMERS_OFFLINE="${TRANSFORMERS_OFFLINE:-1}"
|
| 44 |
+
|
| 45 |
+
accelerate launch \
|
| 46 |
+
--config_file "${ACCELERATE_CONFIG}" \
|
| 47 |
+
--num_processes "${NUM_PROCESSES}" \
|
| 48 |
+
src/eval_math_level1_thinking_zeroshot.py
|
scripts/eval_math_level1_zeroshot_4gpu.sh
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -eo pipefail
|
| 3 |
+
|
| 4 |
+
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
| 5 |
+
# shellcheck disable=SC1091
|
| 6 |
+
source "${SCRIPT_DIR}/conda_env.sh"
|
| 7 |
+
set -u
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
|
| 11 |
+
cd "${ROOT}"
|
| 12 |
+
|
| 13 |
+
BASE_CONFIG="${BASE_CONFIG:-configs/grpo_llama32_3b_bf16.yaml}"
|
| 14 |
+
ACCELERATE_CONFIG="${ACCELERATE_CONFIG:-configs/accelerate_ddp_4gpu.yaml}"
|
| 15 |
+
NUM_PROCESSES="${NUM_PROCESSES:-4}"
|
| 16 |
+
|
| 17 |
+
MODEL_DIR="${MODEL_DIR:-}"
|
| 18 |
+
if [[ -z "${MODEL_DIR}" ]]; then
|
| 19 |
+
echo "Missing MODEL_DIR."
|
| 20 |
+
exit 1
|
| 21 |
+
fi
|
| 22 |
+
|
| 23 |
+
OUTPUT_PATH="${OUTPUT_PATH:-artifacts/eval/math_level1_zeroshot/answers.jsonl}"
|
| 24 |
+
SYSTEM_PROMPT="${SYSTEM_PROMPT:-}"
|
| 25 |
+
MAX_SAMPLES="${MAX_SAMPLES:--1}"
|
| 26 |
+
BATCH_SIZE="${BATCH_SIZE:-4}"
|
| 27 |
+
MAX_NEW_TOKENS="${MAX_NEW_TOKENS:-128}"
|
| 28 |
+
MAX_PROMPT_LENGTH="${MAX_PROMPT_LENGTH:-512}"
|
| 29 |
+
|
| 30 |
+
export PYTHONPATH="${ROOT}/src"
|
| 31 |
+
export BASE_CONFIG
|
| 32 |
+
export MODEL_DIR
|
| 33 |
+
export OUTPUT_PATH
|
| 34 |
+
export MAX_SAMPLES
|
| 35 |
+
export BATCH_SIZE
|
| 36 |
+
export MAX_NEW_TOKENS
|
| 37 |
+
export MAX_PROMPT_LENGTH
|
| 38 |
+
|
| 39 |
+
if [[ -n "${SYSTEM_PROMPT}" ]]; then export SYSTEM_PROMPT; fi
|
| 40 |
+
|
| 41 |
+
export WANDB_MODE="${WANDB_MODE:-offline}"
|
| 42 |
+
export HF_HUB_OFFLINE="${HF_HUB_OFFLINE:-1}"
|
| 43 |
+
export HF_DATASETS_OFFLINE="${HF_DATASETS_OFFLINE:-1}"
|
| 44 |
+
export TRANSFORMERS_OFFLINE="${TRANSFORMERS_OFFLINE:-1}"
|
| 45 |
+
|
| 46 |
+
accelerate launch \
|
| 47 |
+
--config_file "${ACCELERATE_CONFIG}" \
|
| 48 |
+
--num_processes "${NUM_PROCESSES}" \
|
| 49 |
+
src/eval_math_level1_zeroshot.py
|
scripts/eval_math_level2_thinking_zeroshot_4gpu.sh
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -eo pipefail
|
| 3 |
+
|
| 4 |
+
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
| 5 |
+
# shellcheck disable=SC1091
|
| 6 |
+
source "${SCRIPT_DIR}/conda_env.sh"
|
| 7 |
+
set -u
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
|
| 11 |
+
cd "${ROOT}"
|
| 12 |
+
|
| 13 |
+
BASE_CONFIG="${BASE_CONFIG:-configs/grpo_llama32_3b_bf16.yaml}"
|
| 14 |
+
ACCELERATE_CONFIG="${ACCELERATE_CONFIG:-configs/accelerate_ddp_4gpu.yaml}"
|
| 15 |
+
NUM_PROCESSES="${NUM_PROCESSES:-4}"
|
| 16 |
+
|
| 17 |
+
MODEL_DIR="${MODEL_DIR:-}"
|
| 18 |
+
if [[ -z "${MODEL_DIR}" ]]; then
|
| 19 |
+
echo "Missing MODEL_DIR."
|
| 20 |
+
exit 1
|
| 21 |
+
fi
|
| 22 |
+
|
| 23 |
+
OUTPUT_PATH="${OUTPUT_PATH:-artifacts/eval/math_level2_thinking_zeroshot/answers.jsonl}"
|
| 24 |
+
MAX_SAMPLES="${MAX_SAMPLES:--1}"
|
| 25 |
+
BATCH_SIZE="${BATCH_SIZE:-4}"
|
| 26 |
+
MAX_NEW_TOKENS="${MAX_NEW_TOKENS:-128}"
|
| 27 |
+
MAX_PROMPT_LENGTH="${MAX_PROMPT_LENGTH:-512}"
|
| 28 |
+
MATH_SPLIT="${MATH_SPLIT:-test}"
|
| 29 |
+
|
| 30 |
+
export PYTHONPATH="${ROOT}/src"
|
| 31 |
+
export BASE_CONFIG
|
| 32 |
+
export MODEL_DIR
|
| 33 |
+
export OUTPUT_PATH
|
| 34 |
+
export MAX_SAMPLES
|
| 35 |
+
export BATCH_SIZE
|
| 36 |
+
export MAX_NEW_TOKENS
|
| 37 |
+
export MAX_PROMPT_LENGTH
|
| 38 |
+
export MATH_SPLIT
|
| 39 |
+
|
| 40 |
+
export WANDB_MODE="${WANDB_MODE:-offline}"
|
| 41 |
+
export HF_HUB_OFFLINE="${HF_HUB_OFFLINE:-1}"
|
| 42 |
+
export HF_DATASETS_OFFLINE="${HF_DATASETS_OFFLINE:-1}"
|
| 43 |
+
export TRANSFORMERS_OFFLINE="${TRANSFORMERS_OFFLINE:-1}"
|
| 44 |
+
|
| 45 |
+
accelerate launch \
|
| 46 |
+
--config_file "${ACCELERATE_CONFIG}" \
|
| 47 |
+
--num_processes "${NUM_PROCESSES}" \
|
| 48 |
+
src/eval_math_level2_thinking_zeroshot.py
|
scripts/eval_math_level2_zeroshot_4gpu.sh
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -eo pipefail
|
| 3 |
+
|
| 4 |
+
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
| 5 |
+
# shellcheck disable=SC1091
|
| 6 |
+
source "${SCRIPT_DIR}/conda_env.sh"
|
| 7 |
+
set -u
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
|
| 11 |
+
cd "${ROOT}"
|
| 12 |
+
|
| 13 |
+
BASE_CONFIG="${BASE_CONFIG:-configs/grpo_llama32_3b_bf16.yaml}"
|
| 14 |
+
ACCELERATE_CONFIG="${ACCELERATE_CONFIG:-configs/accelerate_ddp_4gpu.yaml}"
|
| 15 |
+
NUM_PROCESSES="${NUM_PROCESSES:-4}"
|
| 16 |
+
|
| 17 |
+
MODEL_DIR="${MODEL_DIR:-}"
|
| 18 |
+
if [[ -z "${MODEL_DIR}" ]]; then
|
| 19 |
+
echo "Missing MODEL_DIR."
|
| 20 |
+
exit 1
|
| 21 |
+
fi
|
| 22 |
+
|
| 23 |
+
OUTPUT_PATH="${OUTPUT_PATH:-artifacts/eval/math_level2_zeroshot/answers.jsonl}"
|
| 24 |
+
SYSTEM_PROMPT="${SYSTEM_PROMPT:-}"
|
| 25 |
+
MAX_SAMPLES="${MAX_SAMPLES:--1}"
|
| 26 |
+
BATCH_SIZE="${BATCH_SIZE:-4}"
|
| 27 |
+
MAX_NEW_TOKENS="${MAX_NEW_TOKENS:-128}"
|
| 28 |
+
MAX_PROMPT_LENGTH="${MAX_PROMPT_LENGTH:-512}"
|
| 29 |
+
|
| 30 |
+
export PYTHONPATH="${ROOT}/src"
|
| 31 |
+
export BASE_CONFIG
|
| 32 |
+
export MODEL_DIR
|
| 33 |
+
export OUTPUT_PATH
|
| 34 |
+
export MAX_SAMPLES
|
| 35 |
+
export BATCH_SIZE
|
| 36 |
+
export MAX_NEW_TOKENS
|
| 37 |
+
export MAX_PROMPT_LENGTH
|
| 38 |
+
|
| 39 |
+
if [[ -n "${SYSTEM_PROMPT}" ]]; then export SYSTEM_PROMPT; fi
|
| 40 |
+
|
| 41 |
+
export WANDB_MODE="${WANDB_MODE:-offline}"
|
| 42 |
+
export HF_HUB_OFFLINE="${HF_HUB_OFFLINE:-1}"
|
| 43 |
+
export HF_DATASETS_OFFLINE="${HF_DATASETS_OFFLINE:-1}"
|
| 44 |
+
export TRANSFORMERS_OFFLINE="${TRANSFORMERS_OFFLINE:-1}"
|
| 45 |
+
|
| 46 |
+
accelerate launch \
|
| 47 |
+
--config_file "${ACCELERATE_CONFIG}" \
|
| 48 |
+
--num_processes "${NUM_PROCESSES}" \
|
| 49 |
+
src/eval_math_level2_zeroshot.py
|
scripts/eval_permanent_root_acc_cot.sh
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
# Evaluate every checkpoint-* under a single checkpoints/permanent directory, then write
|
| 3 |
+
# JSON/CSV summary, per-checkpoint JSONL, and accuracy_vs_avg_cot_words.svg (scatter + step labels).
|
| 4 |
+
#
|
| 5 |
+
# Usage:
|
| 6 |
+
# PERMANENT_ROOT=artifacts/runs/.../checkpoints/permanent ./scripts/eval_permanent_root_acc_cot.sh
|
| 7 |
+
#
|
| 8 |
+
# Optional: RUN_LABEL (default permanent), OUT_ROOT, BASE_CONFIG, EVAL_MAX_SAMPLES, NUM_PROCESSES, offline HF vars.
|
| 9 |
+
|
| 10 |
+
set -eo pipefail
|
| 11 |
+
|
| 12 |
+
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
| 13 |
+
REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
|
| 14 |
+
cd "${REPO_ROOT}"
|
| 15 |
+
# shellcheck disable=SC1091
|
| 16 |
+
source "${SCRIPT_DIR}/conda_env.sh"
|
| 17 |
+
set -u
|
| 18 |
+
|
| 19 |
+
: "${PERMANENT_ROOT:?Set PERMANENT_ROOT to your checkpoints/permanent directory}"
|
| 20 |
+
|
| 21 |
+
BASE_CONFIG="${BASE_CONFIG:-configs/grpo_llama32_3b_bf16.yaml}"
|
| 22 |
+
ACCELERATE_CONFIG="${ACCELERATE_CONFIG:-configs/accelerate_ddp_4gpu.yaml}"
|
| 23 |
+
NUM_PROCESSES="${NUM_PROCESSES:-4}"
|
| 24 |
+
RUN_LABEL="${RUN_LABEL:-permanent}"
|
| 25 |
+
EVAL_MAX_SAMPLES="${EVAL_MAX_SAMPLES:-200}"
|
| 26 |
+
EVAL_BATCH_SIZE="${EVAL_BATCH_SIZE:-4}"
|
| 27 |
+
ROLLOUT_SAMPLES="${ROLLOUT_SAMPLES:-8}"
|
| 28 |
+
|
| 29 |
+
export WANDB_MODE="${WANDB_MODE:-offline}"
|
| 30 |
+
# Allow dataset download into cache unless you already mirrored GSM8K (set HF_DATASETS_OFFLINE=1).
|
| 31 |
+
export HF_HUB_OFFLINE="${HF_HUB_OFFLINE:-1}"
|
| 32 |
+
export HF_DATASETS_OFFLINE="${HF_DATASETS_OFFLINE:-0}"
|
| 33 |
+
export TRANSFORMERS_OFFLINE="${TRANSFORMERS_OFFLINE:-1}"
|
| 34 |
+
export PYTHONUNBUFFERED="${PYTHONUNBUFFERED:-1}"
|
| 35 |
+
|
| 36 |
+
CMD_ENV=(
|
| 37 |
+
PYTHONPATH="${REPO_ROOT}/src"
|
| 38 |
+
BASE_CONFIG="${BASE_CONFIG}"
|
| 39 |
+
PERMANENT_ROOT="${PERMANENT_ROOT}"
|
| 40 |
+
RUN_LABEL="${RUN_LABEL}"
|
| 41 |
+
EVAL_MAX_SAMPLES="${EVAL_MAX_SAMPLES}"
|
| 42 |
+
EVAL_BATCH_SIZE="${EVAL_BATCH_SIZE}"
|
| 43 |
+
ROLLOUT_SAMPLES="${ROLLOUT_SAMPLES}"
|
| 44 |
+
)
|
| 45 |
+
if [[ -n "${OUT_ROOT:-}" ]]; then
|
| 46 |
+
CMD_ENV+=(OUT_ROOT="${OUT_ROOT}")
|
| 47 |
+
fi
|
| 48 |
+
env "${CMD_ENV[@]}" accelerate launch \
|
| 49 |
+
--config_file "${ACCELERATE_CONFIG}" \
|
| 50 |
+
--num_processes "${NUM_PROCESSES}" \
|
| 51 |
+
src/eval_permanent_checkpoints.py
|
| 52 |
+
|
| 53 |
+
echo "Done. Summary and accuracy_vs_avg_cot_words.svg under OUT_ROOT (default: <PERMANENT_ROOT>/eval_permanent)."
|
scripts/eval_sweep_models_offline.sh
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -eo pipefail
|
| 3 |
+
|
| 4 |
+
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
| 5 |
+
# shellcheck disable=SC1091
|
| 6 |
+
source "${SCRIPT_DIR}/conda_env.sh"
|
| 7 |
+
set -u
|
| 8 |
+
|
| 9 |
+
REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
|
| 10 |
+
cd "${REPO_ROOT}"
|
| 11 |
+
export PYTHONPATH="${REPO_ROOT}/src"
|
| 12 |
+
|
| 13 |
+
BASE_CONFIG="${BASE_CONFIG:-configs/grpo_llama32_3b_bf16.yaml}"
|
| 14 |
+
ACCELERATE_CONFIG="${ACCELERATE_CONFIG:-configs/accelerate_ddp_4gpu.yaml}"
|
| 15 |
+
NUM_PROCESSES="${NUM_PROCESSES:-4}"
|
| 16 |
+
SWEEP_ROOT="${SWEEP_ROOT:-artifacts/sweeps/length_penalty_lambda}"
|
| 17 |
+
OUT_ROOT="${OUT_ROOT:-${SWEEP_ROOT}/eval_results}"
|
| 18 |
+
EVAL_MAX_SAMPLES="${EVAL_MAX_SAMPLES:-200}"
|
| 19 |
+
EVAL_BATCH_SIZE="${EVAL_BATCH_SIZE:-4}"
|
| 20 |
+
|
| 21 |
+
export WANDB_MODE="${WANDB_MODE:-offline}"
|
| 22 |
+
export HF_HUB_OFFLINE="${HF_HUB_OFFLINE:-1}"
|
| 23 |
+
export HF_DATASETS_OFFLINE="${HF_DATASETS_OFFLINE:-1}"
|
| 24 |
+
export TRANSFORMERS_OFFLINE="${TRANSFORMERS_OFFLINE:-1}"
|
| 25 |
+
|
| 26 |
+
BASE_CONFIG="${BASE_CONFIG}" \
|
| 27 |
+
SWEEP_ROOT="${SWEEP_ROOT}" \
|
| 28 |
+
OUT_ROOT="${OUT_ROOT}" \
|
| 29 |
+
EVAL_MAX_SAMPLES="${EVAL_MAX_SAMPLES}" \
|
| 30 |
+
EVAL_BATCH_SIZE="${EVAL_BATCH_SIZE}" \
|
| 31 |
+
accelerate launch \
|
| 32 |
+
--config_file "${ACCELERATE_CONFIG}" \
|
| 33 |
+
--num_processes "${NUM_PROCESSES}" \
|
| 34 |
+
src/eval_sweep_models.py
|
scripts/eval_twostage_permanent_checkpoints.sh
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -eo pipefail
|
| 3 |
+
|
| 4 |
+
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
| 5 |
+
REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
|
| 6 |
+
cd "${REPO_ROOT}"
|
| 7 |
+
# shellcheck disable=SC1091
|
| 8 |
+
source "${SCRIPT_DIR}/conda_env.sh"
|
| 9 |
+
set -u
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
# Evaluate every checkpoint under two run-local checkpoints/permanent dirs (cw=1 and cw=5 runs).
|
| 13 |
+
# Uses 4 GPUs. Writes rollouts, full JSONL outputs, CSV/JSON summary, step line charts,
|
| 14 |
+
# and accuracy_vs_avg_cot_words.svg (scatter: accuracy vs mean CoT length).
|
| 15 |
+
|
| 16 |
+
BASE_CONFIG="${BASE_CONFIG:-configs/grpo_llama32_3b_bf16.yaml}"
|
| 17 |
+
ACCELERATE_CONFIG="${ACCELERATE_CONFIG:-configs/accelerate_ddp_4gpu.yaml}"
|
| 18 |
+
NUM_PROCESSES="${NUM_PROCESSES:-4}"
|
| 19 |
+
|
| 20 |
+
PERMANENT_CW1="${PERMANENT_CW1:-artifacts/sweeps/twostage_correctness_weight/run_twostage_correctness1/checkpoints/permanent}"
|
| 21 |
+
PERMANENT_CW5="${PERMANENT_CW5:-artifacts/sweeps/twostage_correctness_weight/run_twostage_correctness5/checkpoints/permanent}"
|
| 22 |
+
|
| 23 |
+
OUT_ROOT="${OUT_ROOT:-artifacts/sweeps/twostage_correctness_weight/eval_permanent_checkpoints}"
|
| 24 |
+
EVAL_MAX_SAMPLES="${EVAL_MAX_SAMPLES:-200}"
|
| 25 |
+
EVAL_BATCH_SIZE="${EVAL_BATCH_SIZE:-4}"
|
| 26 |
+
ROLLOUT_SAMPLES="${ROLLOUT_SAMPLES:-8}"
|
| 27 |
+
|
| 28 |
+
export WANDB_MODE="${WANDB_MODE:-offline}"
|
| 29 |
+
export HF_HUB_OFFLINE="${HF_HUB_OFFLINE:-1}"
|
| 30 |
+
export HF_DATASETS_OFFLINE="${HF_DATASETS_OFFLINE:-0}"
|
| 31 |
+
export TRANSFORMERS_OFFLINE="${TRANSFORMERS_OFFLINE:-1}"
|
| 32 |
+
export PYTHONUNBUFFERED="${PYTHONUNBUFFERED:-1}"
|
| 33 |
+
|
| 34 |
+
PYTHONPATH="${REPO_ROOT}/src" \
|
| 35 |
+
BASE_CONFIG="${BASE_CONFIG}" \
|
| 36 |
+
PERMANENT_CW1="${PERMANENT_CW1}" \
|
| 37 |
+
PERMANENT_CW5="${PERMANENT_CW5}" \
|
| 38 |
+
OUT_ROOT="${OUT_ROOT}" \
|
| 39 |
+
EVAL_MAX_SAMPLES="${EVAL_MAX_SAMPLES}" \
|
| 40 |
+
EVAL_BATCH_SIZE="${EVAL_BATCH_SIZE}" \
|
| 41 |
+
ROLLOUT_SAMPLES="${ROLLOUT_SAMPLES}" \
|
| 42 |
+
accelerate launch \
|
| 43 |
+
--config_file "${ACCELERATE_CONFIG}" \
|
| 44 |
+
--num_processes "${NUM_PROCESSES}" \
|
| 45 |
+
src/eval_permanent_checkpoints.py
|
| 46 |
+
|
| 47 |
+
echo "Done. Results under: ${OUT_ROOT}"
|
scripts/hf_download_repo.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import argparse
|
| 5 |
+
import os
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
|
| 8 |
+
from huggingface_hub import snapshot_download
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def parse_args() -> argparse.Namespace:
|
| 12 |
+
parser = argparse.ArgumentParser(
|
| 13 |
+
description="Download a Hugging Face dataset repo snapshot to local disk."
|
| 14 |
+
)
|
| 15 |
+
parser.add_argument(
|
| 16 |
+
"--repo-name",
|
| 17 |
+
type=str,
|
| 18 |
+
default="neuralese-move",
|
| 19 |
+
help="HF dataset repo name under the username.",
|
| 20 |
+
)
|
| 21 |
+
parser.add_argument(
|
| 22 |
+
"--username",
|
| 23 |
+
type=str,
|
| 24 |
+
default="psidharth567",
|
| 25 |
+
help="HF username/org (default: psidharth567).",
|
| 26 |
+
)
|
| 27 |
+
parser.add_argument(
|
| 28 |
+
"--local-dir",
|
| 29 |
+
type=str,
|
| 30 |
+
default="./neuralese",
|
| 31 |
+
help="Where to download the repo snapshot locally.",
|
| 32 |
+
)
|
| 33 |
+
parser.add_argument(
|
| 34 |
+
"--revision",
|
| 35 |
+
type=str,
|
| 36 |
+
default=None,
|
| 37 |
+
help="Optional branch/tag/commit to download.",
|
| 38 |
+
)
|
| 39 |
+
parser.add_argument(
|
| 40 |
+
"--hf-token",
|
| 41 |
+
type=str,
|
| 42 |
+
default=os.environ.get("HF_TOKEN", ""),
|
| 43 |
+
help="HF token (or set HF_TOKEN env var).",
|
| 44 |
+
)
|
| 45 |
+
return parser.parse_args()
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def main() -> None:
|
| 49 |
+
args = parse_args()
|
| 50 |
+
token = args.hf_token.strip()
|
| 51 |
+
if not token:
|
| 52 |
+
raise RuntimeError("Missing HF token. Set HF_TOKEN or pass --hf-token.")
|
| 53 |
+
|
| 54 |
+
repo_id = f"{args.username}/{args.repo_name}"
|
| 55 |
+
local_dir = Path(args.local_dir).resolve()
|
| 56 |
+
local_dir.parent.mkdir(parents=True, exist_ok=True)
|
| 57 |
+
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
|
| 58 |
+
|
| 59 |
+
snapshot_path = snapshot_download(
|
| 60 |
+
repo_id=repo_id,
|
| 61 |
+
repo_type="dataset",
|
| 62 |
+
local_dir=str(local_dir),
|
| 63 |
+
local_dir_use_symlinks=False,
|
| 64 |
+
revision=args.revision,
|
| 65 |
+
token=token,
|
| 66 |
+
)
|
| 67 |
+
print(f"Downloaded snapshot to: {snapshot_path}")
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
if __name__ == "__main__":
|
| 71 |
+
main()
|
| 72 |
+
|
scripts/hf_upload_repo.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import argparse
|
| 5 |
+
import inspect
|
| 6 |
+
import os
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
|
| 9 |
+
from huggingface_hub import HfApi, create_repo
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def parse_args() -> argparse.Namespace:
|
| 13 |
+
parser = argparse.ArgumentParser(
|
| 14 |
+
description="Upload a local folder to a Hugging Face dataset repo."
|
| 15 |
+
)
|
| 16 |
+
parser.add_argument(
|
| 17 |
+
"--folder-path",
|
| 18 |
+
type=str,
|
| 19 |
+
default=".",
|
| 20 |
+
help="Local folder to upload (default: current directory).",
|
| 21 |
+
)
|
| 22 |
+
parser.add_argument(
|
| 23 |
+
"--repo-name",
|
| 24 |
+
type=str,
|
| 25 |
+
default="neuralese-move",
|
| 26 |
+
help="HF dataset repo name under the username.",
|
| 27 |
+
)
|
| 28 |
+
parser.add_argument(
|
| 29 |
+
"--username",
|
| 30 |
+
type=str,
|
| 31 |
+
default="psidharth567",
|
| 32 |
+
help="HF username/org (default: psidharth567).",
|
| 33 |
+
)
|
| 34 |
+
parser.add_argument(
|
| 35 |
+
"--private",
|
| 36 |
+
action="store_true",
|
| 37 |
+
help="Create repo as private (default is public).",
|
| 38 |
+
)
|
| 39 |
+
parser.add_argument(
|
| 40 |
+
"--path-in-repo",
|
| 41 |
+
type=str,
|
| 42 |
+
default=".",
|
| 43 |
+
help="Destination path inside HF repo (default: repo root).",
|
| 44 |
+
)
|
| 45 |
+
parser.add_argument(
|
| 46 |
+
"--commit-message",
|
| 47 |
+
type=str,
|
| 48 |
+
default="temporary cluster migration upload",
|
| 49 |
+
help="Commit message for upload.",
|
| 50 |
+
)
|
| 51 |
+
parser.add_argument(
|
| 52 |
+
"--ignore",
|
| 53 |
+
type=str,
|
| 54 |
+
nargs="*",
|
| 55 |
+
default=[
|
| 56 |
+
"**/.git/**",
|
| 57 |
+
"**/__pycache__/**",
|
| 58 |
+
"**/*.pyc",
|
| 59 |
+
"**/*.pyo",
|
| 60 |
+
"**/.DS_Store",
|
| 61 |
+
],
|
| 62 |
+
help="Glob patterns to ignore during upload.",
|
| 63 |
+
)
|
| 64 |
+
parser.add_argument(
|
| 65 |
+
"--hf-token",
|
| 66 |
+
type=str,
|
| 67 |
+
default=os.environ.get("HF_TOKEN", ""),
|
| 68 |
+
help="HF token (or set HF_TOKEN env var).",
|
| 69 |
+
)
|
| 70 |
+
return parser.parse_args()
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def main() -> None:
|
| 74 |
+
args = parse_args()
|
| 75 |
+
token = args.hf_token.strip()
|
| 76 |
+
if not token:
|
| 77 |
+
raise RuntimeError("Missing HF token. Set HF_TOKEN or pass --hf-token.")
|
| 78 |
+
|
| 79 |
+
folder = Path(args.folder_path).resolve()
|
| 80 |
+
if not folder.exists():
|
| 81 |
+
raise FileNotFoundError(f"Folder does not exist: {folder}")
|
| 82 |
+
if not folder.is_dir():
|
| 83 |
+
raise NotADirectoryError(f"Not a directory: {folder}")
|
| 84 |
+
|
| 85 |
+
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
|
| 86 |
+
|
| 87 |
+
repo_id = f"{args.username}/{args.repo_name}"
|
| 88 |
+
api = HfApi(token=token)
|
| 89 |
+
|
| 90 |
+
# Create repo if needed; no-op if already exists.
|
| 91 |
+
create_repo(
|
| 92 |
+
repo_id=repo_id,
|
| 93 |
+
repo_type="dataset",
|
| 94 |
+
private=bool(args.private),
|
| 95 |
+
exist_ok=True,
|
| 96 |
+
token=token,
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
large_sig = inspect.signature(api.upload_large_folder)
|
| 100 |
+
large_params = set(large_sig.parameters.keys())
|
| 101 |
+
upload_kwargs = {
|
| 102 |
+
"repo_id": repo_id,
|
| 103 |
+
"repo_type": "dataset",
|
| 104 |
+
"folder_path": str(folder),
|
| 105 |
+
"ignore_patterns": args.ignore,
|
| 106 |
+
}
|
| 107 |
+
if "token" in large_params:
|
| 108 |
+
upload_kwargs["token"] = token
|
| 109 |
+
# Some huggingface_hub versions support path_in_repo for large uploads,
|
| 110 |
+
# others do not. Pass it only when supported.
|
| 111 |
+
if "path_in_repo" in large_params:
|
| 112 |
+
upload_kwargs["path_in_repo"] = args.path_in_repo
|
| 113 |
+
elif args.path_in_repo not in (".", "", "/"):
|
| 114 |
+
print(
|
| 115 |
+
"Warning: this huggingface_hub version does not support "
|
| 116 |
+
"`path_in_repo` for upload_large_folder; uploading to repo root."
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
api.upload_large_folder(**upload_kwargs)
|
| 120 |
+
|
| 121 |
+
info = api.repo_info(repo_id=repo_id, repo_type="dataset")
|
| 122 |
+
print(f"Upload complete: https://huggingface.co/datasets/{repo_id}")
|
| 123 |
+
print(f"Latest commit: {info.sha}")
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
if __name__ == "__main__":
|
| 127 |
+
main()
|
| 128 |
+
|
scripts/resume_grpo_8gpu.sh
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -eo pipefail
|
| 3 |
+
|
| 4 |
+
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
| 5 |
+
# shellcheck disable=SC1091
|
| 6 |
+
source "${SCRIPT_DIR}/conda_env.sh"
|
| 7 |
+
set -u
|
| 8 |
+
|
| 9 |
+
REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
|
| 10 |
+
cd "${REPO_ROOT}"
|
| 11 |
+
export PYTHONPATH="${REPO_ROOT}/src"
|
| 12 |
+
export WANDB_DIR="${WANDB_DIR:-${REPO_ROOT}/cache/logs/wandb}"
|
| 13 |
+
|
| 14 |
+
RESUME_FROM="${1:-latest}"
|
| 15 |
+
GPU_COUNT="${2:-${GPU_COUNT:-8}}"
|
| 16 |
+
|
| 17 |
+
export WANDB_MODE="${WANDB_MODE:-offline}"
|
| 18 |
+
export HF_HUB_OFFLINE="${HF_HUB_OFFLINE:-1}"
|
| 19 |
+
export HF_DATASETS_OFFLINE="${HF_DATASETS_OFFLINE:-1}"
|
| 20 |
+
export TRANSFORMERS_OFFLINE="${TRANSFORMERS_OFFLINE:-1}"
|
| 21 |
+
|
| 22 |
+
accelerate launch \
|
| 23 |
+
--config_file configs/accelerate_ddp_8gpu.yaml \
|
| 24 |
+
--num_processes "${GPU_COUNT}" \
|
| 25 |
+
src/train_grpo.py \
|
| 26 |
+
--config configs/grpo_llama32_3b_bf16.yaml \
|
| 27 |
+
--resume_from_checkpoint "${RESUME_FROM}"
|
scripts/run_grpo.sh
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -eo pipefail
|
| 3 |
+
|
| 4 |
+
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
| 5 |
+
# shellcheck disable=SC1091
|
| 6 |
+
source "${SCRIPT_DIR}/conda_env.sh"
|
| 7 |
+
set -u
|
| 8 |
+
|
| 9 |
+
REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
|
| 10 |
+
cd "${REPO_ROOT}"
|
| 11 |
+
export PYTHONPATH="${REPO_ROOT}/src"
|
| 12 |
+
export WANDB_DIR="${WANDB_DIR:-${REPO_ROOT}/cache/logs/wandb}"
|
| 13 |
+
|
| 14 |
+
export WANDB_MODE="${WANDB_MODE:-offline}"
|
| 15 |
+
export HF_HUB_OFFLINE="${HF_HUB_OFFLINE:-1}"
|
| 16 |
+
export HF_DATASETS_OFFLINE="${HF_DATASETS_OFFLINE:-1}"
|
| 17 |
+
export TRANSFORMERS_OFFLINE="${TRANSFORMERS_OFFLINE:-1}"
|
| 18 |
+
|
| 19 |
+
python src/train_grpo.py --config configs/grpo_llama32_3b_bf16.yaml
|
scripts/run_grpo_2gpu.sh
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -eo pipefail
|
| 3 |
+
|
| 4 |
+
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
| 5 |
+
# shellcheck disable=SC1091
|
| 6 |
+
source "${SCRIPT_DIR}/conda_env.sh"
|
| 7 |
+
set -u
|
| 8 |
+
|
| 9 |
+
REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
|
| 10 |
+
cd "${REPO_ROOT}"
|
| 11 |
+
export PYTHONPATH="${REPO_ROOT}/src"
|
| 12 |
+
export WANDB_DIR="${WANDB_DIR:-${REPO_ROOT}/cache/logs/wandb}"
|
| 13 |
+
|
| 14 |
+
export WANDB_MODE="${WANDB_MODE:-offline}"
|
| 15 |
+
export HF_HUB_OFFLINE="${HF_HUB_OFFLINE:-1}"
|
| 16 |
+
export HF_DATASETS_OFFLINE="${HF_DATASETS_OFFLINE:-1}"
|
| 17 |
+
export TRANSFORMERS_OFFLINE="${TRANSFORMERS_OFFLINE:-1}"
|
| 18 |
+
|
| 19 |
+
accelerate launch \
|
| 20 |
+
--config_file configs/accelerate_ddp_2gpu.yaml \
|
| 21 |
+
src/train_grpo.py \
|
| 22 |
+
--config configs/grpo_llama32_3b_bf16.yaml
|
scripts/run_grpo_4gpu.sh
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -eo pipefail
|
| 3 |
+
|
| 4 |
+
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
| 5 |
+
# shellcheck disable=SC1091
|
| 6 |
+
source "${SCRIPT_DIR}/conda_env.sh"
|
| 7 |
+
set -u
|
| 8 |
+
|
| 9 |
+
REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
|
| 10 |
+
cd "${REPO_ROOT}"
|
| 11 |
+
export PYTHONPATH="${REPO_ROOT}/src"
|
| 12 |
+
export WANDB_DIR="${WANDB_DIR:-${REPO_ROOT}/cache/logs/wandb}"
|
| 13 |
+
|
| 14 |
+
export WANDB_MODE="${WANDB_MODE:-offline}"
|
| 15 |
+
export HF_HUB_OFFLINE="${HF_HUB_OFFLINE:-1}"
|
| 16 |
+
export HF_DATASETS_OFFLINE="${HF_DATASETS_OFFLINE:-1}"
|
| 17 |
+
export TRANSFORMERS_OFFLINE="${TRANSFORMERS_OFFLINE:-1}"
|
| 18 |
+
|
| 19 |
+
accelerate launch \
|
| 20 |
+
--config_file configs/accelerate_ddp_4gpu.yaml \
|
| 21 |
+
src/train_grpo.py \
|
| 22 |
+
--config configs/grpo_llama32_3b_bf16.yaml
|
scripts/run_grpo_8gpu.sh
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -eo pipefail
|
| 3 |
+
|
| 4 |
+
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
| 5 |
+
# shellcheck disable=SC1091
|
| 6 |
+
source "${SCRIPT_DIR}/conda_env.sh"
|
| 7 |
+
set -u
|
| 8 |
+
|
| 9 |
+
REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
|
| 10 |
+
cd "${REPO_ROOT}"
|
| 11 |
+
export PYTHONPATH="${REPO_ROOT}/src"
|
| 12 |
+
export WANDB_DIR="${WANDB_DIR:-${REPO_ROOT}/cache/logs/wandb}"
|
| 13 |
+
|
| 14 |
+
GPU_COUNT="${1:-${GPU_COUNT:-8}}"
|
| 15 |
+
|
| 16 |
+
export WANDB_MODE="${WANDB_MODE:-offline}"
|
| 17 |
+
export HF_HUB_OFFLINE="${HF_HUB_OFFLINE:-1}"
|
| 18 |
+
export HF_DATASETS_OFFLINE="${HF_DATASETS_OFFLINE:-1}"
|
| 19 |
+
export TRANSFORMERS_OFFLINE="${TRANSFORMERS_OFFLINE:-1}"
|
| 20 |
+
|
| 21 |
+
accelerate launch \
|
| 22 |
+
--config_file configs/accelerate_ddp_8gpu.yaml \
|
| 23 |
+
--num_processes "${GPU_COUNT}" \
|
| 24 |
+
src/train_grpo.py \
|
| 25 |
+
--config configs/grpo_llama32_3b_bf16.yaml
|
scripts/run_grpo_thinking_kl_masked_resume_4gpu.sh
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
# Defer `set -u` until after conda: nvidia/conda hooks use unset vars during activate.
|
| 3 |
+
set -eo pipefail
|
| 4 |
+
|
| 5 |
+
# Fine-tune from a stage-1 checkpoint with KL *disabled* on tokens inside
|
| 6 |
+
# <redacted_thinking>...</redacted_thinking> (inner body only). KL still applies
|
| 7 |
+
# to tags, \\boxed{...}, and any text outside the inner thinking span.
|
| 8 |
+
#
|
| 9 |
+
# Conda env: sidharth (override with CONDA_ENV=...).
|
| 10 |
+
#
|
| 11 |
+
# Defaults: online Hugging Face / datasets (fetch if missing). Override with
|
| 12 |
+
# HF_HUB_OFFLINE=1 HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1
|
| 13 |
+
#
|
| 14 |
+
# Effective microbatch for GRPO: 4 GPUs × per_device_batch × grad_accum.
|
| 15 |
+
# Matches configs/grpo_llama32_3b_bf16.yaml baseline: per_device=1, grad_accum=8 -> 32.
|
| 16 |
+
#
|
| 17 |
+
# Overrides (defaults tuned for this launcher):
|
| 18 |
+
# LENGTH_PENALTY_LAMBDA=0.5
|
| 19 |
+
# PERMANENT_CHECKPOINT_STEPS=300 (set to 0 to disable extra copies under permanent_checkpoint_dir)
|
| 20 |
+
# PERMANENT_CHECKPOINT_DIR=checkpoints/permanent
|
| 21 |
+
|
| 22 |
+
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
| 23 |
+
REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
|
| 24 |
+
# shellcheck disable=SC1091
|
| 25 |
+
source "${SCRIPT_DIR}/conda_env.sh"
|
| 26 |
+
set -u
|
| 27 |
+
|
| 28 |
+
cd "${REPO_ROOT}"
|
| 29 |
+
|
| 30 |
+
# DeepSpeed is imported by some TRL/Accelerate stacks even when using plain DDP. Its import path
|
| 31 |
+
# calls installed_cuda_version() and requires CUDA_HOME if CUDA ops are probed. PyTorch GPU builds
|
| 32 |
+
# often work without nvcc, but DeepSpeed does not — set CUDA_HOME from nvcc when unset, or run:
|
| 33 |
+
# pip uninstall deepspeed
|
| 34 |
+
if [[ -z "${CUDA_HOME:-}" ]] && command -v nvcc >/dev/null 2>&1; then
|
| 35 |
+
export CUDA_HOME="$(dirname "$(dirname "$(command -v nvcc)")")"
|
| 36 |
+
fi
|
| 37 |
+
|
| 38 |
+
BASE_CONFIG="${BASE_CONFIG:-configs/grpo_llama32_3b_bf16.yaml}"
|
| 39 |
+
ACCELERATE_CONFIG="${ACCELERATE_CONFIG:-configs/accelerate_ddp_4gpu.yaml}"
|
| 40 |
+
NUM_PROCESSES="${NUM_PROCESSES:-4}"
|
| 41 |
+
PYTHON_BIN="${PYTHON_BIN:-python}"
|
| 42 |
+
|
| 43 |
+
CHECKPOINT="${CHECKPOINT:-${REPO_ROOT}/cache/artifacts/sweeps/twostage_correctness_weight/run_twostage_correctness1/checkpoints/permanent/checkpoint-stage1-boundary-epoch-1p0-step-1868}"
|
| 44 |
+
|
| 45 |
+
# 4 × 1 × 8 = 32 (same as base YAML; override with GRAD_ACCUM=...)
|
| 46 |
+
GRAD_ACCUM="${GRAD_ACCUM:-8}"
|
| 47 |
+
LENGTH_PENALTY_LAMBDA="${LENGTH_PENALTY_LAMBDA:-0.5}"
|
| 48 |
+
PERMANENT_CHECKPOINT_STEPS="${PERMANENT_CHECKPOINT_STEPS:-300}"
|
| 49 |
+
PERMANENT_CHECKPOINT_DIR="${PERMANENT_CHECKPOINT_DIR:-checkpoints/permanent}"
|
| 50 |
+
|
| 51 |
+
OUT_ROOT="${OUT_ROOT:-artifacts/runs/grpo_thinking_kl_masked_from_stage1_4gpu}"
|
| 52 |
+
RUN_NAME="${RUN_NAME:-grpo-think-kl0-resume-stage1-4gpu}"
|
| 53 |
+
export THINKING_INNER_KL_WEIGHT="${THINKING_INNER_KL_WEIGHT:-0.0}"
|
| 54 |
+
|
| 55 |
+
export WANDB_MODE="${WANDB_MODE:-online}"
|
| 56 |
+
export HF_HUB_OFFLINE="${HF_HUB_OFFLINE:-0}"
|
| 57 |
+
export HF_DATASETS_OFFLINE="${HF_DATASETS_OFFLINE:-0}"
|
| 58 |
+
export TRANSFORMERS_OFFLINE="${TRANSFORMERS_OFFLINE:-0}"
|
| 59 |
+
export PYTHONUNBUFFERED="${PYTHONUNBUFFERED:-1}"
|
| 60 |
+
export PYTHONPATH="${REPO_ROOT}/src"
|
| 61 |
+
|
| 62 |
+
mkdir -p "${OUT_ROOT}"
|
| 63 |
+
TMP_DIR="$(mktemp -d "${OUT_ROOT}/tmp_cfgs.XXXXXX")"
|
| 64 |
+
trap 'rm -rf "${TMP_DIR}"' EXIT
|
| 65 |
+
CFG_PATH="${TMP_DIR}/grpo_thinking_kl_4gpu.yaml"
|
| 66 |
+
|
| 67 |
+
BASE_CONFIG="${BASE_CONFIG}" \
|
| 68 |
+
OUT_ROOT="${OUT_ROOT}" \
|
| 69 |
+
RUN_NAME="${RUN_NAME}" \
|
| 70 |
+
GRAD_ACCUM="${GRAD_ACCUM}" \
|
| 71 |
+
THINKING_INNER_KL_WEIGHT="${THINKING_INNER_KL_WEIGHT}" \
|
| 72 |
+
LENGTH_PENALTY_LAMBDA="${LENGTH_PENALTY_LAMBDA}" \
|
| 73 |
+
PERMANENT_CHECKPOINT_STEPS="${PERMANENT_CHECKPOINT_STEPS}" \
|
| 74 |
+
PERMANENT_CHECKPOINT_DIR="${PERMANENT_CHECKPOINT_DIR}" \
|
| 75 |
+
CFG_PATH="${CFG_PATH}" \
|
| 76 |
+
"${PYTHON_BIN}" - <<'PY'
|
| 77 |
+
import copy
|
| 78 |
+
import os
|
| 79 |
+
from pathlib import Path
|
| 80 |
+
|
| 81 |
+
import yaml
|
| 82 |
+
|
| 83 |
+
base = Path(os.environ["BASE_CONFIG"])
|
| 84 |
+
out_root = Path(os.environ["OUT_ROOT"])
|
| 85 |
+
run_name = os.environ["RUN_NAME"]
|
| 86 |
+
grad_accum = int(os.environ["GRAD_ACCUM"])
|
| 87 |
+
cfg_path = Path(os.environ["CFG_PATH"])
|
| 88 |
+
|
| 89 |
+
with base.open("r", encoding="utf-8") as f:
|
| 90 |
+
cfg = yaml.safe_load(f)
|
| 91 |
+
|
| 92 |
+
cfg = copy.deepcopy(cfg)
|
| 93 |
+
cfg.setdefault("model", {})
|
| 94 |
+
# Local Unsloth-touched checkpoints often need this when loading with Transformers.
|
| 95 |
+
cfg["model"]["trust_remote_code"] = True
|
| 96 |
+
cfg.setdefault("trainer", {})
|
| 97 |
+
cfg["trainer"]["gradient_accumulation_steps"] = grad_accum
|
| 98 |
+
cfg["trainer"]["run_name"] = run_name
|
| 99 |
+
cfg["trainer"]["output_dir"] = str(out_root)
|
| 100 |
+
cfg["trainer"]["permanent_checkpoint_steps"] = int(os.environ["PERMANENT_CHECKPOINT_STEPS"])
|
| 101 |
+
cfg["trainer"]["permanent_checkpoint_dir"] = os.environ["PERMANENT_CHECKPOINT_DIR"]
|
| 102 |
+
|
| 103 |
+
cfg.setdefault("objective", {}).setdefault("kwargs", {})
|
| 104 |
+
cfg["objective"]["kwargs"]["length_penalty_lambda"] = float(os.environ["LENGTH_PENALTY_LAMBDA"])
|
| 105 |
+
|
| 106 |
+
cfg.setdefault("thinking_kl", {})
|
| 107 |
+
# 0.0 = no KL on inner <redacted_thinking>...</redacted_thinking> body; override with THINKING_INNER_KL_WEIGHT=1
|
| 108 |
+
cfg["thinking_kl"]["inner_kl_weight"] = float(os.environ.get("THINKING_INNER_KL_WEIGHT", "0.0"))
|
| 109 |
+
|
| 110 |
+
cfg.setdefault("grpo", {})
|
| 111 |
+
cfg["grpo"]["use_liger_kernel"] = False
|
| 112 |
+
# Ensure KL term exists so masking is meaningful (raise beta in YAML if you use 0 today)
|
| 113 |
+
if cfg["grpo"].get("beta", 0) == 0:
|
| 114 |
+
cfg["grpo"]["beta"] = 0.001
|
| 115 |
+
|
| 116 |
+
with cfg_path.open("w", encoding="utf-8") as f:
|
| 117 |
+
yaml.safe_dump(cfg, f, sort_keys=False)
|
| 118 |
+
|
| 119 |
+
print("Wrote", cfg_path)
|
| 120 |
+
print("output_dir:", cfg["trainer"]["output_dir"])
|
| 121 |
+
print("thinking_kl.inner_kl_weight:", cfg["thinking_kl"]["inner_kl_weight"])
|
| 122 |
+
print("trainer.gradient_accumulation_steps:", cfg["trainer"]["gradient_accumulation_steps"])
|
| 123 |
+
print("objective.kwargs.length_penalty_lambda:", cfg["objective"]["kwargs"]["length_penalty_lambda"])
|
| 124 |
+
print("trainer.permanent_checkpoint_steps:", cfg["trainer"]["permanent_checkpoint_steps"])
|
| 125 |
+
print("trainer.permanent_checkpoint_dir:", cfg["trainer"]["permanent_checkpoint_dir"])
|
| 126 |
+
PY
|
| 127 |
+
|
| 128 |
+
echo "Checkpoint: ${CHECKPOINT}"
|
| 129 |
+
echo "Config: ${CFG_PATH}"
|
| 130 |
+
echo "Launching ${NUM_PROCESSES} processes (online HF/datasets unless overridden)..."
|
| 131 |
+
|
| 132 |
+
accelerate launch \
|
| 133 |
+
--config_file "${ACCELERATE_CONFIG}" \
|
| 134 |
+
--num_processes "${NUM_PROCESSES}" \
|
| 135 |
+
src/train_grpo.py \
|
| 136 |
+
--config "${CFG_PATH}" \
|
| 137 |
+
--model_path "${CHECKPOINT}"
|
| 138 |
+
|
| 139 |
+
echo "Done."
|
scripts/run_lambda_0p1_existing_gate_token_util.sh
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -eo pipefail
|
| 3 |
+
|
| 4 |
+
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
| 5 |
+
REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
|
| 6 |
+
cd "${REPO_ROOT}"
|
| 7 |
+
# shellcheck disable=SC1091
|
| 8 |
+
source "${SCRIPT_DIR}/conda_env.sh"
|
| 9 |
+
set -u
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
# Train one GRPO run with:
|
| 13 |
+
# - lambda = 0.1
|
| 14 |
+
# - strict format gate = true
|
| 15 |
+
# - non-multiplicative interaction (correctness_length)
|
| 16 |
+
|
| 17 |
+
BASE_CONFIG="${BASE_CONFIG:-configs/grpo_llama32_3b_bf16.yaml}"
|
| 18 |
+
ACCELERATE_CONFIG="${ACCELERATE_CONFIG:-configs/accelerate_ddp_4gpu.yaml}"
|
| 19 |
+
NUM_PROCESSES="${NUM_PROCESSES:-4}"
|
| 20 |
+
PYTHON_BIN="${PYTHON_BIN:-python}"
|
| 21 |
+
OUT_ROOT="${OUT_ROOT:-artifacts/sweeps/reward_variants_lambda_0p1}"
|
| 22 |
+
|
| 23 |
+
export WANDB_MODE="${WANDB_MODE:-offline}"
|
| 24 |
+
export HF_HUB_OFFLINE="${HF_HUB_OFFLINE:-1}"
|
| 25 |
+
export HF_DATASETS_OFFLINE="${HF_DATASETS_OFFLINE:-1}"
|
| 26 |
+
export TRANSFORMERS_OFFLINE="${TRANSFORMERS_OFFLINE:-1}"
|
| 27 |
+
export PYTHONUNBUFFERED="${PYTHONUNBUFFERED:-1}"
|
| 28 |
+
export ACCELERATE_LOG_LEVEL="${ACCELERATE_LOG_LEVEL:-info}"
|
| 29 |
+
|
| 30 |
+
mkdir -p "${OUT_ROOT}"
|
| 31 |
+
LOG_PATH="${LOG_PATH:-${OUT_ROOT}/run_lambda_0p1_existing_gate_token_util_$(date +%Y%m%d_%H%M%S).log}"
|
| 32 |
+
exec > >(tee -a "${LOG_PATH}") 2>&1
|
| 33 |
+
|
| 34 |
+
TMP_DIR="$(mktemp -d "${OUT_ROOT}/tmp_cfgs.XXXXXX")"
|
| 35 |
+
trap 'rm -rf "${TMP_DIR}"' EXIT
|
| 36 |
+
CFG_PATH="${TMP_DIR}/grpo_lambda_0p1_existing_gate_token_util.yaml"
|
| 37 |
+
|
| 38 |
+
echo "Preparing config..."
|
| 39 |
+
echo "Base config: ${BASE_CONFIG}"
|
| 40 |
+
echo "Out root: ${OUT_ROOT}"
|
| 41 |
+
echo "Accelerate config: ${ACCELERATE_CONFIG}"
|
| 42 |
+
echo "Num processes (GPUs): ${NUM_PROCESSES}"
|
| 43 |
+
echo "Log path: ${LOG_PATH}"
|
| 44 |
+
|
| 45 |
+
BASE_CONFIG="${BASE_CONFIG}" OUT_ROOT="${OUT_ROOT}" CFG_PATH="${CFG_PATH}" REPO_ROOT="${REPO_ROOT}" "${PYTHON_BIN}" - <<'PY'
|
| 46 |
+
import copy
|
| 47 |
+
import os
|
| 48 |
+
from pathlib import Path
|
| 49 |
+
|
| 50 |
+
import yaml
|
| 51 |
+
|
| 52 |
+
repo_root = Path(os.environ["REPO_ROOT"])
|
| 53 |
+
base_config = Path(os.environ["BASE_CONFIG"])
|
| 54 |
+
out_root = Path(os.environ["OUT_ROOT"])
|
| 55 |
+
cfg_path = Path(os.environ["CFG_PATH"])
|
| 56 |
+
|
| 57 |
+
with base_config.open("r", encoding="utf-8") as handle:
|
| 58 |
+
cfg = yaml.safe_load(handle)
|
| 59 |
+
|
| 60 |
+
cfg = copy.deepcopy(cfg)
|
| 61 |
+
|
| 62 |
+
cfg.setdefault("rewards", {})
|
| 63 |
+
cfg["rewards"].setdefault("kwargs", {})
|
| 64 |
+
rewards_kwargs = cfg["rewards"]["kwargs"]
|
| 65 |
+
rewards_kwargs.setdefault("token_utilisation_reward", {})
|
| 66 |
+
rewards_kwargs["token_utilisation_reward"]["results_jsonl_path"] = str(
|
| 67 |
+
repo_root / "artifacts/eval/gsm8k_train_zeroshot/results.jsonl"
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
cfg.setdefault("objective", {})
|
| 71 |
+
cfg["objective"].setdefault("kwargs", {})
|
| 72 |
+
obj = cfg["objective"]["kwargs"]
|
| 73 |
+
|
| 74 |
+
obj["enable_length_penalty"] = True
|
| 75 |
+
obj["enable_token_utilisation_reward"] = True
|
| 76 |
+
obj["reward_mode"] = "weighted_length_penalty"
|
| 77 |
+
obj["length_penalty_lambda"] = 0.1
|
| 78 |
+
obj["length_penalty_interaction"] = "correctness_length_format"
|
| 79 |
+
obj["strict_format_gate"] = True
|
| 80 |
+
obj["non_strict_penalty"] = -1.0
|
| 81 |
+
|
| 82 |
+
cfg.setdefault("trainer", {})
|
| 83 |
+
base_run_name = cfg["trainer"].get("run_name", "grpo")
|
| 84 |
+
cfg["trainer"]["run_name"] = f"{base_run_name}-lambda-0p1-existing-gate-token-util"
|
| 85 |
+
cfg["trainer"]["output_dir"] = str(out_root / "run_lambda_0p1_existing_gate_token_util")
|
| 86 |
+
|
| 87 |
+
with cfg_path.open("w", encoding="utf-8") as handle:
|
| 88 |
+
yaml.safe_dump(cfg, handle, sort_keys=False)
|
| 89 |
+
|
| 90 |
+
print(f"Wrote config: {cfg_path}")
|
| 91 |
+
print(f"Run name: {cfg['trainer']['run_name']}")
|
| 92 |
+
print(f"Output dir: {cfg['trainer']['output_dir']}")
|
| 93 |
+
PY
|
| 94 |
+
|
| 95 |
+
echo
|
| 96 |
+
echo "Starting training..."
|
| 97 |
+
export PYTHONPATH="${REPO_ROOT}/src"
|
| 98 |
+
accelerate launch \
|
| 99 |
+
--config_file "${ACCELERATE_CONFIG}" \
|
| 100 |
+
--num_processes "${NUM_PROCESSES}" \
|
| 101 |
+
src/train_grpo.py \
|
| 102 |
+
--config "${CFG_PATH}"
|
| 103 |
+
|
| 104 |
+
echo
|
| 105 |
+
echo "Done."
|
| 106 |
+
echo "Model outputs/checkpoints under: ${OUT_ROOT}/run_lambda_0p1_existing_gate_token_util"
|
| 107 |
+
echo "Full log: ${LOG_PATH}"
|
scripts/run_reward_variants_and_eval.sh
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -eo pipefail
|
| 3 |
+
|
| 4 |
+
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
| 5 |
+
# shellcheck disable=SC1091
|
| 6 |
+
source "${SCRIPT_DIR}/conda_env.sh"
|
| 7 |
+
set -u
|
| 8 |
+
|
| 9 |
+
REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
|
| 10 |
+
cd "${REPO_ROOT}"
|
| 11 |
+
export PYTHONPATH="${REPO_ROOT}/src"
|
| 12 |
+
|
| 13 |
+
# Train 3 reward variants sequentially, then evaluate all and save outputs+plot.
|
| 14 |
+
#
|
| 15 |
+
# Variants (all lambda=0.75):
|
| 16 |
+
# 1) existing_gate:
|
| 17 |
+
# total = r_correctness + lambda*(r_correctness*r_length) + r_format
|
| 18 |
+
# strict_format_gate = true
|
| 19 |
+
# 2) mult_nogate:
|
| 20 |
+
# total = r_correctness + lambda*(r_correctness*r_length*r_format) + r_format
|
| 21 |
+
# strict_format_gate = false
|
| 22 |
+
# 3) mult_gate:
|
| 23 |
+
# total = r_correctness + lambda*(r_correctness*r_length*r_format) + r_format
|
| 24 |
+
# strict_format_gate = true
|
| 25 |
+
|
| 26 |
+
BASE_CONFIG="${BASE_CONFIG:-configs/grpo_llama32_3b_bf16.yaml}"
|
| 27 |
+
ACCELERATE_CONFIG="${ACCELERATE_CONFIG:-configs/accelerate_ddp_4gpu.yaml}"
|
| 28 |
+
NUM_PROCESSES="${NUM_PROCESSES:-4}"
|
| 29 |
+
PYTHON_BIN="${PYTHON_BIN:-python}"
|
| 30 |
+
LAMBDA="${LAMBDA:-0.75}"
|
| 31 |
+
|
| 32 |
+
SWEEP_ROOT="${SWEEP_ROOT:-artifacts/sweeps/reward_variants_lambda_0p75}"
|
| 33 |
+
EVAL_OUT_ROOT="${EVAL_OUT_ROOT:-${SWEEP_ROOT}/eval_results}"
|
| 34 |
+
EVAL_MAX_SAMPLES="${EVAL_MAX_SAMPLES:-200}"
|
| 35 |
+
EVAL_BATCH_SIZE="${EVAL_BATCH_SIZE:-4}"
|
| 36 |
+
|
| 37 |
+
export WANDB_MODE="${WANDB_MODE:-offline}"
|
| 38 |
+
export HF_HUB_OFFLINE="${HF_HUB_OFFLINE:-1}"
|
| 39 |
+
export HF_DATASETS_OFFLINE="${HF_DATASETS_OFFLINE:-1}"
|
| 40 |
+
export TRANSFORMERS_OFFLINE="${TRANSFORMERS_OFFLINE:-1}"
|
| 41 |
+
export PYTHONUNBUFFERED="${PYTHONUNBUFFERED:-1}"
|
| 42 |
+
export ACCELERATE_LOG_LEVEL="${ACCELERATE_LOG_LEVEL:-info}"
|
| 43 |
+
|
| 44 |
+
mkdir -p "${SWEEP_ROOT}"
|
| 45 |
+
LOG_PATH="${LOG_PATH:-${SWEEP_ROOT}/run_reward_variants_and_eval_$(date +%Y%m%d_%H%M%S).log}"
|
| 46 |
+
exec > >(tee -a "${LOG_PATH}") 2>&1
|
| 47 |
+
|
| 48 |
+
TMP_DIR="$(mktemp -d "${SWEEP_ROOT}/tmp_cfgs.XXXXXX")"
|
| 49 |
+
trap 'rm -rf "${TMP_DIR}"' EXIT
|
| 50 |
+
|
| 51 |
+
VARIANTS=("existing_gate" "mult_nogate" "mult_gate")
|
| 52 |
+
|
| 53 |
+
echo "Starting 3-variant reward sweep..."
|
| 54 |
+
echo "Base config: ${BASE_CONFIG}"
|
| 55 |
+
echo "Sweep root: ${SWEEP_ROOT}"
|
| 56 |
+
echo "Lambda: ${LAMBDA}"
|
| 57 |
+
echo "Num processes (GPUs): ${NUM_PROCESSES}"
|
| 58 |
+
echo "Accelerate config: ${ACCELERATE_CONFIG}"
|
| 59 |
+
echo "Live+saved log path: ${LOG_PATH}"
|
| 60 |
+
|
| 61 |
+
for VARIANT in "${VARIANTS[@]}"; do
|
| 62 |
+
echo
|
| 63 |
+
echo "=== Training variant: ${VARIANT} ==="
|
| 64 |
+
CFG_PATH="${TMP_DIR}/grpo_${VARIANT}.yaml"
|
| 65 |
+
|
| 66 |
+
BASE_CONFIG="${BASE_CONFIG}" \
|
| 67 |
+
SWEEP_ROOT="${SWEEP_ROOT}" \
|
| 68 |
+
CFG_PATH="${CFG_PATH}" \
|
| 69 |
+
VARIANT="${VARIANT}" \
|
| 70 |
+
LAMBDA="${LAMBDA}" \
|
| 71 |
+
"${PYTHON_BIN}" - <<'PY'
|
| 72 |
+
import copy
|
| 73 |
+
import os
|
| 74 |
+
from pathlib import Path
|
| 75 |
+
|
| 76 |
+
import yaml
|
| 77 |
+
|
| 78 |
+
base_config = Path(os.environ["BASE_CONFIG"])
|
| 79 |
+
sweep_root = Path(os.environ["SWEEP_ROOT"])
|
| 80 |
+
cfg_path = Path(os.environ["CFG_PATH"])
|
| 81 |
+
variant = os.environ["VARIANT"]
|
| 82 |
+
lam = float(os.environ["LAMBDA"])
|
| 83 |
+
|
| 84 |
+
with base_config.open("r", encoding="utf-8") as handle:
|
| 85 |
+
cfg = yaml.safe_load(handle)
|
| 86 |
+
|
| 87 |
+
cfg = copy.deepcopy(cfg)
|
| 88 |
+
cfg.setdefault("objective", {})
|
| 89 |
+
cfg["objective"].setdefault("kwargs", {})
|
| 90 |
+
obj = cfg["objective"]["kwargs"]
|
| 91 |
+
|
| 92 |
+
obj["enable_length_penalty"] = True
|
| 93 |
+
obj["reward_mode"] = "weighted_length_penalty"
|
| 94 |
+
obj["length_penalty_lambda"] = lam
|
| 95 |
+
obj["non_strict_penalty"] = -1.0
|
| 96 |
+
|
| 97 |
+
if variant == "existing_gate":
|
| 98 |
+
obj["length_penalty_interaction"] = "correctness_length"
|
| 99 |
+
obj["strict_format_gate"] = True
|
| 100 |
+
elif variant == "mult_nogate":
|
| 101 |
+
obj["length_penalty_interaction"] = "correctness_length_format"
|
| 102 |
+
obj["strict_format_gate"] = False
|
| 103 |
+
elif variant == "mult_gate":
|
| 104 |
+
obj["length_penalty_interaction"] = "correctness_length_format"
|
| 105 |
+
obj["strict_format_gate"] = True
|
| 106 |
+
else:
|
| 107 |
+
raise ValueError(f"Unknown variant: {variant}")
|
| 108 |
+
|
| 109 |
+
cfg.setdefault("trainer", {})
|
| 110 |
+
base_run_name = cfg["trainer"].get("run_name", "grpo")
|
| 111 |
+
cfg["trainer"]["run_name"] = f"{base_run_name}-{variant}"
|
| 112 |
+
cfg["trainer"]["output_dir"] = str(sweep_root / f"run_lambda_0p75_{variant}")
|
| 113 |
+
|
| 114 |
+
with cfg_path.open("w", encoding="utf-8") as handle:
|
| 115 |
+
yaml.safe_dump(cfg, handle, sort_keys=False)
|
| 116 |
+
|
| 117 |
+
print(f"Wrote {cfg_path}")
|
| 118 |
+
print(f"Run dir: {cfg['trainer']['output_dir']}")
|
| 119 |
+
PY
|
| 120 |
+
|
| 121 |
+
accelerate launch \
|
| 122 |
+
--config_file "${ACCELERATE_CONFIG}" \
|
| 123 |
+
--num_processes "${NUM_PROCESSES}" \
|
| 124 |
+
src/train_grpo.py \
|
| 125 |
+
--config "${CFG_PATH}"
|
| 126 |
+
done
|
| 127 |
+
|
| 128 |
+
echo
|
| 129 |
+
echo "=== Running sweep evaluation on all trained variants ==="
|
| 130 |
+
BASE_CONFIG="${BASE_CONFIG}" \
|
| 131 |
+
SWEEP_ROOT="${SWEEP_ROOT}" \
|
| 132 |
+
OUT_ROOT="${EVAL_OUT_ROOT}" \
|
| 133 |
+
EVAL_MAX_SAMPLES="${EVAL_MAX_SAMPLES}" \
|
| 134 |
+
EVAL_BATCH_SIZE="${EVAL_BATCH_SIZE}" \
|
| 135 |
+
accelerate launch \
|
| 136 |
+
--config_file "${ACCELERATE_CONFIG}" \
|
| 137 |
+
--num_processes "${NUM_PROCESSES}" \
|
| 138 |
+
src/eval_sweep_models.py
|
| 139 |
+
|
| 140 |
+
echo
|
| 141 |
+
echo "All done."
|
| 142 |
+
echo "Train runs: ${SWEEP_ROOT}/run_lambda_0p75_*"
|
| 143 |
+
echo "Eval summary JSON: ${EVAL_OUT_ROOT}/sweep_eval_summary.json"
|
| 144 |
+
echo "Eval summary CSV: ${EVAL_OUT_ROOT}/sweep_eval_summary.csv"
|
| 145 |
+
echo "Eval outputs: ${EVAL_OUT_ROOT}/outputs/"
|
| 146 |
+
echo "Accuracy plot: ${EVAL_OUT_ROOT}/sweep_eval_accuracy.svg"
|
scripts/run_sft_gsm8k_boxed_7gpu.sh
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
# SFT on GSM8K: assistant target is only \boxed{answer} (no thinking). DDP on 7 GPUs, then GSM8K test eval on GPU 0.
|
| 3 |
+
#
|
| 4 |
+
# Usage:
|
| 5 |
+
# ./scripts/run_sft_gsm8k_boxed_7gpu.sh
|
| 6 |
+
# MODEL_PATH=/path/to/model HF_DATASETS_OFFLINE=0 ./scripts/run_sft_gsm8k_boxed_7gpu.sh
|
| 7 |
+
# NUM_TRAIN_EPOCHS=3 ./scripts/run_sft_gsm8k_boxed_7gpu.sh
|
| 8 |
+
set -eo pipefail
|
| 9 |
+
|
| 10 |
+
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
| 11 |
+
REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
|
| 12 |
+
# shellcheck disable=SC1091
|
| 13 |
+
source "${SCRIPT_DIR}/conda_env.sh"
|
| 14 |
+
set -u
|
| 15 |
+
|
| 16 |
+
cd "${REPO_ROOT}"
|
| 17 |
+
|
| 18 |
+
ACCELERATE_CONFIG="${ACCELERATE_CONFIG:-configs/accelerate_ddp_7gpu.yaml}"
|
| 19 |
+
NUM_PROCESSES="${NUM_PROCESSES:-7}"
|
| 20 |
+
MODEL_PATH="${MODEL_PATH:-${REPO_ROOT}/cache/models/models--unsloth--Llama-3.2-3B-Instruct}"
|
| 21 |
+
# Unsloth hub cache snapshots often lack config.json / model_type; SFT pulls config from this id (HF token if gated).
|
| 22 |
+
HUB_MODEL_ID="${HUB_MODEL_ID:-meta-llama/Llama-3.2-3B-Instruct}"
|
| 23 |
+
OUTPUT_DIR="${OUTPUT_DIR:-artifacts/runs/sft_gsm8k_boxed}"
|
| 24 |
+
NUM_TRAIN_EPOCHS="${NUM_TRAIN_EPOCHS:-1}"
|
| 25 |
+
|
| 26 |
+
if [[ -z "${CUDA_HOME:-}" ]] && command -v nvcc >/dev/null 2>&1; then
|
| 27 |
+
export CUDA_HOME="$(dirname "$(dirname "$(command -v nvcc)")")"
|
| 28 |
+
fi
|
| 29 |
+
|
| 30 |
+
export PYTHONPATH="${REPO_ROOT}/src"
|
| 31 |
+
export PYTHONUNBUFFERED="${PYTHONUNBUFFERED:-1}"
|
| 32 |
+
export HF_DATASETS_OFFLINE="${HF_DATASETS_OFFLINE:-0}"
|
| 33 |
+
|
| 34 |
+
# Everything after src/sft_gsm8k_boxed.py is passed to Python, not to accelerate.
|
| 35 |
+
accelerate launch \
|
| 36 |
+
--config_file "${ACCELERATE_CONFIG}" \
|
| 37 |
+
--num_processes "${NUM_PROCESSES}" \
|
| 38 |
+
src/sft_gsm8k_boxed.py \
|
| 39 |
+
--model_path "${MODEL_PATH}" \
|
| 40 |
+
--hub_model_id "${HUB_MODEL_ID}" \
|
| 41 |
+
--output_dir "${OUTPUT_DIR}" \
|
| 42 |
+
--num_train_epochs "${NUM_TRAIN_EPOCHS}"
|
| 43 |
+
|
| 44 |
+
echo "Done. Weights: ${OUTPUT_DIR}/final_model metrics: ${OUTPUT_DIR}/gsm8k_test_metrics.json"
|
scripts/run_twostage_correctness1.sh
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -eo pipefail
|
| 3 |
+
|
| 4 |
+
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
| 5 |
+
# shellcheck disable=SC1091
|
| 6 |
+
source "${SCRIPT_DIR}/conda_env.sh"
|
| 7 |
+
set -u
|
| 8 |
+
|
| 9 |
+
REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
|
| 10 |
+
cd "${REPO_ROOT}"
|
| 11 |
+
export PYTHONPATH="${REPO_ROOT}/src"
|
| 12 |
+
|
| 13 |
+
# Two-stage run with correctness_weight = 1.0
|
| 14 |
+
# Stage 1/2 schedule is inherited from BASE_CONFIG:
|
| 15 |
+
# - length_penalty_lambda (stage 1)
|
| 16 |
+
# - stage2_length_penalty_lambda (stage 2)
|
| 17 |
+
# - stage2_start_epoch
|
| 18 |
+
|
| 19 |
+
BASE_CONFIG="${BASE_CONFIG:-configs/grpo_llama32_3b_bf16.yaml}"
|
| 20 |
+
ACCELERATE_CONFIG="${ACCELERATE_CONFIG:-configs/accelerate_ddp_4gpu.yaml}"
|
| 21 |
+
NUM_PROCESSES="${NUM_PROCESSES:-4}"
|
| 22 |
+
PYTHON_BIN="${PYTHON_BIN:-python}"
|
| 23 |
+
OUT_ROOT="${OUT_ROOT:-artifacts/sweeps/twostage_correctness_weight}"
|
| 24 |
+
|
| 25 |
+
export WANDB_DIR="${WANDB_DIR:-${REPO_ROOT}/cache/logs/wandb}"
|
| 26 |
+
export WANDB_MODE="${WANDB_MODE:-offline}"
|
| 27 |
+
export HF_HUB_OFFLINE="${HF_HUB_OFFLINE:-1}"
|
| 28 |
+
export HF_DATASETS_OFFLINE="${HF_DATASETS_OFFLINE:-1}"
|
| 29 |
+
export TRANSFORMERS_OFFLINE="${TRANSFORMERS_OFFLINE:-1}"
|
| 30 |
+
export PYTHONUNBUFFERED="${PYTHONUNBUFFERED:-1}"
|
| 31 |
+
export ACCELERATE_LOG_LEVEL="${ACCELERATE_LOG_LEVEL:-info}"
|
| 32 |
+
|
| 33 |
+
mkdir -p "${OUT_ROOT}"
|
| 34 |
+
LOG_PATH="${LOG_PATH:-${OUT_ROOT}/run_twostage_correctness1_$(date +%Y%m%d_%H%M%S).log}"
|
| 35 |
+
exec > >(tee -a "${LOG_PATH}") 2>&1
|
| 36 |
+
|
| 37 |
+
TMP_DIR="$(mktemp -d "${OUT_ROOT}/tmp_cfgs.XXXXXX")"
|
| 38 |
+
trap 'rm -rf "${TMP_DIR}"' EXIT
|
| 39 |
+
CFG_PATH="${TMP_DIR}/grpo_twostage_correctness1.yaml"
|
| 40 |
+
|
| 41 |
+
BASE_CONFIG="${BASE_CONFIG}" OUT_ROOT="${OUT_ROOT}" CFG_PATH="${CFG_PATH}" "${PYTHON_BIN}" - <<'PY'
|
| 42 |
+
import copy
|
| 43 |
+
import os
|
| 44 |
+
from pathlib import Path
|
| 45 |
+
|
| 46 |
+
import yaml
|
| 47 |
+
|
| 48 |
+
base_config = Path(os.environ["BASE_CONFIG"])
|
| 49 |
+
out_root = Path(os.environ["OUT_ROOT"])
|
| 50 |
+
cfg_path = Path(os.environ["CFG_PATH"])
|
| 51 |
+
|
| 52 |
+
with base_config.open("r", encoding="utf-8") as handle:
|
| 53 |
+
cfg = yaml.safe_load(handle)
|
| 54 |
+
|
| 55 |
+
cfg = copy.deepcopy(cfg)
|
| 56 |
+
cfg.setdefault("objective", {})
|
| 57 |
+
cfg["objective"].setdefault("kwargs", {})
|
| 58 |
+
cfg["objective"]["kwargs"]["correctness_weight"] = 1.0
|
| 59 |
+
|
| 60 |
+
cfg.setdefault("trainer", {})
|
| 61 |
+
base_run_name = cfg["trainer"].get("run_name", "grpo")
|
| 62 |
+
cfg["trainer"]["run_name"] = f"{base_run_name}-twostage-cw1"
|
| 63 |
+
cfg["trainer"]["output_dir"] = str(out_root / "run_twostage_correctness1")
|
| 64 |
+
|
| 65 |
+
with cfg_path.open("w", encoding="utf-8") as handle:
|
| 66 |
+
yaml.safe_dump(cfg, handle, sort_keys=False)
|
| 67 |
+
|
| 68 |
+
print(f"Wrote config: {cfg_path}")
|
| 69 |
+
print(f"Output dir: {cfg['trainer']['output_dir']}")
|
| 70 |
+
PY
|
| 71 |
+
|
| 72 |
+
echo "Starting 4-GPU training (correctness_weight=1.0)"
|
| 73 |
+
accelerate launch \
|
| 74 |
+
--config_file "${ACCELERATE_CONFIG}" \
|
| 75 |
+
--num_processes "${NUM_PROCESSES}" \
|
| 76 |
+
src/train_grpo.py \
|
| 77 |
+
--config "${CFG_PATH}"
|
| 78 |
+
|
| 79 |
+
echo "Done. Log: ${LOG_PATH}"
|
scripts/run_twostage_correctness5.sh
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -eo pipefail
|
| 3 |
+
|
| 4 |
+
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
| 5 |
+
# shellcheck disable=SC1091
|
| 6 |
+
source "${SCRIPT_DIR}/conda_env.sh"
|
| 7 |
+
set -u
|
| 8 |
+
|
| 9 |
+
REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
|
| 10 |
+
cd "${REPO_ROOT}"
|
| 11 |
+
export PYTHONPATH="${REPO_ROOT}/src"
|
| 12 |
+
|
| 13 |
+
# Two-stage run with correctness_weight = 5.0
|
| 14 |
+
# Stage 1/2 schedule is inherited from BASE_CONFIG:
|
| 15 |
+
# - length_penalty_lambda (stage 1)
|
| 16 |
+
# - stage2_length_penalty_lambda (stage 2)
|
| 17 |
+
# - stage2_start_epoch
|
| 18 |
+
|
| 19 |
+
BASE_CONFIG="${BASE_CONFIG:-configs/grpo_llama32_3b_bf16.yaml}"
|
| 20 |
+
ACCELERATE_CONFIG="${ACCELERATE_CONFIG:-configs/accelerate_ddp_4gpu.yaml}"
|
| 21 |
+
NUM_PROCESSES="${NUM_PROCESSES:-4}"
|
| 22 |
+
PYTHON_BIN="${PYTHON_BIN:-python}"
|
| 23 |
+
OUT_ROOT="${OUT_ROOT:-artifacts/sweeps/twostage_correctness_weight}"
|
| 24 |
+
|
| 25 |
+
export WANDB_DIR="${WANDB_DIR:-${REPO_ROOT}/cache/logs/wandb}"
|
| 26 |
+
export WANDB_MODE="${WANDB_MODE:-offline}"
|
| 27 |
+
export HF_HUB_OFFLINE="${HF_HUB_OFFLINE:-1}"
|
| 28 |
+
export HF_DATASETS_OFFLINE="${HF_DATASETS_OFFLINE:-1}"
|
| 29 |
+
export TRANSFORMERS_OFFLINE="${TRANSFORMERS_OFFLINE:-1}"
|
| 30 |
+
export PYTHONUNBUFFERED="${PYTHONUNBUFFERED:-1}"
|
| 31 |
+
export ACCELERATE_LOG_LEVEL="${ACCELERATE_LOG_LEVEL:-info}"
|
| 32 |
+
|
| 33 |
+
mkdir -p "${OUT_ROOT}"
|
| 34 |
+
LOG_PATH="${LOG_PATH:-${OUT_ROOT}/run_twostage_correctness5_$(date +%Y%m%d_%H%M%S).log}"
|
| 35 |
+
exec > >(tee -a "${LOG_PATH}") 2>&1
|
| 36 |
+
|
| 37 |
+
TMP_DIR="$(mktemp -d "${OUT_ROOT}/tmp_cfgs.XXXXXX")"
|
| 38 |
+
trap 'rm -rf "${TMP_DIR}"' EXIT
|
| 39 |
+
CFG_PATH="${TMP_DIR}/grpo_twostage_correctness5.yaml"
|
| 40 |
+
|
| 41 |
+
BASE_CONFIG="${BASE_CONFIG}" OUT_ROOT="${OUT_ROOT}" CFG_PATH="${CFG_PATH}" "${PYTHON_BIN}" - <<'PY'
|
| 42 |
+
import copy
|
| 43 |
+
import os
|
| 44 |
+
from pathlib import Path
|
| 45 |
+
|
| 46 |
+
import yaml
|
| 47 |
+
|
| 48 |
+
base_config = Path(os.environ["BASE_CONFIG"])
|
| 49 |
+
out_root = Path(os.environ["OUT_ROOT"])
|
| 50 |
+
cfg_path = Path(os.environ["CFG_PATH"])
|
| 51 |
+
|
| 52 |
+
with base_config.open("r", encoding="utf-8") as handle:
|
| 53 |
+
cfg = yaml.safe_load(handle)
|
| 54 |
+
|
| 55 |
+
cfg = copy.deepcopy(cfg)
|
| 56 |
+
cfg.setdefault("objective", {})
|
| 57 |
+
cfg["objective"].setdefault("kwargs", {})
|
| 58 |
+
cfg["objective"]["kwargs"]["correctness_weight"] = 5.0
|
| 59 |
+
|
| 60 |
+
cfg.setdefault("trainer", {})
|
| 61 |
+
base_run_name = cfg["trainer"].get("run_name", "grpo")
|
| 62 |
+
cfg["trainer"]["run_name"] = f"{base_run_name}-twostage-cw5"
|
| 63 |
+
cfg["trainer"]["output_dir"] = str(out_root / "run_twostage_correctness5")
|
| 64 |
+
|
| 65 |
+
with cfg_path.open("w", encoding="utf-8") as handle:
|
| 66 |
+
yaml.safe_dump(cfg, handle, sort_keys=False)
|
| 67 |
+
|
| 68 |
+
print(f"Wrote config: {cfg_path}")
|
| 69 |
+
print(f"Output dir: {cfg['trainer']['output_dir']}")
|
| 70 |
+
PY
|
| 71 |
+
|
| 72 |
+
echo "Starting 4-GPU training (correctness_weight=5.0)"
|
| 73 |
+
accelerate launch \
|
| 74 |
+
--config_file "${ACCELERATE_CONFIG}" \
|
| 75 |
+
--num_processes "${NUM_PROCESSES}" \
|
| 76 |
+
src/train_grpo.py \
|
| 77 |
+
--config "${CFG_PATH}"
|
| 78 |
+
|
| 79 |
+
echo "Done. Log: ${LOG_PATH}"
|
scripts/sweep_length_penalty_lambda.sh
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -eo pipefail
|
| 3 |
+
|
| 4 |
+
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
| 5 |
+
# shellcheck disable=SC1091
|
| 6 |
+
source "${SCRIPT_DIR}/conda_env.sh"
|
| 7 |
+
set -u
|
| 8 |
+
|
| 9 |
+
REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
|
| 10 |
+
cd "${REPO_ROOT}"
|
| 11 |
+
export PYTHONPATH="${REPO_ROOT}/src"
|
| 12 |
+
|
| 13 |
+
# Sweep two GRPO runs over length_penalty_lambda:
|
| 14 |
+
# - 0.5
|
| 15 |
+
# - 0.25
|
| 16 |
+
#
|
| 17 |
+
# Each run keeps all other settings identical to BASE_CONFIG.
|
| 18 |
+
|
| 19 |
+
BASE_CONFIG="${BASE_CONFIG:-configs/grpo_llama32_3b_bf16.yaml}"
|
| 20 |
+
ACCELERATE_CONFIG="${ACCELERATE_CONFIG:-configs/accelerate_ddp_4gpu.yaml}"
|
| 21 |
+
NUM_PROCESSES="${NUM_PROCESSES:-4}"
|
| 22 |
+
PYTHON_BIN="${PYTHON_BIN:-python}"
|
| 23 |
+
OUT_ROOT="${OUT_ROOT:-artifacts/sweeps/length_penalty_lambda}"
|
| 24 |
+
LAMBDAS="${LAMBDAS:-0.5 0.25}"
|
| 25 |
+
|
| 26 |
+
export WANDB_MODE="${WANDB_MODE:-offline}"
|
| 27 |
+
export HF_HUB_OFFLINE="${HF_HUB_OFFLINE:-1}"
|
| 28 |
+
export HF_DATASETS_OFFLINE="${HF_DATASETS_OFFLINE:-1}"
|
| 29 |
+
export TRANSFORMERS_OFFLINE="${TRANSFORMERS_OFFLINE:-1}"
|
| 30 |
+
|
| 31 |
+
mkdir -p "${OUT_ROOT}"
|
| 32 |
+
TMP_DIR="$(mktemp -d "${OUT_ROOT}/tmp_cfgs.XXXXXX")"
|
| 33 |
+
trap 'rm -rf "${TMP_DIR}"' EXIT
|
| 34 |
+
|
| 35 |
+
echo "Base config: ${BASE_CONFIG}"
|
| 36 |
+
echo "Lambdas: ${LAMBDAS}"
|
| 37 |
+
echo "Output root: ${OUT_ROOT}"
|
| 38 |
+
|
| 39 |
+
for LAMBDA in ${LAMBDAS}; do
|
| 40 |
+
echo
|
| 41 |
+
echo "=== Running lambda=${LAMBDA} ==="
|
| 42 |
+
|
| 43 |
+
CFG_PATH="${TMP_DIR}/grpo_lambda_${LAMBDA}.yaml"
|
| 44 |
+
BASE_CONFIG="${BASE_CONFIG}" OUT_ROOT="${OUT_ROOT}" LAMBDA="${LAMBDA}" CFG_PATH="${CFG_PATH}" "${PYTHON_BIN}" - <<'PY'
|
| 45 |
+
import copy
|
| 46 |
+
import os
|
| 47 |
+
from pathlib import Path
|
| 48 |
+
|
| 49 |
+
import yaml
|
| 50 |
+
|
| 51 |
+
base_config = Path(os.environ["BASE_CONFIG"])
|
| 52 |
+
out_root = Path(os.environ["OUT_ROOT"])
|
| 53 |
+
cfg_path = Path(os.environ["CFG_PATH"])
|
| 54 |
+
lam = os.environ["LAMBDA"]
|
| 55 |
+
|
| 56 |
+
with base_config.open("r", encoding="utf-8") as handle:
|
| 57 |
+
cfg = yaml.safe_load(handle)
|
| 58 |
+
|
| 59 |
+
cfg = copy.deepcopy(cfg)
|
| 60 |
+
cfg.setdefault("objective", {})
|
| 61 |
+
cfg["objective"].setdefault("kwargs", {})
|
| 62 |
+
cfg["objective"]["kwargs"]["enable_length_penalty"] = True
|
| 63 |
+
cfg["objective"]["kwargs"]["reward_mode"] = "weighted_length_penalty"
|
| 64 |
+
cfg["objective"]["kwargs"]["length_penalty_lambda"] = float(lam)
|
| 65 |
+
|
| 66 |
+
cfg.setdefault("trainer", {})
|
| 67 |
+
base_run_name = cfg["trainer"].get("run_name", "grpo")
|
| 68 |
+
safe_lam = lam.replace(".", "p")
|
| 69 |
+
cfg["trainer"]["run_name"] = f"{base_run_name}-lambda-{safe_lam}"
|
| 70 |
+
cfg["trainer"]["output_dir"] = str(out_root / f"run_lambda_{safe_lam}")
|
| 71 |
+
|
| 72 |
+
with cfg_path.open("w", encoding="utf-8") as handle:
|
| 73 |
+
yaml.safe_dump(cfg, handle, sort_keys=False)
|
| 74 |
+
|
| 75 |
+
print(f"Wrote config: {cfg_path}")
|
| 76 |
+
print(f"run_name: {cfg['trainer']['run_name']}")
|
| 77 |
+
print(f"output_dir: {cfg['trainer']['output_dir']}")
|
| 78 |
+
PY
|
| 79 |
+
|
| 80 |
+
accelerate launch \
|
| 81 |
+
--config_file "${ACCELERATE_CONFIG}" \
|
| 82 |
+
--num_processes "${NUM_PROCESSES}" \
|
| 83 |
+
src/train_grpo.py \
|
| 84 |
+
--config "${CFG_PATH}"
|
| 85 |
+
done
|
| 86 |
+
|
| 87 |
+
echo
|
| 88 |
+
echo "Sweep complete. Runs are under: ${OUT_ROOT}"
|
src/eval_gsm8k_truncated_thinking.py
ADDED
|
@@ -0,0 +1,321 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
GSM8K eval with the same short ``<think>...</think>`` tags as ``data_plugins.PROMPT_PREFIX``.
|
| 3 |
+
|
| 4 |
+
The inner span is truncated to at most ``--think_inner_token_limit`` tokenizer tokens (a close tag
|
| 5 |
+
is forced if the model would exceed that, including when it closes late). Then greedy generation
|
| 6 |
+
continues for the boxed answer.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import argparse
|
| 12 |
+
import importlib.util
|
| 13 |
+
import json
|
| 14 |
+
import os
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
import torch.distributed as dist
|
| 19 |
+
import yaml
|
| 20 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 21 |
+
|
| 22 |
+
import hackable # noqa: F401
|
| 23 |
+
from hackable.data_plugins import GSM8KProvider
|
| 24 |
+
from hackable.paths import resolve_storage_path, storage_layout
|
| 25 |
+
from hackable.reward_plugins import gsm8k_correctness_reward
|
| 26 |
+
from hackable.utils import resolve_repo_path
|
| 27 |
+
|
| 28 |
+
THINK_OPEN = "<think>"
|
| 29 |
+
THINK_CLOSE = "</think>"
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def _greedy_generate_rest(
|
| 33 |
+
model: AutoModelForCausalLM,
|
| 34 |
+
tokenizer: AutoTokenizer,
|
| 35 |
+
cur: torch.Tensor,
|
| 36 |
+
plen: int,
|
| 37 |
+
max_new_tokens: int,
|
| 38 |
+
device: torch.device,
|
| 39 |
+
) -> str:
|
| 40 |
+
"""Continue greedy from ``cur``; completion length (after prompt) is capped by ``max_new_tokens``."""
|
| 41 |
+
generated = int(cur.shape[1]) - plen
|
| 42 |
+
remaining = max_new_tokens - generated
|
| 43 |
+
if remaining <= 0:
|
| 44 |
+
return tokenizer.decode(cur[0, plen:], skip_special_tokens=False)
|
| 45 |
+
attn = torch.ones_like(cur, device=device)
|
| 46 |
+
rest = model.generate(
|
| 47 |
+
cur,
|
| 48 |
+
attention_mask=attn,
|
| 49 |
+
max_new_tokens=remaining,
|
| 50 |
+
do_sample=False,
|
| 51 |
+
pad_token_id=tokenizer.pad_token_id,
|
| 52 |
+
eos_token_id=tokenizer.eos_token_id,
|
| 53 |
+
)
|
| 54 |
+
return tokenizer.decode(rest[0, plen:], skip_special_tokens=False)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def _load_yaml(path: str) -> dict:
|
| 58 |
+
with open(path, "r", encoding="utf-8") as handle:
|
| 59 |
+
return yaml.safe_load(handle)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def _model_dtype(cfg: dict) -> torch.dtype:
|
| 63 |
+
return torch.bfloat16 if bool(cfg.get("trainer", {}).get("bf16", True)) else torch.float16
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def _get_cache_paths(base_cfg: dict) -> tuple[Path, Path]:
|
| 67 |
+
layout = storage_layout(base_cfg.get("storage", {}).get("cache_dir", "cache"))
|
| 68 |
+
return layout.datasets, layout.models
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def _dist_info() -> tuple[int, int, int]:
|
| 72 |
+
rank = int(os.environ.get("RANK", "0"))
|
| 73 |
+
world_size = int(os.environ.get("WORLD_SIZE", "1"))
|
| 74 |
+
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
|
| 75 |
+
return rank, world_size, local_rank
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def _init_distributed() -> tuple[int, int, int]:
|
| 79 |
+
rank, world_size, local_rank = _dist_info()
|
| 80 |
+
if world_size > 1 and not dist.is_initialized():
|
| 81 |
+
backend = "nccl" if torch.cuda.is_available() else "gloo"
|
| 82 |
+
dist.init_process_group(backend=backend, init_method="env://")
|
| 83 |
+
return rank, world_size, local_rank
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def _resolve_local_model_dir(base_cfg: dict, model_dir: str) -> Path:
|
| 87 |
+
candidate = Path(model_dir)
|
| 88 |
+
if candidate.is_absolute() and candidate.exists():
|
| 89 |
+
return candidate.resolve()
|
| 90 |
+
repo_local = resolve_repo_path(model_dir)
|
| 91 |
+
if repo_local.exists():
|
| 92 |
+
return repo_local
|
| 93 |
+
prefixed = resolve_storage_path(model_dir, base_cfg.get("storage", {}).get("cache_dir", "cache"))
|
| 94 |
+
if prefixed.exists():
|
| 95 |
+
return prefixed
|
| 96 |
+
raise FileNotFoundError(f"Model directory not found: {model_dir}")
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def _flash_attn_impl() -> str:
|
| 100 |
+
if importlib.util.find_spec("flash_attn") is not None:
|
| 101 |
+
try:
|
| 102 |
+
__import__("flash_attn")
|
| 103 |
+
return "flash_attention_2"
|
| 104 |
+
except Exception:
|
| 105 |
+
pass
|
| 106 |
+
return "sdpa"
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
@torch.no_grad()
|
| 110 |
+
def generate_truncated_thinking(
|
| 111 |
+
model: AutoModelForCausalLM,
|
| 112 |
+
tokenizer: AutoTokenizer,
|
| 113 |
+
prompt: str,
|
| 114 |
+
device: torch.device,
|
| 115 |
+
inner_token_limit: int,
|
| 116 |
+
max_new_tokens: int,
|
| 117 |
+
) -> str:
|
| 118 |
+
"""
|
| 119 |
+
Inner reasoning is **always** at most ``inner_token_limit`` tokenizer tokens: we never keep a
|
| 120 |
+
longer span. If the model emits ``</redacted_thinking>`` after more than that many inner tokens,
|
| 121 |
+
we rewrite the prefix to the first ``inner_token_limit`` inner tokens, then ``</redacted_thinking>``,
|
| 122 |
+
then greedy-decode the boxed answer. If there is no close tag yet but inner already has that many
|
| 123 |
+
tokens, we force the close tag and decode.
|
| 124 |
+
"""
|
| 125 |
+
enc = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=4096)
|
| 126 |
+
input_ids = enc["input_ids"].to(device)
|
| 127 |
+
plen = int(input_ids.shape[1])
|
| 128 |
+
prompt_ids = input_ids.clone()
|
| 129 |
+
|
| 130 |
+
cur = input_ids.clone()
|
| 131 |
+
generated_count = 0
|
| 132 |
+
|
| 133 |
+
def rewrite_and_continue(gen_text: str, after_open_idx: int, inner_text: str) -> str:
|
| 134 |
+
inner_ids = tokenizer.encode(inner_text, add_special_tokens=False)
|
| 135 |
+
inner_cut = inner_ids[:inner_token_limit]
|
| 136 |
+
inner_str = tokenizer.decode(inner_cut, skip_special_tokens=False)
|
| 137 |
+
fixed_comp = gen_text[:after_open_idx] + inner_str + THINK_CLOSE + "\n"
|
| 138 |
+
comp_ids = tokenizer.encode(fixed_comp, add_special_tokens=False)
|
| 139 |
+
comp_t = torch.tensor([comp_ids], dtype=torch.long, device=device)
|
| 140 |
+
new_cur = torch.cat([prompt_ids, comp_t], dim=-1)
|
| 141 |
+
return _greedy_generate_rest(model, tokenizer, new_cur, plen, max_new_tokens, device)
|
| 142 |
+
|
| 143 |
+
while generated_count < max_new_tokens:
|
| 144 |
+
out = model(cur)
|
| 145 |
+
next_id = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
|
| 146 |
+
cur = torch.cat([cur, next_id], dim=-1)
|
| 147 |
+
generated_count += 1
|
| 148 |
+
|
| 149 |
+
gen_text = tokenizer.decode(cur[0, plen:], skip_special_tokens=False)
|
| 150 |
+
|
| 151 |
+
# If the model never opens thinking, fall back to normal greedy (much faster).
|
| 152 |
+
if generated_count >= 128 and THINK_OPEN not in gen_text:
|
| 153 |
+
return _greedy_generate_rest(model, tokenizer, cur, plen, max_new_tokens, device)
|
| 154 |
+
|
| 155 |
+
if THINK_OPEN not in gen_text:
|
| 156 |
+
continue
|
| 157 |
+
|
| 158 |
+
after_open_idx = gen_text.find(THINK_OPEN) + len(THINK_OPEN)
|
| 159 |
+
after_open = gen_text[after_open_idx:]
|
| 160 |
+
|
| 161 |
+
if THINK_CLOSE in after_open:
|
| 162 |
+
inner_part, _rest = after_open.split(THINK_CLOSE, 1)
|
| 163 |
+
n_inner = len(tokenizer.encode(inner_part, add_special_tokens=False))
|
| 164 |
+
if n_inner > inner_token_limit:
|
| 165 |
+
return rewrite_and_continue(gen_text, after_open_idx, inner_part)
|
| 166 |
+
return _greedy_generate_rest(model, tokenizer, cur, plen, max_new_tokens, device)
|
| 167 |
+
|
| 168 |
+
inner_ids = tokenizer.encode(after_open, add_special_tokens=False)
|
| 169 |
+
if len(inner_ids) >= inner_token_limit:
|
| 170 |
+
inner_str = tokenizer.decode(inner_ids[:inner_token_limit], skip_special_tokens=False)
|
| 171 |
+
fixed_comp = gen_text[:after_open_idx] + inner_str + THINK_CLOSE + "\n"
|
| 172 |
+
comp_ids = tokenizer.encode(fixed_comp, add_special_tokens=False)
|
| 173 |
+
comp_t = torch.tensor([comp_ids], dtype=torch.long, device=device)
|
| 174 |
+
cur = torch.cat([prompt_ids, comp_t], dim=-1)
|
| 175 |
+
return _greedy_generate_rest(model, tokenizer, cur, plen, max_new_tokens, device)
|
| 176 |
+
|
| 177 |
+
return tokenizer.decode(cur[0, plen:], skip_special_tokens=False)
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def main() -> None:
|
| 181 |
+
parser = argparse.ArgumentParser()
|
| 182 |
+
parser.add_argument("--config", type=str, default="configs/grpo_llama32_3b_bf16.yaml")
|
| 183 |
+
parser.add_argument(
|
| 184 |
+
"--model_path",
|
| 185 |
+
type=str,
|
| 186 |
+
required=True,
|
| 187 |
+
help="Checkpoint directory (e.g. artifacts/.../checkpoints/permanent/checkpoint-...).",
|
| 188 |
+
)
|
| 189 |
+
parser.add_argument("--think_inner_token_limit", type=int, default=5)
|
| 190 |
+
parser.add_argument("--max_samples", type=int, default=-1, help="GSM8K test examples; -1 = all.")
|
| 191 |
+
parser.add_argument("--max_new_tokens", type=int, default=256)
|
| 192 |
+
args = parser.parse_args()
|
| 193 |
+
|
| 194 |
+
rank, world_size, local_rank = _init_distributed()
|
| 195 |
+
base_cfg = _load_yaml(str(resolve_repo_path(args.config)))
|
| 196 |
+
model_dir = _resolve_local_model_dir(base_cfg, args.model_path)
|
| 197 |
+
datasets_cache, models_cache = _get_cache_paths(base_cfg)
|
| 198 |
+
dtype = _model_dtype(base_cfg)
|
| 199 |
+
model_name_fallback = str(base_cfg["model"]["name"])
|
| 200 |
+
trust_remote_code = bool(base_cfg.get("model", {}).get("trust_remote_code", True))
|
| 201 |
+
max_prompt_len = int(base_cfg.get("generation", {}).get("max_prompt_length", 512))
|
| 202 |
+
|
| 203 |
+
local_files_only = os.environ.get("LOCAL_FILES_ONLY", "0").strip() in ("1", "true", "True")
|
| 204 |
+
|
| 205 |
+
try:
|
| 206 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 207 |
+
str(model_dir),
|
| 208 |
+
trust_remote_code=trust_remote_code,
|
| 209 |
+
cache_dir=str(models_cache),
|
| 210 |
+
local_files_only=local_files_only,
|
| 211 |
+
)
|
| 212 |
+
except Exception:
|
| 213 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 214 |
+
model_name_fallback,
|
| 215 |
+
trust_remote_code=trust_remote_code,
|
| 216 |
+
cache_dir=str(models_cache),
|
| 217 |
+
local_files_only=local_files_only,
|
| 218 |
+
)
|
| 219 |
+
if tokenizer.pad_token_id is None:
|
| 220 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 221 |
+
|
| 222 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 223 |
+
str(model_dir),
|
| 224 |
+
trust_remote_code=trust_remote_code,
|
| 225 |
+
cache_dir=str(models_cache),
|
| 226 |
+
torch_dtype=dtype,
|
| 227 |
+
local_files_only=local_files_only,
|
| 228 |
+
attn_implementation=_flash_attn_impl(),
|
| 229 |
+
)
|
| 230 |
+
if torch.cuda.is_available():
|
| 231 |
+
torch.cuda.set_device(local_rank)
|
| 232 |
+
device = torch.device(f"cuda:{local_rank}")
|
| 233 |
+
else:
|
| 234 |
+
device = torch.device("cpu")
|
| 235 |
+
model.to(device)
|
| 236 |
+
model.eval()
|
| 237 |
+
|
| 238 |
+
provider = GSM8KProvider()
|
| 239 |
+
eval_max = None if args.max_samples < 0 else args.max_samples
|
| 240 |
+
all_samples = provider.load(
|
| 241 |
+
split="test",
|
| 242 |
+
max_samples=eval_max,
|
| 243 |
+
cache_dir=str(datasets_cache),
|
| 244 |
+
)
|
| 245 |
+
indices = list(range(rank, len(all_samples), world_size))
|
| 246 |
+
local_samples = [all_samples[i] for i in indices]
|
| 247 |
+
|
| 248 |
+
records: list[dict] = []
|
| 249 |
+
for sample in local_samples:
|
| 250 |
+
prompt = sample.prompt
|
| 251 |
+
if len(tokenizer.encode(prompt)) > max_prompt_len:
|
| 252 |
+
# Match eval_sweep style: truncate prompt text is complex; skip ultra-long for this script
|
| 253 |
+
enc = tokenizer(prompt, truncation=True, max_length=max_prompt_len, return_tensors="pt")
|
| 254 |
+
prompt = tokenizer.decode(enc["input_ids"][0], skip_special_tokens=False)
|
| 255 |
+
|
| 256 |
+
completion = generate_truncated_thinking(
|
| 257 |
+
model,
|
| 258 |
+
tokenizer,
|
| 259 |
+
prompt,
|
| 260 |
+
device,
|
| 261 |
+
inner_token_limit=args.think_inner_token_limit,
|
| 262 |
+
max_new_tokens=args.max_new_tokens,
|
| 263 |
+
)
|
| 264 |
+
score = gsm8k_correctness_reward(
|
| 265 |
+
prompts=[prompt],
|
| 266 |
+
completions=[completion],
|
| 267 |
+
references=[sample.target],
|
| 268 |
+
metadata=[sample.metadata],
|
| 269 |
+
)[0]
|
| 270 |
+
records.append(
|
| 271 |
+
{
|
| 272 |
+
"sample_index": sample.metadata.get("sample_index", -1),
|
| 273 |
+
"correctness": float(score),
|
| 274 |
+
"prompt": prompt,
|
| 275 |
+
"reference": sample.target,
|
| 276 |
+
"completion": completion,
|
| 277 |
+
}
|
| 278 |
+
)
|
| 279 |
+
|
| 280 |
+
del model
|
| 281 |
+
if torch.cuda.is_available():
|
| 282 |
+
torch.cuda.empty_cache()
|
| 283 |
+
|
| 284 |
+
if dist.is_initialized():
|
| 285 |
+
gathered: list[list[dict] | None] = [None for _ in range(world_size)]
|
| 286 |
+
dist.all_gather_object(gathered, records)
|
| 287 |
+
merged: list[dict] = []
|
| 288 |
+
for part in gathered:
|
| 289 |
+
if part:
|
| 290 |
+
merged.extend(part)
|
| 291 |
+
else:
|
| 292 |
+
merged = records
|
| 293 |
+
|
| 294 |
+
if rank != 0:
|
| 295 |
+
return
|
| 296 |
+
|
| 297 |
+
merged.sort(key=lambda r: int(r.get("sample_index", 0)))
|
| 298 |
+
acc = sum(r["correctness"] for r in merged) / len(merged) if merged else 0.0
|
| 299 |
+
summary = {
|
| 300 |
+
"model_dir": str(model_dir),
|
| 301 |
+
"think_inner_token_limit": args.think_inner_token_limit,
|
| 302 |
+
"max_new_tokens": args.max_new_tokens,
|
| 303 |
+
"num_examples": len(merged),
|
| 304 |
+
"accuracy": float(acc),
|
| 305 |
+
}
|
| 306 |
+
print(json.dumps(summary, indent=2))
|
| 307 |
+
|
| 308 |
+
out = os.environ.get("OUTPUT_PATH")
|
| 309 |
+
if out:
|
| 310 |
+
p = resolve_repo_path(out)
|
| 311 |
+
p.parent.mkdir(parents=True, exist_ok=True)
|
| 312 |
+
with p.open("w", encoding="utf-8") as handle:
|
| 313 |
+
for row in merged:
|
| 314 |
+
handle.write(json.dumps(row, ensure_ascii=True) + "\n")
|
| 315 |
+
meta_path = p.with_suffix(".summary.json")
|
| 316 |
+
meta_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")
|
| 317 |
+
print("Wrote", p, "and", meta_path)
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
if __name__ == "__main__":
|
| 321 |
+
main()
|
src/eval_gsm8k_zeroshot.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""GSM8K zero-shot eval for a single checkpoint (same pipeline as eval_sweep_models)."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import copy
|
| 6 |
+
import json
|
| 7 |
+
import os
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
from eval_sweep_models import (
|
| 11 |
+
_init_distributed,
|
| 12 |
+
_load_yaml,
|
| 13 |
+
_resolve_local_model_dir,
|
| 14 |
+
_summarize,
|
| 15 |
+
evaluate_one_model,
|
| 16 |
+
)
|
| 17 |
+
from hackable.utils import resolve_repo_path
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def _model_dir_from_env() -> str:
|
| 21 |
+
for key in ("MODEL_PATH", "MODEL_DIR"):
|
| 22 |
+
v = os.environ.get(key)
|
| 23 |
+
if v:
|
| 24 |
+
return v
|
| 25 |
+
raise SystemExit(
|
| 26 |
+
"Set MODEL_PATH or MODEL_DIR to the checkpoint directory "
|
| 27 |
+
"(e.g. artifacts/sweeps/kl_0/run_foo/checkpoints/permanent/checkpoint-1871)."
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def _eval_max_samples() -> int:
|
| 32 |
+
if "EVAL_MAX_SAMPLES" in os.environ:
|
| 33 |
+
return int(os.environ["EVAL_MAX_SAMPLES"])
|
| 34 |
+
if "MAX_SAMPLES" in os.environ:
|
| 35 |
+
return int(os.environ["MAX_SAMPLES"])
|
| 36 |
+
return -1
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def main() -> None:
|
| 40 |
+
rank, _, _ = _init_distributed()
|
| 41 |
+
base_cfg_path = str(resolve_repo_path(os.environ.get("BASE_CONFIG", "configs/grpo_llama32_3b_bf16.yaml")))
|
| 42 |
+
base_cfg = copy.deepcopy(_load_yaml(base_cfg_path))
|
| 43 |
+
gen = base_cfg.setdefault("generation", {})
|
| 44 |
+
if os.environ.get("MAX_NEW_TOKENS"):
|
| 45 |
+
gen["max_completion_length"] = int(os.environ["MAX_NEW_TOKENS"])
|
| 46 |
+
if os.environ.get("MAX_PROMPT_LENGTH"):
|
| 47 |
+
gen["max_prompt_length"] = int(os.environ["MAX_PROMPT_LENGTH"])
|
| 48 |
+
|
| 49 |
+
model_dir = _resolve_local_model_dir(base_cfg, _model_dir_from_env())
|
| 50 |
+
eval_max = _eval_max_samples()
|
| 51 |
+
batch_size = int(os.environ.get("BATCH_SIZE", "4"))
|
| 52 |
+
|
| 53 |
+
records = evaluate_one_model(
|
| 54 |
+
model_dir=model_dir,
|
| 55 |
+
base_cfg=base_cfg,
|
| 56 |
+
eval_max_samples=eval_max,
|
| 57 |
+
batch_size=batch_size,
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
if rank != 0:
|
| 61 |
+
return
|
| 62 |
+
|
| 63 |
+
out = resolve_repo_path(os.environ.get("OUTPUT_PATH", "artifacts/eval/gsm8k_zeroshot_answers.jsonl"))
|
| 64 |
+
out = Path(out)
|
| 65 |
+
out.parent.mkdir(parents=True, exist_ok=True)
|
| 66 |
+
with out.open("w", encoding="utf-8") as handle:
|
| 67 |
+
for row in records:
|
| 68 |
+
handle.write(json.dumps(row, ensure_ascii=True) + "\n")
|
| 69 |
+
|
| 70 |
+
summary = _summarize(records, str(model_dir))
|
| 71 |
+
summary["outputs_jsonl"] = str(out.resolve())
|
| 72 |
+
print(json.dumps(summary, indent=2))
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
if __name__ == "__main__":
|
| 76 |
+
main()
|
src/eval_math_level1_thinking_zeroshot.py
ADDED
|
@@ -0,0 +1,312 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import os
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.distributed as dist
|
| 9 |
+
import yaml
|
| 10 |
+
from datasets import load_dataset
|
| 11 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 12 |
+
|
| 13 |
+
import hackable # noqa: F401
|
| 14 |
+
from hackable import reward_plugins as reward_plugins_mod
|
| 15 |
+
from hackable.utils import resolve_repo_path
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
THINKING_SYSTEM_PROMPT = (
|
| 19 |
+
"Solve the following math problem.\n"
|
| 20 |
+
"Think step-by-step inside <think>...</think> tags.\n"
|
| 21 |
+
"Then output only the final answer in LaTeX boxed format.\n"
|
| 22 |
+
"Do not include any words or explanations outside the tags/boxed answer.\n"
|
| 23 |
+
"Output format must be exactly:\n"
|
| 24 |
+
"<think>your reasoning</think>\n"
|
| 25 |
+
"\\boxed{your_final_answer}\n"
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def _load_yaml(path: str) -> dict:
|
| 30 |
+
with open(path, "r", encoding="utf-8") as handle:
|
| 31 |
+
return yaml.safe_load(handle)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def _dist_info() -> tuple[int, int, int]:
|
| 35 |
+
rank = int(os.environ.get("RANK", "0"))
|
| 36 |
+
world_size = int(os.environ.get("WORLD_SIZE", "1"))
|
| 37 |
+
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
|
| 38 |
+
return rank, world_size, local_rank
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def _init_distributed() -> tuple[int, int, int]:
|
| 42 |
+
rank, world_size, local_rank = _dist_info()
|
| 43 |
+
if world_size > 1 and not dist.is_initialized():
|
| 44 |
+
backend = "nccl" if torch.cuda.is_available() else "gloo"
|
| 45 |
+
dist.init_process_group(backend=backend, init_method="env://")
|
| 46 |
+
return rank, world_size, local_rank
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def _resolve_local_model_dir(base_cfg: dict, model_dir: str) -> Path:
|
| 50 |
+
candidate = Path(model_dir)
|
| 51 |
+
if candidate.is_absolute() and candidate.exists():
|
| 52 |
+
return candidate.resolve()
|
| 53 |
+
if not candidate.is_absolute() and candidate.exists():
|
| 54 |
+
return candidate.resolve()
|
| 55 |
+
|
| 56 |
+
repo_local = resolve_repo_path(model_dir)
|
| 57 |
+
if repo_local.exists():
|
| 58 |
+
return repo_local
|
| 59 |
+
|
| 60 |
+
cache_root = resolve_repo_path(base_cfg.get("storage", {}).get("cache_dir", "cache"))
|
| 61 |
+
prefixed = (cache_root / candidate).resolve()
|
| 62 |
+
if prefixed.exists():
|
| 63 |
+
return prefixed
|
| 64 |
+
|
| 65 |
+
raise FileNotFoundError(
|
| 66 |
+
f"Model directory not found locally: '{model_dir}'. "
|
| 67 |
+
f"Tried '{candidate}', '{repo_local}', and '{prefixed}'."
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def _build_chat_prompts(
|
| 72 |
+
tokenizer: AutoTokenizer, questions: list[str], system_prompt: str
|
| 73 |
+
) -> list[str]:
|
| 74 |
+
if getattr(tokenizer, "chat_template", None) is None:
|
| 75 |
+
raise RuntimeError("Tokenizer has no chat_template; cannot apply chat formatting.")
|
| 76 |
+
|
| 77 |
+
prompts: list[str] = []
|
| 78 |
+
for q in questions:
|
| 79 |
+
messages = [
|
| 80 |
+
{"role": "system", "content": system_prompt},
|
| 81 |
+
{"role": "user", "content": q.strip()},
|
| 82 |
+
]
|
| 83 |
+
text = tokenizer.apply_chat_template(
|
| 84 |
+
messages,
|
| 85 |
+
tokenize=False,
|
| 86 |
+
add_generation_prompt=True,
|
| 87 |
+
)
|
| 88 |
+
prompts.append(text)
|
| 89 |
+
return prompts
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def _load_math_level_rows(
|
| 93 |
+
level: str,
|
| 94 |
+
split: str,
|
| 95 |
+
max_samples: int | None,
|
| 96 |
+
cache_dir: str | None,
|
| 97 |
+
) -> tuple[list[str], list[str]]:
|
| 98 |
+
dataset_name = "EleutherAI/hendrycks_math"
|
| 99 |
+
dataset_configs = (
|
| 100 |
+
"algebra",
|
| 101 |
+
"counting_and_probability",
|
| 102 |
+
"geometry",
|
| 103 |
+
"intermediate_algebra",
|
| 104 |
+
"number_theory",
|
| 105 |
+
"prealgebra",
|
| 106 |
+
"precalculus",
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
questions: list[str] = []
|
| 110 |
+
references: list[str] = []
|
| 111 |
+
|
| 112 |
+
for config_name in dataset_configs:
|
| 113 |
+
rows = load_dataset(
|
| 114 |
+
dataset_name,
|
| 115 |
+
config_name,
|
| 116 |
+
split=split,
|
| 117 |
+
cache_dir=cache_dir,
|
| 118 |
+
)
|
| 119 |
+
for row in rows:
|
| 120 |
+
row_level = str(row.get("level", "")).strip()
|
| 121 |
+
if row_level != level:
|
| 122 |
+
continue
|
| 123 |
+
questions.append(str(row.get("problem", "")))
|
| 124 |
+
references.append(str(row.get("solution", "")))
|
| 125 |
+
if max_samples is not None and len(questions) >= max_samples:
|
| 126 |
+
return questions[:max_samples], references[:max_samples]
|
| 127 |
+
|
| 128 |
+
return questions, references
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
@torch.no_grad()
|
| 132 |
+
def main() -> None:
|
| 133 |
+
rank, world_size, local_rank = _init_distributed()
|
| 134 |
+
|
| 135 |
+
base_cfg = _load_yaml(str(resolve_repo_path(os.environ["BASE_CONFIG"])))
|
| 136 |
+
model_dir = os.environ.get("MODEL_DIR") or os.environ.get("MODEL_PATH")
|
| 137 |
+
if not model_dir:
|
| 138 |
+
raise ValueError("Set MODEL_DIR or MODEL_PATH for the checkpoint to evaluate.")
|
| 139 |
+
resolved_model_dir = _resolve_local_model_dir(base_cfg, model_dir)
|
| 140 |
+
|
| 141 |
+
generation = base_cfg.get("generation", {})
|
| 142 |
+
max_prompt_length = int(generation.get("max_prompt_length", 512))
|
| 143 |
+
max_new_tokens = int(generation.get("max_completion_length", 256))
|
| 144 |
+
max_prompt_length = int(os.environ.get("MAX_PROMPT_LENGTH", str(max_prompt_length)))
|
| 145 |
+
max_new_tokens = int(os.environ.get("MAX_NEW_TOKENS", str(max_new_tokens)))
|
| 146 |
+
|
| 147 |
+
split = os.environ.get("MATH_SPLIT", "test")
|
| 148 |
+
max_samples_env = os.environ.get("MAX_SAMPLES", os.environ.get("EVAL_MAX_SAMPLES", "-1"))
|
| 149 |
+
max_samples = None if int(max_samples_env) < 0 else int(max_samples_env)
|
| 150 |
+
|
| 151 |
+
batch_size = int(os.environ.get("BATCH_SIZE", "4"))
|
| 152 |
+
|
| 153 |
+
cache_root = resolve_repo_path(base_cfg.get("storage", {}).get("cache_dir", "cache"))
|
| 154 |
+
datasets_cache = str(cache_root / "datasets")
|
| 155 |
+
models_cache = str(cache_root / "models")
|
| 156 |
+
|
| 157 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 158 |
+
str(resolved_model_dir),
|
| 159 |
+
trust_remote_code=bool(base_cfg.get("model", {}).get("trust_remote_code", False)),
|
| 160 |
+
cache_dir=models_cache,
|
| 161 |
+
local_files_only=True,
|
| 162 |
+
)
|
| 163 |
+
if tokenizer.pad_token_id is None:
|
| 164 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 165 |
+
# Decoder-only safe.
|
| 166 |
+
tokenizer.padding_side = "left"
|
| 167 |
+
|
| 168 |
+
dtype = torch.bfloat16 if bool(base_cfg.get("trainer", {}).get("bf16", True)) else torch.float16
|
| 169 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 170 |
+
str(resolved_model_dir),
|
| 171 |
+
trust_remote_code=bool(base_cfg.get("model", {}).get("trust_remote_code", False)),
|
| 172 |
+
cache_dir=models_cache,
|
| 173 |
+
torch_dtype=dtype,
|
| 174 |
+
local_files_only=True,
|
| 175 |
+
)
|
| 176 |
+
if torch.cuda.is_available():
|
| 177 |
+
torch.cuda.set_device(local_rank)
|
| 178 |
+
device = torch.device(f"cuda:{local_rank}")
|
| 179 |
+
else:
|
| 180 |
+
device = torch.device("cpu")
|
| 181 |
+
model.to(device)
|
| 182 |
+
model.eval()
|
| 183 |
+
|
| 184 |
+
questions, references = _load_math_level_rows(
|
| 185 |
+
level="Level 1",
|
| 186 |
+
split=split,
|
| 187 |
+
max_samples=max_samples,
|
| 188 |
+
cache_dir=datasets_cache,
|
| 189 |
+
)
|
| 190 |
+
|
| 191 |
+
indices = list(range(rank, len(questions), world_size))
|
| 192 |
+
local_questions = [questions[i] for i in indices]
|
| 193 |
+
local_refs = [references[i] for i in indices]
|
| 194 |
+
|
| 195 |
+
chat_prompts = _build_chat_prompts(tokenizer, local_questions, THINKING_SYSTEM_PROMPT)
|
| 196 |
+
completions: list[str] = []
|
| 197 |
+
|
| 198 |
+
for start in range(0, len(chat_prompts), batch_size):
|
| 199 |
+
batch_prompts = chat_prompts[start : start + batch_size]
|
| 200 |
+
enc = tokenizer(
|
| 201 |
+
batch_prompts,
|
| 202 |
+
return_tensors="pt",
|
| 203 |
+
padding=True,
|
| 204 |
+
truncation=True,
|
| 205 |
+
max_length=max_prompt_length,
|
| 206 |
+
)
|
| 207 |
+
input_ids = enc["input_ids"].to(device)
|
| 208 |
+
attn = enc["attention_mask"].to(device)
|
| 209 |
+
prompt_seq_len = input_ids.shape[1]
|
| 210 |
+
|
| 211 |
+
out = model.generate(
|
| 212 |
+
input_ids=input_ids,
|
| 213 |
+
attention_mask=attn,
|
| 214 |
+
max_new_tokens=max_new_tokens,
|
| 215 |
+
do_sample=False,
|
| 216 |
+
pad_token_id=tokenizer.pad_token_id,
|
| 217 |
+
eos_token_id=tokenizer.eos_token_id,
|
| 218 |
+
)
|
| 219 |
+
for bi in range(out.size(0)):
|
| 220 |
+
gen_ids = out[bi, prompt_seq_len:]
|
| 221 |
+
completions.append(tokenizer.decode(gen_ids, skip_special_tokens=True))
|
| 222 |
+
|
| 223 |
+
# Strict boxed correctness (project metric)
|
| 224 |
+
strict_scores = []
|
| 225 |
+
for completion, reference in zip(completions, local_refs, strict=True):
|
| 226 |
+
pred_text = reward_plugins_mod._extract_predicted_answer_text(completion)
|
| 227 |
+
ref_text = reward_plugins_mod._extract_reference_answer_text(reference)
|
| 228 |
+
if not pred_text or not ref_text:
|
| 229 |
+
strict_scores.append(0.0)
|
| 230 |
+
continue
|
| 231 |
+
pred_norm = reward_plugins_mod._normalize_answer_text(pred_text)
|
| 232 |
+
ref_norm = reward_plugins_mod._normalize_answer_text(ref_text)
|
| 233 |
+
if pred_norm and ref_norm and pred_norm == ref_norm:
|
| 234 |
+
strict_scores.append(1.0)
|
| 235 |
+
continue
|
| 236 |
+
pred_value = reward_plugins_mod._parse_numeric(pred_text)
|
| 237 |
+
ref_value = reward_plugins_mod._parse_numeric(ref_text)
|
| 238 |
+
if pred_value is not None and ref_value is not None and reward_plugins_mod._is_close(pred_value, ref_value):
|
| 239 |
+
strict_scores.append(1.0)
|
| 240 |
+
else:
|
| 241 |
+
strict_scores.append(0.0)
|
| 242 |
+
|
| 243 |
+
# Lenient numeric correctness fallback
|
| 244 |
+
lenient_scores: list[float] = []
|
| 245 |
+
for completion, reference in zip(completions, local_refs, strict=True):
|
| 246 |
+
ref_val = reward_plugins_mod._extract_reference_target(reference)
|
| 247 |
+
boxed = reward_plugins_mod._extract_last_boxed(completion)
|
| 248 |
+
if boxed:
|
| 249 |
+
pred_val = reward_plugins_mod._parse_numeric(boxed)
|
| 250 |
+
if pred_val is None:
|
| 251 |
+
nums = reward_plugins_mod._extract_numbers(boxed)
|
| 252 |
+
pred_val = nums[-1] if nums else None
|
| 253 |
+
else:
|
| 254 |
+
nums = reward_plugins_mod._extract_numbers(completion)
|
| 255 |
+
pred_val = nums[-1] if nums else None
|
| 256 |
+
|
| 257 |
+
if ref_val is not None and pred_val is not None and reward_plugins_mod._is_close(pred_val, ref_val):
|
| 258 |
+
lenient_scores.append(1.0)
|
| 259 |
+
else:
|
| 260 |
+
lenient_scores.append(0.0)
|
| 261 |
+
|
| 262 |
+
local_records: list[dict] = []
|
| 263 |
+
for i, idx in enumerate(indices):
|
| 264 |
+
local_records.append(
|
| 265 |
+
{
|
| 266 |
+
"sample_index": int(idx),
|
| 267 |
+
"question": local_questions[i],
|
| 268 |
+
"reference_answer": local_refs[i],
|
| 269 |
+
"model_answer_raw": completions[i],
|
| 270 |
+
"correctness": float(lenient_scores[i]),
|
| 271 |
+
"correctness_strict_boxed": float(strict_scores[i]),
|
| 272 |
+
}
|
| 273 |
+
)
|
| 274 |
+
|
| 275 |
+
if dist.is_initialized():
|
| 276 |
+
gathered: list[list[dict] | None] = [None for _ in range(world_size)]
|
| 277 |
+
dist.all_gather_object(gathered, local_records)
|
| 278 |
+
merged: list[dict] = []
|
| 279 |
+
for part in gathered:
|
| 280 |
+
if part:
|
| 281 |
+
merged.extend(part)
|
| 282 |
+
else:
|
| 283 |
+
merged = local_records
|
| 284 |
+
|
| 285 |
+
if rank != 0:
|
| 286 |
+
return
|
| 287 |
+
|
| 288 |
+
merged.sort(key=lambda r: r["sample_index"])
|
| 289 |
+
output_path = resolve_repo_path(
|
| 290 |
+
os.environ.get(
|
| 291 |
+
"OUTPUT_PATH",
|
| 292 |
+
"artifacts/eval/math_level1_thinking_zeroshot/answers.jsonl",
|
| 293 |
+
)
|
| 294 |
+
)
|
| 295 |
+
output_path.parent.mkdir(parents=True, exist_ok=True)
|
| 296 |
+
with output_path.open("w", encoding="utf-8") as handle:
|
| 297 |
+
for row in merged:
|
| 298 |
+
handle.write(json.dumps(row, ensure_ascii=True) + "\n")
|
| 299 |
+
|
| 300 |
+
acc = sum(r["correctness"] for r in merged) / len(merged) if merged else 0.0
|
| 301 |
+
acc_strict = (
|
| 302 |
+
sum(r["correctness_strict_boxed"] for r in merged) / len(merged)
|
| 303 |
+
if merged
|
| 304 |
+
else 0.0
|
| 305 |
+
)
|
| 306 |
+
print(f"Wrote {len(merged)} rows to {output_path}")
|
| 307 |
+
print(f"Accuracy (lenient numeric): {acc:.4f}")
|
| 308 |
+
print(f"Accuracy (strict boxed): {acc_strict:.4f}")
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
if __name__ == "__main__":
|
| 312 |
+
main()
|
src/eval_permanent_checkpoints.py
ADDED
|
@@ -0,0 +1,433 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import csv
|
| 4 |
+
import json
|
| 5 |
+
import os
|
| 6 |
+
import re
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
|
| 9 |
+
import torch.distributed as dist
|
| 10 |
+
|
| 11 |
+
import hackable # noqa: F401
|
| 12 |
+
from hackable.utils import resolve_repo_path
|
| 13 |
+
from eval_sweep_models import (
|
| 14 |
+
_init_distributed,
|
| 15 |
+
_load_yaml,
|
| 16 |
+
_resolve_local_model_dir,
|
| 17 |
+
evaluate_one_model,
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def _parse_checkpoint_step(dirname: str) -> int | None:
|
| 22 |
+
m = re.match(r"^checkpoint-(\d+)$", dirname)
|
| 23 |
+
if m:
|
| 24 |
+
return int(m.group(1))
|
| 25 |
+
m = re.search(r"-step-(\d+)$", dirname)
|
| 26 |
+
if m:
|
| 27 |
+
return int(m.group(1))
|
| 28 |
+
return None
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def _discover_checkpoint_jobs(
|
| 32 |
+
base_cfg: dict, permanent_root: Path, run_label: str
|
| 33 |
+
) -> list[tuple[str, int, str, Path, str]]:
|
| 34 |
+
"""(run_label, step, resolved_model_dir_str, resolved_path, dir_name)"""
|
| 35 |
+
root = permanent_root.resolve()
|
| 36 |
+
if not root.is_dir():
|
| 37 |
+
raise FileNotFoundError(f"Not a directory: {root}")
|
| 38 |
+
jobs: list[tuple[str, int, str, Path, str]] = []
|
| 39 |
+
for p in sorted(root.iterdir()):
|
| 40 |
+
if not p.is_dir():
|
| 41 |
+
continue
|
| 42 |
+
step = _parse_checkpoint_step(p.name)
|
| 43 |
+
if step is None:
|
| 44 |
+
continue
|
| 45 |
+
resolved = _resolve_local_model_dir(base_cfg, str(p))
|
| 46 |
+
jobs.append((run_label, step, str(resolved), resolved, p.name))
|
| 47 |
+
jobs.sort(key=lambda x: (x[1], x[4]))
|
| 48 |
+
return jobs
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def _line_chart_svg(
|
| 52 |
+
series: list[tuple[str, list[tuple[int, float]], str]],
|
| 53 |
+
title: str,
|
| 54 |
+
y_label: str,
|
| 55 |
+
y_max: float,
|
| 56 |
+
path: Path,
|
| 57 |
+
) -> None:
|
| 58 |
+
width = 900
|
| 59 |
+
height = 420
|
| 60 |
+
lm, rm, tm, bm = 70, 40, 50, 55
|
| 61 |
+
pw = width - lm - rm
|
| 62 |
+
ph = height - tm - bm
|
| 63 |
+
yb = tm + ph
|
| 64 |
+
|
| 65 |
+
all_steps: list[int] = []
|
| 66 |
+
for _, pts, _ in series:
|
| 67 |
+
all_steps.extend(s for s, _ in pts)
|
| 68 |
+
if not all_steps:
|
| 69 |
+
path.write_text(
|
| 70 |
+
f'<svg xmlns="http://www.w3.org/2000/svg" width="{width}" height="{height}">'
|
| 71 |
+
f'<text x="40" y="40">{title} (no data)</text></svg>',
|
| 72 |
+
encoding="utf-8",
|
| 73 |
+
)
|
| 74 |
+
return
|
| 75 |
+
x_min, x_max = min(all_steps), max(all_steps)
|
| 76 |
+
if x_max == x_min:
|
| 77 |
+
x_max = x_min + 1
|
| 78 |
+
|
| 79 |
+
def sx(x: int) -> int:
|
| 80 |
+
return lm + int((x - x_min) / (x_max - x_min) * pw)
|
| 81 |
+
|
| 82 |
+
def sy(y: float) -> int:
|
| 83 |
+
y = max(0.0, min(y_max, y))
|
| 84 |
+
return yb - int((y / y_max) * ph) if y_max > 0 else yb
|
| 85 |
+
|
| 86 |
+
parts: list[str] = [
|
| 87 |
+
f'<svg xmlns="http://www.w3.org/2000/svg" width="{width}" height="{height}">',
|
| 88 |
+
'<rect width="100%" height="100%" fill="#ffffff"/>',
|
| 89 |
+
f'<text x="{lm}" y="28" font-size="16" font-family="sans-serif">{title}</text>',
|
| 90 |
+
f'<text x="20" y="{tm + ph // 2}" font-size="12" font-family="sans-serif" '
|
| 91 |
+
f'transform="rotate(-90 20 {tm + ph // 2})">{y_label}</text>',
|
| 92 |
+
f'<line x1="{lm}" y1="{yb}" x2="{lm + pw}" y2="{yb}" stroke="#111" stroke-width="2"/>',
|
| 93 |
+
f'<line x1="{lm}" y1="{tm}" x2="{lm}" y2="{yb}" stroke="#111" stroke-width="2"/>',
|
| 94 |
+
f'<text x="{lm + pw // 2}" y="{height - 12}" text-anchor="middle" '
|
| 95 |
+
f'font-size="12" font-family="sans-serif">Training step</text>',
|
| 96 |
+
]
|
| 97 |
+
|
| 98 |
+
for i in range(5):
|
| 99 |
+
val = (i / 4) * y_max
|
| 100 |
+
yy = sy(val)
|
| 101 |
+
parts.append(
|
| 102 |
+
f'<line x1="{lm - 4}" y1="{yy}" x2="{lm}" y2="{yy}" stroke="#999"/>'
|
| 103 |
+
)
|
| 104 |
+
parts.append(
|
| 105 |
+
f'<text x="{lm - 8}" y="{yy + 4}" text-anchor="end" font-size="10" '
|
| 106 |
+
f'font-family="sans-serif">{val:.2f}</text>'
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
legend_x = lm + pw - 200
|
| 110 |
+
legend_y = tm + 8
|
| 111 |
+
for idx, (name, pts, color) in enumerate(series):
|
| 112 |
+
if len(pts) < 2:
|
| 113 |
+
pts_sorted = sorted(pts, key=lambda z: z[0])
|
| 114 |
+
if not pts_sorted:
|
| 115 |
+
continue
|
| 116 |
+
cx, cy = sx(pts_sorted[0][0]), sy(pts_sorted[0][1])
|
| 117 |
+
parts.append(
|
| 118 |
+
f'<circle cx="{cx}" cy="{cy}" r="4" fill="{color}" stroke="#111"/>'
|
| 119 |
+
)
|
| 120 |
+
else:
|
| 121 |
+
pts_sorted = sorted(pts, key=lambda z: z[0])
|
| 122 |
+
d = "M " + " L ".join(f"{sx(s)} {sy(v)}" for s, v in pts_sorted)
|
| 123 |
+
parts.append(
|
| 124 |
+
f'<path d="{d}" fill="none" stroke="{color}" stroke-width="2.5"/>'
|
| 125 |
+
)
|
| 126 |
+
parts.append(
|
| 127 |
+
f'<rect x="{legend_x}" y="{legend_y + idx * 18}" width="10" height="10" fill="{color}"/>'
|
| 128 |
+
)
|
| 129 |
+
parts.append(
|
| 130 |
+
f'<text x="{legend_x + 16}" y="{legend_y + idx * 18 + 9}" font-size="11" '
|
| 131 |
+
f'font-family="sans-serif">{name}</text>'
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
parts.append("</svg>")
|
| 135 |
+
path.write_text("\n".join(parts), encoding="utf-8")
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def _scatter_accuracy_vs_cot_svg(rows: list[dict], path: Path, title: str) -> None:
|
| 139 |
+
"""Scatter: x = avg_cot_words, y = accuracy. One color per ``run_label``; optional path by training step."""
|
| 140 |
+
width = 640
|
| 141 |
+
height = 520
|
| 142 |
+
lm, rm, tm, bm = 72, 160, 52, 64
|
| 143 |
+
pw = width - lm - rm
|
| 144 |
+
ph = height - tm - bm
|
| 145 |
+
yb = tm + ph
|
| 146 |
+
|
| 147 |
+
if not rows:
|
| 148 |
+
path.write_text(
|
| 149 |
+
f'<svg xmlns="http://www.w3.org/2000/svg" width="{width}" height="{height}">'
|
| 150 |
+
f'<text x="40" y="40">{title} (no data)</text></svg>',
|
| 151 |
+
encoding="utf-8",
|
| 152 |
+
)
|
| 153 |
+
return
|
| 154 |
+
|
| 155 |
+
labels: list[str] = []
|
| 156 |
+
seen: set[str] = set()
|
| 157 |
+
for r in rows:
|
| 158 |
+
lab = str(r.get("run_label", "run"))
|
| 159 |
+
if lab not in seen:
|
| 160 |
+
seen.add(lab)
|
| 161 |
+
labels.append(lab)
|
| 162 |
+
|
| 163 |
+
colors = ["#2563eb", "#dc2626", "#16a34a", "#9333ea", "#ca8a04", "#0891b2"]
|
| 164 |
+
color_map = {lab: colors[i % len(colors)] for i, lab in enumerate(labels)}
|
| 165 |
+
|
| 166 |
+
xs = [float(r["avg_cot_words"]) for r in rows]
|
| 167 |
+
ys = [float(r["accuracy"]) for r in rows]
|
| 168 |
+
x_min, x_max = min(xs), max(xs)
|
| 169 |
+
y_min, y_max = 0.0, 1.0
|
| 170 |
+
if x_max <= x_min:
|
| 171 |
+
x_max = x_min + 1.0
|
| 172 |
+
pad = (x_max - x_min) * 0.06 + 1.0
|
| 173 |
+
x_min = max(0.0, x_min - pad)
|
| 174 |
+
x_max = x_max + pad
|
| 175 |
+
|
| 176 |
+
def sx(x: float) -> float:
|
| 177 |
+
return lm + (x - x_min) / (x_max - x_min) * pw
|
| 178 |
+
|
| 179 |
+
def sy(y: float) -> float:
|
| 180 |
+
y = max(y_min, min(y_max, y))
|
| 181 |
+
return yb - (y - y_min) / (y_max - y_min) * ph
|
| 182 |
+
|
| 183 |
+
parts: list[str] = [
|
| 184 |
+
f'<svg xmlns="http://www.w3.org/2000/svg" width="{width}" height="{height}">',
|
| 185 |
+
'<rect width="100%" height="100%" fill="#fafafa"/>',
|
| 186 |
+
f'<text x="{lm}" y="30" font-size="15" font-family="sans-serif">{title}</text>',
|
| 187 |
+
f'<text x="{width // 2}" y="{height - 18}" text-anchor="middle" font-size="12" '
|
| 188 |
+
f'font-family="sans-serif">Avg CoT length (words)</text>',
|
| 189 |
+
f'<text x="18" y="{tm + ph // 2}" font-size="12" font-family="sans-serif" '
|
| 190 |
+
f'transform="rotate(-90 18 {tm + ph // 2})">Accuracy</text>',
|
| 191 |
+
f'<line x1="{lm}" y1="{yb}" x2="{lm + pw}" y2="{yb}" stroke="#111" stroke-width="2"/>',
|
| 192 |
+
f'<line x1="{lm}" y1="{tm}" x2="{lm}" y2="{yb}" stroke="#111" stroke-width="2"/>',
|
| 193 |
+
]
|
| 194 |
+
|
| 195 |
+
for i in range(5):
|
| 196 |
+
val = y_min + (i / 4) * (y_max - y_min)
|
| 197 |
+
yy = sy(val)
|
| 198 |
+
parts.append(f'<line x1="{lm - 4}" y1="{yy}" x2="{lm}" y2="{yy}" stroke="#bbb"/>')
|
| 199 |
+
parts.append(
|
| 200 |
+
f'<text x="{lm - 8}" y="{yy + 4}" text-anchor="end" font-size="10" '
|
| 201 |
+
f'font-family="sans-serif">{val:.2f}</text>'
|
| 202 |
+
)
|
| 203 |
+
|
| 204 |
+
for i in range(5):
|
| 205 |
+
frac = i / 4
|
| 206 |
+
xv = x_min + frac * (x_max - x_min)
|
| 207 |
+
xx = sx(xv)
|
| 208 |
+
parts.append(f'<line x1="{xx}" y1="{yb}" x2="{xx}" y2="{yb + 4}" stroke="#bbb"/>')
|
| 209 |
+
parts.append(
|
| 210 |
+
f'<text x="{xx}" y="{yb + 18}" text-anchor="middle" font-size="10" '
|
| 211 |
+
f'font-family="sans-serif">{xv:.0f}</text>'
|
| 212 |
+
)
|
| 213 |
+
|
| 214 |
+
for lab in labels:
|
| 215 |
+
sub = [r for r in rows if str(r.get("run_label", "run")) == lab]
|
| 216 |
+
sub.sort(key=lambda r: int(r["checkpoint_step"]))
|
| 217 |
+
color = color_map[lab]
|
| 218 |
+
if len(sub) >= 2:
|
| 219 |
+
d = "M " + " L ".join(f'{sx(float(r["avg_cot_words"])):.1f} {sy(float(r["accuracy"])):.1f}' for r in sub)
|
| 220 |
+
parts.append(
|
| 221 |
+
f'<path d="{d}" fill="none" stroke="{color}" stroke-width="1.5" stroke-opacity="0.35"/>'
|
| 222 |
+
)
|
| 223 |
+
|
| 224 |
+
for r in rows:
|
| 225 |
+
lab = str(r.get("run_label", "run"))
|
| 226 |
+
color = color_map[lab]
|
| 227 |
+
cx = sx(float(r["avg_cot_words"]))
|
| 228 |
+
cy = sy(float(r["accuracy"]))
|
| 229 |
+
step = int(r["checkpoint_step"])
|
| 230 |
+
name = str(r.get("checkpoint_dir", f"step-{step}"))
|
| 231 |
+
tip = f"{name}: accuracy={float(r['accuracy']):.4f}, avg_cot_words={float(r['avg_cot_words']):.2f}"
|
| 232 |
+
parts.append(
|
| 233 |
+
f'<g><circle cx="{cx:.1f}" cy="{cy:.1f}" r="5" fill="{color}" stroke="#111" stroke-width="1">'
|
| 234 |
+
f"<title>{tip}</title></circle>"
|
| 235 |
+
f'<text x="{cx + 8:.1f}" y="{cy - 6:.1f}" font-size="9" font-family="sans-serif" fill="#333">{step}</text></g>'
|
| 236 |
+
)
|
| 237 |
+
|
| 238 |
+
legend_x = lm + pw + 14
|
| 239 |
+
legend_y = tm + 4
|
| 240 |
+
parts.append(
|
| 241 |
+
f'<text x="{legend_x}" y="{legend_y}" font-size="11" font-family="sans-serif" font-weight="bold">Series</text>'
|
| 242 |
+
)
|
| 243 |
+
for idx, lab in enumerate(labels):
|
| 244 |
+
cy = legend_y + 18 + idx * 20
|
| 245 |
+
parts.append(
|
| 246 |
+
f'<rect x="{legend_x}" y="{cy - 8}" width="10" height="10" fill="{color_map[lab]}"/>'
|
| 247 |
+
)
|
| 248 |
+
parts.append(
|
| 249 |
+
f'<text x="{legend_x + 16}" y="{cy}" font-size="11" font-family="sans-serif">{lab}</text>'
|
| 250 |
+
)
|
| 251 |
+
|
| 252 |
+
parts.append("</svg>")
|
| 253 |
+
path.write_text("\n".join(parts), encoding="utf-8")
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
def _resolve_out_root(default: Path) -> Path:
|
| 257 |
+
raw = os.environ.get("OUT_ROOT")
|
| 258 |
+
if raw is None or not str(raw).strip():
|
| 259 |
+
return resolve_repo_path(str(default))
|
| 260 |
+
return resolve_repo_path(raw)
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
def main() -> None:
|
| 264 |
+
rank, _, _ = _init_distributed()
|
| 265 |
+
base_cfg = _load_yaml(str(resolve_repo_path(os.environ["BASE_CONFIG"])))
|
| 266 |
+
|
| 267 |
+
eval_max_samples = int(os.environ.get("EVAL_MAX_SAMPLES", "200"))
|
| 268 |
+
eval_batch_size = int(os.environ.get("EVAL_BATCH_SIZE", "4"))
|
| 269 |
+
rollout_n = int(os.environ.get("ROLLOUT_SAMPLES", "8"))
|
| 270 |
+
|
| 271 |
+
permanent_root = os.environ.get("PERMANENT_ROOT", "").strip()
|
| 272 |
+
if permanent_root:
|
| 273 |
+
pr = resolve_repo_path(permanent_root)
|
| 274 |
+
run_label_single = os.environ.get("RUN_LABEL", "permanent")
|
| 275 |
+
out_default = pr / "eval_permanent"
|
| 276 |
+
out_root = _resolve_out_root(out_default)
|
| 277 |
+
jobs_single = _discover_checkpoint_jobs(base_cfg, pr, run_label_single)
|
| 278 |
+
all_jobs = jobs_single
|
| 279 |
+
jobs_cw1: list = []
|
| 280 |
+
jobs_cw5: list = []
|
| 281 |
+
else:
|
| 282 |
+
cw1_root = resolve_repo_path(os.environ["PERMANENT_CW1"])
|
| 283 |
+
cw5_root = resolve_repo_path(os.environ["PERMANENT_CW5"])
|
| 284 |
+
out_default = cw1_root.parent / "eval_permanent"
|
| 285 |
+
out_root = _resolve_out_root(out_default)
|
| 286 |
+
jobs_cw1 = _discover_checkpoint_jobs(base_cfg, cw1_root, "correctness_weight_1")
|
| 287 |
+
jobs_cw5 = _discover_checkpoint_jobs(base_cfg, cw5_root, "correctness_weight_5")
|
| 288 |
+
all_jobs = jobs_cw1 + jobs_cw5
|
| 289 |
+
|
| 290 |
+
if rank == 0:
|
| 291 |
+
out_root.mkdir(parents=True, exist_ok=True)
|
| 292 |
+
(out_root / "rollouts").mkdir(parents=True, exist_ok=True)
|
| 293 |
+
(out_root / "full_outputs").mkdir(parents=True, exist_ok=True)
|
| 294 |
+
if permanent_root:
|
| 295 |
+
print(f"PERMANENT_ROOT: {resolve_repo_path(permanent_root)} ({len(all_jobs)} checkpoints)")
|
| 296 |
+
for run_label, step, _, _, name in all_jobs:
|
| 297 |
+
print(f" {run_label} step={step} ({name})")
|
| 298 |
+
else:
|
| 299 |
+
print(f"Found {len(jobs_cw1)} checkpoints (cw=1), {len(jobs_cw5)} checkpoints (cw=5)")
|
| 300 |
+
for jl in (jobs_cw1, jobs_cw5):
|
| 301 |
+
for run_label, step, _, _, name in jl:
|
| 302 |
+
print(f" {run_label} step={step} ({name})")
|
| 303 |
+
|
| 304 |
+
if dist.is_initialized():
|
| 305 |
+
dist.barrier()
|
| 306 |
+
|
| 307 |
+
rows: list[dict] = []
|
| 308 |
+
for run_label, step, _resolved_str, resolved_path, dir_name in all_jobs:
|
| 309 |
+
records = evaluate_one_model(
|
| 310 |
+
model_dir=resolved_path,
|
| 311 |
+
base_cfg=base_cfg,
|
| 312 |
+
eval_max_samples=eval_max_samples,
|
| 313 |
+
batch_size=eval_batch_size,
|
| 314 |
+
)
|
| 315 |
+
if rank == 0:
|
| 316 |
+
acc = sum(float(r["correctness"]) for r in records) / len(records) if records else 0.0
|
| 317 |
+
avg_cot = sum(float(r["cot_words"]) for r in records) / len(records) if records else 0.0
|
| 318 |
+
row = {
|
| 319 |
+
"run_label": run_label,
|
| 320 |
+
"checkpoint_step": step,
|
| 321 |
+
"checkpoint_dir": dir_name,
|
| 322 |
+
"model_dir": str(resolved_path),
|
| 323 |
+
"num_examples": len(records),
|
| 324 |
+
"accuracy": acc,
|
| 325 |
+
"avg_cot_words": avg_cot,
|
| 326 |
+
}
|
| 327 |
+
rows.append(row)
|
| 328 |
+
|
| 329 |
+
rollout_dir = out_root / "rollouts" / run_label
|
| 330 |
+
rollout_dir.mkdir(parents=True, exist_ok=True)
|
| 331 |
+
rollout_path = rollout_dir / f"{dir_name}_rollouts.jsonl"
|
| 332 |
+
with rollout_path.open("w", encoding="utf-8") as handle:
|
| 333 |
+
for rec in records[:rollout_n]:
|
| 334 |
+
handle.write(json.dumps(rec, ensure_ascii=True) + "\n")
|
| 335 |
+
|
| 336 |
+
full_path = out_root / "full_outputs" / run_label / f"{dir_name}_outputs.jsonl"
|
| 337 |
+
full_path.parent.mkdir(parents=True, exist_ok=True)
|
| 338 |
+
with full_path.open("w", encoding="utf-8") as handle:
|
| 339 |
+
for rec in records:
|
| 340 |
+
handle.write(json.dumps(rec, ensure_ascii=True) + "\n")
|
| 341 |
+
|
| 342 |
+
print(
|
| 343 |
+
f"Eval {run_label} {dir_name}: acc={acc:.4f} avg_cot_words={avg_cot:.2f} n={len(records)}"
|
| 344 |
+
)
|
| 345 |
+
|
| 346 |
+
if dist.is_initialized():
|
| 347 |
+
dist.barrier()
|
| 348 |
+
|
| 349 |
+
if rank != 0:
|
| 350 |
+
return
|
| 351 |
+
|
| 352 |
+
rows.sort(key=lambda r: (r["run_label"], r["checkpoint_step"], r["checkpoint_dir"]))
|
| 353 |
+
|
| 354 |
+
summary_json = out_root / "permanent_checkpoints_eval.json"
|
| 355 |
+
summary_csv = out_root / "permanent_checkpoints_eval.csv"
|
| 356 |
+
summary_json.write_text(json.dumps(rows, indent=2), encoding="utf-8")
|
| 357 |
+
with summary_csv.open("w", encoding="utf-8", newline="") as handle:
|
| 358 |
+
w = csv.DictWriter(
|
| 359 |
+
handle,
|
| 360 |
+
fieldnames=[
|
| 361 |
+
"run_label",
|
| 362 |
+
"checkpoint_step",
|
| 363 |
+
"checkpoint_dir",
|
| 364 |
+
"model_dir",
|
| 365 |
+
"num_examples",
|
| 366 |
+
"accuracy",
|
| 367 |
+
"avg_cot_words",
|
| 368 |
+
],
|
| 369 |
+
)
|
| 370 |
+
w.writeheader()
|
| 371 |
+
for row in rows:
|
| 372 |
+
w.writerow(row)
|
| 373 |
+
|
| 374 |
+
def series_for(label: str, ykey: str) -> list[tuple[int, float]]:
|
| 375 |
+
return [
|
| 376 |
+
(int(r["checkpoint_step"]), float(r[ykey]))
|
| 377 |
+
for r in rows
|
| 378 |
+
if r["run_label"] == label
|
| 379 |
+
]
|
| 380 |
+
|
| 381 |
+
palette = ["#2563eb", "#dc2626", "#16a34a", "#9333ea", "#ca8a04", "#0891b2"]
|
| 382 |
+
uniq_labels = sorted({str(r["run_label"]) for r in rows})
|
| 383 |
+
acc_series = [
|
| 384 |
+
(lab, series_for(lab, "accuracy"), palette[i % len(palette)])
|
| 385 |
+
for i, lab in enumerate(uniq_labels)
|
| 386 |
+
if series_for(lab, "accuracy")
|
| 387 |
+
]
|
| 388 |
+
cot_series = [
|
| 389 |
+
(lab, series_for(lab, "avg_cot_words"), palette[i % len(palette)])
|
| 390 |
+
for i, lab in enumerate(uniq_labels)
|
| 391 |
+
if series_for(lab, "avg_cot_words")
|
| 392 |
+
]
|
| 393 |
+
|
| 394 |
+
cot_max = 1.0
|
| 395 |
+
for r in rows:
|
| 396 |
+
cot_max = max(cot_max, float(r["avg_cot_words"]))
|
| 397 |
+
|
| 398 |
+
if acc_series:
|
| 399 |
+
_line_chart_svg(
|
| 400 |
+
acc_series,
|
| 401 |
+
"GSM8K accuracy vs checkpoint step",
|
| 402 |
+
"Accuracy",
|
| 403 |
+
1.0,
|
| 404 |
+
out_root / "accuracy_vs_step.svg",
|
| 405 |
+
)
|
| 406 |
+
if cot_series:
|
| 407 |
+
_line_chart_svg(
|
| 408 |
+
cot_series,
|
| 409 |
+
"Average CoT length (words) vs checkpoint step",
|
| 410 |
+
"Avg CoT words",
|
| 411 |
+
cot_max,
|
| 412 |
+
out_root / "avg_cot_vs_step.svg",
|
| 413 |
+
)
|
| 414 |
+
|
| 415 |
+
_scatter_accuracy_vs_cot_svg(
|
| 416 |
+
rows,
|
| 417 |
+
out_root / "accuracy_vs_avg_cot_words.svg",
|
| 418 |
+
"GSM8K accuracy vs average CoT length (words)",
|
| 419 |
+
)
|
| 420 |
+
|
| 421 |
+
print(f"Saved: {summary_json}")
|
| 422 |
+
print(f"Saved: {summary_csv}")
|
| 423 |
+
if acc_series:
|
| 424 |
+
print(f"Saved: {out_root / 'accuracy_vs_step.svg'}")
|
| 425 |
+
if cot_series:
|
| 426 |
+
print(f"Saved: {out_root / 'avg_cot_vs_step.svg'}")
|
| 427 |
+
print(f"Saved: {out_root / 'accuracy_vs_avg_cot_words.svg'}")
|
| 428 |
+
print(f"Rollouts: {out_root / 'rollouts'}/<run_label>/")
|
| 429 |
+
print(f"Full outputs: {out_root / 'full_outputs'}/<run_label>/")
|
| 430 |
+
|
| 431 |
+
|
| 432 |
+
if __name__ == "__main__":
|
| 433 |
+
main()
|
src/eval_sweep_models.py
ADDED
|
@@ -0,0 +1,386 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import csv
|
| 4 |
+
import json
|
| 5 |
+
import os
|
| 6 |
+
import re
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.distributed as dist
|
| 11 |
+
import yaml
|
| 12 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 13 |
+
|
| 14 |
+
import hackable # noqa: F401
|
| 15 |
+
from hackable.data_plugins import GSM8KProvider
|
| 16 |
+
from hackable.paths import resolve_storage_path, storage_layout
|
| 17 |
+
from hackable.reward_plugins import gsm8k_correctness_reward
|
| 18 |
+
from hackable.utils import resolve_repo_path
|
| 19 |
+
|
| 20 |
+
THINK_RE = re.compile(r"<think>(.*?)</think>", re.DOTALL)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def _load_yaml(path: str) -> dict:
|
| 24 |
+
with open(path, "r", encoding="utf-8") as handle:
|
| 25 |
+
return yaml.safe_load(handle)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def _cot_word_len(completion: str) -> int:
|
| 29 |
+
match = THINK_RE.search(completion)
|
| 30 |
+
text = match.group(1).strip() if match else ""
|
| 31 |
+
return len(text.split()) if text else 0
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def _model_dtype(cfg: dict):
|
| 35 |
+
return torch.bfloat16 if bool(cfg.get("trainer", {}).get("bf16", True)) else torch.float16
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def _get_cache_paths(base_cfg: dict) -> tuple[Path, Path]:
|
| 39 |
+
layout = storage_layout(base_cfg.get("storage", {}).get("cache_dir", "cache"))
|
| 40 |
+
return layout.datasets, layout.models
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def _dist_info() -> tuple[int, int, int]:
|
| 44 |
+
rank = int(os.environ.get("RANK", "0"))
|
| 45 |
+
world_size = int(os.environ.get("WORLD_SIZE", "1"))
|
| 46 |
+
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
|
| 47 |
+
return rank, world_size, local_rank
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def _init_distributed() -> tuple[int, int, int]:
|
| 51 |
+
rank, world_size, local_rank = _dist_info()
|
| 52 |
+
if world_size > 1 and not dist.is_initialized():
|
| 53 |
+
backend = "nccl" if torch.cuda.is_available() else "gloo"
|
| 54 |
+
dist.init_process_group(backend=backend, init_method="env://")
|
| 55 |
+
return rank, world_size, local_rank
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def _resolve_local_model_dir(base_cfg: dict, model_dir: str) -> Path:
|
| 59 |
+
candidate = Path(model_dir)
|
| 60 |
+
if candidate.is_absolute() and candidate.exists():
|
| 61 |
+
return candidate.resolve()
|
| 62 |
+
if not candidate.is_absolute() and candidate.exists():
|
| 63 |
+
return candidate.resolve()
|
| 64 |
+
|
| 65 |
+
repo_local = resolve_repo_path(model_dir)
|
| 66 |
+
if repo_local.exists():
|
| 67 |
+
return repo_local
|
| 68 |
+
|
| 69 |
+
cache_root = resolve_repo_path(base_cfg.get("storage", {}).get("cache_dir", "cache"))
|
| 70 |
+
prefixed = (cache_root / candidate).resolve()
|
| 71 |
+
if prefixed.exists():
|
| 72 |
+
return prefixed
|
| 73 |
+
|
| 74 |
+
raise FileNotFoundError(
|
| 75 |
+
f"Model directory not found locally: '{model_dir}'. "
|
| 76 |
+
f"Tried '{candidate}', '{repo_local}', and '{prefixed}'."
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def _resolve_sweep_root(base_cfg: dict, requested_sweep_root: Path) -> Path:
|
| 81 |
+
candidate = resolve_storage_path(
|
| 82 |
+
requested_sweep_root,
|
| 83 |
+
base_cfg.get("storage", {}).get("cache_dir", "cache"),
|
| 84 |
+
)
|
| 85 |
+
if candidate.is_dir() and any(path.is_dir() and path.name.startswith("run_") for path in candidate.iterdir()):
|
| 86 |
+
return candidate
|
| 87 |
+
raise FileNotFoundError(
|
| 88 |
+
"Could not resolve SWEEP_ROOT with run directories: "
|
| 89 |
+
f"{candidate}"
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def _discover_model_dirs(sweep_root: Path) -> list[Path]:
|
| 94 |
+
dirs = [
|
| 95 |
+
path
|
| 96 |
+
for path in sweep_root.iterdir()
|
| 97 |
+
if path.is_dir() and path.name.startswith("run_")
|
| 98 |
+
]
|
| 99 |
+
if not dirs:
|
| 100 |
+
raise FileNotFoundError(
|
| 101 |
+
f"No run directories starting with 'run_' found in {sweep_root}"
|
| 102 |
+
)
|
| 103 |
+
return sorted(dirs)
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
@torch.no_grad()
|
| 107 |
+
def evaluate_one_model(
|
| 108 |
+
model_dir: Path,
|
| 109 |
+
base_cfg: dict,
|
| 110 |
+
eval_max_samples: int,
|
| 111 |
+
batch_size: int,
|
| 112 |
+
) -> list[dict]:
|
| 113 |
+
rank, world_size, local_rank = _dist_info()
|
| 114 |
+
generation = base_cfg.get("generation", {})
|
| 115 |
+
max_prompt_len = int(generation.get("max_prompt_length", 512))
|
| 116 |
+
max_completion_len = int(generation.get("max_completion_length", 256))
|
| 117 |
+
model_name_fallback = str(base_cfg["model"]["name"])
|
| 118 |
+
trust_remote_code = bool(base_cfg.get("model", {}).get("trust_remote_code", False))
|
| 119 |
+
dtype = _model_dtype(base_cfg)
|
| 120 |
+
datasets_cache, models_cache = _get_cache_paths(base_cfg)
|
| 121 |
+
|
| 122 |
+
provider = GSM8KProvider()
|
| 123 |
+
all_samples = provider.load(
|
| 124 |
+
split="test",
|
| 125 |
+
max_samples=None if eval_max_samples < 0 else eval_max_samples,
|
| 126 |
+
cache_dir=str(datasets_cache),
|
| 127 |
+
)
|
| 128 |
+
indices = list(range(rank, len(all_samples), world_size))
|
| 129 |
+
local_samples = [all_samples[idx] for idx in indices]
|
| 130 |
+
prompts = [sample.prompt for sample in local_samples]
|
| 131 |
+
refs = [sample.target for sample in local_samples]
|
| 132 |
+
metadata = [sample.metadata for sample in local_samples]
|
| 133 |
+
|
| 134 |
+
try:
|
| 135 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 136 |
+
str(model_dir),
|
| 137 |
+
trust_remote_code=trust_remote_code,
|
| 138 |
+
cache_dir=str(models_cache),
|
| 139 |
+
local_files_only=True,
|
| 140 |
+
)
|
| 141 |
+
except Exception:
|
| 142 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 143 |
+
model_name_fallback,
|
| 144 |
+
trust_remote_code=trust_remote_code,
|
| 145 |
+
cache_dir=str(models_cache),
|
| 146 |
+
local_files_only=True,
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 150 |
+
str(model_dir),
|
| 151 |
+
trust_remote_code=trust_remote_code,
|
| 152 |
+
cache_dir=str(models_cache),
|
| 153 |
+
torch_dtype=dtype,
|
| 154 |
+
local_files_only=True,
|
| 155 |
+
)
|
| 156 |
+
if torch.cuda.is_available():
|
| 157 |
+
torch.cuda.set_device(local_rank)
|
| 158 |
+
device = torch.device(f"cuda:{local_rank}")
|
| 159 |
+
else:
|
| 160 |
+
device = torch.device("cpu")
|
| 161 |
+
model.to(device)
|
| 162 |
+
model.eval()
|
| 163 |
+
|
| 164 |
+
if tokenizer.pad_token_id is None:
|
| 165 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 166 |
+
|
| 167 |
+
completions: list[str] = []
|
| 168 |
+
for start in range(0, len(prompts), batch_size):
|
| 169 |
+
batch_prompts = prompts[start : start + batch_size]
|
| 170 |
+
enc = tokenizer(
|
| 171 |
+
batch_prompts,
|
| 172 |
+
return_tensors="pt",
|
| 173 |
+
padding=True,
|
| 174 |
+
truncation=True,
|
| 175 |
+
max_length=max_prompt_len,
|
| 176 |
+
)
|
| 177 |
+
input_ids = enc["input_ids"].to(device)
|
| 178 |
+
attn = enc["attention_mask"].to(device)
|
| 179 |
+
out = model.generate(
|
| 180 |
+
input_ids=input_ids,
|
| 181 |
+
attention_mask=attn,
|
| 182 |
+
max_new_tokens=max_completion_len,
|
| 183 |
+
do_sample=False,
|
| 184 |
+
pad_token_id=tokenizer.pad_token_id,
|
| 185 |
+
eos_token_id=tokenizer.eos_token_id,
|
| 186 |
+
)
|
| 187 |
+
prompt_lens = attn.sum(dim=1).tolist()
|
| 188 |
+
for idx in range(out.size(0)):
|
| 189 |
+
completion_ids = out[idx, int(prompt_lens[idx]) :]
|
| 190 |
+
completions.append(tokenizer.decode(completion_ids, skip_special_tokens=True))
|
| 191 |
+
|
| 192 |
+
scores = gsm8k_correctness_reward(
|
| 193 |
+
prompts=prompts,
|
| 194 |
+
completions=completions,
|
| 195 |
+
references=refs,
|
| 196 |
+
metadata=metadata,
|
| 197 |
+
)
|
| 198 |
+
|
| 199 |
+
local_records: list[dict] = []
|
| 200 |
+
for i, (prompt, reference, completion, score) in enumerate(
|
| 201 |
+
zip(prompts, refs, completions, scores, strict=True)
|
| 202 |
+
):
|
| 203 |
+
local_records.append(
|
| 204 |
+
{
|
| 205 |
+
"sample_index": int(indices[i]),
|
| 206 |
+
"prompt": prompt,
|
| 207 |
+
"reference": reference,
|
| 208 |
+
"completion": completion,
|
| 209 |
+
"correctness": float(score),
|
| 210 |
+
"cot_words": int(_cot_word_len(completion)),
|
| 211 |
+
}
|
| 212 |
+
)
|
| 213 |
+
|
| 214 |
+
del model
|
| 215 |
+
if torch.cuda.is_available():
|
| 216 |
+
torch.cuda.empty_cache()
|
| 217 |
+
|
| 218 |
+
if dist.is_initialized():
|
| 219 |
+
gathered: list[list[dict] | None] = [None for _ in range(world_size)]
|
| 220 |
+
dist.all_gather_object(gathered, local_records)
|
| 221 |
+
merged: list[dict] = []
|
| 222 |
+
for part in gathered:
|
| 223 |
+
if part:
|
| 224 |
+
merged.extend(part)
|
| 225 |
+
else:
|
| 226 |
+
merged = local_records
|
| 227 |
+
|
| 228 |
+
merged.sort(key=lambda row: row["sample_index"])
|
| 229 |
+
return merged
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
def _summarize(records: list[dict], model_dir: str) -> dict:
|
| 233 |
+
if not records:
|
| 234 |
+
return {
|
| 235 |
+
"name": Path(model_dir).name,
|
| 236 |
+
"model_dir": model_dir,
|
| 237 |
+
"num_examples": 0,
|
| 238 |
+
"accuracy": 0.0,
|
| 239 |
+
"avg_cot_words": 0.0,
|
| 240 |
+
}
|
| 241 |
+
accuracy = sum(float(row["correctness"]) for row in records) / len(records)
|
| 242 |
+
avg_cot = sum(float(row["cot_words"]) for row in records) / len(records)
|
| 243 |
+
return {
|
| 244 |
+
"name": Path(model_dir).name,
|
| 245 |
+
"model_dir": model_dir,
|
| 246 |
+
"num_examples": len(records),
|
| 247 |
+
"accuracy": float(accuracy),
|
| 248 |
+
"avg_cot_words": float(avg_cot),
|
| 249 |
+
}
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
def _write_accuracy_svg(summaries: list[dict], path: Path) -> None:
|
| 253 |
+
width = 1000
|
| 254 |
+
height = 460
|
| 255 |
+
left_margin = 70
|
| 256 |
+
right_margin = 30
|
| 257 |
+
top_margin = 70
|
| 258 |
+
bottom_margin = 90
|
| 259 |
+
plot_w = width - left_margin - right_margin
|
| 260 |
+
plot_h = height - top_margin - bottom_margin
|
| 261 |
+
y_base = top_margin + plot_h
|
| 262 |
+
|
| 263 |
+
runs = [row["name"] for row in summaries]
|
| 264 |
+
acc_vals = [float(row["accuracy"]) for row in summaries]
|
| 265 |
+
vmax = max(1.0, max(acc_vals) if acc_vals else 1.0)
|
| 266 |
+
|
| 267 |
+
bar_count = max(1, len(runs))
|
| 268 |
+
slot_w = plot_w / bar_count
|
| 269 |
+
bar_w = min(120, max(30, int(slot_w * 0.55)))
|
| 270 |
+
palette = ["#2563eb", "#dc2626", "#16a34a", "#ca8a04", "#7c3aed", "#0891b2"]
|
| 271 |
+
|
| 272 |
+
parts: list[str] = []
|
| 273 |
+
parts.append(f'<svg xmlns="http://www.w3.org/2000/svg" width="{width}" height="{height}">')
|
| 274 |
+
parts.append('<rect width="100%" height="100%" fill="#ffffff"/>')
|
| 275 |
+
parts.append(
|
| 276 |
+
'<text x="40" y="34" font-size="20" font-family="sans-serif">Sweep Evaluation: GSM8K Accuracy</text>'
|
| 277 |
+
)
|
| 278 |
+
parts.append(
|
| 279 |
+
f'<line x1="{left_margin}" y1="{y_base}" x2="{left_margin + plot_w}" y2="{y_base}" stroke="#111" stroke-width="2" />'
|
| 280 |
+
)
|
| 281 |
+
parts.append(
|
| 282 |
+
f'<line x1="{left_margin}" y1="{top_margin}" x2="{left_margin}" y2="{y_base}" stroke="#111" stroke-width="2" />'
|
| 283 |
+
)
|
| 284 |
+
|
| 285 |
+
# y-axis ticks
|
| 286 |
+
for tick in [0.0, 0.25, 0.5, 0.75, 1.0]:
|
| 287 |
+
y = y_base - int((tick / vmax) * plot_h) if vmax > 0 else y_base
|
| 288 |
+
parts.append(
|
| 289 |
+
f'<line x1="{left_margin - 6}" y1="{y}" x2="{left_margin}" y2="{y}" stroke="#111" stroke-width="1" />'
|
| 290 |
+
)
|
| 291 |
+
parts.append(
|
| 292 |
+
f'<text x="{left_margin - 10}" y="{y + 4}" text-anchor="end" font-size="11" font-family="sans-serif">{tick:.2f}</text>'
|
| 293 |
+
)
|
| 294 |
+
|
| 295 |
+
for idx, (run_name, acc) in enumerate(zip(runs, acc_vals, strict=True)):
|
| 296 |
+
center_x = left_margin + int((idx + 0.5) * slot_w)
|
| 297 |
+
bar_h = int((acc / vmax) * plot_h) if vmax > 0 else 0
|
| 298 |
+
x = center_x - bar_w // 2
|
| 299 |
+
y = y_base - bar_h
|
| 300 |
+
color = palette[idx % len(palette)]
|
| 301 |
+
parts.append(f'<rect x="{x}" y="{y}" width="{bar_w}" height="{bar_h}" fill="{color}" />')
|
| 302 |
+
parts.append(
|
| 303 |
+
f'<text x="{center_x}" y="{y - 8}" text-anchor="middle" font-size="12" font-family="sans-serif">{acc:.3f}</text>'
|
| 304 |
+
)
|
| 305 |
+
parts.append(
|
| 306 |
+
f'<text x="{center_x}" y="{y_base + 18}" text-anchor="middle" font-size="11" font-family="sans-serif">{run_name}</text>'
|
| 307 |
+
)
|
| 308 |
+
|
| 309 |
+
parts.append("</svg>")
|
| 310 |
+
path.write_text("\n".join(parts), encoding="utf-8")
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
def main() -> None:
|
| 314 |
+
rank, _, _ = _init_distributed()
|
| 315 |
+
base_cfg = _load_yaml(str(resolve_repo_path(os.environ["BASE_CONFIG"])))
|
| 316 |
+
requested_sweep_root = Path(os.environ["SWEEP_ROOT"])
|
| 317 |
+
sweep_root = _resolve_sweep_root(base_cfg, requested_sweep_root)
|
| 318 |
+
if "OUT_ROOT" in os.environ:
|
| 319 |
+
out_root = resolve_repo_path(os.environ["OUT_ROOT"])
|
| 320 |
+
else:
|
| 321 |
+
out_root = (sweep_root / "eval_results").resolve()
|
| 322 |
+
eval_max_samples = int(os.environ.get("EVAL_MAX_SAMPLES", "200"))
|
| 323 |
+
eval_batch_size = int(os.environ.get("EVAL_BATCH_SIZE", "4"))
|
| 324 |
+
|
| 325 |
+
model_dirs = _discover_model_dirs(sweep_root)
|
| 326 |
+
resolved_model_dirs = [_resolve_local_model_dir(base_cfg, str(path)) for path in model_dirs]
|
| 327 |
+
|
| 328 |
+
if rank == 0:
|
| 329 |
+
out_root.mkdir(parents=True, exist_ok=True)
|
| 330 |
+
(out_root / "outputs").mkdir(parents=True, exist_ok=True)
|
| 331 |
+
|
| 332 |
+
if dist.is_initialized():
|
| 333 |
+
dist.barrier()
|
| 334 |
+
|
| 335 |
+
summaries: list[dict] = []
|
| 336 |
+
for model_dir in resolved_model_dirs:
|
| 337 |
+
records = evaluate_one_model(
|
| 338 |
+
model_dir=model_dir,
|
| 339 |
+
base_cfg=base_cfg,
|
| 340 |
+
eval_max_samples=eval_max_samples,
|
| 341 |
+
batch_size=eval_batch_size,
|
| 342 |
+
)
|
| 343 |
+
if rank == 0:
|
| 344 |
+
output_jsonl = out_root / "outputs" / f"{model_dir.name}_outputs.jsonl"
|
| 345 |
+
with output_jsonl.open("w", encoding="utf-8") as handle:
|
| 346 |
+
for row in records:
|
| 347 |
+
handle.write(json.dumps(row, ensure_ascii=True) + "\n")
|
| 348 |
+
summary = _summarize(records, str(model_dir))
|
| 349 |
+
summary["outputs_jsonl"] = str(output_jsonl)
|
| 350 |
+
summaries.append(summary)
|
| 351 |
+
|
| 352 |
+
if dist.is_initialized():
|
| 353 |
+
dist.barrier()
|
| 354 |
+
|
| 355 |
+
if rank != 0:
|
| 356 |
+
return
|
| 357 |
+
|
| 358 |
+
json_path = out_root / "sweep_eval_summary.json"
|
| 359 |
+
csv_path = out_root / "sweep_eval_summary.csv"
|
| 360 |
+
svg_path = out_root / "sweep_eval_accuracy.svg"
|
| 361 |
+
json_path.write_text(json.dumps(summaries, indent=2), encoding="utf-8")
|
| 362 |
+
with csv_path.open("w", encoding="utf-8", newline="") as handle:
|
| 363 |
+
writer = csv.DictWriter(
|
| 364 |
+
handle,
|
| 365 |
+
fieldnames=[
|
| 366 |
+
"name",
|
| 367 |
+
"model_dir",
|
| 368 |
+
"num_examples",
|
| 369 |
+
"accuracy",
|
| 370 |
+
"avg_cot_words",
|
| 371 |
+
"outputs_jsonl",
|
| 372 |
+
],
|
| 373 |
+
)
|
| 374 |
+
writer.writeheader()
|
| 375 |
+
for row in summaries:
|
| 376 |
+
writer.writerow(row)
|
| 377 |
+
_write_accuracy_svg(summaries, svg_path)
|
| 378 |
+
|
| 379 |
+
print(f"Saved summary: {json_path}")
|
| 380 |
+
print(f"Saved summary: {csv_path}")
|
| 381 |
+
print(f"Saved plot: {svg_path}")
|
| 382 |
+
print(f"Saved outputs dir: {out_root / 'outputs'}")
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
if __name__ == "__main__":
|
| 386 |
+
main()
|
src/hackable/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Hackable GRPO training primitives."""
|
| 2 |
+
|
| 3 |
+
# Ensure default plugins register on import.
|
| 4 |
+
from . import data_plugins as _data_plugins # noqa: F401
|
| 5 |
+
from . import objectives as _objectives # noqa: F401
|
| 6 |
+
from . import reward_plugins as _reward_plugins # noqa: F401
|
src/hackable/backends.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import importlib.util
|
| 4 |
+
from typing import Any
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def load_model_and_tokenizer(
|
| 10 |
+
model_name: str,
|
| 11 |
+
trust_remote_code: bool = False,
|
| 12 |
+
cache_dir: str | None = None,
|
| 13 |
+
load_in_4bit: bool = False,
|
| 14 |
+
torch_dtype: str = "bfloat16",
|
| 15 |
+
):
|
| 16 |
+
del load_in_4bit
|
| 17 |
+
dtype = torch.bfloat16 if torch_dtype == "bfloat16" else torch.float16
|
| 18 |
+
|
| 19 |
+
# Apply Liger kernels before constructing Llama models.
|
| 20 |
+
if "llama" in model_name.lower():
|
| 21 |
+
try:
|
| 22 |
+
from liger_kernel.transformers import apply_liger_kernel_to_llama
|
| 23 |
+
except Exception as exc:
|
| 24 |
+
raise RuntimeError(
|
| 25 |
+
"Failed to import Liger kernel patcher for Llama. "
|
| 26 |
+
"Install liger-kernel in the runtime environment."
|
| 27 |
+
) from exc
|
| 28 |
+
apply_liger_kernel_to_llama()
|
| 29 |
+
|
| 30 |
+
attn_impl = "sdpa"
|
| 31 |
+
if importlib.util.find_spec("flash_attn") is not None:
|
| 32 |
+
try:
|
| 33 |
+
__import__("flash_attn")
|
| 34 |
+
attn_impl = "flash_attention_2"
|
| 35 |
+
except Exception:
|
| 36 |
+
attn_impl = "sdpa"
|
| 37 |
+
|
| 38 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 39 |
+
|
| 40 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 41 |
+
model_name,
|
| 42 |
+
trust_remote_code=trust_remote_code,
|
| 43 |
+
cache_dir=cache_dir,
|
| 44 |
+
)
|
| 45 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 46 |
+
model_name,
|
| 47 |
+
trust_remote_code=trust_remote_code,
|
| 48 |
+
cache_dir=cache_dir,
|
| 49 |
+
dtype=dtype,
|
| 50 |
+
attn_implementation=attn_impl,
|
| 51 |
+
)
|
| 52 |
+
return model, tokenizer, "transformers"
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def generation_kwargs(cfg: Any) -> dict[str, Any]:
|
| 56 |
+
return {
|
| 57 |
+
"max_prompt_length": cfg.max_prompt_length,
|
| 58 |
+
"max_completion_length": cfg.max_completion_length,
|
| 59 |
+
"num_generations": cfg.num_generations,
|
| 60 |
+
"temperature": cfg.temperature,
|
| 61 |
+
"top_p": cfg.top_p,
|
| 62 |
+
}
|
src/hackable/config.py
ADDED
|
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass, field
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Any
|
| 6 |
+
|
| 7 |
+
import yaml
|
| 8 |
+
|
| 9 |
+
from .utils import resolve_repo_path
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@dataclass
|
| 13 |
+
class ModelConfig:
|
| 14 |
+
name: str
|
| 15 |
+
trust_remote_code: bool = False
|
| 16 |
+
load_in_4bit: bool = False
|
| 17 |
+
use_lora_adapters: bool = False
|
| 18 |
+
lora_r: int = 16
|
| 19 |
+
lora_alpha: int = 16
|
| 20 |
+
lora_dropout: float = 0.0
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@dataclass
|
| 24 |
+
class TrainerConfig:
|
| 25 |
+
output_dir: str
|
| 26 |
+
run_name: str = "grpo-run"
|
| 27 |
+
max_steps: int = -1
|
| 28 |
+
num_train_epochs: float = 1.0
|
| 29 |
+
per_device_train_batch_size: int = 1
|
| 30 |
+
gradient_accumulation_steps: int = 8
|
| 31 |
+
learning_rate: float = 1.0e-6
|
| 32 |
+
logging_steps: int = 1
|
| 33 |
+
save_steps: int = 25
|
| 34 |
+
save_total_limit: int = 5
|
| 35 |
+
bf16: bool = True
|
| 36 |
+
seed: int = 42
|
| 37 |
+
report_to: str = "wandb"
|
| 38 |
+
optim: str = "adamw_torch"
|
| 39 |
+
gradient_checkpointing: bool = True
|
| 40 |
+
max_grad_norm: float = 1.0
|
| 41 |
+
shuffle_dataset: bool = False
|
| 42 |
+
lr_scheduler_type: str = "cosine"
|
| 43 |
+
lr_scheduler_kwargs: dict[str, Any] = field(default_factory=dict)
|
| 44 |
+
warmup_steps: int = 20
|
| 45 |
+
sanity_log_examples: int = 8
|
| 46 |
+
sanity_log_max_chars: int = 300
|
| 47 |
+
permanent_checkpoint_steps: int = 300
|
| 48 |
+
permanent_checkpoint_dir: str = "checkpoints/permanent"
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
@dataclass
|
| 53 |
+
class DataConfig:
|
| 54 |
+
provider: str = "gsm8k_math_curriculum"
|
| 55 |
+
split: str = "train"
|
| 56 |
+
max_samples: int | None = None
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
@dataclass
|
| 60 |
+
class GenerationConfig:
|
| 61 |
+
max_prompt_length: int = 512
|
| 62 |
+
max_completion_length: int = 256
|
| 63 |
+
num_generations: int = 4
|
| 64 |
+
temperature: float = 0.9
|
| 65 |
+
top_p: float = 0.95
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
@dataclass
|
| 69 |
+
class ObjectiveConfig:
|
| 70 |
+
name: str = "token_grpo"
|
| 71 |
+
kwargs: dict[str, Any] = field(default_factory=dict)
|
| 72 |
+
class_path: str | None = None
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
@dataclass
|
| 76 |
+
class RewardsConfig:
|
| 77 |
+
kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
@dataclass
|
| 81 |
+
class AuthConfig:
|
| 82 |
+
hf_api_key: str | None = None
|
| 83 |
+
wandb_api_key: str | None = None
|
| 84 |
+
hf_api_key_env: str = "HF_TOKEN"
|
| 85 |
+
wandb_api_key_env: str = "WANDB_API_KEY"
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
@dataclass
|
| 89 |
+
class StorageConfig:
|
| 90 |
+
cache_dir: str = "cache"
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
@dataclass
|
| 94 |
+
class ThinkingKLConfig:
|
| 95 |
+
"""Scale KL penalty on completion tokens that overlap the *inner* redacted thinking body."""
|
| 96 |
+
|
| 97 |
+
inner_kl_weight: float = 1.0
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
@dataclass
|
| 101 |
+
class ExperimentConfig:
|
| 102 |
+
model: ModelConfig
|
| 103 |
+
trainer: TrainerConfig
|
| 104 |
+
data: DataConfig = field(default_factory=DataConfig)
|
| 105 |
+
generation: GenerationConfig = field(default_factory=GenerationConfig)
|
| 106 |
+
objective: ObjectiveConfig = field(default_factory=ObjectiveConfig)
|
| 107 |
+
rewards: RewardsConfig = field(default_factory=RewardsConfig)
|
| 108 |
+
auth: AuthConfig = field(default_factory=AuthConfig)
|
| 109 |
+
storage: StorageConfig = field(default_factory=StorageConfig)
|
| 110 |
+
thinking_kl: ThinkingKLConfig = field(default_factory=ThinkingKLConfig)
|
| 111 |
+
grpo: dict[str, Any] = field(default_factory=dict)
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def load_config(path: str | Path) -> ExperimentConfig:
|
| 115 |
+
resolved = resolve_repo_path(path)
|
| 116 |
+
with resolved.open("r", encoding="utf-8") as handle:
|
| 117 |
+
raw = yaml.safe_load(handle)
|
| 118 |
+
|
| 119 |
+
trainer_raw = raw["trainer"]
|
| 120 |
+
# Backward-compatible alias: allow "optimizer" in YAML.
|
| 121 |
+
if "optim" not in trainer_raw and "optimizer" in trainer_raw:
|
| 122 |
+
trainer_raw = {**trainer_raw, "optim": trainer_raw["optimizer"]}
|
| 123 |
+
if "optimizer" in trainer_raw:
|
| 124 |
+
trainer_raw = {k: v for k, v in trainer_raw.items() if k != "optimizer"}
|
| 125 |
+
# Accept common shorthand names.
|
| 126 |
+
optim_aliases = {
|
| 127 |
+
"adamw": "adamw_torch",
|
| 128 |
+
"adamw_fused": "adamw_torch_fused",
|
| 129 |
+
}
|
| 130 |
+
scheduler_aliases = {
|
| 131 |
+
"cosine_decay": "cosine",
|
| 132 |
+
}
|
| 133 |
+
if "optim" in trainer_raw:
|
| 134 |
+
trainer_raw = {
|
| 135 |
+
**trainer_raw,
|
| 136 |
+
"optim": optim_aliases.get(trainer_raw["optim"], trainer_raw["optim"]),
|
| 137 |
+
}
|
| 138 |
+
if "lr_scheduler_type" in trainer_raw:
|
| 139 |
+
trainer_raw = {
|
| 140 |
+
**trainer_raw,
|
| 141 |
+
"lr_scheduler_type": scheduler_aliases.get(
|
| 142 |
+
trainer_raw["lr_scheduler_type"], trainer_raw["lr_scheduler_type"]
|
| 143 |
+
),
|
| 144 |
+
}
|
| 145 |
+
|
| 146 |
+
# Normalize numeric fields that may come from YAML as strings.
|
| 147 |
+
float_fields = {
|
| 148 |
+
"learning_rate",
|
| 149 |
+
"max_grad_norm",
|
| 150 |
+
"num_train_epochs",
|
| 151 |
+
}
|
| 152 |
+
int_fields = {
|
| 153 |
+
"max_steps",
|
| 154 |
+
"per_device_train_batch_size",
|
| 155 |
+
"gradient_accumulation_steps",
|
| 156 |
+
"logging_steps",
|
| 157 |
+
"save_steps",
|
| 158 |
+
"save_total_limit",
|
| 159 |
+
"seed",
|
| 160 |
+
"warmup_steps",
|
| 161 |
+
"sanity_log_examples",
|
| 162 |
+
"sanity_log_max_chars",
|
| 163 |
+
"permanent_checkpoint_steps",
|
| 164 |
+
}
|
| 165 |
+
for key in float_fields:
|
| 166 |
+
if key in trainer_raw:
|
| 167 |
+
trainer_raw = {**trainer_raw, key: float(trainer_raw[key])}
|
| 168 |
+
for key in int_fields:
|
| 169 |
+
if key in trainer_raw:
|
| 170 |
+
trainer_raw = {**trainer_raw, key: int(trainer_raw[key])}
|
| 171 |
+
|
| 172 |
+
return ExperimentConfig(
|
| 173 |
+
model=ModelConfig(**raw["model"]),
|
| 174 |
+
trainer=TrainerConfig(**trainer_raw),
|
| 175 |
+
data=DataConfig(**raw.get("data", {})),
|
| 176 |
+
generation=GenerationConfig(**raw.get("generation", {})),
|
| 177 |
+
objective=ObjectiveConfig(**raw.get("objective", {})),
|
| 178 |
+
rewards=RewardsConfig(**raw.get("rewards", {})),
|
| 179 |
+
auth=AuthConfig(**raw.get("auth", {})),
|
| 180 |
+
storage=StorageConfig(**raw.get("storage", {})),
|
| 181 |
+
thinking_kl=ThinkingKLConfig(**raw.get("thinking_kl", {})),
|
| 182 |
+
grpo=dict(raw.get("grpo", {})),
|
| 183 |
+
)
|
src/hackable/data_plugins.py
ADDED
|
@@ -0,0 +1,291 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from contextlib import contextmanager
|
| 4 |
+
from dataclasses import asdict
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
from .interfaces import TrainingSample
|
| 8 |
+
from .registry import register_data_provider
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@contextmanager
|
| 12 |
+
def _serialized_hf_dataset_download():
|
| 13 |
+
"""
|
| 14 |
+
Serialize Hugging Face ``datasets`` downloads/prepare across processes.
|
| 15 |
+
|
| 16 |
+
Multi-GPU ``accelerate launch`` otherwise races on the same ``cache_dir`` and can
|
| 17 |
+
leave a half-written tree (e.g. missing ``dataset_info.json``).
|
| 18 |
+
"""
|
| 19 |
+
root = Path.home() / ".cache" / "neuralese"
|
| 20 |
+
root.mkdir(parents=True, exist_ok=True)
|
| 21 |
+
lock_path = root / "hf_dataset_download.lock"
|
| 22 |
+
try:
|
| 23 |
+
from filelock import FileLock
|
| 24 |
+
|
| 25 |
+
with FileLock(str(lock_path), timeout=7200):
|
| 26 |
+
yield
|
| 27 |
+
except ImportError:
|
| 28 |
+
yield
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def _load_hf_split(
|
| 32 |
+
path: str,
|
| 33 |
+
split: str,
|
| 34 |
+
cache_dir: str | None,
|
| 35 |
+
config_name: str | None = None,
|
| 36 |
+
):
|
| 37 |
+
from datasets import load_dataset
|
| 38 |
+
|
| 39 |
+
with _serialized_hf_dataset_download():
|
| 40 |
+
try:
|
| 41 |
+
if config_name is not None:
|
| 42 |
+
return load_dataset(path, config_name, split=split, cache_dir=cache_dir)
|
| 43 |
+
return load_dataset(path, split=split, cache_dir=cache_dir)
|
| 44 |
+
except FileNotFoundError as exc:
|
| 45 |
+
if not cache_dir:
|
| 46 |
+
raise RuntimeError(
|
| 47 |
+
"Hugging Face dataset files are missing from the default cache. "
|
| 48 |
+
"Run once with HF_DATASETS_OFFLINE=0 (or download the dataset), "
|
| 49 |
+
"or set HF_HOME / HF_DATASETS_CACHE to a populated cache."
|
| 50 |
+
) from exc
|
| 51 |
+
try:
|
| 52 |
+
if config_name is not None:
|
| 53 |
+
return load_dataset(path, config_name, split=split, cache_dir=None)
|
| 54 |
+
return load_dataset(path, split=split, cache_dir=None)
|
| 55 |
+
except FileNotFoundError as exc2:
|
| 56 |
+
raise RuntimeError(
|
| 57 |
+
"Could not load the dataset from the experiment cache_dir or the default HF cache. "
|
| 58 |
+
"Seed ~/.cache/huggingface/datasets (or your HF_HOME) with HF_DATASETS_OFFLINE=0, "
|
| 59 |
+
"or point storage.cache_dir at a shared cache that already contains the dataset."
|
| 60 |
+
) from exc2
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
PROMPT_PREFIX = (
|
| 64 |
+
"Solve the following math problem.\n"
|
| 65 |
+
"Think step-by-step inside <think>...</think> tags.\n"
|
| 66 |
+
"Then output only the final answer in LaTeX boxed format.\n"
|
| 67 |
+
"Do not include any words or explanations outside the tags/boxed answer.\n"
|
| 68 |
+
"Output format must be exactly:\n"
|
| 69 |
+
"<think>your reasoning</think>\n"
|
| 70 |
+
"\\boxed{your_final_answer}\n\n"
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def _build_math_prompt(question: str) -> str:
|
| 75 |
+
user_content = f"{PROMPT_PREFIX}Question: {question}"
|
| 76 |
+
# Chat-style prefill so decoding starts after "assistant:".
|
| 77 |
+
return f"user: {user_content}\nassistant:"
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def _interleave_samples(
|
| 81 |
+
left: list[TrainingSample], right: list[TrainingSample]
|
| 82 |
+
) -> list[TrainingSample]:
|
| 83 |
+
output: list[TrainingSample] = []
|
| 84 |
+
width = max(len(left), len(right))
|
| 85 |
+
for idx in range(width):
|
| 86 |
+
if idx < len(left):
|
| 87 |
+
output.append(left[idx])
|
| 88 |
+
if idx < len(right):
|
| 89 |
+
output.append(right[idx])
|
| 90 |
+
return output
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def _slice_if_needed(
|
| 94 |
+
samples: list[TrainingSample], max_samples: int | None
|
| 95 |
+
) -> list[TrainingSample]:
|
| 96 |
+
if max_samples is None:
|
| 97 |
+
return samples
|
| 98 |
+
return samples[: max(0, max_samples)]
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
class _MathProviderBase:
|
| 102 |
+
dataset_name = "EleutherAI/hendrycks_math"
|
| 103 |
+
dataset_configs = (
|
| 104 |
+
"algebra",
|
| 105 |
+
"counting_and_probability",
|
| 106 |
+
"geometry",
|
| 107 |
+
"intermediate_algebra",
|
| 108 |
+
"number_theory",
|
| 109 |
+
"prealgebra",
|
| 110 |
+
"precalculus",
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
def __init__(self, levels: tuple[str, ...]):
|
| 114 |
+
self.levels = levels
|
| 115 |
+
|
| 116 |
+
def load(
|
| 117 |
+
self,
|
| 118 |
+
split: str,
|
| 119 |
+
max_samples: int | None = None,
|
| 120 |
+
cache_dir: str | None = None,
|
| 121 |
+
) -> list[TrainingSample]:
|
| 122 |
+
try:
|
| 123 |
+
import datasets # noqa: F401
|
| 124 |
+
except Exception as exc:
|
| 125 |
+
raise RuntimeError(
|
| 126 |
+
"datasets is required for Hendrycks MATH providers. Install dependencies first."
|
| 127 |
+
) from exc
|
| 128 |
+
|
| 129 |
+
level_set = {level.strip() for level in self.levels}
|
| 130 |
+
|
| 131 |
+
output: list[TrainingSample] = []
|
| 132 |
+
for config_name in self.dataset_configs:
|
| 133 |
+
rows = _load_hf_split(
|
| 134 |
+
self.dataset_name,
|
| 135 |
+
split,
|
| 136 |
+
cache_dir,
|
| 137 |
+
config_name=config_name,
|
| 138 |
+
)
|
| 139 |
+
for row in rows:
|
| 140 |
+
level = str(row.get("level", "")).strip()
|
| 141 |
+
if level not in level_set:
|
| 142 |
+
continue
|
| 143 |
+
question = str(row.get("problem", ""))
|
| 144 |
+
target = str(row.get("solution", ""))
|
| 145 |
+
output.append(
|
| 146 |
+
TrainingSample(
|
| 147 |
+
prompt=_build_math_prompt(question),
|
| 148 |
+
target=target,
|
| 149 |
+
metadata={
|
| 150 |
+
"dataset": "hendrycks_math",
|
| 151 |
+
"subject": config_name,
|
| 152 |
+
"level": level,
|
| 153 |
+
},
|
| 154 |
+
)
|
| 155 |
+
)
|
| 156 |
+
if max_samples is not None and len(output) >= max_samples:
|
| 157 |
+
return output
|
| 158 |
+
return output
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
@register_data_provider("gsm8k")
|
| 162 |
+
class GSM8KProvider:
|
| 163 |
+
def __init__(self, dataset_name: str = "openai/gsm8k", subset: str = "main"):
|
| 164 |
+
self.dataset_name = dataset_name
|
| 165 |
+
self.subset = subset
|
| 166 |
+
|
| 167 |
+
def load(
|
| 168 |
+
self,
|
| 169 |
+
split: str,
|
| 170 |
+
max_samples: int | None = None,
|
| 171 |
+
cache_dir: str | None = None,
|
| 172 |
+
) -> list[TrainingSample]:
|
| 173 |
+
try:
|
| 174 |
+
import datasets # noqa: F401
|
| 175 |
+
except Exception as exc:
|
| 176 |
+
raise RuntimeError(
|
| 177 |
+
"datasets is required for GSM8K provider. Install dependencies first."
|
| 178 |
+
) from exc
|
| 179 |
+
|
| 180 |
+
rows = _load_hf_split(
|
| 181 |
+
self.dataset_name,
|
| 182 |
+
split,
|
| 183 |
+
cache_dir,
|
| 184 |
+
config_name=self.subset,
|
| 185 |
+
)
|
| 186 |
+
if max_samples is not None:
|
| 187 |
+
rows = rows.select(range(min(max_samples, len(rows))))
|
| 188 |
+
|
| 189 |
+
output: list[TrainingSample] = []
|
| 190 |
+
for sample_index, row in enumerate(rows):
|
| 191 |
+
prompt = _build_math_prompt(str(row["question"]))
|
| 192 |
+
output.append(
|
| 193 |
+
TrainingSample(
|
| 194 |
+
prompt=prompt,
|
| 195 |
+
target=row["answer"],
|
| 196 |
+
metadata={
|
| 197 |
+
"dataset": "gsm8k",
|
| 198 |
+
"sample_index": int(sample_index),
|
| 199 |
+
"split": str(split),
|
| 200 |
+
},
|
| 201 |
+
)
|
| 202 |
+
)
|
| 203 |
+
return output
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
@register_data_provider("math_level_1")
|
| 207 |
+
class MathLevel1Provider(_MathProviderBase):
|
| 208 |
+
def __init__(self):
|
| 209 |
+
super().__init__(levels=("Level 1",))
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
@register_data_provider("math_level_2")
|
| 213 |
+
class MathLevel2Provider(_MathProviderBase):
|
| 214 |
+
def __init__(self):
|
| 215 |
+
super().__init__(levels=("Level 2",))
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
@register_data_provider("math_level_3")
|
| 219 |
+
class MathLevel3Provider(_MathProviderBase):
|
| 220 |
+
def __init__(self):
|
| 221 |
+
super().__init__(levels=("Level 3",))
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
@register_data_provider("math_level_4")
|
| 225 |
+
class MathLevel4Provider(_MathProviderBase):
|
| 226 |
+
def __init__(self):
|
| 227 |
+
super().__init__(levels=("Level 4",))
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
@register_data_provider("math_level_5")
|
| 231 |
+
class MathLevel5Provider(_MathProviderBase):
|
| 232 |
+
def __init__(self):
|
| 233 |
+
super().__init__(levels=("Level 5",))
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
@register_data_provider("math_levels_12")
|
| 237 |
+
class MathLevels12Provider(_MathProviderBase):
|
| 238 |
+
def __init__(self):
|
| 239 |
+
super().__init__(levels=("Level 1", "Level 2"))
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
@register_data_provider("math_levels_345")
|
| 243 |
+
class MathLevels345Provider(_MathProviderBase):
|
| 244 |
+
def __init__(self):
|
| 245 |
+
super().__init__(levels=("Level 3", "Level 4", "Level 5"))
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
@register_data_provider("gsm8k_math_stage12")
|
| 249 |
+
class GSM8KMathStage12Provider:
|
| 250 |
+
def load(
|
| 251 |
+
self,
|
| 252 |
+
split: str,
|
| 253 |
+
max_samples: int | None = None,
|
| 254 |
+
cache_dir: str | None = None,
|
| 255 |
+
) -> list[TrainingSample]:
|
| 256 |
+
gsm = GSM8KProvider().load(split=split, max_samples=None, cache_dir=cache_dir)
|
| 257 |
+
math12 = MathLevels12Provider().load(
|
| 258 |
+
split=split, max_samples=None, cache_dir=cache_dir
|
| 259 |
+
)
|
| 260 |
+
mixed = _interleave_samples(gsm, math12)
|
| 261 |
+
return _slice_if_needed(mixed, max_samples)
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
@register_data_provider("gsm8k_math_curriculum")
|
| 265 |
+
class GSM8KMathCurriculumProvider:
|
| 266 |
+
def load(
|
| 267 |
+
self,
|
| 268 |
+
split: str,
|
| 269 |
+
max_samples: int | None = None,
|
| 270 |
+
cache_dir: str | None = None,
|
| 271 |
+
) -> list[TrainingSample]:
|
| 272 |
+
if max_samples is None:
|
| 273 |
+
stage12_budget = None
|
| 274 |
+
stage345_budget = None
|
| 275 |
+
else:
|
| 276 |
+
stage12_budget = (max_samples + 1) // 2
|
| 277 |
+
stage345_budget = max_samples // 2
|
| 278 |
+
|
| 279 |
+
stage12 = GSM8KMathStage12Provider().load(
|
| 280 |
+
split=split, max_samples=stage12_budget, cache_dir=cache_dir
|
| 281 |
+
)
|
| 282 |
+
stage345 = MathLevels345Provider().load(
|
| 283 |
+
split=split, max_samples=stage345_budget, cache_dir=cache_dir
|
| 284 |
+
)
|
| 285 |
+
|
| 286 |
+
# Curriculum order: first easier mixed set, then harder levels.
|
| 287 |
+
return stage12 + stage345
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
def to_dataset_rows(samples: list[TrainingSample]) -> list[dict]:
|
| 291 |
+
return [asdict(sample) for sample in samples]
|
src/hackable/interfaces.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import Any, Protocol
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@dataclass
|
| 8 |
+
class TrainingSample:
|
| 9 |
+
prompt: str
|
| 10 |
+
target: str
|
| 11 |
+
metadata: dict[str, Any]
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class DataProvider(Protocol):
|
| 15 |
+
def load(
|
| 16 |
+
self,
|
| 17 |
+
split: str,
|
| 18 |
+
max_samples: int | None = None,
|
| 19 |
+
cache_dir: str | None = None,
|
| 20 |
+
) -> list[TrainingSample]:
|
| 21 |
+
...
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class RewardFunction(Protocol):
|
| 25 |
+
def __call__(
|
| 26 |
+
self,
|
| 27 |
+
prompts: list[str],
|
| 28 |
+
completions: list[str],
|
| 29 |
+
references: list[str],
|
| 30 |
+
metadata: list[dict[str, Any]],
|
| 31 |
+
) -> list[float]:
|
| 32 |
+
...
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class ObjectiveModule(Protocol):
|
| 36 |
+
name: str
|
| 37 |
+
|
| 38 |
+
def reward_names(self) -> list[str]:
|
| 39 |
+
...
|
| 40 |
+
|
| 41 |
+
def extra_reward(
|
| 42 |
+
self,
|
| 43 |
+
prompts: list[str],
|
| 44 |
+
completions: list[str],
|
| 45 |
+
references: list[str],
|
| 46 |
+
metadata: list[dict[str, Any]],
|
| 47 |
+
) -> list[float]:
|
| 48 |
+
...
|