psidharth567 committed on
Commit
dbc69f3
·
0 Parent(s):

Export neuralese codebase (cache and .env excluded).

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +108 -0
  2. .gitignore +15 -0
  3. AGENTS.md +56 -0
  4. README.md +164 -0
  5. configs/accelerate_ddp_2gpu.yaml +17 -0
  6. configs/accelerate_ddp_4gpu.yaml +17 -0
  7. configs/accelerate_ddp_7gpu.yaml +17 -0
  8. configs/accelerate_ddp_8gpu.yaml +17 -0
  9. configs/grpo_llama32_3b_bf16.yaml +70 -0
  10. docs/repository-map.md +50 -0
  11. out.jsonl +0 -0
  12. requirements.txt +9 -0
  13. scripts/conda_env.sh +40 -0
  14. scripts/eval_gsm8k_think_entropy_4gpu.sh +42 -0
  15. scripts/eval_gsm8k_truncated_thinking.sh +51 -0
  16. scripts/eval_gsm8k_zeroshot_4gpu.sh +36 -0
  17. scripts/eval_gsm8k_zeroshot_train_4gpu.sh +54 -0
  18. scripts/eval_length_penalty_ablation_offline.sh +39 -0
  19. scripts/eval_math_level1_4gpu.sh +55 -0
  20. scripts/eval_math_level1_thinking_zeroshot_4gpu.sh +48 -0
  21. scripts/eval_math_level1_zeroshot_4gpu.sh +49 -0
  22. scripts/eval_math_level2_thinking_zeroshot_4gpu.sh +48 -0
  23. scripts/eval_math_level2_zeroshot_4gpu.sh +49 -0
  24. scripts/eval_permanent_root_acc_cot.sh +53 -0
  25. scripts/eval_sweep_models_offline.sh +34 -0
  26. scripts/eval_twostage_permanent_checkpoints.sh +47 -0
  27. scripts/hf_download_repo.py +72 -0
  28. scripts/hf_upload_repo.py +128 -0
  29. scripts/resume_grpo_8gpu.sh +27 -0
  30. scripts/run_grpo.sh +19 -0
  31. scripts/run_grpo_2gpu.sh +22 -0
  32. scripts/run_grpo_4gpu.sh +22 -0
  33. scripts/run_grpo_8gpu.sh +25 -0
  34. scripts/run_grpo_thinking_kl_masked_resume_4gpu.sh +139 -0
  35. scripts/run_lambda_0p1_existing_gate_token_util.sh +107 -0
  36. scripts/run_reward_variants_and_eval.sh +146 -0
  37. scripts/run_sft_gsm8k_boxed_7gpu.sh +44 -0
  38. scripts/run_twostage_correctness1.sh +79 -0
  39. scripts/run_twostage_correctness5.sh +79 -0
  40. scripts/sweep_length_penalty_lambda.sh +88 -0
  41. src/eval_gsm8k_truncated_thinking.py +321 -0
  42. src/eval_gsm8k_zeroshot.py +76 -0
  43. src/eval_math_level1_thinking_zeroshot.py +312 -0
  44. src/eval_permanent_checkpoints.py +433 -0
  45. src/eval_sweep_models.py +386 -0
  46. src/hackable/__init__.py +6 -0
  47. src/hackable/backends.py +62 -0
  48. src/hackable/config.py +183 -0
  49. src/hackable/data_plugins.py +291 -0
  50. src/hackable/interfaces.py +48 -0
.gitattributes ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.avro filter=lfs diff=lfs merge=lfs -text
4
+ *.bin filter=lfs diff=lfs merge=lfs -text
5
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
6
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
7
+ *.ftz filter=lfs diff=lfs merge=lfs -text
8
+ *.gz filter=lfs diff=lfs merge=lfs -text
9
+ *.h5 filter=lfs diff=lfs merge=lfs -text
10
+ *.joblib filter=lfs diff=lfs merge=lfs -text
11
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
12
+ *.lz4 filter=lfs diff=lfs merge=lfs -text
13
+ *.mds filter=lfs diff=lfs merge=lfs -text
14
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
15
+ *.model filter=lfs diff=lfs merge=lfs -text
16
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
17
+ *.npy filter=lfs diff=lfs merge=lfs -text
18
+ *.npz filter=lfs diff=lfs merge=lfs -text
19
+ *.onnx filter=lfs diff=lfs merge=lfs -text
20
+ *.ot filter=lfs diff=lfs merge=lfs -text
21
+ *.parquet filter=lfs diff=lfs merge=lfs -text
22
+ *.pb filter=lfs diff=lfs merge=lfs -text
23
+ *.pickle filter=lfs diff=lfs merge=lfs -text
24
+ *.pkl filter=lfs diff=lfs merge=lfs -text
25
+ *.pt filter=lfs diff=lfs merge=lfs -text
26
+ *.pth filter=lfs diff=lfs merge=lfs -text
27
+ *.rar filter=lfs diff=lfs merge=lfs -text
28
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
29
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
30
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
31
+ *.tar filter=lfs diff=lfs merge=lfs -text
32
+ *.tflite filter=lfs diff=lfs merge=lfs -text
33
+ *.tgz filter=lfs diff=lfs merge=lfs -text
34
+ *.wasm filter=lfs diff=lfs merge=lfs -text
35
+ *.xz filter=lfs diff=lfs merge=lfs -text
36
+ *.zip filter=lfs diff=lfs merge=lfs -text
37
+ *.zst filter=lfs diff=lfs merge=lfs -text
38
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
39
+ # Audio files - uncompressed
40
+ *.pcm filter=lfs diff=lfs merge=lfs -text
41
+ *.sam filter=lfs diff=lfs merge=lfs -text
42
+ *.raw filter=lfs diff=lfs merge=lfs -text
43
+ # Audio files - compressed
44
+ *.aac filter=lfs diff=lfs merge=lfs -text
45
+ *.flac filter=lfs diff=lfs merge=lfs -text
46
+ *.mp3 filter=lfs diff=lfs merge=lfs -text
47
+ *.ogg filter=lfs diff=lfs merge=lfs -text
48
+ *.wav filter=lfs diff=lfs merge=lfs -text
49
+ # Image files - uncompressed
50
+ *.bmp filter=lfs diff=lfs merge=lfs -text
51
+ *.gif filter=lfs diff=lfs merge=lfs -text
52
+ *.png filter=lfs diff=lfs merge=lfs -text
53
+ *.tiff filter=lfs diff=lfs merge=lfs -text
54
+ # Image files - compressed
55
+ *.jpg filter=lfs diff=lfs merge=lfs -text
56
+ *.jpeg filter=lfs diff=lfs merge=lfs -text
57
+ *.webp filter=lfs diff=lfs merge=lfs -text
58
+ # Video files - compressed
59
+ *.mp4 filter=lfs diff=lfs merge=lfs -text
60
+ *.webm filter=lfs diff=lfs merge=lfs -text
61
+ cache/models/models--unsloth--llama-3.2-3b-instruct-unsloth-bnb-4bit/blobs/9151ba0c695e63ba3dea7436a55bb6aa9f2d0a3d7b7eead62086959354bd6c67 filter=lfs diff=lfs merge=lfs -text
62
+ cache/models/models--unsloth--llama-3.2-3b-instruct-unsloth-bnb-4bit/blobs/6b9e4e7fb171f92fd137b777cc2714bf87d11576700a1dcd7a399e7bbe39537b filter=lfs diff=lfs merge=lfs -text
63
+ cache/models/models--unsloth--llama-3.2-3b-instruct-unsloth-bnb-4bit/snapshots/19846d3f624f3eb96f3bdd275620c6bc7e21e1f8/tokenizer.json filter=lfs diff=lfs merge=lfs -text
64
+ cache/artifacts/runs/llama32-3b-grpo/tokenizer.json filter=lfs diff=lfs merge=lfs -text
65
+ cache/artifacts/runs/llama32-3b-grpo/checkpoint-275/tokenizer.json filter=lfs diff=lfs merge=lfs -text
66
+ cache/cache/artifacts/sweeps/reward_variants_lambda_0p1/run_lambda_0p1_existing_gate_token_util/checkpoint-1350/tokenizer.json filter=lfs diff=lfs merge=lfs -text
67
+ cache/cache/artifacts/sweeps/twostage_correctness_weight/run_twostage_correctness5/tokenizer.json filter=lfs diff=lfs merge=lfs -text
68
+ cache/cache/artifacts/sweeps/twostage_correctness_weight/run_twostage_correctness5/checkpoint-3736/tokenizer.json filter=lfs diff=lfs merge=lfs -text
69
+ cache/cache/artifacts/sweeps/twostage_correctness_weight/run_twostage_correctness5/permanent_checkpoints/checkpoint-stage1-boundary-epoch-1p0-step-1868/tokenizer.json filter=lfs diff=lfs merge=lfs -text
70
+ cache/cache/artifacts/sweeps/twostage_correctness_weight/run_twostage_correctness1/checkpoint-2925/tokenizer.json filter=lfs diff=lfs merge=lfs -text
71
+ cache/cache/artifacts/sweeps/twostage_correctness_weight/run_twostage_correctness1/permanent_checkpoints/checkpoint-stage1-boundary-epoch-1p0-step-1868/tokenizer.json filter=lfs diff=lfs merge=lfs -text
72
+ cache/wandb/wandb/offline-run-20260218_164737-w4awch9k/run-w4awch9k.wandb filter=lfs diff=lfs merge=lfs -text
73
+ cache/wandb/wandb/offline-run-20260218_172032-b9o54gvf/run-b9o54gvf.wandb filter=lfs diff=lfs merge=lfs -text
74
+ cache/wandb/wandb/offline-run-20260218_170233-ydg9el7o/run-ydg9el7o.wandb filter=lfs diff=lfs merge=lfs -text
75
+ cache/wandb/wandb/offline-run-20260218_165455-zj6y61aw/run-zj6y61aw.wandb filter=lfs diff=lfs merge=lfs -text
76
+ cache/wandb/wandb/offline-run-20260326_035115-p68cj9a2/run-p68cj9a2.wandb filter=lfs diff=lfs merge=lfs -text
77
+ cache/wandb/wandb/offline-run-20260331_190529-w0p8d22o/run-w0p8d22o.wandb filter=lfs diff=lfs merge=lfs -text
78
+ cache/wandb/wandb/offline-run-20260326_235622-lo899m3z/run-lo899m3z.wandb filter=lfs diff=lfs merge=lfs -text
79
+ cache/wandb/wandb/run-20260218_150224-l50vxprm/run-l50vxprm.wandb filter=lfs diff=lfs merge=lfs -text
80
+ cache/wandb/wandb/offline-run-20260327_235926-kx2wc1tp/run-kx2wc1tp.wandb filter=lfs diff=lfs merge=lfs -text
81
+ cache/wandb/wandb/offline-run-20260329_125618-21yt0wn1/run-21yt0wn1.wandb filter=lfs diff=lfs merge=lfs -text
82
+ cache/wandb/wandb/run-20260218_145059-onpmdiea/run-onpmdiea.wandb filter=lfs diff=lfs merge=lfs -text
83
+ cache/wandb/wandb/offline-run-20260218_172557-kst72vbj/run-kst72vbj.wandb filter=lfs diff=lfs merge=lfs -text
84
+ cache/wandb/wandb/offline-run-20260325_203352-8ncdpjq1/run-8ncdpjq1.wandb filter=lfs diff=lfs merge=lfs -text
85
+ cache/wandb/wandb/offline-run-20260329_125618-nuzntjga/run-nuzntjga.wandb filter=lfs diff=lfs merge=lfs -text
86
+ cache/wandb/wandb/offline-run-20260328_025352-z84jpmb8/run-z84jpmb8.wandb filter=lfs diff=lfs merge=lfs -text
87
+ cache/wandb/wandb/offline-run-20260328_142710-or2vd7o6/run-or2vd7o6.wandb filter=lfs diff=lfs merge=lfs -text
88
+ cache/wandb/wandb/offline-run-20260329_124408-koirzdhw/run-koirzdhw.wandb filter=lfs diff=lfs merge=lfs -text
89
+ cache/wandb/wandb/offline-run-20260325_052529-dz6rih26/run-dz6rih26.wandb filter=lfs diff=lfs merge=lfs -text
90
+ cache/wandb/wandb/offline-run-20260325_015712-7ialnk1w/run-7ialnk1w.wandb filter=lfs diff=lfs merge=lfs -text
91
+ cache/wandb/wandb/offline-run-20260327_235600-s2pr0n6l/run-s2pr0n6l.wandb filter=lfs diff=lfs merge=lfs -text
92
+ cache/wandb/wandb/offline-run-20260218_154633-igplmjku/run-igplmjku.wandb filter=lfs diff=lfs merge=lfs -text
93
+ cache/wandb/wandb/offline-run-20260218_204810-jg5x3bn6/run-jg5x3bn6.wandb filter=lfs diff=lfs merge=lfs -text
94
+ cache/wandb/wandb/offline-run-20260218_205814-3cvhbicy/run-3cvhbicy.wandb filter=lfs diff=lfs merge=lfs -text
95
+ cache/wandb/wandb/offline-run-20260326_000646-w5fgyzj3/run-w5fgyzj3.wandb filter=lfs diff=lfs merge=lfs -text
96
+ cache/wandb/wandb/offline-run-20260326_201451-g14348cv/run-g14348cv.wandb filter=lfs diff=lfs merge=lfs -text
97
+ neuralese_latent/wandb/offline-run-20260219_214352-58nmo2ma/run-58nmo2ma.wandb filter=lfs diff=lfs merge=lfs -text
98
+ cache/models/datasets--openai--gsm8k/blobs/ee7b8da9e381df27b9e3f7758a159ab2bdaa4dbaa910546cbbc47e0cb44e4f59 filter=lfs diff=lfs merge=lfs -text
99
+ cache/models/datasets--openai--gsm8k/blobs/ea82612ea9582142387730c793eb67d3b12849002bc0b7fa6f8efafa7351419d filter=lfs diff=lfs merge=lfs -text
100
+ cache/models/models--unsloth--Llama-3.2-3B-Instruct/blobs/6b9e4e7fb171f92fd137b777cc2714bf87d11576700a1dcd7a399e7bbe39537b filter=lfs diff=lfs merge=lfs -text
101
+ cache/models/models--unsloth--Llama-3.2-3B-Instruct/blobs/7b770216613ac5c34d7c54bdff1fa616bc4e338a9d0b20af6303e48c295ee23c filter=lfs diff=lfs merge=lfs -text
102
+ cache/models/models--unsloth--Llama-3.2-3B-Instruct/blobs/13cbd6d16e927a0c5bad54102514e6e18b4a47b3a6eb911e39d678d328d19f55 filter=lfs diff=lfs merge=lfs -text
103
+ cache/models/models--unsloth--Llama-3.2-3B-Instruct/snapshots/006f5dcd1393c3add266de40994ba96225e9689d/tokenizer.json filter=lfs diff=lfs merge=lfs -text
104
+ cache/wandb/wandb/offline-run-20260219_155344-i4fzutup/run-i4fzutup.wandb filter=lfs diff=lfs merge=lfs -text
105
+ cache/wandb/wandb/offline-run-20260328_061445-eyjtftlz/run-eyjtftlz.wandb filter=lfs diff=lfs merge=lfs -text
106
+ cache/wandb/wandb/offline-run-20260327_234927-gbd6yui5/run-gbd6yui5.wandb filter=lfs diff=lfs merge=lfs -text
107
+ cache/wandb/wandb/offline-run-20260326_154346-zf18ek2n/run-zf18ek2n.wandb filter=lfs diff=lfs merge=lfs -text
108
+ cache/wandb/wandb/offline-run-20260325_011840-yvcn1q14/run-yvcn1q14.wandb filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .env
2
+
3
+ # Large local caches and artifacts (not for Hub)
4
+ cache/
5
+ **/__pycache__/
6
+ *.py[cod]
7
+ .pytest_cache/
8
+ .mypy_cache/
9
+ .ruff_cache/
10
+ *.egg-info/
11
+ .eggs/
12
+ .wandb/
13
+ *.swp
14
+ .DS_Store
15
+
AGENTS.md ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AGENTS.md
2
+
3
+ ## Purpose
4
+
5
+ This repo is a GRPO-style math-reasoning training workspace. Agents should preserve a clean split between tracked source code and untracked runtime artifacts.
6
+
7
+ ## Canonical Layout
8
+
9
+ Tracked code and docs:
10
+
11
+ - `src/`: Python entrypoints and the `hackable` library.
12
+ - `configs/`: Accelerate and experiment YAMLs.
13
+ - `scripts/`: launchers and workflow helpers.
14
+ - `docs/`: durable repo documentation.
15
+ - `README.md`: user-facing overview.
16
+
17
+ Runtime storage under `cache/`:
18
+
19
+ - `cache/datasets/`: Hugging Face dataset cache and prepared dataset assets.
20
+ - `cache/models/`: local model weights, checkpoint loads, tokenizer cache.
21
+ - `cache/hf/`: extra Hugging Face runtime cache.
22
+ - `cache/artifacts/runs/`: single training or SFT runs.
23
+ - `cache/artifacts/sweeps/`: parameter sweeps and multi-run experiments.
24
+ - `cache/artifacts/eval/`: evaluation outputs, JSONL, CSV, plots.
25
+ - `cache/logs/wandb/`: W&B offline or online run logs.
26
+
27
+ Within a run directory, use:
28
+
29
+ - `checkpoint-*`: rotating trainer checkpoints.
30
+ - `checkpoints/permanent/`: non-rotating full-model snapshots saved by the callback.
31
+ - run-local analysis files only when they belong to that run.
32
+
33
+ ## Path Rules
34
+
35
+ - Treat `storage.cache_dir` as the runtime root. Relative runtime paths such as `artifacts/runs/foo` resolve under that root.
36
+ - Do not introduce new top-level folders for datasets, checkpoints, models, or logs.
37
+ - Prefer `hackable.paths.resolve_storage_path()` and `hackable.paths.storage_layout()` over manual string concatenation.
38
+ - Keep tracked docs and code out of `cache/`. Keep large generated outputs out of `src/`, `scripts/`, and the repo root.
39
+
40
+ ## Code Map
41
+
42
+ - `src/train_grpo.py`: main GRPO training entrypoint.
43
+ - `src/sft_gsm8k_boxed.py`: SFT baseline for boxed-only GSM8K.
44
+ - `src/eval_*.py`: evaluation entrypoints.
45
+ - `src/hackable/config.py`: YAML dataclasses.
46
+ - `src/hackable/paths.py`: canonical storage layout helpers.
47
+ - `src/hackable/data_plugins.py`: dataset providers.
48
+ - `src/hackable/reward_plugins.py`: reward functions.
49
+ - `src/hackable/objectives.py`: reward combination logic.
50
+ - `src/hackable/permanent_checkpoint_callback.py`: permanent checkpoint writer.
51
+
52
+ ## Working Agreement
53
+
54
+ - Before changing structure, inspect existing path assumptions in both `src/` and `scripts/`.
55
+ - When adding a new workflow, document it in `README.md` or `docs/repository-map.md`.
56
+ - Preserve user changes already present in the worktree unless explicitly asked to rewrite them.
README.md ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Neuralese / hackable GRPO baseline
2
+
3
+ This repository is a **hackable GRPO-style training baseline** for math reasoning with a fixed chain-of-thought format: reasoning inside `<redacted_thinking>...</redacted_thinking>` and a final `\boxed{...}` answer. The design keeps **objectives**, **rewards**, and **data providers** in small pluggable modules so you can experiment without rewriting a monolithic trainer.
4
+
5
+ ## Repository layout
6
+
7
+ | Path | Role |
8
+ |------|------|
9
+ | `src/hackable/` | Core library: config types, registries, objectives, rewards, data providers, model loading. |
10
+ | `src/train_grpo.py` | Training entrypoint expected by `scripts/*.sh` (may be absent in a partial checkout; scripts default to `--config configs/grpo_llama32_3b_bf16.yaml`). |
11
+ | `docs/repository-map.md` | Maintainer-focused map of entrypoints, modules, and storage conventions. |
12
+ | `AGENTS.md` | Agent-facing workflow and folder contract for future repo changes. |
13
+ | `configs/` | YAML experiment configs (referenced by scripts; may need to be added locally). |
14
+ | `scripts/` | Bash launchers for single/multi-GPU training, sweeps, and evaluation. |
15
+ | `cache/` | Runtime storage root: `datasets/`, `models/`, `hf/`, `artifacts/{runs,sweeps,eval}/`, and `logs/wandb/`. |
16
+ | `requirements.txt` | Python dependencies (Torch, TRL, Transformers, Accelerate, etc.). |
17
+ | `note.txt` | Unrelated environment/pip noise from a past install; not project documentation. |
18
+
19
+ ### Storage conventions
20
+
21
+ - `storage.cache_dir` is the canonical runtime root. Relative runtime paths like `artifacts/runs/grpo-llama32-3b` resolve under that directory.
22
+ - Use `cache/datasets` for dataset cache, `cache/models` for model/tokenizer cache, and `cache/logs/wandb` for W&B logs.
23
+ - Training runs live in `cache/artifacts/runs`, sweep outputs in `cache/artifacts/sweeps`, eval outputs in `cache/artifacts/eval`.
24
+ - Permanent checkpoints now live under each run at `checkpoints/permanent/`.
25
+
26
+ ### `src/hackable/` modules
27
+
28
+ - **`config.py`** — Loads YAML into typed dataclasses (`ExperimentConfig`, `ModelConfig`, `TrainerConfig`, `GenerationConfig`, `ObjectiveConfig`, `RewardsConfig`, …). Normalizes optimizer/scheduler aliases and numeric fields.
29
+ - **`registry.py`** — `@register_data_provider`, `@register_reward`, `@register_objective` plus `build_*` factories. Reward kwargs from YAML are partially applied via `build_reward` so each reward can receive static options (e.g. tokenizer name for length penalty).
30
+ - **`interfaces.py`** — `TrainingSample` and protocol shapes for providers/rewards/objectives.
31
+ - **`objectives.py`** — `TokenGRPOObjective` (main token-level GRPO recipe) and `LatentNeuraleseObjective` (stub: format reward only, `extra_reward` no-op for future latent scoring).
32
+ - **`reward_plugins.py`** — Registered rewards: strict format, GSM8K/MATH-style correctness, **length penalty**, optional token-utilisation shaping. See [Length penalty rewards](#length-penalty-rewards) below.
33
+ - **`data_plugins.py`** — GSM8K, Hendrycks MATH by level, interleaved curricula (`gsm8k_math_curriculum`, etc.). Shared prompt prefix matches the strict completion format expected by rewards.
34
+ - **`backends.py`** — Loads causal LMs with Liger Llama patches when applicable, FlashAttention2 when importable, else SDPA. Exposes `generation_kwargs` from config.
35
+ - **`utils.py`** — `import_from_path` for `objective.class_path` style `"module:Class"` loading.
36
+
37
+ Importing `hackable` registers default plugins (`hackable/__init__.py` imports `data_plugins`, `objectives`, `reward_plugins`).
38
+
39
+ ### `src/` evaluation utilities
40
+
41
+ - **`eval_sweep_models.py`** — Distributed evaluation of every `run_*` directory under a sweep root: loads checkpoints, runs GSM8K test generations, scores correctness, records CoT word length stats, writes CSV/JSON summaries (used after lambda or reward-variant sweeps).
42
+ - **`eval_permanent_checkpoints.py`** — Walks `checkpoints/permanent` folders, evaluates each checkpoint, can emit simple SVG learning curves.
43
+ - **`eval_math_level1_thinking_zeroshot.py`** — Zero-shot / thinking-format eval on MATH-style data with JSONL output (for downstream rewards or analysis).
44
+
45
+ ### `scripts/` (high level)
46
+
47
+ | Script | Purpose |
48
+ |--------|---------|
49
+ | `run_grpo.sh`, `run_grpo_2gpu.sh`, `run_grpo_4gpu.sh`, `run_grpo_8gpu.sh` | Launch training with Accelerate. |
50
+ | `resume_grpo_8gpu.sh` | Resume from latest or explicit checkpoint. |
51
+ | `sweep_length_penalty_lambda.sh` | Trains multiple runs with different `length_penalty_lambda` (weighted length-penalty mode). |
52
+ | `run_reward_variants_and_eval.sh` | Trains three interaction/gating variants, then runs `eval_sweep_models.py`. |
53
+ | `run_twostage_correctness1.sh`, `run_twostage_correctness5.sh` | Two-stage schedules; YAML is expected to set `correctness_weight` and optionally stage-2 length-penalty fields (see below). |
54
+ | `run_lambda_0p1_existing_gate_token_util.sh` | Example run with low λ and `token_utilisation_reward` enabled. |
55
+ | `eval_sweep_models_offline.sh` | Offline eval driver for a length-penalty λ sweep directory. |
56
+ | `eval_length_penalty_ablation_offline.sh` | Launches `src/eval_length_penalty_ablation.py` (script must exist alongside training code). |
57
+ | `eval_twostage_permanent_checkpoints.sh` | Eval for two-stage permanent checkpoint trees. |
58
+ | `eval_gsm8k_*.sh`, `eval_math_level*_*.sh` | Dataset-specific eval launchers. |
59
+ | `hf_upload_repo.py`, `hf_download_repo.py` | Push/pull Hugging Face dataset repo snapshots. |
60
+
61
+ ---
62
+
63
+ ## Length penalty rewards
64
+
65
+ Training prefers **shorter thinking traces** *within each GRPO group* (same prompt, multiple sampled completions). The signal is implemented as a **reward** (`length_penalty_reward`) and combined with correctness and format either **additively** or via a **weighted multiplicative** term controlled by `TokenGRPOObjective`.
66
+
67
+ ### What gets measured
68
+
69
+ 1. **Strict format only** — `_think_length_tokens` in `reward_plugins.py` parses the completion with the same regex as `format_tag_reward`: a single block `<redacted_thinking>...</redacted_thinking>` followed by `\boxed{...}`. If the completion does not match, thinking length is treated as **0** (so length reward does not reward malformed outputs on length grounds).
70
+ 2. **Token count** — If `tokenizer_name` is passed (via `rewards.kwargs.length_penalty_reward` in YAML and `build_reward`), the **thinking substring** is encoded with that tokenizer (`add_special_tokens=False`) and the length is the number of token IDs. If no tokenizer is configured, length falls back to **whitespace-split words** inside the thinking block.
71
+ 3. **Per-group normalization** — For each group \(G\) of completions (see grouping below), let \(L_i\) be the thinking length of completion \(i\), and \(\bar{L} = \frac{1}{|G|}\sum_j L_j\). If \(\bar{L} \le 0\), every score in the group is **0**. Otherwise:
72
+
73
+ \[
74
+ R^{\mathrm{length}}_i = \max\left(0,\ 1 - \frac{L_i}{\bar{L}}\right).
75
+ \]
76
+
77
+ So **shorter-than-average** thinking in the group scores closer to **1** (the shorter relative to the group mean, the higher the score), while any completion whose thinking length is **at or above the group average** is clamped to exactly **0**. This is a **relative** length preference, not an absolute token budget.
78
+
79
+ ### How groups are formed
80
+
81
+ `length_penalty_reward` assigns the same group normalization to completions that belong to the same GRPO comparison group:
82
+
83
+ - If `group_size` is set (again, typically under `rewards.kwargs.length_penalty_reward`), the flat batch is chunked in order: `[0:group_size)`, `[group_size:2*group_size)`, …
84
+ - Else groups are inferred by **contiguous runs of identical `prompt` text** in the parallel lists passed to the reward.
85
+
86
+ ### How `TokenGRPOObjective` combines length with other rewards
87
+
88
+ Configured under `objective.name: token_grpo` with `objective.kwargs` (see `src/hackable/objectives.py`).
89
+
90
+ **Registered reward names** (when enabled):
91
+
92
+ - `format_tag_reward` — 1.0 if strict thinking + non-empty boxed answer, else 0.0.
93
+ - `gsm8k_correctness_reward` — Parses `\boxed{...}` vs reference (GSM8K `####` answers or MATH boxed solutions), numeric normalization and tolerant float compare.
94
+ - `length_penalty_reward` — if `enable_length_penalty: true`.
95
+ - `token_utilisation_reward` — optional; shapes training vs a frozen zero-shot correctness JSONL (see docstring in `reward_plugins.py`).
96
+
97
+ **`reward_mode`**
98
+
99
+ 1. **`additive`** — Total score is the **sum** of all enabled reward outputs for that sample. If `strict_format_gate: true`, the total for any sample with `format_tag_reward ≤ 0.5` is replaced by `non_strict_penalty` (default **-1.0**) instead of the sum.
100
+ 2. **`weighted_length_penalty`** — Correctness and length interact multiplicatively; format is added (and optionally multiplied into the interaction):
101
+
102
+ - Let \(r_c, r_f, r_\ell\) be correctness, format, and length scores in \([0,1]\) (length may be 0 if disabled or malformed).
103
+ - Base interaction: \(r_c \times r_\ell\).
104
+ - If `length_penalty_interaction` is `correctness_length_format`, the interaction is \(r_c \times r_\ell \times r_f\).
105
+ - Total (where \(r_{\mathrm{util}}\) denotes the optional token-util term):
106
+
107
+ \[
108
+ \texttt{correctness\_weight} \cdot r_c + \texttt{length\_penalty\_lambda} \cdot \text{interaction} + r_f + r_{\mathrm{util}}.
109
+ \]
110
+
111
+ If `strict_format_gate` is true and \(r_f \le 0.5\), the total is **`non_strict_penalty`** and the formula above is skipped.
112
+
113
+ **Important knobs**
114
+
115
+ | Kwarg | Meaning |
116
+ |-------|---------|
117
+ | `enable_length_penalty` | When false, length reward is not registered; in weighted mode \(r_\ell\) is treated as 0. |
118
+ | `length_penalty_lambda` | Scales the \(r_c \cdot r_\ell\) (or \(r_c \cdot r_\ell \cdot r_f\)) term in weighted mode. Sweeps often try e.g. 0.25, 0.5, 0.75, 1.0. |
119
+ | `correctness_weight` | Scales the standalone correctness term in weighted mode. |
120
+ | `length_penalty_interaction` | `correctness_length` vs `correctness_length_format` (whether format enters the product). |
121
+ | `strict_format_gate` / `non_strict_penalty` | Hard gate on format before crediting other terms. |
122
+ | `stage2_length_penalty_lambda`, `stage2_start_epoch` | Stored on `TokenGRPOObjective` for **two-stage** schedules; **changing λ at epoch boundaries must be implemented in the trainer** (`combine_rewards` itself only reads the current `length_penalty_lambda` on the instance). |
123
+
124
+ **YAML wiring example** (illustrative; paths depend on your repo):
125
+
126
+ ```yaml
127
+ objective:
128
+ name: token_grpo
129
+ kwargs:
130
+ reward_mode: weighted_length_penalty
131
+ enable_length_penalty: true
132
+ correctness_weight: 1.0
133
+ length_penalty_lambda: 0.5
134
+ length_penalty_interaction: correctness_length
135
+ strict_format_gate: true
136
+ non_strict_penalty: -1.0
137
+
138
+ rewards:
139
+ kwargs:
140
+ length_penalty_reward:
141
+ tokenizer_name: meta-llama/Llama-3.2-3B-Instruct # example
142
+ group_size: 4 # often match generation.num_generations
143
+ cache_dir: cache/models
144
+ ```
145
+
146
+ ---
147
+
148
+ ## Hackable extension path
149
+
150
+ - Implement a new objective class with `reward_names()` and `combine_rewards`-equivalent behavior (your trainer must call into it the same way as the baseline), or add new `@register_reward` functions.
151
+ - Point config at a custom class: `objective.class_path: "my_package.my_module:MyObjective"`.
152
+ - `LatentNeuraleseObjective` is a stub: keep token rewards while you add latent signals via `extra_reward` or a future trainer hook.
153
+
154
+ ## Quickstart
155
+
156
+ 1. `pip install -r requirements.txt` (you may need a CUDA-matched Torch and optional FlashAttention build).
157
+ 2. Set `HF_TOKEN` / `WANDB_API_KEY` if needed (names configurable under `auth` in YAML).
158
+ 3. Single GPU: `bash scripts/run_grpo.sh` (requires `src/train_grpo.py` and `configs/grpo_llama32_3b_bf16.yaml` present).
159
+ 4. Many scripts default to offline Hub/datasets/W&B; override `WANDB_MODE`, `HF_HUB_OFFLINE`, etc. if you need network access.
160
+ 5. Runtime outputs default to `cache/artifacts/...` through `storage.cache_dir`; agent docs live in `AGENTS.md` and `docs/repository-map.md`.
161
+
162
+ ## Attention backend
163
+
164
+ `backends.py` prefers **FlashAttention2** when `flash_attn` imports cleanly; otherwise **SDPA**. Llama models optionally use **Liger** kernels when installed.
configs/accelerate_ddp_2gpu.yaml ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ compute_environment: LOCAL_MACHINE
2
+ debug: false
3
+ distributed_type: MULTI_GPU
4
+ downcast_bf16: "no"
5
+ enable_cpu_affinity: false
6
+ gpu_ids: all
7
+ machine_rank: 0
8
+ main_training_function: main
9
+ mixed_precision: bf16
10
+ num_machines: 1
11
+ num_processes: 2
12
+ rdzv_backend: static
13
+ same_network: true
14
+ tpu_env: []
15
+ tpu_use_cluster: false
16
+ tpu_use_sudo: false
17
+ use_cpu: false
configs/accelerate_ddp_4gpu.yaml ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ compute_environment: LOCAL_MACHINE
2
+ debug: false
3
+ distributed_type: MULTI_GPU
4
+ downcast_bf16: 'no'
5
+ enable_cpu_affinity: false
6
+ gpu_ids: all
7
+ machine_rank: 0
8
+ main_training_function: main
9
+ mixed_precision: bf16
10
+ num_machines: 1
11
+ num_processes: 4
12
+ rdzv_backend: static
13
+ same_network: true
14
+ tpu_env: []
15
+ tpu_use_cluster: false
16
+ tpu_use_sudo: false
17
+ use_cpu: false
configs/accelerate_ddp_7gpu.yaml ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ compute_environment: LOCAL_MACHINE
2
+ debug: false
3
+ distributed_type: MULTI_GPU
4
+ downcast_bf16: 'no'
5
+ enable_cpu_affinity: false
6
+ gpu_ids: all
7
+ machine_rank: 0
8
+ main_training_function: main
9
+ mixed_precision: bf16
10
+ num_machines: 1
11
+ num_processes: 7
12
+ rdzv_backend: static
13
+ same_network: true
14
+ tpu_env: []
15
+ tpu_use_cluster: false
16
+ tpu_use_sudo: false
17
+ use_cpu: false
configs/accelerate_ddp_8gpu.yaml ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ compute_environment: LOCAL_MACHINE
2
+ debug: false
3
+ distributed_type: MULTI_GPU
4
+ downcast_bf16: 'no'
5
+ enable_cpu_affinity: false
6
+ gpu_ids: all
7
+ machine_rank: 0
8
+ main_training_function: main
9
+ mixed_precision: bf16
10
+ num_machines: 1
11
+ num_processes: 8
12
+ rdzv_backend: static
13
+ same_network: true
14
+ tpu_env: []
15
+ tpu_use_cluster: false
16
+ tpu_use_sudo: false
17
+ use_cpu: false
configs/grpo_llama32_3b_bf16.yaml ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Baseline GRPO experiment (paths are relative to storage.cache_dir unless absolute).
2
+ model:
3
+ name: meta-llama/Llama-3.2-3B-Instruct
4
+ trust_remote_code: false
5
+
6
+ trainer:
7
+ output_dir: artifacts/runs/grpo-llama32-3b
8
+ run_name: grpo-llama32-3b
9
+ num_train_epochs: 1.0
10
+ max_steps: -1
11
+ per_device_train_batch_size: 1
12
+ gradient_accumulation_steps: 8
13
+ learning_rate: 1.0e-6
14
+ logging_steps: 1
15
+ save_steps: 200
16
+ save_total_limit: 3
17
+ permanent_checkpoint_steps: 300
18
+ permanent_checkpoint_dir: checkpoints/permanent
19
+ bf16: true
20
+ shuffle_dataset: false
21
+ report_to: wandb
22
+ warmup_steps: 20
23
+ gradient_checkpointing: true
24
+ max_grad_norm: 1.0
25
+ lr_scheduler_type: cosine
26
+ lr_scheduler_kwargs: {}
27
+
28
+ data:
29
+ provider: gsm8k_math_curriculum
30
+ split: train
31
+ max_samples: null
32
+
33
+ generation:
34
+ max_prompt_length: 512
35
+ max_completion_length: 256
36
+ num_generations: 4
37
+ temperature: 0.9
38
+ top_p: 0.95
39
+
40
+ objective:
41
+ name: token_grpo
42
+ kwargs:
43
+ reward_mode: weighted_length_penalty
44
+ enable_length_penalty: true
45
+ length_penalty_lambda: 0.1
46
+ correctness_weight: 1.0
47
+ strict_format_gate: true
48
+
49
+ rewards:
50
+ kwargs:
51
+ length_penalty_reward:
52
+ group_size: 4
53
+
54
+ # Multiply KL on tokens inside <redacted_thinking>...</redacted_thinking> (inner body) by this weight.
55
+ # 1.0 = default TRL behavior; 0.0 = no KL penalty on those tokens (policy loss still applies everywhere).
56
+ thinking_kl:
57
+ inner_kl_weight: 1.0
58
+
59
+ # TRL GRPO-specific flags (see GRPOConfig). When thinking_kl.inner_kl_weight != 1 and beta != 0,
60
+ # train_grpo.py forces use_liger_kernel: false.
61
+ grpo:
62
+ beta: 0.001
63
+ loss_type: grpo
64
+ importance_sampling_level: token
65
+ use_liger_kernel: false
66
+
67
+ storage:
68
+ cache_dir: cache
69
+
70
+ auth: {}
docs/repository-map.md ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Repository Map
2
+
3
+ ## What Lives Where
4
+
5
+ - `src/train_grpo.py`: launches GRPO training from a YAML experiment config.
6
+ - `src/sft_gsm8k_boxed.py`: supervised fine-tuning entrypoint for boxed-answer GSM8K.
7
+ - `src/eval_sweep_models.py`: evaluates every `run_*` directory in a sweep root on GSM8K.
8
+ - `src/eval_permanent_checkpoints.py`: evaluates a run's permanent checkpoints and writes summaries and plots.
9
+ - `src/eval_gsm8k_zeroshot.py`: single-checkpoint GSM8K evaluation.
10
+ - `src/eval_gsm8k_truncated_thinking.py`: GSM8K eval with forced truncation of the thinking span.
11
+ - `src/eval_math_level1_thinking_zeroshot.py`: MATH level-1 zero-shot thinking-format evaluation.
12
+
13
+ ## `src/hackable/`
14
+
15
+ - `config.py`: experiment dataclasses and YAML loading.
16
+ - `paths.py`: canonical runtime storage layout and path resolution.
17
+ - `registry.py`: plugin registration and construction for objectives, rewards, and data providers.
18
+ - `interfaces.py`: shared sample and protocol types.
19
+ - `data_plugins.py`: GSM8K, MATH, and curriculum data providers.
20
+ - `reward_plugins.py`: format, correctness, length, and token-utilization rewards.
21
+ - `objectives.py`: objective classes and reward-combination strategies.
22
+ - `backends.py`: model-loading helpers and backend selection.
23
+ - `thinking_kl_grpo_trainer.py`: custom GRPO trainer with masked thinking-KL handling.
24
+ - `thinking_kl_mask.py`: token masking logic for inner-thinking KL control.
25
+ - `permanent_checkpoint_callback.py`: periodic permanent checkpoint export.
26
+ - `utils.py`: repository-relative import and path utilities.
27
+
28
+ ## Config And Launch Surface
29
+
30
+ - `configs/grpo_llama32_3b_bf16.yaml`: baseline experiment config.
31
+ - `configs/accelerate_ddp_*.yaml`: multi-GPU Accelerate launch configs.
32
+ - `scripts/run_grpo*.sh`: baseline training launchers.
33
+ - `scripts/run_twostage_*.sh`: two-stage experiments.
34
+ - `scripts/sweep_length_penalty_lambda.sh`: lambda sweep generator.
35
+ - `scripts/run_reward_variants_and_eval.sh`: reward variant sweep plus evaluation.
36
+ - `scripts/eval_*.sh`: distributed evaluation wrappers.
37
+
38
+ ## Runtime Storage
39
+
40
+ All generated assets should live under `cache/`:
41
+
42
+ - `cache/datasets/`: dataset downloads and prepared data.
43
+ - `cache/models/`: Hugging Face model cache and local model roots.
44
+ - `cache/hf/`: extra Hugging Face cache state.
45
+ - `cache/artifacts/runs/`: one-off runs.
46
+ - `cache/artifacts/sweeps/`: grouped experiments.
47
+ - `cache/artifacts/eval/`: evaluation outputs.
48
+ - `cache/logs/wandb/`: W&B logs.
49
+
50
+ Relative paths like `artifacts/runs/foo` are intentionally interpreted relative to `storage.cache_dir`, so they land at `cache/artifacts/runs/foo` by default.
out.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ flash-attn
3
+ liger-kernel
4
+ transformers
5
+ trl>=1.0.0
6
+ datasets
7
+ accelerate
8
+ pyyaml
9
+ wandb
scripts/conda_env.sh ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Shared conda bootstrap. Meant to be sourced by the launcher scripts, not executed.
# Activates the env named by CONDA_ENV (default: "sidharth").
# Override example: CONDA_ENV=myenv bash scripts/run_grpo_8gpu.sh
#
# Conda and nvidia packages (e.g. cuda-nvcc) ship deactivate hooks that read variables
# which may be unset (CUDAARCHS_BACKUP, etc.); those hooks abort under `set -u`, which
# the sourcing scripts often enable. nounset is therefore switched off for this whole
# file and switched back on at the end if the sourcing shell had it enabled.

if [[ $- == *u* ]]; then
  _nounset_was_on=1
else
  _nounset_was_on=0
fi
set +u

CONDA_ENV="${CONDA_ENV:-sidharth}"

# Make the `conda` shell function available.
# Returns 0 once a hook/profile script has been loaded, 1 when conda is neither on PATH
# nor present under one of the install prefixes probed below.
_conda_try_init() {
  local prefix

  if ! command -v conda >/dev/null 2>&1; then
    for prefix in "${CONDA_PREFIX:-}" "$HOME/miniconda3" "$HOME/mambaforge" "$HOME/anaconda3" "/opt/conda"; do
      # Skip the empty candidate produced when CONDA_PREFIX is unset.
      [[ -n "$prefix" ]] || continue
      if [[ -f "$prefix/etc/profile.d/conda.sh" ]]; then
        # shellcheck disable=SC1090
        source "$prefix/etc/profile.d/conda.sh"
        return 0
      fi
    done
    return 1
  fi

  # shellcheck disable=SC2312
  eval "$(conda shell.bash hook)"
  return 0
}

_conda_try_init || {
  echo "conda_env.sh: could not initialize conda (not on PATH and no known install)." >&2
  exit 1
}

conda activate "${CONDA_ENV}"

# Restore the caller's nounset setting.
if (( _nounset_was_on == 1 )); then
  set -u
fi
scripts/eval_gsm8k_think_entropy_4gpu.sh ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Launch src/eval_gsm8k_think_entropy.py under `accelerate launch` (4 processes by default).
#
# Usage:
#   MODEL_DIR=/path/to/checkpoint ./scripts/eval_gsm8k_think_entropy_4gpu.sh
#
# Settings reach the Python entrypoint through exported environment variables only;
# no command-line arguments are passed to it.
set -eo pipefail

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
# nounset is enabled only after conda_env.sh has been sourced.
# shellcheck disable=SC1091
source "${SCRIPT_DIR}/conda_env.sh"
set -u

ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
cd "${ROOT}"

BASE_CONFIG="${BASE_CONFIG:-configs/grpo_llama32_3b_bf16.yaml}"
ACCELERATE_CONFIG="${ACCELERATE_CONFIG:-configs/accelerate_ddp_4gpu.yaml}"
NUM_PROCESSES="${NUM_PROCESSES:-4}"

# Required: MODEL_DIR=/path/to/checkpoint
# Fail fast here instead of exporting an empty value to every worker process.
MODEL_DIR="${MODEL_DIR:-}"
if [[ -z "${MODEL_DIR}" ]]; then
  echo "Missing MODEL_DIR. Example:" >&2
  echo "  MODEL_DIR=/path/to/checkpoint ./scripts/eval_gsm8k_think_entropy_4gpu.sh" >&2
  exit 1
fi

OUTPUT_PATH="${OUTPUT_PATH:-artifacts/eval/gsm8k_think_entropy/entropy.json}"
BATCH_SIZE="${BATCH_SIZE:-4}"
MAX_SAMPLES="${MAX_SAMPLES:--1}"
SAVE_PER_SAMPLE="${SAVE_PER_SAMPLE:-0}"

export PYTHONPATH="${ROOT}/src"
export BASE_CONFIG MODEL_DIR OUTPUT_PATH BATCH_SIZE MAX_SAMPLES SAVE_PER_SAMPLE

# Offline by default; each flag can be overridden from the caller's environment.
export WANDB_MODE="${WANDB_MODE:-offline}"
export HF_HUB_OFFLINE="${HF_HUB_OFFLINE:-1}"
export HF_DATASETS_OFFLINE="${HF_DATASETS_OFFLINE:-1}"
export TRANSFORMERS_OFFLINE="${TRANSFORMERS_OFFLINE:-1}"

accelerate launch \
  --config_file "${ACCELERATE_CONFIG}" \
  --num_processes "${NUM_PROCESSES}" \
  src/eval_gsm8k_think_entropy.py
scripts/eval_gsm8k_truncated_thinking.sh ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Run src/eval_gsm8k_truncated_thinking.py under `accelerate launch`: GSM8K eval with a
# forced early </redacted_thinking> after N inner thinking tokens (default 5).
#
# Example:
#   MODEL_PATH=artifacts/sweeps/twostage_correctness_weight/run_twostage_correctness1/checkpoints/permanent/checkpoint-stage1-boundary-epoch-1p0-step-1868 \
#     bash scripts/eval_gsm8k_truncated_thinking.sh
set -eo pipefail

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
cd "${REPO_ROOT}"
# nounset is enabled only after conda_env.sh has been sourced.
# shellcheck disable=SC1091
source "${SCRIPT_DIR}/conda_env.sh"
set -u

# MODEL_PATH is mandatory; every other knob falls back to a default.
: "${MODEL_PATH:?Set MODEL_PATH to your checkpoint directory}"
: "${BASE_CONFIG:=configs/grpo_llama32_3b_bf16.yaml}"
: "${ACCELERATE_CONFIG:=configs/accelerate_ddp_4gpu.yaml}"
: "${NUM_PROCESSES:=4}"
: "${THINK_INNER_TOKEN_LIMIT:=5}"
: "${MAX_SAMPLES:=-1}"
: "${MAX_NEW_TOKENS:=256}"
: "${OUTPUT_PATH:=}"

export PYTHONPATH="${REPO_ROOT}/src"
export PYTHONUNBUFFERED="${PYTHONUNBUFFERED:-1}"
export HF_DATASETS_OFFLINE="${HF_DATASETS_OFFLINE:-0}"

# Environment handed to the launched processes via `env`.
child_env=(
  PYTHONPATH="${REPO_ROOT}/src"
  BASE_CONFIG="${BASE_CONFIG}"
  MODEL_PATH="${MODEL_PATH}"
  THINK_INNER_TOKEN_LIMIT="${THINK_INNER_TOKEN_LIMIT}"
  MAX_SAMPLES="${MAX_SAMPLES}"
  MAX_NEW_TOKENS="${MAX_NEW_TOKENS}"
)
# OUTPUT_PATH is forwarded only when the caller supplied a non-empty value.
if [[ -n "${OUTPUT_PATH}" ]]; then
  child_env+=(OUTPUT_PATH="${OUTPUT_PATH}")
fi

launcher_args=(
  --config_file "${ACCELERATE_CONFIG}"
  --num_processes "${NUM_PROCESSES}"
)
eval_args=(
  --config "${BASE_CONFIG}"
  --model_path "${MODEL_PATH}"
  --think_inner_token_limit "${THINK_INNER_TOKEN_LIMIT}"
  --max_samples "${MAX_SAMPLES}"
  --max_new_tokens "${MAX_NEW_TOKENS}"
)

env "${child_env[@]}" accelerate launch "${launcher_args[@]}" \
  src/eval_gsm8k_truncated_thinking.py "${eval_args[@]}"

echo "Done."
scripts/eval_gsm8k_zeroshot_4gpu.sh ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Zero-shot GSM-8K eval: system prompt asks for answer only in \\boxed{}; chat template via tokenizer.
# Launches src/eval_gsm8k_zeroshot.py under `accelerate launch`.
#
# Usage:
#   MODEL_PATH=/path/to/checkpoint OUTPUT_PATH=out.jsonl ./scripts/eval_gsm8k_zeroshot_4gpu.sh
# Or with base YAML (uses model.name if MODEL_PATH unset):
#   BASE_CONFIG=configs/grpo_llama32_3b_bf16.yaml OUTPUT_PATH=out.jsonl ./scripts/eval_gsm8k_zeroshot_4gpu.sh
set -eo pipefail

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
# nounset is enabled only after conda_env.sh has been sourced.
# shellcheck disable=SC1091
source "${SCRIPT_DIR}/conda_env.sh"
set -u

ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
cd "${ROOT}"

# Launcher-only settings (consumed by `accelerate launch` below).
ACCELERATE_CONFIG="${ACCELERATE_CONFIG:-configs/accelerate_ddp_4gpu.yaml}"
NUM_PROCESSES="${NUM_PROCESSES:-4}"

# Settings exported into the environment of the launched processes.
export PYTHONPATH="${ROOT}/src"
export BASE_CONFIG="${BASE_CONFIG:-configs/grpo_llama32_3b_bf16.yaml}"
export OUTPUT_PATH="${OUTPUT_PATH:-artifacts/eval/gsm8k_zeroshot_answers.jsonl}"
export BATCH_SIZE="${BATCH_SIZE:-4}"
export MAX_NEW_TOKENS="${MAX_NEW_TOKENS:-128}"
export MAX_SAMPLES="${MAX_SAMPLES:--1}"
# Optional: MODEL_PATH, SYSTEM_PROMPT, GSM8K_SPLIT, MAX_PROMPT_LENGTH, LOCAL_FILES_ONLY=1

accelerate launch \
  --config_file "${ACCELERATE_CONFIG}" \
  --num_processes "${NUM_PROCESSES}" \
  src/eval_gsm8k_zeroshot.py
scripts/eval_gsm8k_zeroshot_train_4gpu.sh ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Launch src/eval_gsm8k_zeroshot.py on the GSM8K train split (entire split by default).
#
# Usage:
#   MODEL_PATH=/path/to/checkpoint ./scripts/eval_gsm8k_zeroshot_train_4gpu.sh
set -eo pipefail

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
# nounset is enabled only after conda_env.sh has been sourced.
# shellcheck disable=SC1091
source "${SCRIPT_DIR}/conda_env.sh"
set -u

ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
cd "${ROOT}"

BASE_CONFIG="${BASE_CONFIG:-configs/grpo_llama32_3b_bf16.yaml}"
ACCELERATE_CONFIG="${ACCELERATE_CONFIG:-configs/accelerate_ddp_4gpu.yaml}"
NUM_PROCESSES="${NUM_PROCESSES:-4}"

MODEL_PATH="${MODEL_PATH:-}"
if [[ -z "${MODEL_PATH}" ]]; then
  # Usage errors go to stderr so they are not mixed into captured stdout.
  {
    echo "Missing MODEL_PATH."
    echo "Example:"
    echo "  MODEL_PATH=/path/to/checkpoint OUTPUT_PATH=artifacts/eval/gsm8k_train_zeroshot/results.jsonl ./scripts/eval_gsm8k_zeroshot_train_4gpu.sh"
  } >&2
  exit 1
fi

# Train split, entire dataset by default.
GSM8K_SPLIT="${GSM8K_SPLIT:-train}"
MAX_SAMPLES="${MAX_SAMPLES:--1}"
OUTPUT_PATH="${OUTPUT_PATH:-artifacts/eval/gsm8k_train_zeroshot/results.jsonl}"
BATCH_SIZE="${BATCH_SIZE:-4}"
MAX_NEW_TOKENS="${MAX_NEW_TOKENS:-128}"
MAX_PROMPT_LENGTH="${MAX_PROMPT_LENGTH:-1024}"

export PYTHONPATH="${ROOT}/src"
export BASE_CONFIG MODEL_PATH GSM8K_SPLIT MAX_SAMPLES OUTPUT_PATH
export BATCH_SIZE MAX_NEW_TOKENS MAX_PROMPT_LENGTH

# Optional override for prompt behavior.
# export SYSTEM_PROMPT="..."

# Offline by default; each flag can be overridden from the caller's environment.
export WANDB_MODE="${WANDB_MODE:-offline}"
export HF_HUB_OFFLINE="${HF_HUB_OFFLINE:-1}"
export HF_DATASETS_OFFLINE="${HF_DATASETS_OFFLINE:-1}"
export TRANSFORMERS_OFFLINE="${TRANSFORMERS_OFFLINE:-1}"

accelerate launch \
  --config_file "${ACCELERATE_CONFIG}" \
  --num_processes "${NUM_PROCESSES}" \
  src/eval_gsm8k_zeroshot.py
scripts/eval_length_penalty_ablation_offline.sh ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Launch src/eval_length_penalty_ablation.py under `accelerate launch`, offline by default.
# The two model directories to compare default to run_len_on / run_len_off under OUT_ROOT.
set -eo pipefail

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
# nounset is enabled only after conda_env.sh has been sourced.
# shellcheck disable=SC1091
source "${SCRIPT_DIR}/conda_env.sh"
set -u

REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
cd "${REPO_ROOT}"

# Launcher-only settings.
# NOTE(review): the default is 2 processes although the default config file is the
# 4-GPU one; confirm this is intentional.
ACCELERATE_CONFIG="${ACCELERATE_CONFIG:-configs/accelerate_ddp_4gpu.yaml}"
NUM_PROCESSES="${NUM_PROCESSES:-2}"

export PYTHONPATH="${REPO_ROOT}/src"
export BASE_CONFIG="${BASE_CONFIG:-configs/grpo_llama32_3b_bf16.yaml}"
export OUT_ROOT="${OUT_ROOT:-artifacts/ablation_length_penalty}"

# Update these if your model dirs differ.
export MODEL_LEN_ON="${MODEL_LEN_ON:-${OUT_ROOT}/run_len_on}"
export MODEL_LEN_OFF="${MODEL_LEN_OFF:-${OUT_ROOT}/run_len_off}"

export EVAL_MAX_SAMPLES="${EVAL_MAX_SAMPLES:-200}"
export EVAL_BATCH_SIZE="${EVAL_BATCH_SIZE:-4}"

export WANDB_MODE="${WANDB_MODE:-offline}"
export HF_HUB_OFFLINE="${HF_HUB_OFFLINE:-1}"
export HF_DATASETS_OFFLINE="${HF_DATASETS_OFFLINE:-1}"
export TRANSFORMERS_OFFLINE="${TRANSFORMERS_OFFLINE:-1}"

accelerate launch \
  --config_file "${ACCELERATE_CONFIG}" \
  --num_processes "${NUM_PROCESSES}" \
  src/eval_length_penalty_ablation.py
scripts/eval_math_level1_4gpu.sh ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Launch src/eval_math_level1.py under `accelerate launch` (4 processes by default).
#
# Usage:
#   MODEL_DIR=/path/to/checkpoint OUTPUT_PATH=out.json ./scripts/eval_math_level1_4gpu.sh
set -eo pipefail

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
# nounset is enabled only after conda_env.sh has been sourced.
# shellcheck disable=SC1091
source "${SCRIPT_DIR}/conda_env.sh"
set -u

ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
cd "${ROOT}"

BASE_CONFIG="${BASE_CONFIG:-configs/grpo_llama32_3b_bf16.yaml}"
ACCELERATE_CONFIG="${ACCELERATE_CONFIG:-configs/accelerate_ddp_4gpu.yaml}"
NUM_PROCESSES="${NUM_PROCESSES:-4}"

MODEL_DIR="${MODEL_DIR:-}"
if [[ -z "${MODEL_DIR}" ]]; then
  # Usage errors go to stderr so they are not mixed into captured stdout.
  echo "Missing MODEL_DIR. Example:" >&2
  echo "  MODEL_DIR=/path/to/checkpoint OUTPUT_PATH=out.json ./scripts/eval_math_level1_4gpu.sh" >&2
  exit 1
fi

OUTPUT_PATH="${OUTPUT_PATH:-artifacts/eval/math_level1/accuracy.json}"
MAX_SAMPLES="${MAX_SAMPLES:--1}"
BATCH_SIZE="${BATCH_SIZE:-4}"
MATH_SPLIT="${MATH_SPLIT:-test}"
SAVE_ROLLOUTS="${SAVE_ROLLOUTS:-0}"
ROLLOUTS_MAX_SAMPLES="${ROLLOUTS_MAX_SAMPLES:-50}"
ROLLOUTS_OUT_PATH="${ROLLOUTS_OUT_PATH:-artifacts/eval/math_level1/rollouts.jsonl}"
MAX_PROMPT_LENGTH="${MAX_PROMPT_LENGTH:-}"
MAX_COMPLETION_LENGTH="${MAX_COMPLETION_LENGTH:-}"

export PYTHONPATH="${ROOT}/src"
export BASE_CONFIG MODEL_DIR OUTPUT_PATH MAX_SAMPLES BATCH_SIZE MATH_SPLIT
export SAVE_ROLLOUTS ROLLOUTS_MAX_SAMPLES ROLLOUTS_OUT_PATH
# Length overrides are exported only when the caller supplied a non-empty value.
if [[ -n "${MAX_PROMPT_LENGTH}" ]]; then export MAX_PROMPT_LENGTH; fi
if [[ -n "${MAX_COMPLETION_LENGTH}" ]]; then export MAX_COMPLETION_LENGTH; fi

# Offline by default; each flag can be overridden from the caller's environment.
export WANDB_MODE="${WANDB_MODE:-offline}"
export HF_HUB_OFFLINE="${HF_HUB_OFFLINE:-1}"
export HF_DATASETS_OFFLINE="${HF_DATASETS_OFFLINE:-1}"
export TRANSFORMERS_OFFLINE="${TRANSFORMERS_OFFLINE:-1}"

accelerate launch \
  --config_file "${ACCELERATE_CONFIG}" \
  --num_processes "${NUM_PROCESSES}" \
  src/eval_math_level1.py
scripts/eval_math_level1_thinking_zeroshot_4gpu.sh ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Launch src/eval_math_level1_thinking_zeroshot.py under `accelerate launch`
# (4 processes by default).
#
# Usage:
#   MODEL_DIR=/path/to/checkpoint ./scripts/eval_math_level1_thinking_zeroshot_4gpu.sh
set -eo pipefail

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
# nounset is enabled only after conda_env.sh has been sourced.
# shellcheck disable=SC1091
source "${SCRIPT_DIR}/conda_env.sh"
set -u

ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
cd "${ROOT}"

BASE_CONFIG="${BASE_CONFIG:-configs/grpo_llama32_3b_bf16.yaml}"
ACCELERATE_CONFIG="${ACCELERATE_CONFIG:-configs/accelerate_ddp_4gpu.yaml}"
NUM_PROCESSES="${NUM_PROCESSES:-4}"

MODEL_DIR="${MODEL_DIR:-}"
if [[ -z "${MODEL_DIR}" ]]; then
  # Usage errors go to stderr so they are not mixed into captured stdout.
  echo "Missing MODEL_DIR." >&2
  echo "  MODEL_DIR=/path/to/checkpoint ./scripts/eval_math_level1_thinking_zeroshot_4gpu.sh" >&2
  exit 1
fi

OUTPUT_PATH="${OUTPUT_PATH:-artifacts/eval/math_level1_thinking_zeroshot/answers.jsonl}"
MAX_SAMPLES="${MAX_SAMPLES:--1}"
BATCH_SIZE="${BATCH_SIZE:-4}"
MAX_NEW_TOKENS="${MAX_NEW_TOKENS:-128}"
MAX_PROMPT_LENGTH="${MAX_PROMPT_LENGTH:-512}"
MATH_SPLIT="${MATH_SPLIT:-test}"

export PYTHONPATH="${ROOT}/src"
export BASE_CONFIG MODEL_DIR OUTPUT_PATH MAX_SAMPLES BATCH_SIZE
export MAX_NEW_TOKENS MAX_PROMPT_LENGTH MATH_SPLIT

# Offline by default; each flag can be overridden from the caller's environment.
export WANDB_MODE="${WANDB_MODE:-offline}"
export HF_HUB_OFFLINE="${HF_HUB_OFFLINE:-1}"
export HF_DATASETS_OFFLINE="${HF_DATASETS_OFFLINE:-1}"
export TRANSFORMERS_OFFLINE="${TRANSFORMERS_OFFLINE:-1}"

accelerate launch \
  --config_file "${ACCELERATE_CONFIG}" \
  --num_processes "${NUM_PROCESSES}" \
  src/eval_math_level1_thinking_zeroshot.py
scripts/eval_math_level1_zeroshot_4gpu.sh ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Launch src/eval_math_level1_zeroshot.py under `accelerate launch` (4 processes by default).
#
# Usage:
#   MODEL_DIR=/path/to/checkpoint ./scripts/eval_math_level1_zeroshot_4gpu.sh
set -eo pipefail

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
# nounset is enabled only after conda_env.sh has been sourced.
# shellcheck disable=SC1091
source "${SCRIPT_DIR}/conda_env.sh"
set -u

ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
cd "${ROOT}"

BASE_CONFIG="${BASE_CONFIG:-configs/grpo_llama32_3b_bf16.yaml}"
ACCELERATE_CONFIG="${ACCELERATE_CONFIG:-configs/accelerate_ddp_4gpu.yaml}"
NUM_PROCESSES="${NUM_PROCESSES:-4}"

MODEL_DIR="${MODEL_DIR:-}"
if [[ -z "${MODEL_DIR}" ]]; then
  # Usage errors go to stderr so they are not mixed into captured stdout.
  echo "Missing MODEL_DIR." >&2
  echo "  MODEL_DIR=/path/to/checkpoint ./scripts/eval_math_level1_zeroshot_4gpu.sh" >&2
  exit 1
fi

OUTPUT_PATH="${OUTPUT_PATH:-artifacts/eval/math_level1_zeroshot/answers.jsonl}"
SYSTEM_PROMPT="${SYSTEM_PROMPT:-}"
MAX_SAMPLES="${MAX_SAMPLES:--1}"
BATCH_SIZE="${BATCH_SIZE:-4}"
MAX_NEW_TOKENS="${MAX_NEW_TOKENS:-128}"
MAX_PROMPT_LENGTH="${MAX_PROMPT_LENGTH:-512}"

export PYTHONPATH="${ROOT}/src"
export BASE_CONFIG MODEL_DIR OUTPUT_PATH MAX_SAMPLES BATCH_SIZE
export MAX_NEW_TOKENS MAX_PROMPT_LENGTH

# The system prompt is exported only when the caller supplied a non-empty value.
if [[ -n "${SYSTEM_PROMPT}" ]]; then export SYSTEM_PROMPT; fi

# Offline by default; each flag can be overridden from the caller's environment.
export WANDB_MODE="${WANDB_MODE:-offline}"
export HF_HUB_OFFLINE="${HF_HUB_OFFLINE:-1}"
export HF_DATASETS_OFFLINE="${HF_DATASETS_OFFLINE:-1}"
export TRANSFORMERS_OFFLINE="${TRANSFORMERS_OFFLINE:-1}"

accelerate launch \
  --config_file "${ACCELERATE_CONFIG}" \
  --num_processes "${NUM_PROCESSES}" \
  src/eval_math_level1_zeroshot.py
scripts/eval_math_level2_thinking_zeroshot_4gpu.sh ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Launch src/eval_math_level2_thinking_zeroshot.py under `accelerate launch`
# (4 processes by default).
#
# Usage:
#   MODEL_DIR=/path/to/checkpoint ./scripts/eval_math_level2_thinking_zeroshot_4gpu.sh
set -eo pipefail

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
# nounset is enabled only after conda_env.sh has been sourced.
# shellcheck disable=SC1091
source "${SCRIPT_DIR}/conda_env.sh"
set -u

ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
cd "${ROOT}"

BASE_CONFIG="${BASE_CONFIG:-configs/grpo_llama32_3b_bf16.yaml}"
ACCELERATE_CONFIG="${ACCELERATE_CONFIG:-configs/accelerate_ddp_4gpu.yaml}"
NUM_PROCESSES="${NUM_PROCESSES:-4}"

MODEL_DIR="${MODEL_DIR:-}"
if [[ -z "${MODEL_DIR}" ]]; then
  # Usage errors go to stderr so they are not mixed into captured stdout.
  echo "Missing MODEL_DIR." >&2
  echo "  MODEL_DIR=/path/to/checkpoint ./scripts/eval_math_level2_thinking_zeroshot_4gpu.sh" >&2
  exit 1
fi

OUTPUT_PATH="${OUTPUT_PATH:-artifacts/eval/math_level2_thinking_zeroshot/answers.jsonl}"
MAX_SAMPLES="${MAX_SAMPLES:--1}"
BATCH_SIZE="${BATCH_SIZE:-4}"
MAX_NEW_TOKENS="${MAX_NEW_TOKENS:-128}"
MAX_PROMPT_LENGTH="${MAX_PROMPT_LENGTH:-512}"
MATH_SPLIT="${MATH_SPLIT:-test}"

export PYTHONPATH="${ROOT}/src"
export BASE_CONFIG MODEL_DIR OUTPUT_PATH MAX_SAMPLES BATCH_SIZE
export MAX_NEW_TOKENS MAX_PROMPT_LENGTH MATH_SPLIT

# Offline by default; each flag can be overridden from the caller's environment.
export WANDB_MODE="${WANDB_MODE:-offline}"
export HF_HUB_OFFLINE="${HF_HUB_OFFLINE:-1}"
export HF_DATASETS_OFFLINE="${HF_DATASETS_OFFLINE:-1}"
export TRANSFORMERS_OFFLINE="${TRANSFORMERS_OFFLINE:-1}"

accelerate launch \
  --config_file "${ACCELERATE_CONFIG}" \
  --num_processes "${NUM_PROCESSES}" \
  src/eval_math_level2_thinking_zeroshot.py
scripts/eval_math_level2_zeroshot_4gpu.sh ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Launch src/eval_math_level2_zeroshot.py under `accelerate launch` (4 processes by default).
#
# Usage:
#   MODEL_DIR=/path/to/checkpoint ./scripts/eval_math_level2_zeroshot_4gpu.sh
set -eo pipefail

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
# nounset is enabled only after conda_env.sh has been sourced.
# shellcheck disable=SC1091
source "${SCRIPT_DIR}/conda_env.sh"
set -u

ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
cd "${ROOT}"

BASE_CONFIG="${BASE_CONFIG:-configs/grpo_llama32_3b_bf16.yaml}"
ACCELERATE_CONFIG="${ACCELERATE_CONFIG:-configs/accelerate_ddp_4gpu.yaml}"
NUM_PROCESSES="${NUM_PROCESSES:-4}"

MODEL_DIR="${MODEL_DIR:-}"
if [[ -z "${MODEL_DIR}" ]]; then
  # Usage errors go to stderr so they are not mixed into captured stdout.
  echo "Missing MODEL_DIR." >&2
  echo "  MODEL_DIR=/path/to/checkpoint ./scripts/eval_math_level2_zeroshot_4gpu.sh" >&2
  exit 1
fi

OUTPUT_PATH="${OUTPUT_PATH:-artifacts/eval/math_level2_zeroshot/answers.jsonl}"
SYSTEM_PROMPT="${SYSTEM_PROMPT:-}"
MAX_SAMPLES="${MAX_SAMPLES:--1}"
BATCH_SIZE="${BATCH_SIZE:-4}"
MAX_NEW_TOKENS="${MAX_NEW_TOKENS:-128}"
MAX_PROMPT_LENGTH="${MAX_PROMPT_LENGTH:-512}"

export PYTHONPATH="${ROOT}/src"
export BASE_CONFIG MODEL_DIR OUTPUT_PATH MAX_SAMPLES BATCH_SIZE
export MAX_NEW_TOKENS MAX_PROMPT_LENGTH

# The system prompt is exported only when the caller supplied a non-empty value.
if [[ -n "${SYSTEM_PROMPT}" ]]; then export SYSTEM_PROMPT; fi

# Offline by default; each flag can be overridden from the caller's environment.
export WANDB_MODE="${WANDB_MODE:-offline}"
export HF_HUB_OFFLINE="${HF_HUB_OFFLINE:-1}"
export HF_DATASETS_OFFLINE="${HF_DATASETS_OFFLINE:-1}"
export TRANSFORMERS_OFFLINE="${TRANSFORMERS_OFFLINE:-1}"

accelerate launch \
  --config_file "${ACCELERATE_CONFIG}" \
  --num_processes "${NUM_PROCESSES}" \
  src/eval_math_level2_zeroshot.py
scripts/eval_permanent_root_acc_cot.sh ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Evaluate every checkpoint-* under a single checkpoints/permanent directory, then write
# JSON/CSV summary, per-checkpoint JSONL, and accuracy_vs_avg_cot_words.svg (scatter + step labels).
# Launches src/eval_permanent_checkpoints.py under `accelerate launch`.
#
# Usage:
#   PERMANENT_ROOT=artifacts/runs/.../checkpoints/permanent ./scripts/eval_permanent_root_acc_cot.sh
#
# Optional: RUN_LABEL (default permanent), OUT_ROOT, BASE_CONFIG, EVAL_MAX_SAMPLES, NUM_PROCESSES, offline HF vars.

set -eo pipefail

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
cd "${REPO_ROOT}"
# nounset is enabled only after conda_env.sh has been sourced.
# shellcheck disable=SC1091
source "${SCRIPT_DIR}/conda_env.sh"
set -u

# PERMANENT_ROOT is mandatory; every other knob falls back to a default.
: "${PERMANENT_ROOT:?Set PERMANENT_ROOT to your checkpoints/permanent directory}"
: "${BASE_CONFIG:=configs/grpo_llama32_3b_bf16.yaml}"
: "${ACCELERATE_CONFIG:=configs/accelerate_ddp_4gpu.yaml}"
: "${NUM_PROCESSES:=4}"
: "${RUN_LABEL:=permanent}"
: "${EVAL_MAX_SAMPLES:=200}"
: "${EVAL_BATCH_SIZE:=4}"
: "${ROLLOUT_SAMPLES:=8}"

export WANDB_MODE="${WANDB_MODE:-offline}"
# Allow dataset download into cache unless you already mirrored GSM8K (set HF_DATASETS_OFFLINE=1).
export HF_HUB_OFFLINE="${HF_HUB_OFFLINE:-1}"
export HF_DATASETS_OFFLINE="${HF_DATASETS_OFFLINE:-0}"
export TRANSFORMERS_OFFLINE="${TRANSFORMERS_OFFLINE:-1}"
export PYTHONUNBUFFERED="${PYTHONUNBUFFERED:-1}"

# Environment handed to the launched processes via `env`.
child_env=(
  PYTHONPATH="${REPO_ROOT}/src"
  BASE_CONFIG="${BASE_CONFIG}"
  PERMANENT_ROOT="${PERMANENT_ROOT}"
  RUN_LABEL="${RUN_LABEL}"
  EVAL_MAX_SAMPLES="${EVAL_MAX_SAMPLES}"
  EVAL_BATCH_SIZE="${EVAL_BATCH_SIZE}"
  ROLLOUT_SAMPLES="${ROLLOUT_SAMPLES}"
)
# OUT_ROOT is forwarded only when the caller supplied a non-empty value.
if [[ -n "${OUT_ROOT:-}" ]]; then
  child_env+=(OUT_ROOT="${OUT_ROOT}")
fi

launcher_args=(
  --config_file "${ACCELERATE_CONFIG}"
  --num_processes "${NUM_PROCESSES}"
)
env "${child_env[@]}" accelerate launch "${launcher_args[@]}" \
  src/eval_permanent_checkpoints.py

echo "Done. Summary and accuracy_vs_avg_cot_words.svg under OUT_ROOT (default: <PERMANENT_ROOT>/eval_permanent)."
scripts/eval_sweep_models_offline.sh ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Launch src/eval_sweep_models.py under `accelerate launch`, offline by default.
# Results go to OUT_ROOT, which defaults to <SWEEP_ROOT>/eval_results.
set -eo pipefail

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
# nounset is enabled only after conda_env.sh has been sourced.
# shellcheck disable=SC1091
source "${SCRIPT_DIR}/conda_env.sh"
set -u

REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
cd "${REPO_ROOT}"

# Launcher-only settings (consumed by `accelerate launch` below).
ACCELERATE_CONFIG="${ACCELERATE_CONFIG:-configs/accelerate_ddp_4gpu.yaml}"
NUM_PROCESSES="${NUM_PROCESSES:-4}"

# Settings exported into the environment of the launched processes.
export PYTHONPATH="${REPO_ROOT}/src"
export BASE_CONFIG="${BASE_CONFIG:-configs/grpo_llama32_3b_bf16.yaml}"
export SWEEP_ROOT="${SWEEP_ROOT:-artifacts/sweeps/length_penalty_lambda}"
export OUT_ROOT="${OUT_ROOT:-${SWEEP_ROOT}/eval_results}"
export EVAL_MAX_SAMPLES="${EVAL_MAX_SAMPLES:-200}"
export EVAL_BATCH_SIZE="${EVAL_BATCH_SIZE:-4}"

export WANDB_MODE="${WANDB_MODE:-offline}"
export HF_HUB_OFFLINE="${HF_HUB_OFFLINE:-1}"
export HF_DATASETS_OFFLINE="${HF_DATASETS_OFFLINE:-1}"
export TRANSFORMERS_OFFLINE="${TRANSFORMERS_OFFLINE:-1}"

accelerate launch \
  --config_file "${ACCELERATE_CONFIG}" \
  --num_processes "${NUM_PROCESSES}" \
  src/eval_sweep_models.py
scripts/eval_twostage_permanent_checkpoints.sh ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Evaluate every checkpoint under two run-local checkpoints/permanent dirs (cw=1 and cw=5 runs).
# Uses 4 GPUs. Writes rollouts, full JSONL outputs, CSV/JSON summary, step line charts,
# and accuracy_vs_avg_cot_words.svg (scatter: accuracy vs mean CoT length).
# Launches src/eval_permanent_checkpoints.py under `accelerate launch`.
set -eo pipefail

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
cd "${REPO_ROOT}"
# nounset is enabled only after conda_env.sh has been sourced.
# shellcheck disable=SC1091
source "${SCRIPT_DIR}/conda_env.sh"
set -u

# Launcher-only settings (consumed by `accelerate launch` below).
ACCELERATE_CONFIG="${ACCELERATE_CONFIG:-configs/accelerate_ddp_4gpu.yaml}"
NUM_PROCESSES="${NUM_PROCESSES:-4}"

# Settings exported into the environment of the launched processes.
export PYTHONPATH="${REPO_ROOT}/src"
export BASE_CONFIG="${BASE_CONFIG:-configs/grpo_llama32_3b_bf16.yaml}"
export PERMANENT_CW1="${PERMANENT_CW1:-artifacts/sweeps/twostage_correctness_weight/run_twostage_correctness1/checkpoints/permanent}"
export PERMANENT_CW5="${PERMANENT_CW5:-artifacts/sweeps/twostage_correctness_weight/run_twostage_correctness5/checkpoints/permanent}"
export OUT_ROOT="${OUT_ROOT:-artifacts/sweeps/twostage_correctness_weight/eval_permanent_checkpoints}"
export EVAL_MAX_SAMPLES="${EVAL_MAX_SAMPLES:-200}"
export EVAL_BATCH_SIZE="${EVAL_BATCH_SIZE:-4}"
export ROLLOUT_SAMPLES="${ROLLOUT_SAMPLES:-8}"

export WANDB_MODE="${WANDB_MODE:-offline}"
export HF_HUB_OFFLINE="${HF_HUB_OFFLINE:-1}"
export HF_DATASETS_OFFLINE="${HF_DATASETS_OFFLINE:-0}"
export TRANSFORMERS_OFFLINE="${TRANSFORMERS_OFFLINE:-1}"
export PYTHONUNBUFFERED="${PYTHONUNBUFFERED:-1}"

accelerate launch \
  --config_file "${ACCELERATE_CONFIG}" \
  --num_processes "${NUM_PROCESSES}" \
  src/eval_permanent_checkpoints.py

echo "Done. Results under: ${OUT_ROOT}"
scripts/hf_download_repo.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ from __future__ import annotations
3
+
4
+ import argparse
5
+ import os
6
+ from pathlib import Path
7
+
8
+ from huggingface_hub import snapshot_download
9
+
10
+
11
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse command-line options for the dataset-snapshot download script.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is the behaviour of a plain
            ``parse_args()`` call.

    Returns:
        Namespace with ``repo_name``, ``username``, ``local_dir``, ``revision``
        and ``hf_token``.
    """
    parser = argparse.ArgumentParser(
        description="Download a Hugging Face dataset repo snapshot to local disk."
    )
    parser.add_argument(
        "--repo-name",
        type=str,
        default="neuralese-move",
        help="HF dataset repo name under the username.",
    )
    parser.add_argument(
        "--username",
        type=str,
        default="psidharth567",
        help="HF username/org (default: psidharth567).",
    )
    parser.add_argument(
        "--local-dir",
        type=str,
        default="./neuralese",
        help="Where to download the repo snapshot locally.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        help="Optional branch/tag/commit to download.",
    )
    parser.add_argument(
        "--hf-token",
        type=str,
        # The environment is read when the parser is built, so HF_TOKEN acts as the default.
        default=os.environ.get("HF_TOKEN", ""),
        help="HF token (or set HF_TOKEN env var).",
    )
    return parser.parse_args(argv)
46
+
47
+
48
def main() -> None:
    """Download the configured dataset repo snapshot into ``--local-dir``.

    Raises:
        RuntimeError: If no token was supplied via ``--hf-token`` or ``HF_TOKEN``.
    """
    import inspect  # local import: only needed for the signature probe below

    args = parse_args()
    token = args.hf_token.strip()
    if not token:
        raise RuntimeError("Missing HF token. Set HF_TOKEN or pass --hf-token.")

    repo_id = f"{args.username}/{args.repo_name}"
    local_dir = Path(args.local_dir).resolve()
    local_dir.parent.mkdir(parents=True, exist_ok=True)
    # NOTE(review): huggingface_hub is imported at module load, before this line runs.
    # If the library reads HF_HUB_ENABLE_HF_TRANSFER at import time, this assignment
    # has no effect on the current process - confirm, and export it in the shell if so.
    os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")

    download_kwargs = {
        "repo_id": repo_id,
        "repo_type": "dataset",
        "local_dir": str(local_dir),
        "revision": args.revision,
        "token": token,
    }
    # `local_dir_use_symlinks` is deprecated in huggingface_hub and may be missing from
    # newer releases; pass it only when the installed version still accepts it so the
    # call neither loses the no-symlink request on old versions nor fails on new ones.
    if "local_dir_use_symlinks" in inspect.signature(snapshot_download).parameters:
        download_kwargs["local_dir_use_symlinks"] = False

    snapshot_path = snapshot_download(**download_kwargs)
    print(f"Downloaded snapshot to: {snapshot_path}")
68
+
69
+
70
+ if __name__ == "__main__":
71
+ main()
72
+
scripts/hf_upload_repo.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ from __future__ import annotations
3
+
4
+ import argparse
5
+ import inspect
6
+ import os
7
+ from pathlib import Path
8
+
9
+ from huggingface_hub import HfApi, create_repo
10
+
11
+
12
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse command-line options for the dataset-repo upload script.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is the behaviour of a plain
            ``parse_args()`` call.

    Returns:
        Namespace with ``folder_path``, ``repo_name``, ``username``, ``private``,
        ``path_in_repo``, ``commit_message``, ``ignore`` and ``hf_token``.
    """
    parser = argparse.ArgumentParser(
        description="Upload a local folder to a Hugging Face dataset repo."
    )
    parser.add_argument(
        "--folder-path",
        type=str,
        default=".",
        help="Local folder to upload (default: current directory).",
    )
    parser.add_argument(
        "--repo-name",
        type=str,
        default="neuralese-move",
        help="HF dataset repo name under the username.",
    )
    parser.add_argument(
        "--username",
        type=str,
        default="psidharth567",
        help="HF username/org (default: psidharth567).",
    )
    parser.add_argument(
        "--private",
        action="store_true",
        help="Create repo as private (default is public).",
    )
    parser.add_argument(
        "--path-in-repo",
        type=str,
        default=".",
        help="Destination path inside HF repo (default: repo root).",
    )
    parser.add_argument(
        "--commit-message",
        type=str,
        default="temporary cluster migration upload",
        help="Commit message for upload.",
    )
    parser.add_argument(
        "--ignore",
        type=str,
        nargs="*",
        default=[
            "**/.git/**",
            "**/__pycache__/**",
            "**/*.pyc",
            "**/*.pyo",
            "**/.DS_Store",
        ],
        help="Glob patterns to ignore during upload.",
    )
    parser.add_argument(
        "--hf-token",
        type=str,
        # The environment is read when the parser is built, so HF_TOKEN acts as the default.
        default=os.environ.get("HF_TOKEN", ""),
        help="HF token (or set HF_TOKEN env var).",
    )
    return parser.parse_args(argv)
71
+
72
+
73
def main() -> None:
    """Create the dataset repo if needed and upload ``--folder-path`` into it.

    Raises:
        RuntimeError: If no token was supplied via ``--hf-token`` or ``HF_TOKEN``.
        FileNotFoundError: If ``--folder-path`` does not exist.
        NotADirectoryError: If ``--folder-path`` is not a directory.
    """
    args = parse_args()
    token = args.hf_token.strip()
    if not token:
        raise RuntimeError("Missing HF token. Set HF_TOKEN or pass --hf-token.")

    folder = Path(args.folder_path).resolve()
    if not folder.exists():
        raise FileNotFoundError(f"Folder does not exist: {folder}")
    if not folder.is_dir():
        raise NotADirectoryError(f"Not a directory: {folder}")

    os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")

    repo_id = f"{args.username}/{args.repo_name}"
    api = HfApi(token=token)

    # Create repo if needed; no-op if already exists.
    create_repo(
        repo_id=repo_id,
        repo_type="dataset",
        private=bool(args.private),
        exist_ok=True,
        token=token,
    )

    # upload_large_folder's keyword set differs between huggingface_hub versions, so
    # optional keywords are forwarded only when the installed signature accepts them.
    large_params = set(inspect.signature(api.upload_large_folder).parameters)
    upload_kwargs = {
        "repo_id": repo_id,
        "repo_type": "dataset",
        "folder_path": str(folder),
        "ignore_patterns": args.ignore,
    }
    if "token" in large_params:
        upload_kwargs["token"] = token
    # Some huggingface_hub versions support path_in_repo for large uploads,
    # others do not. Pass it only when supported.
    if "path_in_repo" in large_params:
        upload_kwargs["path_in_repo"] = args.path_in_repo
    elif args.path_in_repo not in (".", "", "/"):
        print(
            "Warning: this huggingface_hub version does not support "
            "`path_in_repo` for upload_large_folder; uploading to repo root."
        )
    # --commit-message was previously parsed but never used. Forward it when the
    # installed upload_large_folder accepts it; otherwise the library's own commit
    # messages apply (presumably auto-generated - verify against the installed version).
    if "commit_message" in large_params:
        upload_kwargs["commit_message"] = args.commit_message

    api.upload_large_folder(**upload_kwargs)

    info = api.repo_info(repo_id=repo_id, repo_type="dataset")
    print(f"Upload complete: https://huggingface.co/datasets/{repo_id}")
    print(f"Latest commit: {info.sha}")
124
+
125
+
126
+ if __name__ == "__main__":
127
+ main()
128
+
scripts/resume_grpo_8gpu.sh ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ set -eo pipefail
3
+
4
+ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
5
+ # shellcheck disable=SC1091
6
+ source "${SCRIPT_DIR}/conda_env.sh"
7
+ set -u
8
+
9
+ REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
10
+ cd "${REPO_ROOT}"
11
+ export PYTHONPATH="${REPO_ROOT}/src"
12
+ export WANDB_DIR="${WANDB_DIR:-${REPO_ROOT}/cache/logs/wandb}"
13
+
14
+ RESUME_FROM="${1:-latest}"
15
+ GPU_COUNT="${2:-${GPU_COUNT:-8}}"
16
+
17
+ export WANDB_MODE="${WANDB_MODE:-offline}"
18
+ export HF_HUB_OFFLINE="${HF_HUB_OFFLINE:-1}"
19
+ export HF_DATASETS_OFFLINE="${HF_DATASETS_OFFLINE:-1}"
20
+ export TRANSFORMERS_OFFLINE="${TRANSFORMERS_OFFLINE:-1}"
21
+
22
+ accelerate launch \
23
+ --config_file configs/accelerate_ddp_8gpu.yaml \
24
+ --num_processes "${GPU_COUNT}" \
25
+ src/train_grpo.py \
26
+ --config configs/grpo_llama32_3b_bf16.yaml \
27
+ --resume_from_checkpoint "${RESUME_FROM}"
scripts/run_grpo.sh ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ set -eo pipefail
3
+
4
+ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
5
+ # shellcheck disable=SC1091
6
+ source "${SCRIPT_DIR}/conda_env.sh"
7
+ set -u
8
+
9
+ REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
10
+ cd "${REPO_ROOT}"
11
+ export PYTHONPATH="${REPO_ROOT}/src"
12
+ export WANDB_DIR="${WANDB_DIR:-${REPO_ROOT}/cache/logs/wandb}"
13
+
14
+ export WANDB_MODE="${WANDB_MODE:-offline}"
15
+ export HF_HUB_OFFLINE="${HF_HUB_OFFLINE:-1}"
16
+ export HF_DATASETS_OFFLINE="${HF_DATASETS_OFFLINE:-1}"
17
+ export TRANSFORMERS_OFFLINE="${TRANSFORMERS_OFFLINE:-1}"
18
+
19
+ python src/train_grpo.py --config configs/grpo_llama32_3b_bf16.yaml
scripts/run_grpo_2gpu.sh ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ set -eo pipefail
3
+
4
+ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
5
+ # shellcheck disable=SC1091
6
+ source "${SCRIPT_DIR}/conda_env.sh"
7
+ set -u
8
+
9
+ REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
10
+ cd "${REPO_ROOT}"
11
+ export PYTHONPATH="${REPO_ROOT}/src"
12
+ export WANDB_DIR="${WANDB_DIR:-${REPO_ROOT}/cache/logs/wandb}"
13
+
14
+ export WANDB_MODE="${WANDB_MODE:-offline}"
15
+ export HF_HUB_OFFLINE="${HF_HUB_OFFLINE:-1}"
16
+ export HF_DATASETS_OFFLINE="${HF_DATASETS_OFFLINE:-1}"
17
+ export TRANSFORMERS_OFFLINE="${TRANSFORMERS_OFFLINE:-1}"
18
+
19
+ accelerate launch \
20
+ --config_file configs/accelerate_ddp_2gpu.yaml \
21
+ src/train_grpo.py \
22
+ --config configs/grpo_llama32_3b_bf16.yaml
scripts/run_grpo_4gpu.sh ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ set -eo pipefail
3
+
4
+ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
5
+ # shellcheck disable=SC1091
6
+ source "${SCRIPT_DIR}/conda_env.sh"
7
+ set -u
8
+
9
+ REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
10
+ cd "${REPO_ROOT}"
11
+ export PYTHONPATH="${REPO_ROOT}/src"
12
+ export WANDB_DIR="${WANDB_DIR:-${REPO_ROOT}/cache/logs/wandb}"
13
+
14
+ export WANDB_MODE="${WANDB_MODE:-offline}"
15
+ export HF_HUB_OFFLINE="${HF_HUB_OFFLINE:-1}"
16
+ export HF_DATASETS_OFFLINE="${HF_DATASETS_OFFLINE:-1}"
17
+ export TRANSFORMERS_OFFLINE="${TRANSFORMERS_OFFLINE:-1}"
18
+
19
+ accelerate launch \
20
+ --config_file configs/accelerate_ddp_4gpu.yaml \
21
+ src/train_grpo.py \
22
+ --config configs/grpo_llama32_3b_bf16.yaml
scripts/run_grpo_8gpu.sh ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ set -eo pipefail
3
+
4
+ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
5
+ # shellcheck disable=SC1091
6
+ source "${SCRIPT_DIR}/conda_env.sh"
7
+ set -u
8
+
9
+ REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
10
+ cd "${REPO_ROOT}"
11
+ export PYTHONPATH="${REPO_ROOT}/src"
12
+ export WANDB_DIR="${WANDB_DIR:-${REPO_ROOT}/cache/logs/wandb}"
13
+
14
+ GPU_COUNT="${1:-${GPU_COUNT:-8}}"
15
+
16
+ export WANDB_MODE="${WANDB_MODE:-offline}"
17
+ export HF_HUB_OFFLINE="${HF_HUB_OFFLINE:-1}"
18
+ export HF_DATASETS_OFFLINE="${HF_DATASETS_OFFLINE:-1}"
19
+ export TRANSFORMERS_OFFLINE="${TRANSFORMERS_OFFLINE:-1}"
20
+
21
+ accelerate launch \
22
+ --config_file configs/accelerate_ddp_8gpu.yaml \
23
+ --num_processes "${GPU_COUNT}" \
24
+ src/train_grpo.py \
25
+ --config configs/grpo_llama32_3b_bf16.yaml
scripts/run_grpo_thinking_kl_masked_resume_4gpu.sh ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ # Defer `set -u` until after conda: nvidia/conda hooks use unset vars during activate.
3
+ set -eo pipefail
4
+
5
+ # Fine-tune from a stage-1 checkpoint with KL *disabled* on tokens inside
6
+ # <redacted_thinking>...</redacted_thinking> (inner body only). KL still applies
7
+ # to tags, \\boxed{...}, and any text outside the inner thinking span.
8
+ #
9
+ # Conda env: sidharth (override with CONDA_ENV=...).
10
+ #
11
+ # Defaults: online Hugging Face / datasets (fetch if missing). Override with
12
+ # HF_HUB_OFFLINE=1 HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1
13
+ #
14
+ # Effective batch size for GRPO: 4 GPUs × per_device_batch × grad_accum.
15
+ # Matches configs/grpo_llama32_3b_bf16.yaml baseline: per_device=1, grad_accum=8 -> 32.
16
+ #
17
+ # Overrides (defaults tuned for this launcher):
18
+ # LENGTH_PENALTY_LAMBDA=0.5
19
+ # PERMANENT_CHECKPOINT_STEPS=300 (set to 0 to disable extra copies under permanent_checkpoint_dir)
20
+ # PERMANENT_CHECKPOINT_DIR=checkpoints/permanent
21
+
22
+ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
23
+ REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
24
+ # shellcheck disable=SC1091
25
+ source "${SCRIPT_DIR}/conda_env.sh"
26
+ set -u
27
+
28
+ cd "${REPO_ROOT}"
29
+
30
+ # DeepSpeed is imported by some TRL/Accelerate stacks even when using plain DDP. Its import path
31
+ # calls installed_cuda_version() and requires CUDA_HOME if CUDA ops are probed. PyTorch GPU builds
32
+ # often work without nvcc, but DeepSpeed does not — set CUDA_HOME from nvcc when unset, or run:
33
+ # pip uninstall deepspeed
34
+ if [[ -z "${CUDA_HOME:-}" ]] && command -v nvcc >/dev/null 2>&1; then
35
+ export CUDA_HOME="$(dirname "$(dirname "$(command -v nvcc)")")"
36
+ fi
37
+
38
+ BASE_CONFIG="${BASE_CONFIG:-configs/grpo_llama32_3b_bf16.yaml}"
39
+ ACCELERATE_CONFIG="${ACCELERATE_CONFIG:-configs/accelerate_ddp_4gpu.yaml}"
40
+ NUM_PROCESSES="${NUM_PROCESSES:-4}"
41
+ PYTHON_BIN="${PYTHON_BIN:-python}"
42
+
43
+ CHECKPOINT="${CHECKPOINT:-${REPO_ROOT}/cache/artifacts/sweeps/twostage_correctness_weight/run_twostage_correctness1/checkpoints/permanent/checkpoint-stage1-boundary-epoch-1p0-step-1868}"
44
+
45
+ # 4 × 1 × 8 = 32 (same as base YAML; override with GRAD_ACCUM=...)
46
+ GRAD_ACCUM="${GRAD_ACCUM:-8}"
47
+ LENGTH_PENALTY_LAMBDA="${LENGTH_PENALTY_LAMBDA:-0.5}"
48
+ PERMANENT_CHECKPOINT_STEPS="${PERMANENT_CHECKPOINT_STEPS:-300}"
49
+ PERMANENT_CHECKPOINT_DIR="${PERMANENT_CHECKPOINT_DIR:-checkpoints/permanent}"
50
+
51
+ OUT_ROOT="${OUT_ROOT:-artifacts/runs/grpo_thinking_kl_masked_from_stage1_4gpu}"
52
+ RUN_NAME="${RUN_NAME:-grpo-think-kl0-resume-stage1-4gpu}"
53
+ export THINKING_INNER_KL_WEIGHT="${THINKING_INNER_KL_WEIGHT:-0.0}"
54
+
55
+ export WANDB_MODE="${WANDB_MODE:-online}"
56
+ export HF_HUB_OFFLINE="${HF_HUB_OFFLINE:-0}"
57
+ export HF_DATASETS_OFFLINE="${HF_DATASETS_OFFLINE:-0}"
58
+ export TRANSFORMERS_OFFLINE="${TRANSFORMERS_OFFLINE:-0}"
59
+ export PYTHONUNBUFFERED="${PYTHONUNBUFFERED:-1}"
60
+ export PYTHONPATH="${REPO_ROOT}/src"
61
+
62
+ mkdir -p "${OUT_ROOT}"
63
+ TMP_DIR="$(mktemp -d "${OUT_ROOT}/tmp_cfgs.XXXXXX")"
64
+ trap 'rm -rf "${TMP_DIR}"' EXIT
65
+ CFG_PATH="${TMP_DIR}/grpo_thinking_kl_4gpu.yaml"
66
+
67
+ BASE_CONFIG="${BASE_CONFIG}" \
68
+ OUT_ROOT="${OUT_ROOT}" \
69
+ RUN_NAME="${RUN_NAME}" \
70
+ GRAD_ACCUM="${GRAD_ACCUM}" \
71
+ THINKING_INNER_KL_WEIGHT="${THINKING_INNER_KL_WEIGHT}" \
72
+ LENGTH_PENALTY_LAMBDA="${LENGTH_PENALTY_LAMBDA}" \
73
+ PERMANENT_CHECKPOINT_STEPS="${PERMANENT_CHECKPOINT_STEPS}" \
74
+ PERMANENT_CHECKPOINT_DIR="${PERMANENT_CHECKPOINT_DIR}" \
75
+ CFG_PATH="${CFG_PATH}" \
76
+ "${PYTHON_BIN}" - <<'PY'
77
+ import copy
78
+ import os
79
+ from pathlib import Path
80
+
81
+ import yaml
82
+
83
+ base = Path(os.environ["BASE_CONFIG"])
84
+ out_root = Path(os.environ["OUT_ROOT"])
85
+ run_name = os.environ["RUN_NAME"]
86
+ grad_accum = int(os.environ["GRAD_ACCUM"])
87
+ cfg_path = Path(os.environ["CFG_PATH"])
88
+
89
+ with base.open("r", encoding="utf-8") as f:
90
+ cfg = yaml.safe_load(f)
91
+
92
+ cfg = copy.deepcopy(cfg)
93
+ cfg.setdefault("model", {})
94
+ # Local Unsloth-touched checkpoints often need this when loading with Transformers.
95
+ cfg["model"]["trust_remote_code"] = True
96
+ cfg.setdefault("trainer", {})
97
+ cfg["trainer"]["gradient_accumulation_steps"] = grad_accum
98
+ cfg["trainer"]["run_name"] = run_name
99
+ cfg["trainer"]["output_dir"] = str(out_root)
100
+ cfg["trainer"]["permanent_checkpoint_steps"] = int(os.environ["PERMANENT_CHECKPOINT_STEPS"])
101
+ cfg["trainer"]["permanent_checkpoint_dir"] = os.environ["PERMANENT_CHECKPOINT_DIR"]
102
+
103
+ cfg.setdefault("objective", {}).setdefault("kwargs", {})
104
+ cfg["objective"]["kwargs"]["length_penalty_lambda"] = float(os.environ["LENGTH_PENALTY_LAMBDA"])
105
+
106
+ cfg.setdefault("thinking_kl", {})
107
+ # 0.0 = no KL on inner <redacted_thinking>...</redacted_thinking> body; override with THINKING_INNER_KL_WEIGHT=1
108
+ cfg["thinking_kl"]["inner_kl_weight"] = float(os.environ.get("THINKING_INNER_KL_WEIGHT", "0.0"))
109
+
110
+ cfg.setdefault("grpo", {})
111
+ cfg["grpo"]["use_liger_kernel"] = False
112
+ # Ensure a KL term exists so masking is meaningful: if beta is 0 (or missing) it is forced to 0.001 below; set a non-zero grpo.beta in the YAML to choose a different value.
113
+ if cfg["grpo"].get("beta", 0) == 0:
114
+ cfg["grpo"]["beta"] = 0.001
115
+
116
+ with cfg_path.open("w", encoding="utf-8") as f:
117
+ yaml.safe_dump(cfg, f, sort_keys=False)
118
+
119
+ print("Wrote", cfg_path)
120
+ print("output_dir:", cfg["trainer"]["output_dir"])
121
+ print("thinking_kl.inner_kl_weight:", cfg["thinking_kl"]["inner_kl_weight"])
122
+ print("trainer.gradient_accumulation_steps:", cfg["trainer"]["gradient_accumulation_steps"])
123
+ print("objective.kwargs.length_penalty_lambda:", cfg["objective"]["kwargs"]["length_penalty_lambda"])
124
+ print("trainer.permanent_checkpoint_steps:", cfg["trainer"]["permanent_checkpoint_steps"])
125
+ print("trainer.permanent_checkpoint_dir:", cfg["trainer"]["permanent_checkpoint_dir"])
126
+ PY
127
+
128
+ echo "Checkpoint: ${CHECKPOINT}"
129
+ echo "Config: ${CFG_PATH}"
130
+ echo "Launching ${NUM_PROCESSES} processes (online HF/datasets unless overridden)..."
131
+
132
+ accelerate launch \
133
+ --config_file "${ACCELERATE_CONFIG}" \
134
+ --num_processes "${NUM_PROCESSES}" \
135
+ src/train_grpo.py \
136
+ --config "${CFG_PATH}" \
137
+ --model_path "${CHECKPOINT}"
138
+
139
+ echo "Done."
scripts/run_lambda_0p1_existing_gate_token_util.sh ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ set -eo pipefail
3
+
4
+ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
5
+ REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
6
+ cd "${REPO_ROOT}"
7
+ # shellcheck disable=SC1091
8
+ source "${SCRIPT_DIR}/conda_env.sh"
9
+ set -u
10
+
11
+
12
+ # Train one GRPO run with:
13
+ # - lambda = 0.1
14
+ # - strict format gate = true
15
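+ # - length_penalty_interaction = correctness_length_format, as written to the config below (NOTE(review): the "existing_gate" name is paired with the non-multiplicative correctness_length mode in run_reward_variants_and_eval.sh — confirm which is intended)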
16
+
17
+ BASE_CONFIG="${BASE_CONFIG:-configs/grpo_llama32_3b_bf16.yaml}"
18
+ ACCELERATE_CONFIG="${ACCELERATE_CONFIG:-configs/accelerate_ddp_4gpu.yaml}"
19
+ NUM_PROCESSES="${NUM_PROCESSES:-4}"
20
+ PYTHON_BIN="${PYTHON_BIN:-python}"
21
+ OUT_ROOT="${OUT_ROOT:-artifacts/sweeps/reward_variants_lambda_0p1}"
22
+
23
+ export WANDB_MODE="${WANDB_MODE:-offline}"
24
+ export HF_HUB_OFFLINE="${HF_HUB_OFFLINE:-1}"
25
+ export HF_DATASETS_OFFLINE="${HF_DATASETS_OFFLINE:-1}"
26
+ export TRANSFORMERS_OFFLINE="${TRANSFORMERS_OFFLINE:-1}"
27
+ export PYTHONUNBUFFERED="${PYTHONUNBUFFERED:-1}"
28
+ export ACCELERATE_LOG_LEVEL="${ACCELERATE_LOG_LEVEL:-info}"
29
+
30
+ mkdir -p "${OUT_ROOT}"
31
+ LOG_PATH="${LOG_PATH:-${OUT_ROOT}/run_lambda_0p1_existing_gate_token_util_$(date +%Y%m%d_%H%M%S).log}"
32
+ exec > >(tee -a "${LOG_PATH}") 2>&1
33
+
34
+ TMP_DIR="$(mktemp -d "${OUT_ROOT}/tmp_cfgs.XXXXXX")"
35
+ trap 'rm -rf "${TMP_DIR}"' EXIT
36
+ CFG_PATH="${TMP_DIR}/grpo_lambda_0p1_existing_gate_token_util.yaml"
37
+
38
+ echo "Preparing config..."
39
+ echo "Base config: ${BASE_CONFIG}"
40
+ echo "Out root: ${OUT_ROOT}"
41
+ echo "Accelerate config: ${ACCELERATE_CONFIG}"
42
+ echo "Num processes (GPUs): ${NUM_PROCESSES}"
43
+ echo "Log path: ${LOG_PATH}"
44
+
45
+ BASE_CONFIG="${BASE_CONFIG}" OUT_ROOT="${OUT_ROOT}" CFG_PATH="${CFG_PATH}" REPO_ROOT="${REPO_ROOT}" "${PYTHON_BIN}" - <<'PY'
46
+ import copy
47
+ import os
48
+ from pathlib import Path
49
+
50
+ import yaml
51
+
52
+ repo_root = Path(os.environ["REPO_ROOT"])
53
+ base_config = Path(os.environ["BASE_CONFIG"])
54
+ out_root = Path(os.environ["OUT_ROOT"])
55
+ cfg_path = Path(os.environ["CFG_PATH"])
56
+
57
+ with base_config.open("r", encoding="utf-8") as handle:
58
+ cfg = yaml.safe_load(handle)
59
+
60
+ cfg = copy.deepcopy(cfg)
61
+
62
+ cfg.setdefault("rewards", {})
63
+ cfg["rewards"].setdefault("kwargs", {})
64
+ rewards_kwargs = cfg["rewards"]["kwargs"]
65
+ rewards_kwargs.setdefault("token_utilisation_reward", {})
66
+ rewards_kwargs["token_utilisation_reward"]["results_jsonl_path"] = str(
67
+ repo_root / "artifacts/eval/gsm8k_train_zeroshot/results.jsonl"
68
+ )
69
+
70
+ cfg.setdefault("objective", {})
71
+ cfg["objective"].setdefault("kwargs", {})
72
+ obj = cfg["objective"]["kwargs"]
73
+
74
+ obj["enable_length_penalty"] = True
75
+ obj["enable_token_utilisation_reward"] = True
76
+ obj["reward_mode"] = "weighted_length_penalty"
77
+ obj["length_penalty_lambda"] = 0.1
78
+ obj["length_penalty_interaction"] = "correctness_length_format"
79
+ obj["strict_format_gate"] = True
80
+ obj["non_strict_penalty"] = -1.0
81
+
82
+ cfg.setdefault("trainer", {})
83
+ base_run_name = cfg["trainer"].get("run_name", "grpo")
84
+ cfg["trainer"]["run_name"] = f"{base_run_name}-lambda-0p1-existing-gate-token-util"
85
+ cfg["trainer"]["output_dir"] = str(out_root / "run_lambda_0p1_existing_gate_token_util")
86
+
87
+ with cfg_path.open("w", encoding="utf-8") as handle:
88
+ yaml.safe_dump(cfg, handle, sort_keys=False)
89
+
90
+ print(f"Wrote config: {cfg_path}")
91
+ print(f"Run name: {cfg['trainer']['run_name']}")
92
+ print(f"Output dir: {cfg['trainer']['output_dir']}")
93
+ PY
94
+
95
+ echo
96
+ echo "Starting training..."
97
+ export PYTHONPATH="${REPO_ROOT}/src"
98
+ accelerate launch \
99
+ --config_file "${ACCELERATE_CONFIG}" \
100
+ --num_processes "${NUM_PROCESSES}" \
101
+ src/train_grpo.py \
102
+ --config "${CFG_PATH}"
103
+
104
+ echo
105
+ echo "Done."
106
+ echo "Model outputs/checkpoints under: ${OUT_ROOT}/run_lambda_0p1_existing_gate_token_util"
107
+ echo "Full log: ${LOG_PATH}"
scripts/run_reward_variants_and_eval.sh ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ set -eo pipefail
3
+
4
+ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
5
+ # shellcheck disable=SC1091
6
+ source "${SCRIPT_DIR}/conda_env.sh"
7
+ set -u
8
+
9
+ REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
10
+ cd "${REPO_ROOT}"
11
+ export PYTHONPATH="${REPO_ROOT}/src"
12
+
13
+ # Train 3 reward variants sequentially, then evaluate all and save outputs+plot.
14
+ #
15
+ # Variants (all share one lambda: default 0.75, override with LAMBDA=...; run dirs are always named run_lambda_0p75_*):
16
+ # 1) existing_gate:
17
+ # total = r_correctness + lambda*(r_correctness*r_length) + r_format
18
+ # strict_format_gate = true
19
+ # 2) mult_nogate:
20
+ # total = r_correctness + lambda*(r_correctness*r_length*r_format) + r_format
21
+ # strict_format_gate = false
22
+ # 3) mult_gate:
23
+ # total = r_correctness + lambda*(r_correctness*r_length*r_format) + r_format
24
+ # strict_format_gate = true
25
+
26
+ BASE_CONFIG="${BASE_CONFIG:-configs/grpo_llama32_3b_bf16.yaml}"
27
+ ACCELERATE_CONFIG="${ACCELERATE_CONFIG:-configs/accelerate_ddp_4gpu.yaml}"
28
+ NUM_PROCESSES="${NUM_PROCESSES:-4}"
29
+ PYTHON_BIN="${PYTHON_BIN:-python}"
30
+ LAMBDA="${LAMBDA:-0.75}"
31
+
32
+ SWEEP_ROOT="${SWEEP_ROOT:-artifacts/sweeps/reward_variants_lambda_0p75}"
33
+ EVAL_OUT_ROOT="${EVAL_OUT_ROOT:-${SWEEP_ROOT}/eval_results}"
34
+ EVAL_MAX_SAMPLES="${EVAL_MAX_SAMPLES:-200}"
35
+ EVAL_BATCH_SIZE="${EVAL_BATCH_SIZE:-4}"
36
+
37
+ export WANDB_MODE="${WANDB_MODE:-offline}"
38
+ export HF_HUB_OFFLINE="${HF_HUB_OFFLINE:-1}"
39
+ export HF_DATASETS_OFFLINE="${HF_DATASETS_OFFLINE:-1}"
40
+ export TRANSFORMERS_OFFLINE="${TRANSFORMERS_OFFLINE:-1}"
41
+ export PYTHONUNBUFFERED="${PYTHONUNBUFFERED:-1}"
42
+ export ACCELERATE_LOG_LEVEL="${ACCELERATE_LOG_LEVEL:-info}"
43
+
44
+ mkdir -p "${SWEEP_ROOT}"
45
+ LOG_PATH="${LOG_PATH:-${SWEEP_ROOT}/run_reward_variants_and_eval_$(date +%Y%m%d_%H%M%S).log}"
46
+ exec > >(tee -a "${LOG_PATH}") 2>&1
47
+
48
+ TMP_DIR="$(mktemp -d "${SWEEP_ROOT}/tmp_cfgs.XXXXXX")"
49
+ trap 'rm -rf "${TMP_DIR}"' EXIT
50
+
51
+ VARIANTS=("existing_gate" "mult_nogate" "mult_gate")
52
+
53
+ echo "Starting 3-variant reward sweep..."
54
+ echo "Base config: ${BASE_CONFIG}"
55
+ echo "Sweep root: ${SWEEP_ROOT}"
56
+ echo "Lambda: ${LAMBDA}"
57
+ echo "Num processes (GPUs): ${NUM_PROCESSES}"
58
+ echo "Accelerate config: ${ACCELERATE_CONFIG}"
59
+ echo "Live+saved log path: ${LOG_PATH}"
60
+
61
+ for VARIANT in "${VARIANTS[@]}"; do
62
+ echo
63
+ echo "=== Training variant: ${VARIANT} ==="
64
+ CFG_PATH="${TMP_DIR}/grpo_${VARIANT}.yaml"
65
+
66
+ BASE_CONFIG="${BASE_CONFIG}" \
67
+ SWEEP_ROOT="${SWEEP_ROOT}" \
68
+ CFG_PATH="${CFG_PATH}" \
69
+ VARIANT="${VARIANT}" \
70
+ LAMBDA="${LAMBDA}" \
71
+ "${PYTHON_BIN}" - <<'PY'
72
+ import copy
73
+ import os
74
+ from pathlib import Path
75
+
76
+ import yaml
77
+
78
+ base_config = Path(os.environ["BASE_CONFIG"])
79
+ sweep_root = Path(os.environ["SWEEP_ROOT"])
80
+ cfg_path = Path(os.environ["CFG_PATH"])
81
+ variant = os.environ["VARIANT"]
82
+ lam = float(os.environ["LAMBDA"])
83
+
84
+ with base_config.open("r", encoding="utf-8") as handle:
85
+ cfg = yaml.safe_load(handle)
86
+
87
+ cfg = copy.deepcopy(cfg)
88
+ cfg.setdefault("objective", {})
89
+ cfg["objective"].setdefault("kwargs", {})
90
+ obj = cfg["objective"]["kwargs"]
91
+
92
+ obj["enable_length_penalty"] = True
93
+ obj["reward_mode"] = "weighted_length_penalty"
94
+ obj["length_penalty_lambda"] = lam
95
+ obj["non_strict_penalty"] = -1.0
96
+
97
+ if variant == "existing_gate":
98
+ obj["length_penalty_interaction"] = "correctness_length"
99
+ obj["strict_format_gate"] = True
100
+ elif variant == "mult_nogate":
101
+ obj["length_penalty_interaction"] = "correctness_length_format"
102
+ obj["strict_format_gate"] = False
103
+ elif variant == "mult_gate":
104
+ obj["length_penalty_interaction"] = "correctness_length_format"
105
+ obj["strict_format_gate"] = True
106
+ else:
107
+ raise ValueError(f"Unknown variant: {variant}")
108
+
109
+ cfg.setdefault("trainer", {})
110
+ base_run_name = cfg["trainer"].get("run_name", "grpo")
111
+ cfg["trainer"]["run_name"] = f"{base_run_name}-{variant}"
112
+ cfg["trainer"]["output_dir"] = str(sweep_root / f"run_lambda_0p75_{variant}")
113
+
114
+ with cfg_path.open("w", encoding="utf-8") as handle:
115
+ yaml.safe_dump(cfg, handle, sort_keys=False)
116
+
117
+ print(f"Wrote {cfg_path}")
118
+ print(f"Run dir: {cfg['trainer']['output_dir']}")
119
+ PY
120
+
121
+ accelerate launch \
122
+ --config_file "${ACCELERATE_CONFIG}" \
123
+ --num_processes "${NUM_PROCESSES}" \
124
+ src/train_grpo.py \
125
+ --config "${CFG_PATH}"
126
+ done
127
+
128
+ echo
129
+ echo "=== Running sweep evaluation on all trained variants ==="
130
+ BASE_CONFIG="${BASE_CONFIG}" \
131
+ SWEEP_ROOT="${SWEEP_ROOT}" \
132
+ OUT_ROOT="${EVAL_OUT_ROOT}" \
133
+ EVAL_MAX_SAMPLES="${EVAL_MAX_SAMPLES}" \
134
+ EVAL_BATCH_SIZE="${EVAL_BATCH_SIZE}" \
135
+ accelerate launch \
136
+ --config_file "${ACCELERATE_CONFIG}" \
137
+ --num_processes "${NUM_PROCESSES}" \
138
+ src/eval_sweep_models.py
139
+
140
+ echo
141
+ echo "All done."
142
+ echo "Train runs: ${SWEEP_ROOT}/run_lambda_0p75_*"
143
+ echo "Eval summary JSON: ${EVAL_OUT_ROOT}/sweep_eval_summary.json"
144
+ echo "Eval summary CSV: ${EVAL_OUT_ROOT}/sweep_eval_summary.csv"
145
+ echo "Eval outputs: ${EVAL_OUT_ROOT}/outputs/"
146
+ echo "Accuracy plot: ${EVAL_OUT_ROOT}/sweep_eval_accuracy.svg"
scripts/run_sft_gsm8k_boxed_7gpu.sh ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ # SFT on GSM8K: assistant target is only \boxed{answer} (no thinking). DDP on 7 GPUs, then GSM8K test eval on GPU 0.
3
+ #
4
+ # Usage:
5
+ # ./scripts/run_sft_gsm8k_boxed_7gpu.sh
6
+ # MODEL_PATH=/path/to/model HF_DATASETS_OFFLINE=0 ./scripts/run_sft_gsm8k_boxed_7gpu.sh
7
+ # NUM_TRAIN_EPOCHS=3 ./scripts/run_sft_gsm8k_boxed_7gpu.sh
8
+ set -eo pipefail
9
+
10
+ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
11
+ REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
12
+ # shellcheck disable=SC1091
13
+ source "${SCRIPT_DIR}/conda_env.sh"
14
+ set -u
15
+
16
+ cd "${REPO_ROOT}"
17
+
18
+ ACCELERATE_CONFIG="${ACCELERATE_CONFIG:-configs/accelerate_ddp_7gpu.yaml}"
19
+ NUM_PROCESSES="${NUM_PROCESSES:-7}"
20
+ MODEL_PATH="${MODEL_PATH:-${REPO_ROOT}/cache/models/models--unsloth--Llama-3.2-3B-Instruct}"
21
+ # Unsloth hub cache snapshots often lack config.json / model_type; SFT pulls config from this id (HF token if gated).
22
+ HUB_MODEL_ID="${HUB_MODEL_ID:-meta-llama/Llama-3.2-3B-Instruct}"
23
+ OUTPUT_DIR="${OUTPUT_DIR:-artifacts/runs/sft_gsm8k_boxed}"
24
+ NUM_TRAIN_EPOCHS="${NUM_TRAIN_EPOCHS:-1}"
25
+
26
+ if [[ -z "${CUDA_HOME:-}" ]] && command -v nvcc >/dev/null 2>&1; then
27
+ export CUDA_HOME="$(dirname "$(dirname "$(command -v nvcc)")")"
28
+ fi
29
+
30
+ export PYTHONPATH="${REPO_ROOT}/src"
31
+ export PYTHONUNBUFFERED="${PYTHONUNBUFFERED:-1}"
32
+ export HF_DATASETS_OFFLINE="${HF_DATASETS_OFFLINE:-0}"
33
+
34
+ # Everything after src/sft_gsm8k_boxed.py is passed to Python, not to accelerate.
35
+ accelerate launch \
36
+ --config_file "${ACCELERATE_CONFIG}" \
37
+ --num_processes "${NUM_PROCESSES}" \
38
+ src/sft_gsm8k_boxed.py \
39
+ --model_path "${MODEL_PATH}" \
40
+ --hub_model_id "${HUB_MODEL_ID}" \
41
+ --output_dir "${OUTPUT_DIR}" \
42
+ --num_train_epochs "${NUM_TRAIN_EPOCHS}"
43
+
44
+ echo "Done. Weights: ${OUTPUT_DIR}/final_model metrics: ${OUTPUT_DIR}/gsm8k_test_metrics.json"
scripts/run_twostage_correctness1.sh ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ set -eo pipefail
3
+
4
+ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
5
+ # shellcheck disable=SC1091
6
+ source "${SCRIPT_DIR}/conda_env.sh"
7
+ set -u
8
+
9
+ REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
10
+ cd "${REPO_ROOT}"
11
+ export PYTHONPATH="${REPO_ROOT}/src"
12
+
13
+ # Two-stage run with correctness_weight = 1.0
14
+ # Stage 1/2 schedule is inherited from BASE_CONFIG:
15
+ # - length_penalty_lambda (stage 1)
16
+ # - stage2_length_penalty_lambda (stage 2)
17
+ # - stage2_start_epoch
18
+
19
+ BASE_CONFIG="${BASE_CONFIG:-configs/grpo_llama32_3b_bf16.yaml}"
20
+ ACCELERATE_CONFIG="${ACCELERATE_CONFIG:-configs/accelerate_ddp_4gpu.yaml}"
21
+ NUM_PROCESSES="${NUM_PROCESSES:-4}"
22
+ PYTHON_BIN="${PYTHON_BIN:-python}"
23
+ OUT_ROOT="${OUT_ROOT:-artifacts/sweeps/twostage_correctness_weight}"
24
+
25
+ export WANDB_DIR="${WANDB_DIR:-${REPO_ROOT}/cache/logs/wandb}"
26
+ export WANDB_MODE="${WANDB_MODE:-offline}"
27
+ export HF_HUB_OFFLINE="${HF_HUB_OFFLINE:-1}"
28
+ export HF_DATASETS_OFFLINE="${HF_DATASETS_OFFLINE:-1}"
29
+ export TRANSFORMERS_OFFLINE="${TRANSFORMERS_OFFLINE:-1}"
30
+ export PYTHONUNBUFFERED="${PYTHONUNBUFFERED:-1}"
31
+ export ACCELERATE_LOG_LEVEL="${ACCELERATE_LOG_LEVEL:-info}"
32
+
33
+ mkdir -p "${OUT_ROOT}"
34
+ LOG_PATH="${LOG_PATH:-${OUT_ROOT}/run_twostage_correctness1_$(date +%Y%m%d_%H%M%S).log}"
35
+ exec > >(tee -a "${LOG_PATH}") 2>&1
36
+
37
+ TMP_DIR="$(mktemp -d "${OUT_ROOT}/tmp_cfgs.XXXXXX")"
38
+ trap 'rm -rf "${TMP_DIR}"' EXIT
39
+ CFG_PATH="${TMP_DIR}/grpo_twostage_correctness1.yaml"
40
+
41
+ BASE_CONFIG="${BASE_CONFIG}" OUT_ROOT="${OUT_ROOT}" CFG_PATH="${CFG_PATH}" "${PYTHON_BIN}" - <<'PY'
42
+ import copy
43
+ import os
44
+ from pathlib import Path
45
+
46
+ import yaml
47
+
48
+ base_config = Path(os.environ["BASE_CONFIG"])
49
+ out_root = Path(os.environ["OUT_ROOT"])
50
+ cfg_path = Path(os.environ["CFG_PATH"])
51
+
52
+ with base_config.open("r", encoding="utf-8") as handle:
53
+ cfg = yaml.safe_load(handle)
54
+
55
+ cfg = copy.deepcopy(cfg)
56
+ cfg.setdefault("objective", {})
57
+ cfg["objective"].setdefault("kwargs", {})
58
+ cfg["objective"]["kwargs"]["correctness_weight"] = 1.0
59
+
60
+ cfg.setdefault("trainer", {})
61
+ base_run_name = cfg["trainer"].get("run_name", "grpo")
62
+ cfg["trainer"]["run_name"] = f"{base_run_name}-twostage-cw1"
63
+ cfg["trainer"]["output_dir"] = str(out_root / "run_twostage_correctness1")
64
+
65
+ with cfg_path.open("w", encoding="utf-8") as handle:
66
+ yaml.safe_dump(cfg, handle, sort_keys=False)
67
+
68
+ print(f"Wrote config: {cfg_path}")
69
+ print(f"Output dir: {cfg['trainer']['output_dir']}")
70
+ PY
71
+
72
+ echo "Starting 4-GPU training (correctness_weight=1.0)"
73
+ accelerate launch \
74
+ --config_file "${ACCELERATE_CONFIG}" \
75
+ --num_processes "${NUM_PROCESSES}" \
76
+ src/train_grpo.py \
77
+ --config "${CFG_PATH}"
78
+
79
+ echo "Done. Log: ${LOG_PATH}"
scripts/run_twostage_correctness5.sh ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ set -eo pipefail
3
+
4
+ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
5
+ # shellcheck disable=SC1091
6
+ source "${SCRIPT_DIR}/conda_env.sh"
7
+ set -u
8
+
9
+ REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
10
+ cd "${REPO_ROOT}"
11
+ export PYTHONPATH="${REPO_ROOT}/src"
12
+
13
+ # Two-stage run with correctness_weight = 5.0
14
+ # Stage 1/2 schedule is inherited from BASE_CONFIG:
15
+ # - length_penalty_lambda (stage 1)
16
+ # - stage2_length_penalty_lambda (stage 2)
17
+ # - stage2_start_epoch
18
+
19
+ BASE_CONFIG="${BASE_CONFIG:-configs/grpo_llama32_3b_bf16.yaml}"
20
+ ACCELERATE_CONFIG="${ACCELERATE_CONFIG:-configs/accelerate_ddp_4gpu.yaml}"
21
+ NUM_PROCESSES="${NUM_PROCESSES:-4}"
22
+ PYTHON_BIN="${PYTHON_BIN:-python}"
23
+ OUT_ROOT="${OUT_ROOT:-artifacts/sweeps/twostage_correctness_weight}"
24
+
25
+ export WANDB_DIR="${WANDB_DIR:-${REPO_ROOT}/cache/logs/wandb}"
26
+ export WANDB_MODE="${WANDB_MODE:-offline}"
27
+ export HF_HUB_OFFLINE="${HF_HUB_OFFLINE:-1}"
28
+ export HF_DATASETS_OFFLINE="${HF_DATASETS_OFFLINE:-1}"
29
+ export TRANSFORMERS_OFFLINE="${TRANSFORMERS_OFFLINE:-1}"
30
+ export PYTHONUNBUFFERED="${PYTHONUNBUFFERED:-1}"
31
+ export ACCELERATE_LOG_LEVEL="${ACCELERATE_LOG_LEVEL:-info}"
32
+
33
+ mkdir -p "${OUT_ROOT}"
34
+ LOG_PATH="${LOG_PATH:-${OUT_ROOT}/run_twostage_correctness5_$(date +%Y%m%d_%H%M%S).log}"
35
+ exec > >(tee -a "${LOG_PATH}") 2>&1
36
+
37
+ TMP_DIR="$(mktemp -d "${OUT_ROOT}/tmp_cfgs.XXXXXX")"
38
+ trap 'rm -rf "${TMP_DIR}"' EXIT
39
+ CFG_PATH="${TMP_DIR}/grpo_twostage_correctness5.yaml"
40
+
41
+ BASE_CONFIG="${BASE_CONFIG}" OUT_ROOT="${OUT_ROOT}" CFG_PATH="${CFG_PATH}" "${PYTHON_BIN}" - <<'PY'
42
+ import copy
43
+ import os
44
+ from pathlib import Path
45
+
46
+ import yaml
47
+
48
+ base_config = Path(os.environ["BASE_CONFIG"])
49
+ out_root = Path(os.environ["OUT_ROOT"])
50
+ cfg_path = Path(os.environ["CFG_PATH"])
51
+
52
+ with base_config.open("r", encoding="utf-8") as handle:
53
+ cfg = yaml.safe_load(handle)
54
+
55
+ cfg = copy.deepcopy(cfg)
56
+ cfg.setdefault("objective", {})
57
+ cfg["objective"].setdefault("kwargs", {})
58
+ cfg["objective"]["kwargs"]["correctness_weight"] = 5.0
59
+
60
+ cfg.setdefault("trainer", {})
61
+ base_run_name = cfg["trainer"].get("run_name", "grpo")
62
+ cfg["trainer"]["run_name"] = f"{base_run_name}-twostage-cw5"
63
+ cfg["trainer"]["output_dir"] = str(out_root / "run_twostage_correctness5")
64
+
65
+ with cfg_path.open("w", encoding="utf-8") as handle:
66
+ yaml.safe_dump(cfg, handle, sort_keys=False)
67
+
68
+ print(f"Wrote config: {cfg_path}")
69
+ print(f"Output dir: {cfg['trainer']['output_dir']}")
70
+ PY
71
+
72
+ echo "Starting 4-GPU training (correctness_weight=5.0)"
73
+ accelerate launch \
74
+ --config_file "${ACCELERATE_CONFIG}" \
75
+ --num_processes "${NUM_PROCESSES}" \
76
+ src/train_grpo.py \
77
+ --config "${CFG_PATH}"
78
+
79
+ echo "Done. Log: ${LOG_PATH}"
scripts/sweep_length_penalty_lambda.sh ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ set -eo pipefail
3
+
4
+ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
5
+ # shellcheck disable=SC1091
6
+ source "${SCRIPT_DIR}/conda_env.sh"
7
+ set -u
8
+
9
+ REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
10
+ cd "${REPO_ROOT}"
11
+ export PYTHONPATH="${REPO_ROOT}/src"
12
+
13
+ # Sweep GRPO runs over length_penalty_lambda, one run per value in LAMBDAS (defaults):
14
+ # - 0.5
15
+ # - 0.25
16
+ #
17
+ # Each run keeps all other settings identical to BASE_CONFIG.
18
+
19
+ BASE_CONFIG="${BASE_CONFIG:-configs/grpo_llama32_3b_bf16.yaml}"
20
+ ACCELERATE_CONFIG="${ACCELERATE_CONFIG:-configs/accelerate_ddp_4gpu.yaml}"
21
+ NUM_PROCESSES="${NUM_PROCESSES:-4}"
22
+ PYTHON_BIN="${PYTHON_BIN:-python}"
23
+ OUT_ROOT="${OUT_ROOT:-artifacts/sweeps/length_penalty_lambda}"
24
+ LAMBDAS="${LAMBDAS:-0.5 0.25}"
25
+
26
+ export WANDB_MODE="${WANDB_MODE:-offline}"
27
+ export HF_HUB_OFFLINE="${HF_HUB_OFFLINE:-1}"
28
+ export HF_DATASETS_OFFLINE="${HF_DATASETS_OFFLINE:-1}"
29
+ export TRANSFORMERS_OFFLINE="${TRANSFORMERS_OFFLINE:-1}"
30
+
31
+ mkdir -p "${OUT_ROOT}"
32
+ TMP_DIR="$(mktemp -d "${OUT_ROOT}/tmp_cfgs.XXXXXX")"
33
+ trap 'rm -rf "${TMP_DIR}"' EXIT
34
+
35
+ echo "Base config: ${BASE_CONFIG}"
36
+ echo "Lambdas: ${LAMBDAS}"
37
+ echo "Output root: ${OUT_ROOT}"
38
+
39
+ for LAMBDA in ${LAMBDAS}; do
40
+ echo
41
+ echo "=== Running lambda=${LAMBDA} ==="
42
+
43
+ CFG_PATH="${TMP_DIR}/grpo_lambda_${LAMBDA}.yaml"
44
+ BASE_CONFIG="${BASE_CONFIG}" OUT_ROOT="${OUT_ROOT}" LAMBDA="${LAMBDA}" CFG_PATH="${CFG_PATH}" "${PYTHON_BIN}" - <<'PY'
45
+ import copy
46
+ import os
47
+ from pathlib import Path
48
+
49
+ import yaml
50
+
51
+ base_config = Path(os.environ["BASE_CONFIG"])
52
+ out_root = Path(os.environ["OUT_ROOT"])
53
+ cfg_path = Path(os.environ["CFG_PATH"])
54
+ lam = os.environ["LAMBDA"]
55
+
56
+ with base_config.open("r", encoding="utf-8") as handle:
57
+ cfg = yaml.safe_load(handle)
58
+
59
+ cfg = copy.deepcopy(cfg)
60
+ cfg.setdefault("objective", {})
61
+ cfg["objective"].setdefault("kwargs", {})
62
+ cfg["objective"]["kwargs"]["enable_length_penalty"] = True
63
+ cfg["objective"]["kwargs"]["reward_mode"] = "weighted_length_penalty"
64
+ cfg["objective"]["kwargs"]["length_penalty_lambda"] = float(lam)
65
+
66
+ cfg.setdefault("trainer", {})
67
+ base_run_name = cfg["trainer"].get("run_name", "grpo")
68
+ safe_lam = lam.replace(".", "p")
69
+ cfg["trainer"]["run_name"] = f"{base_run_name}-lambda-{safe_lam}"
70
+ cfg["trainer"]["output_dir"] = str(out_root / f"run_lambda_{safe_lam}")
71
+
72
+ with cfg_path.open("w", encoding="utf-8") as handle:
73
+ yaml.safe_dump(cfg, handle, sort_keys=False)
74
+
75
+ print(f"Wrote config: {cfg_path}")
76
+ print(f"run_name: {cfg['trainer']['run_name']}")
77
+ print(f"output_dir: {cfg['trainer']['output_dir']}")
78
+ PY
79
+
80
+ accelerate launch \
81
+ --config_file "${ACCELERATE_CONFIG}" \
82
+ --num_processes "${NUM_PROCESSES}" \
83
+ src/train_grpo.py \
84
+ --config "${CFG_PATH}"
85
+ done
86
+
87
+ echo
88
+ echo "Sweep complete. Runs are under: ${OUT_ROOT}"
src/eval_gsm8k_truncated_thinking.py ADDED
@@ -0,0 +1,321 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ GSM8K eval with the same short ``<think>...</think>`` tags as ``data_plugins.PROMPT_PREFIX``.
3
+
4
+ The inner span is truncated to at most ``--think_inner_token_limit`` tokenizer tokens (a close tag
5
+ is forced if the model would exceed that, including when it closes late). Then greedy generation
6
+ continues for the boxed answer.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import argparse
12
+ import importlib.util
13
+ import json
14
+ import os
15
+ from pathlib import Path
16
+
17
+ import torch
18
+ import torch.distributed as dist
19
+ import yaml
20
+ from transformers import AutoModelForCausalLM, AutoTokenizer
21
+
22
+ import hackable # noqa: F401
23
+ from hackable.data_plugins import GSM8KProvider
24
+ from hackable.paths import resolve_storage_path, storage_layout
25
+ from hackable.reward_plugins import gsm8k_correctness_reward
26
+ from hackable.utils import resolve_repo_path
27
+
28
# Literal tag strings searched for in decoded completions to locate the reasoning span.
THINK_OPEN = "<think>"
THINK_CLOSE = "</think>"
30
+
31
+
32
def _greedy_generate_rest(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    cur: torch.Tensor,
    plen: int,
    max_new_tokens: int,
    device: torch.device,
) -> str:
    """Finish a completion greedily and return its decoded text.

    ``cur`` holds the prompt (its first ``plen`` positions) plus whatever has been
    generated so far. At most ``max_new_tokens`` tokens may follow the prompt in
    total; if that budget is already used up, ``cur`` is decoded unchanged.
    Special tokens are kept in the returned text.
    """
    budget_left = max_new_tokens - (int(cur.shape[1]) - plen)
    sequence = cur
    if budget_left > 0:
        sequence = model.generate(
            cur,
            # Single unpadded sequence, so every position is attended to.
            attention_mask=torch.ones_like(cur, device=device),
            max_new_tokens=budget_left,
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    return tokenizer.decode(sequence[0, plen:], skip_special_tokens=False)
55
+
56
+
57
def _load_yaml(path: str) -> dict:
    """Read the UTF-8 YAML file at ``path`` and return what ``yaml.safe_load`` parses from it."""
    with open(path, mode="r", encoding="utf-8") as stream:
        parsed = yaml.safe_load(stream)
    return parsed
60
+
61
+
62
def _model_dtype(cfg: dict) -> torch.dtype:
    """Pick the model dtype from ``cfg["trainer"]["bf16"]``: bfloat16 when truthy (the default), else float16."""
    trainer_cfg = cfg.get("trainer", {})
    if bool(trainer_cfg.get("bf16", True)):
        return torch.bfloat16
    return torch.float16
64
+
65
+
66
def _get_cache_paths(base_cfg: dict) -> tuple[Path, Path]:
    """Return the ``(datasets, models)`` attributes of the storage layout for the configured cache dir.

    The cache dir is read from ``base_cfg["storage"]["cache_dir"]`` and defaults to ``"cache"``.
    """
    storage_cfg = base_cfg.get("storage", {})
    cache_dir = storage_cfg.get("cache_dir", "cache")
    layout = storage_layout(cache_dir)
    return layout.datasets, layout.models
69
+
70
+
71
def _dist_info() -> tuple[int, int, int]:
    """Read ``(rank, world_size, local_rank)`` from the environment.

    Missing variables default to a single-process setup: ``RANK=0``,
    ``WORLD_SIZE=1``, ``LOCAL_RANK=0``.
    """
    env = os.environ
    return (
        int(env.get("RANK", "0")),
        int(env.get("WORLD_SIZE", "1")),
        int(env.get("LOCAL_RANK", "0")),
    )
76
+
77
+
78
def _init_distributed() -> tuple[int, int, int]:
    """Create the default process group when launched with more than one process.

    Uses the ``nccl`` backend when CUDA is available and ``gloo`` otherwise, with
    ``env://`` rendezvous. Nothing is initialized for a single process or when a
    group already exists. Returns ``(rank, world_size, local_rank)``.
    """
    rank, world_size, local_rank = _dist_info()
    needs_group = world_size > 1 and not dist.is_initialized()
    if needs_group:
        dist.init_process_group(
            backend="nccl" if torch.cuda.is_available() else "gloo",
            init_method="env://",
        )
    return rank, world_size, local_rank
84
+
85
+
86
def _resolve_local_model_dir(base_cfg: dict, model_dir: str) -> Path:
    """Locate ``model_dir`` on disk, trying three locations in order.

    1. ``model_dir`` itself, when it is an absolute path that exists.
    2. ``resolve_repo_path(model_dir)``.
    3. ``resolve_storage_path(model_dir, cache_dir)`` with the configured cache dir
       (``base_cfg["storage"]["cache_dir"]``, default ``"cache"``).

    Raises:
        FileNotFoundError: if none of the candidates exists.
    """
    as_given = Path(model_dir)
    if as_given.is_absolute() and as_given.exists():
        return as_given.resolve()

    in_repo = resolve_repo_path(model_dir)
    if in_repo.exists():
        return in_repo

    cache_dir = base_cfg.get("storage", {}).get("cache_dir", "cache")
    in_storage = resolve_storage_path(model_dir, cache_dir)
    if not in_storage.exists():
        raise FileNotFoundError(f"Model directory not found: {model_dir}")
    return in_storage
97
+
98
+
99
def _flash_attn_impl() -> str:
    """Return ``"flash_attention_2"`` when ``flash_attn`` can actually be imported, else ``"sdpa"``.

    Finding the package spec is not enough: the import itself is attempted, and any
    exception it raises is treated as "not usable" (deliberately best-effort).
    """
    fallback = "sdpa"
    if importlib.util.find_spec("flash_attn") is None:
        return fallback
    try:
        __import__("flash_attn")
    except Exception:
        return fallback
    return "flash_attention_2"
107
+
108
+
109
@torch.no_grad()
def generate_truncated_thinking(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    prompt: str,
    device: torch.device,
    inner_token_limit: int,
    max_new_tokens: int,
) -> str:
    """
    Greedy-decode a completion whose thinking span holds at most ``inner_token_limit`` tokens.

    Tokens are produced one at a time (a full forward pass per step) so the decoded
    text can be inspected after every step:

    * If ``THINK_CLOSE`` appears and the inner span re-encodes to more than
      ``inner_token_limit`` tokens, the completion is rewritten to the first
      ``inner_token_limit`` inner tokens, then ``THINK_CLOSE`` and a newline, and the
      remainder is decoded with ``_greedy_generate_rest``.
    * If ``THINK_CLOSE`` appears within the limit, the remainder is decoded as is.
    * If there is no close tag yet but the inner span already has
      ``inner_token_limit`` tokens, the close tag is forced the same way.
    * If ``THINK_OPEN`` has not appeared after 128 generated tokens, decoding falls
      back to plain greedy generation.
    * If the model emits an end-of-sequence token during the step-wise phase,
      generation stops there, matching the ``eos_token_id`` stop used by
      ``_greedy_generate_rest``.

    Args:
        model: Causal LM, called as ``model(ids)`` and via ``model.generate``.
        tokenizer: Tokenizer matching ``model``.
        prompt: Prompt text; tokenized here with truncation to 4096 tokens.
        device: Device the input ids are moved to.
        inner_token_limit: Maximum number of tokenizer tokens kept inside the tags.
        max_new_tokens: Cap on the number of tokens following the prompt.

    Returns:
        The decoded completion (everything after the prompt), special tokens kept.
    """
    enc = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=4096)
    prompt_ids = enc["input_ids"].to(device)
    plen = int(prompt_ids.shape[1])
    # torch.cat below always allocates a new tensor, so prompt_ids is never mutated.
    cur = prompt_ids

    # eos_token_id may be None, a single id, or a list of ids depending on the tokenizer.
    eos_raw = tokenizer.eos_token_id
    if eos_raw is None:
        eos_ids: set[int] = set()
    elif isinstance(eos_raw, int):
        eos_ids = {eos_raw}
    else:
        eos_ids = {int(token_id) for token_id in eos_raw}

    # Number of generated tokens after which a missing THINK_OPEN triggers the fast path.
    no_think_fallback_after = 128

    def rewrite_and_continue(gen_text: str, after_open_idx: int, inner_text: str) -> str:
        """Cut ``inner_text`` to the token limit, close the tag, and decode the rest."""
        inner_ids = tokenizer.encode(inner_text, add_special_tokens=False)
        inner_cut = inner_ids[:inner_token_limit]
        inner_str = tokenizer.decode(inner_cut, skip_special_tokens=False)
        fixed_comp = gen_text[:after_open_idx] + inner_str + THINK_CLOSE + "\n"
        comp_ids = tokenizer.encode(fixed_comp, add_special_tokens=False)
        comp_t = torch.tensor([comp_ids], dtype=torch.long, device=device)
        new_cur = torch.cat([prompt_ids, comp_t], dim=-1)
        return _greedy_generate_rest(model, tokenizer, new_cur, plen, max_new_tokens, device)

    for generated_count in range(1, max_new_tokens + 1):
        out = model(cur)
        next_id = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        cur = torch.cat([cur, next_id], dim=-1)

        # Stop at end-of-sequence instead of feeding the model past it.
        if int(next_id[0, 0]) in eos_ids:
            break

        gen_text = tokenizer.decode(cur[0, plen:], skip_special_tokens=False)

        if THINK_OPEN not in gen_text:
            # If the model never opens thinking, fall back to normal greedy (much faster).
            if generated_count >= no_think_fallback_after:
                return _greedy_generate_rest(model, tokenizer, cur, plen, max_new_tokens, device)
            continue

        after_open_idx = gen_text.find(THINK_OPEN) + len(THINK_OPEN)
        after_open = gen_text[after_open_idx:]

        if THINK_CLOSE in after_open:
            inner_part, _rest = after_open.split(THINK_CLOSE, 1)
            n_inner = len(tokenizer.encode(inner_part, add_special_tokens=False))
            if n_inner > inner_token_limit:
                return rewrite_and_continue(gen_text, after_open_idx, inner_part)
            return _greedy_generate_rest(model, tokenizer, cur, plen, max_new_tokens, device)

        # No close tag yet: force one as soon as the inner span reaches the limit.
        if len(tokenizer.encode(after_open, add_special_tokens=False)) >= inner_token_limit:
            return rewrite_and_continue(gen_text, after_open_idx, after_open)

    return tokenizer.decode(cur[0, plen:], skip_special_tokens=False)
178
+
179
+
180
def main() -> None:
    """Evaluate one checkpoint on the GSM8K test split with a capped thinking span.

    Samples are sharded round-robin across ranks, each completion is produced by
    ``generate_truncated_thinking`` and scored with ``gsm8k_correctness_reward``.
    Rank 0 prints a JSON summary and, when ``OUTPUT_PATH`` is set, writes the
    per-sample rows as JSONL plus a ``.summary.json`` file next to it.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="configs/grpo_llama32_3b_bf16.yaml")
    parser.add_argument(
        "--model_path",
        type=str,
        required=True,
        help="Checkpoint directory (e.g. artifacts/.../checkpoints/permanent/checkpoint-...).",
    )
    parser.add_argument("--think_inner_token_limit", type=int, default=5)
    parser.add_argument("--max_samples", type=int, default=-1, help="GSM8K test examples; -1 = all.")
    parser.add_argument("--max_new_tokens", type=int, default=256)
    args = parser.parse_args()

    rank, world_size, local_rank = _init_distributed()
    base_cfg = _load_yaml(str(resolve_repo_path(args.config)))
    model_dir = _resolve_local_model_dir(base_cfg, args.model_path)
    datasets_cache, models_cache = _get_cache_paths(base_cfg)
    dtype = _model_dtype(base_cfg)
    # Hub/model name used only if the checkpoint directory has no loadable tokenizer.
    model_name_fallback = str(base_cfg["model"]["name"])
    trust_remote_code = bool(base_cfg.get("model", {}).get("trust_remote_code", True))
    max_prompt_len = int(base_cfg.get("generation", {}).get("max_prompt_length", 512))

    # Offline mode is opt-in through the LOCAL_FILES_ONLY environment variable.
    local_files_only = os.environ.get("LOCAL_FILES_ONLY", "0").strip() in ("1", "true", "True")

    # Prefer the tokenizer saved with the checkpoint; on any failure fall back to the base model's.
    try:
        tokenizer = AutoTokenizer.from_pretrained(
            str(model_dir),
            trust_remote_code=trust_remote_code,
            cache_dir=str(models_cache),
            local_files_only=local_files_only,
        )
    except Exception:
        tokenizer = AutoTokenizer.from_pretrained(
            model_name_fallback,
            trust_remote_code=trust_remote_code,
            cache_dir=str(models_cache),
            local_files_only=local_files_only,
        )
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        str(model_dir),
        trust_remote_code=trust_remote_code,
        cache_dir=str(models_cache),
        torch_dtype=dtype,
        local_files_only=local_files_only,
        attn_implementation=_flash_attn_impl(),
    )
    # One GPU per process, selected by LOCAL_RANK; CPU when CUDA is unavailable.
    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)
        device = torch.device(f"cuda:{local_rank}")
    else:
        device = torch.device("cpu")
    model.to(device)
    model.eval()

    provider = GSM8KProvider()
    eval_max = None if args.max_samples < 0 else args.max_samples
    all_samples = provider.load(
        split="test",
        max_samples=eval_max,
        cache_dir=str(datasets_cache),
    )
    # Round-robin sharding: rank r evaluates samples r, r + world_size, ...
    indices = list(range(rank, len(all_samples), world_size))
    local_samples = [all_samples[i] for i in indices]

    records: list[dict] = []
    for sample in local_samples:
        prompt = sample.prompt
        if len(tokenizer.encode(prompt)) > max_prompt_len:
            # Over-long prompt: cut it to max_prompt_len tokens and decode back to text
            # (special tokens kept), because generate_truncated_thinking takes a string.
            # NOTE(review): if the tokenizer adds a BOS token, the decoded text may carry it
            # and be tokenized with a second one downstream -- confirm.
            enc = tokenizer(prompt, truncation=True, max_length=max_prompt_len, return_tensors="pt")
            prompt = tokenizer.decode(enc["input_ids"][0], skip_special_tokens=False)

        completion = generate_truncated_thinking(
            model,
            tokenizer,
            prompt,
            device,
            inner_token_limit=args.think_inner_token_limit,
            max_new_tokens=args.max_new_tokens,
        )
        score = gsm8k_correctness_reward(
            prompts=[prompt],
            completions=[completion],
            references=[sample.target],
            metadata=[sample.metadata],
        )[0]
        records.append(
            {
                "sample_index": sample.metadata.get("sample_index", -1),
                "correctness": float(score),
                "prompt": prompt,
                "reference": sample.target,
                "completion": completion,
            }
        )

    # Drop the model before gathering results.
    del model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Every rank takes part in the gather; only rank 0 goes on to report.
    if dist.is_initialized():
        gathered: list[list[dict] | None] = [None for _ in range(world_size)]
        dist.all_gather_object(gathered, records)
        merged: list[dict] = []
        for part in gathered:
            if part:
                merged.extend(part)
    else:
        merged = records

    if rank != 0:
        return

    merged.sort(key=lambda r: int(r.get("sample_index", 0)))
    acc = sum(r["correctness"] for r in merged) / len(merged) if merged else 0.0
    summary = {
        "model_dir": str(model_dir),
        "think_inner_token_limit": args.think_inner_token_limit,
        "max_new_tokens": args.max_new_tokens,
        "num_examples": len(merged),
        "accuracy": float(acc),
    }
    print(json.dumps(summary, indent=2))

    # Writing files is optional and controlled by the OUTPUT_PATH environment variable.
    out = os.environ.get("OUTPUT_PATH")
    if out:
        p = resolve_repo_path(out)
        p.parent.mkdir(parents=True, exist_ok=True)
        with p.open("w", encoding="utf-8") as handle:
            for row in merged:
                handle.write(json.dumps(row, ensure_ascii=True) + "\n")
        meta_path = p.with_suffix(".summary.json")
        meta_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")
        print("Wrote", p, "and", meta_path)
318
+
319
+
320
+ if __name__ == "__main__":
321
+ main()
src/eval_gsm8k_zeroshot.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """GSM8K zero-shot eval for a single checkpoint (same pipeline as eval_sweep_models)."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import copy
6
+ import json
7
+ import os
8
+ from pathlib import Path
9
+
10
+ from eval_sweep_models import (
11
+ _init_distributed,
12
+ _load_yaml,
13
+ _resolve_local_model_dir,
14
+ _summarize,
15
+ evaluate_one_model,
16
+ )
17
+ from hackable.utils import resolve_repo_path
18
+
19
+
20
def _model_dir_from_env() -> str:
    """Return the checkpoint directory named by ``MODEL_PATH`` or, failing that, ``MODEL_DIR``.

    Unset and empty values are both ignored.

    Raises:
        SystemExit: if neither variable holds a non-empty value.
    """
    candidates = (os.environ.get(key) for key in ("MODEL_PATH", "MODEL_DIR"))
    chosen = next((value for value in candidates if value), None)
    if chosen is None:
        raise SystemExit(
            "Set MODEL_PATH or MODEL_DIR to the checkpoint directory "
            "(e.g. artifacts/sweeps/kl_0/run_foo/checkpoints/permanent/checkpoint-1871)."
        )
    return chosen
29
+
30
+
31
def _eval_max_samples() -> int:
    """Return the sample cap from ``EVAL_MAX_SAMPLES``, else ``MAX_SAMPLES``, else ``-1``.

    The first variable that is set wins and is parsed with ``int``.
    """
    for key in ("EVAL_MAX_SAMPLES", "MAX_SAMPLES"):
        raw = os.environ.get(key)
        if raw is not None:
            return int(raw)
    return -1
37
+
38
+
39
def main() -> None:
    """Run the GSM8K zero-shot evaluation for one checkpoint, configured through the environment.

    ``BASE_CONFIG`` selects the YAML config; ``MAX_NEW_TOKENS`` and
    ``MAX_PROMPT_LENGTH`` override the ``generation`` section when set to a
    non-empty value. Rank 0 writes the records as JSONL to ``OUTPUT_PATH`` and
    prints a JSON summary; other ranks return after evaluation.
    """
    rank, _world_size, _local_rank = _init_distributed()

    cfg_location = os.environ.get("BASE_CONFIG", "configs/grpo_llama32_3b_bf16.yaml")
    base_cfg = copy.deepcopy(_load_yaml(str(resolve_repo_path(cfg_location))))

    # Environment overrides for the generation limits; empty strings are ignored.
    generation_cfg = base_cfg.setdefault("generation", {})
    overrides = (
        ("MAX_NEW_TOKENS", "max_completion_length"),
        ("MAX_PROMPT_LENGTH", "max_prompt_length"),
    )
    for env_key, cfg_key in overrides:
        raw = os.environ.get(env_key)
        if raw:
            generation_cfg[cfg_key] = int(raw)

    model_dir = _resolve_local_model_dir(base_cfg, _model_dir_from_env())
    records = evaluate_one_model(
        model_dir=model_dir,
        base_cfg=base_cfg,
        eval_max_samples=_eval_max_samples(),
        batch_size=int(os.environ.get("BATCH_SIZE", "4")),
    )

    if rank != 0:
        return

    out_path = Path(
        resolve_repo_path(os.environ.get("OUTPUT_PATH", "artifacts/eval/gsm8k_zeroshot_answers.jsonl"))
    )
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with out_path.open("w", encoding="utf-8") as handle:
        for record in records:
            handle.write(json.dumps(record, ensure_ascii=True) + "\n")

    summary = _summarize(records, str(model_dir))
    summary["outputs_jsonl"] = str(out_path.resolve())
    print(json.dumps(summary, indent=2))
73
+
74
+
75
+ if __name__ == "__main__":
76
+ main()
src/eval_math_level1_thinking_zeroshot.py ADDED
@@ -0,0 +1,312 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import os
5
+ from pathlib import Path
6
+
7
+ import torch
8
+ import torch.distributed as dist
9
+ import yaml
10
+ from datasets import load_dataset
11
+ from transformers import AutoModelForCausalLM, AutoTokenizer
12
+
13
+ import hackable # noqa: F401
14
+ from hackable import reward_plugins as reward_plugins_mod
15
+ from hackable.utils import resolve_repo_path
16
+
17
+
18
# System message for the chat prompt: asks for reasoning inside <think> tags
# followed by a single LaTeX \boxed{...} final answer and nothing else.
THINKING_SYSTEM_PROMPT = (
    "Solve the following math problem.\n"
    "Think step-by-step inside <think>...</think> tags.\n"
    "Then output only the final answer in LaTeX boxed format.\n"
    "Do not include any words or explanations outside the tags/boxed answer.\n"
    "Output format must be exactly:\n"
    "<think>your reasoning</think>\n"
    "\\boxed{your_final_answer}\n"
)
27
+
28
+
29
def _load_yaml(path: str) -> dict:
    """Open ``path`` as UTF-8 text and return the result of ``yaml.safe_load``."""
    with open(path, mode="r", encoding="utf-8") as stream:
        document = yaml.safe_load(stream)
    return document
32
+
33
+
34
def _dist_info() -> tuple[int, int, int]:
    """Return ``(rank, world_size, local_rank)`` parsed from ``RANK``, ``WORLD_SIZE``, ``LOCAL_RANK``.

    Unset variables default to ``0``, ``1`` and ``0`` respectively.
    """
    defaults = (("RANK", "0"), ("WORLD_SIZE", "1"), ("LOCAL_RANK", "0"))
    rank, world_size, local_rank = (int(os.environ.get(name, default)) for name, default in defaults)
    return rank, world_size, local_rank
39
+
40
+
41
def _init_distributed() -> tuple[int, int, int]:
    """Set up ``torch.distributed`` for multi-process runs and report this process's ranks.

    A process group is created only when ``WORLD_SIZE`` exceeds 1 and none exists
    yet: ``nccl`` with CUDA, ``gloo`` without, using ``env://`` rendezvous.
    Returns ``(rank, world_size, local_rank)``.
    """
    rank, world_size, local_rank = _dist_info()
    if world_size <= 1 or dist.is_initialized():
        return rank, world_size, local_rank
    chosen_backend = "nccl" if torch.cuda.is_available() else "gloo"
    dist.init_process_group(backend=chosen_backend, init_method="env://")
    return rank, world_size, local_rank
47
+
48
+
49
def _resolve_local_model_dir(base_cfg: dict, model_dir: str) -> Path:
    """Find ``model_dir`` on the local filesystem, trying three locations in order.

    1. ``model_dir`` as given (absolute, or relative to the working directory).
    2. ``resolve_repo_path(model_dir)``.
    3. ``model_dir`` under the configured cache root
       (``base_cfg["storage"]["cache_dir"]``, default ``"cache"``).

    Raises:
        FileNotFoundError: if none of the candidates exists; the message lists all three.
    """
    as_given = Path(model_dir)
    if as_given.exists():
        return as_given.resolve()

    in_repo = resolve_repo_path(model_dir)
    if in_repo.exists():
        return in_repo

    cache_root = resolve_repo_path(base_cfg.get("storage", {}).get("cache_dir", "cache"))
    in_cache = (cache_root / as_given).resolve()
    if in_cache.exists():
        return in_cache

    raise FileNotFoundError(
        f"Model directory not found locally: '{model_dir}'. "
        f"Tried '{as_given}', '{in_repo}', and '{in_cache}'."
    )
69
+
70
+
71
def _build_chat_prompts(
    tokenizer: AutoTokenizer, questions: list[str], system_prompt: str
) -> list[str]:
    """Render each question as a chat prompt string using the tokenizer's chat template.

    Every prompt is a system message (``system_prompt``) followed by a user message
    (the question with surrounding whitespace stripped), rendered untokenized with
    the generation prompt appended.

    Raises:
        RuntimeError: if the tokenizer has no ``chat_template``.
    """
    if getattr(tokenizer, "chat_template", None) is None:
        raise RuntimeError("Tokenizer has no chat_template; cannot apply chat formatting.")

    def render(question: str) -> str:
        conversation = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": question.strip()},
        ]
        return tokenizer.apply_chat_template(
            conversation,
            tokenize=False,
            add_generation_prompt=True,
        )

    return [render(question) for question in questions]
90
+
91
+
92
def _load_math_level_rows(
    level: str,
    split: str,
    max_samples: int | None,
    cache_dir: str | None,
) -> tuple[list[str], list[str]]:
    """Collect problems and solutions of one difficulty level from ``EleutherAI/hendrycks_math``.

    The seven subject configs are loaded in a fixed order and rows are kept when
    their stripped ``level`` field equals ``level``. Collection stops as soon as
    ``max_samples`` rows are gathered (``None`` means no cap).

    Returns:
        ``(questions, references)`` as parallel lists of the ``problem`` and
        ``solution`` fields.
    """
    subjects = (
        "algebra",
        "counting_and_probability",
        "geometry",
        "intermediate_algebra",
        "number_theory",
        "prealgebra",
        "precalculus",
    )
    capped = max_samples is not None

    questions: list[str] = []
    references: list[str] = []
    for subject in subjects:
        subject_rows = load_dataset(
            "EleutherAI/hendrycks_math",
            subject,
            split=split,
            cache_dir=cache_dir,
        )
        for row in subject_rows:
            if str(row.get("level", "")).strip() != level:
                continue
            questions.append(str(row.get("problem", "")))
            references.append(str(row.get("solution", "")))
            if capped and len(questions) >= max_samples:
                return questions[:max_samples], references[:max_samples]

    return questions, references
129
+
130
+
131
@torch.no_grad()
def main() -> None:
    """Zero-shot "thinking" evaluation of one checkpoint on MATH Level 1 problems.

    Configuration comes from environment variables: ``BASE_CONFIG`` (required),
    ``MODEL_DIR``/``MODEL_PATH``, ``MAX_PROMPT_LENGTH``, ``MAX_NEW_TOKENS``,
    ``MATH_SPLIT``, ``MAX_SAMPLES``/``EVAL_MAX_SAMPLES``, ``BATCH_SIZE`` and
    ``OUTPUT_PATH``. Questions are sharded round-robin across ranks, generated
    greedily in batches, and scored two ways (strict boxed match and lenient
    numeric match). Rank 0 writes the merged rows as JSONL and prints both accuracies.

    Raises:
        KeyError: if ``BASE_CONFIG`` is not set.
        ValueError: if neither ``MODEL_DIR`` nor ``MODEL_PATH`` is set.
    """
    rank, world_size, local_rank = _init_distributed()

    base_cfg = _load_yaml(str(resolve_repo_path(os.environ["BASE_CONFIG"])))
    model_dir = os.environ.get("MODEL_DIR") or os.environ.get("MODEL_PATH")
    if not model_dir:
        raise ValueError("Set MODEL_DIR or MODEL_PATH for the checkpoint to evaluate.")
    resolved_model_dir = _resolve_local_model_dir(base_cfg, model_dir)

    # Generation limits: config values first, then environment overrides.
    generation = base_cfg.get("generation", {})
    max_prompt_length = int(generation.get("max_prompt_length", 512))
    max_new_tokens = int(generation.get("max_completion_length", 256))
    max_prompt_length = int(os.environ.get("MAX_PROMPT_LENGTH", str(max_prompt_length)))
    max_new_tokens = int(os.environ.get("MAX_NEW_TOKENS", str(max_new_tokens)))

    split = os.environ.get("MATH_SPLIT", "test")
    # MAX_SAMPLES takes precedence over EVAL_MAX_SAMPLES; a negative value means "all".
    max_samples_env = os.environ.get("MAX_SAMPLES", os.environ.get("EVAL_MAX_SAMPLES", "-1"))
    max_samples = None if int(max_samples_env) < 0 else int(max_samples_env)

    batch_size = int(os.environ.get("BATCH_SIZE", "4"))

    cache_root = resolve_repo_path(base_cfg.get("storage", {}).get("cache_dir", "cache"))
    datasets_cache = str(cache_root / "datasets")
    models_cache = str(cache_root / "models")

    # Tokenizer and model are loaded strictly from local files.
    tokenizer = AutoTokenizer.from_pretrained(
        str(resolved_model_dir),
        trust_remote_code=bool(base_cfg.get("model", {}).get("trust_remote_code", False)),
        cache_dir=models_cache,
        local_files_only=True,
    )
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token
    # Decoder-only safe.
    tokenizer.padding_side = "left"

    dtype = torch.bfloat16 if bool(base_cfg.get("trainer", {}).get("bf16", True)) else torch.float16
    model = AutoModelForCausalLM.from_pretrained(
        str(resolved_model_dir),
        trust_remote_code=bool(base_cfg.get("model", {}).get("trust_remote_code", False)),
        cache_dir=models_cache,
        torch_dtype=dtype,
        local_files_only=True,
    )
    # One GPU per process, selected by LOCAL_RANK; CPU when CUDA is unavailable.
    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)
        device = torch.device(f"cuda:{local_rank}")
    else:
        device = torch.device("cpu")
    model.to(device)
    model.eval()

    questions, references = _load_math_level_rows(
        level="Level 1",
        split=split,
        max_samples=max_samples,
        cache_dir=datasets_cache,
    )

    # Round-robin sharding: rank r handles questions r, r + world_size, ...
    indices = list(range(rank, len(questions), world_size))
    local_questions = [questions[i] for i in indices]
    local_refs = [references[i] for i in indices]

    chat_prompts = _build_chat_prompts(tokenizer, local_questions, THINKING_SYSTEM_PROMPT)
    completions: list[str] = []

    for start in range(0, len(chat_prompts), batch_size):
        batch_prompts = chat_prompts[start : start + batch_size]
        enc = tokenizer(
            batch_prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_prompt_length,
        )
        input_ids = enc["input_ids"].to(device)
        attn = enc["attention_mask"].to(device)
        # With left padding every row's prompt ends at the same column, so one
        # offset separates prompt from completion for the whole batch.
        prompt_seq_len = input_ids.shape[1]

        out = model.generate(
            input_ids=input_ids,
            attention_mask=attn,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
        for bi in range(out.size(0)):
            gen_ids = out[bi, prompt_seq_len:]
            completions.append(tokenizer.decode(gen_ids, skip_special_tokens=True))

    # Strict boxed correctness (project metric)
    strict_scores = []
    for completion, reference in zip(completions, local_refs, strict=True):
        pred_text = reward_plugins_mod._extract_predicted_answer_text(completion)
        ref_text = reward_plugins_mod._extract_reference_answer_text(reference)
        if not pred_text or not ref_text:
            strict_scores.append(0.0)
            continue
        # First try an exact match on normalized text, then a numeric comparison.
        pred_norm = reward_plugins_mod._normalize_answer_text(pred_text)
        ref_norm = reward_plugins_mod._normalize_answer_text(ref_text)
        if pred_norm and ref_norm and pred_norm == ref_norm:
            strict_scores.append(1.0)
            continue
        pred_value = reward_plugins_mod._parse_numeric(pred_text)
        ref_value = reward_plugins_mod._parse_numeric(ref_text)
        if pred_value is not None and ref_value is not None and reward_plugins_mod._is_close(pred_value, ref_value):
            strict_scores.append(1.0)
        else:
            strict_scores.append(0.0)

    # Lenient numeric correctness fallback
    lenient_scores: list[float] = []
    for completion, reference in zip(completions, local_refs, strict=True):
        ref_val = reward_plugins_mod._extract_reference_target(reference)
        boxed = reward_plugins_mod._extract_last_boxed(completion)
        if boxed:
            # Prefer the boxed content; if it does not parse, use the last number inside it.
            pred_val = reward_plugins_mod._parse_numeric(boxed)
            if pred_val is None:
                nums = reward_plugins_mod._extract_numbers(boxed)
                pred_val = nums[-1] if nums else None
        else:
            # No boxed answer: use the last number anywhere in the completion.
            nums = reward_plugins_mod._extract_numbers(completion)
            pred_val = nums[-1] if nums else None

        if ref_val is not None and pred_val is not None and reward_plugins_mod._is_close(pred_val, ref_val):
            lenient_scores.append(1.0)
        else:
            lenient_scores.append(0.0)

    local_records: list[dict] = []
    for i, idx in enumerate(indices):
        local_records.append(
            {
                "sample_index": int(idx),
                "question": local_questions[i],
                "reference_answer": local_refs[i],
                "model_answer_raw": completions[i],
                # "correctness" is the lenient score; the strict score has its own key.
                "correctness": float(lenient_scores[i]),
                "correctness_strict_boxed": float(strict_scores[i]),
            }
        )

    # Every rank takes part in the gather; only rank 0 goes on to write output.
    if dist.is_initialized():
        gathered: list[list[dict] | None] = [None for _ in range(world_size)]
        dist.all_gather_object(gathered, local_records)
        merged: list[dict] = []
        for part in gathered:
            if part:
                merged.extend(part)
    else:
        merged = local_records

    if rank != 0:
        return

    # Restore dataset order after the round-robin split.
    merged.sort(key=lambda r: r["sample_index"])
    output_path = resolve_repo_path(
        os.environ.get(
            "OUTPUT_PATH",
            "artifacts/eval/math_level1_thinking_zeroshot/answers.jsonl",
        )
    )
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with output_path.open("w", encoding="utf-8") as handle:
        for row in merged:
            handle.write(json.dumps(row, ensure_ascii=True) + "\n")

    acc = sum(r["correctness"] for r in merged) / len(merged) if merged else 0.0
    acc_strict = (
        sum(r["correctness_strict_boxed"] for r in merged) / len(merged)
        if merged
        else 0.0
    )
    print(f"Wrote {len(merged)} rows to {output_path}")
    print(f"Accuracy (lenient numeric): {acc:.4f}")
    print(f"Accuracy (strict boxed): {acc_strict:.4f}")
309
+
310
+
311
+ if __name__ == "__main__":
312
+ main()
src/eval_permanent_checkpoints.py ADDED
@@ -0,0 +1,433 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import csv
4
+ import json
5
+ import os
6
+ import re
7
+ from pathlib import Path
8
+
9
+ import torch.distributed as dist
10
+
11
+ import hackable # noqa: F401
12
+ from hackable.utils import resolve_repo_path
13
+ from eval_sweep_models import (
14
+ _init_distributed,
15
+ _load_yaml,
16
+ _resolve_local_model_dir,
17
+ evaluate_one_model,
18
+ )
19
+
20
+
21
def _parse_checkpoint_step(dirname: str) -> int | None:
    """Extract the training step from a checkpoint directory name.

    Two layouts are recognized, tried in this order: the whole name is
    ``checkpoint-<digits>``, or the name ends in ``-step-<digits>``.
    Returns ``None`` when neither matches.
    """
    attempts = (
        (re.match, r"^checkpoint-(\d+)$"),
        (re.search, r"-step-(\d+)$"),
    )
    for matcher, pattern in attempts:
        found = matcher(pattern, dirname)
        if found is not None:
            return int(found.group(1))
    return None
29
+
30
+
31
def _discover_checkpoint_jobs(
    base_cfg: dict, permanent_root: Path, run_label: str
) -> list[tuple[str, int, str, Path, str]]:
    """List the checkpoint directories directly under ``permanent_root``.

    Each job is ``(run_label, step, resolved_model_dir_str, resolved_path, dir_name)``.
    Entries that are not directories, or whose name carries no parsable step, are
    skipped. Jobs are ordered by step, then by directory name.

    Raises:
        FileNotFoundError: if ``permanent_root`` is not a directory.
    """
    root = permanent_root.resolve()
    if not root.is_dir():
        raise FileNotFoundError(f"Not a directory: {root}")

    jobs: list[tuple[str, int, str, Path, str]] = []
    for entry in sorted(root.iterdir()):
        step = _parse_checkpoint_step(entry.name) if entry.is_dir() else None
        if step is None:
            continue
        model_path = _resolve_local_model_dir(base_cfg, str(entry))
        jobs.append((run_label, step, str(model_path), model_path, entry.name))

    jobs.sort(key=lambda job: (job[1], job[4]))
    return jobs
49
+
50
+
51
def _line_chart_svg(
    series: list[tuple[str, list[tuple[int, float]], str]],
    title: str,
    y_label: str,
    y_max: float,
    path: Path,
) -> None:
    """Write a simple multi-series line chart as a standalone SVG file.

    Args:
        series: One ``(name, points, color)`` tuple per line, where ``points`` is a
            list of ``(step, value)`` pairs. A series with a single point is drawn
            as a dot; an empty series is skipped entirely (no legend entry).
        title: Chart title.
        y_label: Label drawn rotated along the y axis.
        y_max: Upper bound of the y axis; values are clamped to ``[0, y_max]``.
        path: Destination file, written as UTF-8.

    The x axis spans the smallest to the largest step found in ``series``. When no
    series has any point, a placeholder SVG containing "<title> (no data)" is
    written instead. All caller-supplied text (title, axis label, series names) is
    XML-escaped so characters such as ``&`` and ``<`` cannot produce a malformed file.
    """
    from xml.sax.saxutils import escape  # local import: used by this function only

    width = 900
    height = 420
    lm, rm, tm, bm = 70, 40, 50, 55  # left/right/top/bottom margins
    pw = width - lm - rm  # plot width
    ph = height - tm - bm  # plot height
    yb = tm + ph  # pixel y of the x axis

    safe_title = escape(str(title))
    safe_y_label = escape(str(y_label))

    all_steps: list[int] = [step for _, pts, _ in series for step, _ in pts]
    if not all_steps:
        path.write_text(
            f'<svg xmlns="http://www.w3.org/2000/svg" width="{width}" height="{height}">'
            f'<text x="40" y="40">{safe_title} (no data)</text></svg>',
            encoding="utf-8",
        )
        return
    x_min, x_max = min(all_steps), max(all_steps)
    if x_max == x_min:
        # Avoid a zero-width x range when every point shares one step.
        x_max = x_min + 1

    def sx(x: int) -> int:
        return lm + int((x - x_min) / (x_max - x_min) * pw)

    def sy(y: float) -> int:
        y = max(0.0, min(y_max, y))
        return yb - int((y / y_max) * ph) if y_max > 0 else yb

    parts: list[str] = [
        f'<svg xmlns="http://www.w3.org/2000/svg" width="{width}" height="{height}">',
        '<rect width="100%" height="100%" fill="#ffffff"/>',
        f'<text x="{lm}" y="28" font-size="16" font-family="sans-serif">{safe_title}</text>',
        f'<text x="20" y="{tm + ph // 2}" font-size="12" font-family="sans-serif" '
        f'transform="rotate(-90 20 {tm + ph // 2})">{safe_y_label}</text>',
        f'<line x1="{lm}" y1="{yb}" x2="{lm + pw}" y2="{yb}" stroke="#111" stroke-width="2"/>',
        f'<line x1="{lm}" y1="{tm}" x2="{lm}" y2="{yb}" stroke="#111" stroke-width="2"/>',
        f'<text x="{lm + pw // 2}" y="{height - 12}" text-anchor="middle" '
        f'font-size="12" font-family="sans-serif">Training step</text>',
    ]

    # Five evenly spaced y-axis ticks from 0 to y_max.
    for i in range(5):
        val = (i / 4) * y_max
        yy = sy(val)
        parts.append(
            f'<line x1="{lm - 4}" y1="{yy}" x2="{lm}" y2="{yy}" stroke="#999"/>'
        )
        parts.append(
            f'<text x="{lm - 8}" y="{yy + 4}" text-anchor="end" font-size="10" '
            f'font-family="sans-serif">{val:.2f}</text>'
        )

    legend_x = lm + pw - 200
    legend_y = tm + 8
    for idx, (name, pts, color) in enumerate(series):
        pts_sorted = sorted(pts, key=lambda z: z[0])
        if not pts_sorted:
            continue
        # Colors land inside attribute values, so quotes are escaped as well.
        safe_color = escape(str(color), {'"': "&quot;"})
        if len(pts_sorted) == 1:
            cx, cy = sx(pts_sorted[0][0]), sy(pts_sorted[0][1])
            parts.append(
                f'<circle cx="{cx}" cy="{cy}" r="4" fill="{safe_color}" stroke="#111"/>'
            )
        else:
            d = "M " + " L ".join(f"{sx(s)} {sy(v)}" for s, v in pts_sorted)
            parts.append(
                f'<path d="{d}" fill="none" stroke="{safe_color}" stroke-width="2.5"/>'
            )
        # The legend slot is indexed by position in ``series``, so skipped series leave a gap.
        parts.append(
            f'<rect x="{legend_x}" y="{legend_y + idx * 18}" width="10" height="10" fill="{safe_color}"/>'
        )
        parts.append(
            f'<text x="{legend_x + 16}" y="{legend_y + idx * 18 + 9}" font-size="11" '
            f'font-family="sans-serif">{escape(str(name))}</text>'
        )

    parts.append("</svg>")
    path.write_text("\n".join(parts), encoding="utf-8")
136
+
137
+
138
+ def _scatter_accuracy_vs_cot_svg(rows: list[dict], path: Path, title: str) -> None:
139
+ """Scatter: x = avg_cot_words, y = accuracy. One color per ``run_label``; optional path by training step."""
140
+ width = 640
141
+ height = 520
142
+ lm, rm, tm, bm = 72, 160, 52, 64
143
+ pw = width - lm - rm
144
+ ph = height - tm - bm
145
+ yb = tm + ph
146
+
147
+ if not rows:
148
+ path.write_text(
149
+ f'<svg xmlns="http://www.w3.org/2000/svg" width="{width}" height="{height}">'
150
+ f'<text x="40" y="40">{title} (no data)</text></svg>',
151
+ encoding="utf-8",
152
+ )
153
+ return
154
+
155
+ labels: list[str] = []
156
+ seen: set[str] = set()
157
+ for r in rows:
158
+ lab = str(r.get("run_label", "run"))
159
+ if lab not in seen:
160
+ seen.add(lab)
161
+ labels.append(lab)
162
+
163
+ colors = ["#2563eb", "#dc2626", "#16a34a", "#9333ea", "#ca8a04", "#0891b2"]
164
+ color_map = {lab: colors[i % len(colors)] for i, lab in enumerate(labels)}
165
+
166
+ xs = [float(r["avg_cot_words"]) for r in rows]
167
+ ys = [float(r["accuracy"]) for r in rows]
168
+ x_min, x_max = min(xs), max(xs)
169
+ y_min, y_max = 0.0, 1.0
170
+ if x_max <= x_min:
171
+ x_max = x_min + 1.0
172
+ pad = (x_max - x_min) * 0.06 + 1.0
173
+ x_min = max(0.0, x_min - pad)
174
+ x_max = x_max + pad
175
+
176
+ def sx(x: float) -> float:
177
+ return lm + (x - x_min) / (x_max - x_min) * pw
178
+
179
+ def sy(y: float) -> float:
180
+ y = max(y_min, min(y_max, y))
181
+ return yb - (y - y_min) / (y_max - y_min) * ph
182
+
183
+ parts: list[str] = [
184
+ f'<svg xmlns="http://www.w3.org/2000/svg" width="{width}" height="{height}">',
185
+ '<rect width="100%" height="100%" fill="#fafafa"/>',
186
+ f'<text x="{lm}" y="30" font-size="15" font-family="sans-serif">{title}</text>',
187
+ f'<text x="{width // 2}" y="{height - 18}" text-anchor="middle" font-size="12" '
188
+ f'font-family="sans-serif">Avg CoT length (words)</text>',
189
+ f'<text x="18" y="{tm + ph // 2}" font-size="12" font-family="sans-serif" '
190
+ f'transform="rotate(-90 18 {tm + ph // 2})">Accuracy</text>',
191
+ f'<line x1="{lm}" y1="{yb}" x2="{lm + pw}" y2="{yb}" stroke="#111" stroke-width="2"/>',
192
+ f'<line x1="{lm}" y1="{tm}" x2="{lm}" y2="{yb}" stroke="#111" stroke-width="2"/>',
193
+ ]
194
+
195
+ for i in range(5):
196
+ val = y_min + (i / 4) * (y_max - y_min)
197
+ yy = sy(val)
198
+ parts.append(f'<line x1="{lm - 4}" y1="{yy}" x2="{lm}" y2="{yy}" stroke="#bbb"/>')
199
+ parts.append(
200
+ f'<text x="{lm - 8}" y="{yy + 4}" text-anchor="end" font-size="10" '
201
+ f'font-family="sans-serif">{val:.2f}</text>'
202
+ )
203
+
204
+ for i in range(5):
205
+ frac = i / 4
206
+ xv = x_min + frac * (x_max - x_min)
207
+ xx = sx(xv)
208
+ parts.append(f'<line x1="{xx}" y1="{yb}" x2="{xx}" y2="{yb + 4}" stroke="#bbb"/>')
209
+ parts.append(
210
+ f'<text x="{xx}" y="{yb + 18}" text-anchor="middle" font-size="10" '
211
+ f'font-family="sans-serif">{xv:.0f}</text>'
212
+ )
213
+
214
+ for lab in labels:
215
+ sub = [r for r in rows if str(r.get("run_label", "run")) == lab]
216
+ sub.sort(key=lambda r: int(r["checkpoint_step"]))
217
+ color = color_map[lab]
218
+ if len(sub) >= 2:
219
+ d = "M " + " L ".join(f'{sx(float(r["avg_cot_words"])):.1f} {sy(float(r["accuracy"])):.1f}' for r in sub)
220
+ parts.append(
221
+ f'<path d="{d}" fill="none" stroke="{color}" stroke-width="1.5" stroke-opacity="0.35"/>'
222
+ )
223
+
224
+ for r in rows:
225
+ lab = str(r.get("run_label", "run"))
226
+ color = color_map[lab]
227
+ cx = sx(float(r["avg_cot_words"]))
228
+ cy = sy(float(r["accuracy"]))
229
+ step = int(r["checkpoint_step"])
230
+ name = str(r.get("checkpoint_dir", f"step-{step}"))
231
+ tip = f"{name}: accuracy={float(r['accuracy']):.4f}, avg_cot_words={float(r['avg_cot_words']):.2f}"
232
+ parts.append(
233
+ f'<g><circle cx="{cx:.1f}" cy="{cy:.1f}" r="5" fill="{color}" stroke="#111" stroke-width="1">'
234
+ f"<title>{tip}</title></circle>"
235
+ f'<text x="{cx + 8:.1f}" y="{cy - 6:.1f}" font-size="9" font-family="sans-serif" fill="#333">{step}</text></g>'
236
+ )
237
+
238
+ legend_x = lm + pw + 14
239
+ legend_y = tm + 4
240
+ parts.append(
241
+ f'<text x="{legend_x}" y="{legend_y}" font-size="11" font-family="sans-serif" font-weight="bold">Series</text>'
242
+ )
243
+ for idx, lab in enumerate(labels):
244
+ cy = legend_y + 18 + idx * 20
245
+ parts.append(
246
+ f'<rect x="{legend_x}" y="{cy - 8}" width="10" height="10" fill="{color_map[lab]}"/>'
247
+ )
248
+ parts.append(
249
+ f'<text x="{legend_x + 16}" y="{cy}" font-size="11" font-family="sans-serif">{lab}</text>'
250
+ )
251
+
252
+ parts.append("</svg>")
253
+ path.write_text("\n".join(parts), encoding="utf-8")
254
+
255
+
256
def _resolve_out_root(default: Path) -> Path:
    """Return the output root: ``OUT_ROOT`` from the environment, else ``default``.

    An unset or whitespace-only ``OUT_ROOT`` falls back to ``default``. The
    chosen value is passed through ``resolve_repo_path``.
    """
    override = os.environ.get("OUT_ROOT")
    has_override = override is not None and bool(str(override).strip())
    target = override if has_override else str(default)
    return resolve_repo_path(target)
261
+
262
+
263
def main() -> None:
    """Evaluate every discovered checkpoint and write summaries, rollouts and plots.

    Environment variables read here:
        BASE_CONFIG: YAML config path (required).
        EVAL_MAX_SAMPLES, EVAL_BATCH_SIZE, ROLLOUT_SAMPLES: integers, defaulting
            to 200, 4 and 8.
        PERMANENT_ROOT: when non-blank, a single checkpoint root labelled by
            RUN_LABEL (default ``"permanent"``); otherwise PERMANENT_CW1 and
            PERMANENT_CW5 are both required.
        OUT_ROOT: optional output directory override (see ``_resolve_out_root``).

    All ranks take part in evaluation; only rank 0 creates directories, writes
    files and prints.
    """
    rank, _, _ = _init_distributed()
    base_cfg = _load_yaml(str(resolve_repo_path(os.environ["BASE_CONFIG"])))

    eval_max_samples = int(os.environ.get("EVAL_MAX_SAMPLES", "200"))
    eval_batch_size = int(os.environ.get("EVAL_BATCH_SIZE", "4"))
    rollout_n = int(os.environ.get("ROLLOUT_SAMPLES", "8"))

    # Job discovery: one labelled root, or the cw=1 / cw=5 pair.
    jobs_cw1: list = []
    jobs_cw5: list = []
    permanent_root = os.environ.get("PERMANENT_ROOT", "").strip()
    if permanent_root:
        pr = resolve_repo_path(permanent_root)
        run_label_single = os.environ.get("RUN_LABEL", "permanent")
        out_root = _resolve_out_root(pr / "eval_permanent")
        all_jobs = _discover_checkpoint_jobs(base_cfg, pr, run_label_single)
    else:
        cw1_root = resolve_repo_path(os.environ["PERMANENT_CW1"])
        cw5_root = resolve_repo_path(os.environ["PERMANENT_CW5"])
        out_root = _resolve_out_root(cw1_root.parent / "eval_permanent")
        jobs_cw1 = _discover_checkpoint_jobs(base_cfg, cw1_root, "correctness_weight_1")
        jobs_cw5 = _discover_checkpoint_jobs(base_cfg, cw5_root, "correctness_weight_5")
        all_jobs = jobs_cw1 + jobs_cw5

    if rank == 0:
        for directory in (out_root, out_root / "rollouts", out_root / "full_outputs"):
            directory.mkdir(parents=True, exist_ok=True)
        if permanent_root:
            print(f"PERMANENT_ROOT: {resolve_repo_path(permanent_root)} ({len(all_jobs)} checkpoints)")
        else:
            print(f"Found {len(jobs_cw1)} checkpoints (cw=1), {len(jobs_cw5)} checkpoints (cw=5)")
        # all_jobs is jobs_cw1 followed by jobs_cw5, so one loop covers both modes.
        for run_label, step, _, _, name in all_jobs:
            print(f" {run_label} step={step} ({name})")

    if dist.is_initialized():
        dist.barrier()

    def _dump_jsonl(target: Path, recs: list[dict]) -> None:
        # One ASCII-escaped JSON object per line.
        with target.open("w", encoding="utf-8") as handle:
            for rec in recs:
                handle.write(json.dumps(rec, ensure_ascii=True) + "\n")

    rows: list[dict] = []
    for run_label, step, _resolved_str, resolved_path, dir_name in all_jobs:
        records = evaluate_one_model(
            model_dir=resolved_path,
            base_cfg=base_cfg,
            eval_max_samples=eval_max_samples,
            batch_size=eval_batch_size,
        )
        if rank != 0:
            continue

        if records:
            acc = sum(float(r["correctness"]) for r in records) / len(records)
            avg_cot = sum(float(r["cot_words"]) for r in records) / len(records)
        else:
            acc = 0.0
            avg_cot = 0.0
        rows.append(
            {
                "run_label": run_label,
                "checkpoint_step": step,
                "checkpoint_dir": dir_name,
                "model_dir": str(resolved_path),
                "num_examples": len(records),
                "accuracy": acc,
                "avg_cot_words": avg_cot,
            }
        )

        # A short sample of rollouts for quick inspection ...
        rollout_dir = out_root / "rollouts" / run_label
        rollout_dir.mkdir(parents=True, exist_ok=True)
        _dump_jsonl(rollout_dir / f"{dir_name}_rollouts.jsonl", records[:rollout_n])

        # ... and the complete set of records.
        full_path = out_root / "full_outputs" / run_label / f"{dir_name}_outputs.jsonl"
        full_path.parent.mkdir(parents=True, exist_ok=True)
        _dump_jsonl(full_path, records)

        print(
            f"Eval {run_label} {dir_name}: acc={acc:.4f} avg_cot_words={avg_cot:.2f} n={len(records)}"
        )

    if dist.is_initialized():
        dist.barrier()

    if rank != 0:
        return

    rows.sort(key=lambda r: (r["run_label"], r["checkpoint_step"], r["checkpoint_dir"]))

    summary_json = out_root / "permanent_checkpoints_eval.json"
    summary_csv = out_root / "permanent_checkpoints_eval.csv"
    summary_json.write_text(json.dumps(rows, indent=2), encoding="utf-8")
    columns = [
        "run_label",
        "checkpoint_step",
        "checkpoint_dir",
        "model_dir",
        "num_examples",
        "accuracy",
        "avg_cot_words",
    ]
    with summary_csv.open("w", encoding="utf-8", newline="") as handle:
        w = csv.DictWriter(handle, fieldnames=columns)
        w.writeheader()
        w.writerows(rows)

    def series_for(label: str, ykey: str) -> list[tuple[int, float]]:
        return [
            (int(r["checkpoint_step"]), float(r[ykey]))
            for r in rows
            if r["run_label"] == label
        ]

    palette = ["#2563eb", "#dc2626", "#16a34a", "#9333ea", "#ca8a04", "#0891b2"]
    uniq_labels = sorted({str(r["run_label"]) for r in rows})
    acc_series = []
    cot_series = []
    for i, lab in enumerate(uniq_labels):
        color = palette[i % len(palette)]
        acc_points = series_for(lab, "accuracy")
        if acc_points:
            acc_series.append((lab, acc_points, color))
        cot_points = series_for(lab, "avg_cot_words")
        if cot_points:
            cot_series.append((lab, cot_points, color))

    # The CoT chart's y-axis tops out at the largest average, never below 1.0.
    cot_max = max(1.0, max((float(r["avg_cot_words"]) for r in rows), default=1.0))

    acc_svg = out_root / "accuracy_vs_step.svg"
    cot_svg = out_root / "avg_cot_vs_step.svg"
    scatter_svg = out_root / "accuracy_vs_avg_cot_words.svg"

    if acc_series:
        _line_chart_svg(
            acc_series,
            "GSM8K accuracy vs checkpoint step",
            "Accuracy",
            1.0,
            acc_svg,
        )
    if cot_series:
        _line_chart_svg(
            cot_series,
            "Average CoT length (words) vs checkpoint step",
            "Avg CoT words",
            cot_max,
            cot_svg,
        )

    _scatter_accuracy_vs_cot_svg(
        rows,
        scatter_svg,
        "GSM8K accuracy vs average CoT length (words)",
    )

    print(f"Saved: {summary_json}")
    print(f"Saved: {summary_csv}")
    if acc_series:
        print(f"Saved: {acc_svg}")
    if cot_series:
        print(f"Saved: {cot_svg}")
    print(f"Saved: {scatter_svg}")
    print(f"Rollouts: {out_root / 'rollouts'}/<run_label>/")
    print(f"Full outputs: {out_root / 'full_outputs'}/<run_label>/")
430
+
431
+
432
+ if __name__ == "__main__":
433
+ main()
src/eval_sweep_models.py ADDED
@@ -0,0 +1,386 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import csv
4
+ import json
5
+ import os
6
+ import re
7
+ from pathlib import Path
8
+
9
+ import torch
10
+ import torch.distributed as dist
11
+ import yaml
12
+ from transformers import AutoModelForCausalLM, AutoTokenizer
13
+
14
+ import hackable # noqa: F401
15
+ from hackable.data_plugins import GSM8KProvider
16
+ from hackable.paths import resolve_storage_path, storage_layout
17
+ from hackable.reward_plugins import gsm8k_correctness_reward
18
+ from hackable.utils import resolve_repo_path
19
+
20
+ THINK_RE = re.compile(r"<think>(.*?)</think>", re.DOTALL)
21
+
22
+
23
def _load_yaml(path: str) -> dict:
    """Read the UTF-8 file at ``path`` and return what ``yaml.safe_load`` parses from it."""
    with open(path, encoding="utf-8") as stream:
        parsed = yaml.safe_load(stream)
    return parsed
26
+
27
+
28
def _cot_word_len(completion: str) -> int:
    """Count whitespace-separated words inside the first ``<think>...</think>`` span.

    Returns 0 when ``THINK_RE`` finds no match or the span contains only whitespace.
    """
    found = THINK_RE.search(completion)
    if found is None:
        return 0
    # str.split() with no separator ignores leading/trailing whitespace and
    # returns [] for a blank string, so no explicit strip/empty check is needed.
    return len(found.group(1).split())
32
+
33
+
34
+ def _model_dtype(cfg: dict):
35
+ return torch.bfloat16 if bool(cfg.get("trainer", {}).get("bf16", True)) else torch.float16
36
+
37
+
38
def _get_cache_paths(base_cfg: dict) -> tuple[Path, Path]:
    """Return ``(datasets, models)`` from the storage layout for the configured cache dir.

    The cache dir is ``base_cfg["storage"]["cache_dir"]``, defaulting to ``"cache"``.
    """
    storage_cfg = base_cfg.get("storage", {})
    layout = storage_layout(storage_cfg.get("cache_dir", "cache"))
    return (layout.datasets, layout.models)
41
+
42
+
43
+ def _dist_info() -> tuple[int, int, int]:
44
+ rank = int(os.environ.get("RANK", "0"))
45
+ world_size = int(os.environ.get("WORLD_SIZE", "1"))
46
+ local_rank = int(os.environ.get("LOCAL_RANK", "0"))
47
+ return rank, world_size, local_rank
48
+
49
+
50
def _init_distributed() -> tuple[int, int, int]:
    """Initialise ``torch.distributed`` when needed and return ``_dist_info()``.

    A process group is created only when the world size exceeds 1 and none is
    initialised yet; the backend is ``nccl`` if CUDA is available, else ``gloo``.
    """
    info = _dist_info()
    world_size = info[1]
    needs_init = world_size > 1 and not dist.is_initialized()
    if needs_init:
        dist.init_process_group(
            backend="nccl" if torch.cuda.is_available() else "gloo",
            init_method="env://",
        )
    return info
56
+
57
+
58
+ def _resolve_local_model_dir(base_cfg: dict, model_dir: str) -> Path:
59
+ candidate = Path(model_dir)
60
+ if candidate.is_absolute() and candidate.exists():
61
+ return candidate.resolve()
62
+ if not candidate.is_absolute() and candidate.exists():
63
+ return candidate.resolve()
64
+
65
+ repo_local = resolve_repo_path(model_dir)
66
+ if repo_local.exists():
67
+ return repo_local
68
+
69
+ cache_root = resolve_repo_path(base_cfg.get("storage", {}).get("cache_dir", "cache"))
70
+ prefixed = (cache_root / candidate).resolve()
71
+ if prefixed.exists():
72
+ return prefixed
73
+
74
+ raise FileNotFoundError(
75
+ f"Model directory not found locally: '{model_dir}'. "
76
+ f"Tried '{candidate}', '{repo_local}', and '{prefixed}'."
77
+ )
78
+
79
+
80
def _resolve_sweep_root(base_cfg: dict, requested_sweep_root: Path) -> Path:
    """Resolve the sweep root and verify it holds at least one ``run_*`` directory.

    The requested path is mapped through ``resolve_storage_path`` with the
    configured cache dir (``base_cfg["storage"]["cache_dir"]``, default ``"cache"``).

    Raises:
        FileNotFoundError: if the resolved path is not a directory or contains
            no sub-directory whose name starts with ``run_``.
    """
    cache_dir = base_cfg.get("storage", {}).get("cache_dir", "cache")
    candidate = resolve_storage_path(requested_sweep_root, cache_dir)
    if candidate.is_dir():
        for child in candidate.iterdir():
            if child.is_dir() and child.name.startswith("run_"):
                return candidate
    raise FileNotFoundError(
        "Could not resolve SWEEP_ROOT with run directories: "
        f"{candidate}"
    )
91
+
92
+
93
+ def _discover_model_dirs(sweep_root: Path) -> list[Path]:
94
+ dirs = [
95
+ path
96
+ for path in sweep_root.iterdir()
97
+ if path.is_dir() and path.name.startswith("run_")
98
+ ]
99
+ if not dirs:
100
+ raise FileNotFoundError(
101
+ f"No run directories starting with 'run_' found in {sweep_root}"
102
+ )
103
+ return sorted(dirs)
104
+
105
+
106
@torch.no_grad()
def evaluate_one_model(
    model_dir: Path,
    base_cfg: dict,
    eval_max_samples: int,
    batch_size: int,
) -> list[dict]:
    """Greedy-decode the GSM8K test split with one checkpoint and score every sample.

    The test split is sharded across ranks round-robin (``rank, rank + world_size, ...``).
    Each rank generates completions for its shard, scores them with
    ``gsm8k_correctness_reward``, and the per-rank records are combined with
    ``all_gather_object`` when a process group is initialised.

    Args:
        model_dir: Local checkpoint directory (loaded with ``local_files_only=True``).
        base_cfg: Parsed YAML config; ``generation``, ``model``, ``trainer`` and
            ``storage`` sections are consulted.
        eval_max_samples: Cap on the number of test samples; a negative value
            means no cap.
        batch_size: Prompts per ``generate`` call; must be positive.

    Returns:
        Records sorted by ``sample_index``, each holding ``sample_index``,
        ``prompt``, ``reference``, ``completion``, ``correctness`` and ``cot_words``.

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    rank, world_size, local_rank = _dist_info()
    generation = base_cfg.get("generation", {})
    max_prompt_len = int(generation.get("max_prompt_length", 512))
    max_completion_len = int(generation.get("max_completion_length", 256))
    model_name_fallback = str(base_cfg["model"]["name"])
    trust_remote_code = bool(base_cfg.get("model", {}).get("trust_remote_code", False))
    dtype = _model_dtype(base_cfg)
    datasets_cache, models_cache = _get_cache_paths(base_cfg)

    provider = GSM8KProvider()
    all_samples = provider.load(
        split="test",
        max_samples=None if eval_max_samples < 0 else eval_max_samples,
        cache_dir=str(datasets_cache),
    )
    # Round-robin shard: this rank handles every world_size-th sample.
    indices = list(range(rank, len(all_samples), world_size))
    local_samples = [all_samples[idx] for idx in indices]
    prompts = [sample.prompt for sample in local_samples]
    refs = [sample.target for sample in local_samples]
    metadata = [sample.metadata for sample in local_samples]

    try:
        tokenizer = AutoTokenizer.from_pretrained(
            str(model_dir),
            trust_remote_code=trust_remote_code,
            cache_dir=str(models_cache),
            local_files_only=True,
        )
    except Exception:
        # Deliberate best-effort: fall back to the base model's tokenizer when
        # the checkpoint directory cannot supply one.
        tokenizer = AutoTokenizer.from_pretrained(
            model_name_fallback,
            trust_remote_code=trust_remote_code,
            cache_dir=str(models_cache),
            local_files_only=True,
        )

    model = AutoModelForCausalLM.from_pretrained(
        str(model_dir),
        trust_remote_code=trust_remote_code,
        cache_dir=str(models_cache),
        torch_dtype=dtype,
        local_files_only=True,
    )
    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)
        device = torch.device(f"cuda:{local_rank}")
    else:
        device = torch.device("cpu")
    model.to(device)
    model.eval()

    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token
    # Decoder-only generation continues from the LAST position of each row, so
    # batched prompts must be padded on the left; with right padding shorter
    # prompts would be continued from a pad token.
    tokenizer.padding_side = "left"

    completions: list[str] = []
    for start in range(0, len(prompts), batch_size):
        batch_prompts = prompts[start : start + batch_size]
        enc = tokenizer(
            batch_prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_prompt_len,
        )
        input_ids = enc["input_ids"].to(device)
        attn = enc["attention_mask"].to(device)
        out = model.generate(
            input_ids=input_ids,
            attention_mask=attn,
            max_new_tokens=max_completion_len,
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
        # With left padding every prompt ends at the same column, so everything
        # past the padded prompt width is newly generated text.
        prompt_width = input_ids.size(1)
        for completion_ids in out[:, prompt_width:]:
            completions.append(tokenizer.decode(completion_ids, skip_special_tokens=True))

    scores = gsm8k_correctness_reward(
        prompts=prompts,
        completions=completions,
        references=refs,
        metadata=metadata,
    )

    local_records: list[dict] = [
        {
            "sample_index": int(sample_index),
            "prompt": prompt,
            "reference": reference,
            "completion": completion,
            "correctness": float(score),
            "cot_words": int(_cot_word_len(completion)),
        }
        for sample_index, prompt, reference, completion, score in zip(
            indices, prompts, refs, completions, scores, strict=True
        )
    ]

    # Drop the model before the next checkpoint is loaded to keep peak memory down.
    del model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    if dist.is_initialized():
        gathered: list[list[dict] | None] = [None for _ in range(world_size)]
        dist.all_gather_object(gathered, local_records)
        merged: list[dict] = []
        for part in gathered:
            if part:
                merged.extend(part)
    else:
        merged = local_records

    merged.sort(key=lambda row: row["sample_index"])
    return merged
230
+
231
+
232
+ def _summarize(records: list[dict], model_dir: str) -> dict:
233
+ if not records:
234
+ return {
235
+ "name": Path(model_dir).name,
236
+ "model_dir": model_dir,
237
+ "num_examples": 0,
238
+ "accuracy": 0.0,
239
+ "avg_cot_words": 0.0,
240
+ }
241
+ accuracy = sum(float(row["correctness"]) for row in records) / len(records)
242
+ avg_cot = sum(float(row["cot_words"]) for row in records) / len(records)
243
+ return {
244
+ "name": Path(model_dir).name,
245
+ "model_dir": model_dir,
246
+ "num_examples": len(records),
247
+ "accuracy": float(accuracy),
248
+ "avg_cot_words": float(avg_cot),
249
+ }
250
+
251
+
252
+ def _write_accuracy_svg(summaries: list[dict], path: Path) -> None:
253
+ width = 1000
254
+ height = 460
255
+ left_margin = 70
256
+ right_margin = 30
257
+ top_margin = 70
258
+ bottom_margin = 90
259
+ plot_w = width - left_margin - right_margin
260
+ plot_h = height - top_margin - bottom_margin
261
+ y_base = top_margin + plot_h
262
+
263
+ runs = [row["name"] for row in summaries]
264
+ acc_vals = [float(row["accuracy"]) for row in summaries]
265
+ vmax = max(1.0, max(acc_vals) if acc_vals else 1.0)
266
+
267
+ bar_count = max(1, len(runs))
268
+ slot_w = plot_w / bar_count
269
+ bar_w = min(120, max(30, int(slot_w * 0.55)))
270
+ palette = ["#2563eb", "#dc2626", "#16a34a", "#ca8a04", "#7c3aed", "#0891b2"]
271
+
272
+ parts: list[str] = []
273
+ parts.append(f'<svg xmlns="http://www.w3.org/2000/svg" width="{width}" height="{height}">')
274
+ parts.append('<rect width="100%" height="100%" fill="#ffffff"/>')
275
+ parts.append(
276
+ '<text x="40" y="34" font-size="20" font-family="sans-serif">Sweep Evaluation: GSM8K Accuracy</text>'
277
+ )
278
+ parts.append(
279
+ f'<line x1="{left_margin}" y1="{y_base}" x2="{left_margin + plot_w}" y2="{y_base}" stroke="#111" stroke-width="2" />'
280
+ )
281
+ parts.append(
282
+ f'<line x1="{left_margin}" y1="{top_margin}" x2="{left_margin}" y2="{y_base}" stroke="#111" stroke-width="2" />'
283
+ )
284
+
285
+ # y-axis ticks
286
+ for tick in [0.0, 0.25, 0.5, 0.75, 1.0]:
287
+ y = y_base - int((tick / vmax) * plot_h) if vmax > 0 else y_base
288
+ parts.append(
289
+ f'<line x1="{left_margin - 6}" y1="{y}" x2="{left_margin}" y2="{y}" stroke="#111" stroke-width="1" />'
290
+ )
291
+ parts.append(
292
+ f'<text x="{left_margin - 10}" y="{y + 4}" text-anchor="end" font-size="11" font-family="sans-serif">{tick:.2f}</text>'
293
+ )
294
+
295
+ for idx, (run_name, acc) in enumerate(zip(runs, acc_vals, strict=True)):
296
+ center_x = left_margin + int((idx + 0.5) * slot_w)
297
+ bar_h = int((acc / vmax) * plot_h) if vmax > 0 else 0
298
+ x = center_x - bar_w // 2
299
+ y = y_base - bar_h
300
+ color = palette[idx % len(palette)]
301
+ parts.append(f'<rect x="{x}" y="{y}" width="{bar_w}" height="{bar_h}" fill="{color}" />')
302
+ parts.append(
303
+ f'<text x="{center_x}" y="{y - 8}" text-anchor="middle" font-size="12" font-family="sans-serif">{acc:.3f}</text>'
304
+ )
305
+ parts.append(
306
+ f'<text x="{center_x}" y="{y_base + 18}" text-anchor="middle" font-size="11" font-family="sans-serif">{run_name}</text>'
307
+ )
308
+
309
+ parts.append("</svg>")
310
+ path.write_text("\n".join(parts), encoding="utf-8")
311
+
312
+
313
def main() -> None:
    """Evaluate every ``run_*`` model under ``SWEEP_ROOT`` and write summary artefacts.

    Environment variables read here:
        BASE_CONFIG: YAML config path (required).
        SWEEP_ROOT: directory holding the ``run_*`` model directories (required).
        OUT_ROOT: output directory; defaults to ``<sweep_root>/eval_results``.
        EVAL_MAX_SAMPLES, EVAL_BATCH_SIZE: integers, defaulting to 200 and 4.

    All ranks take part in evaluation; only rank 0 writes the per-model JSONL
    outputs, the JSON/CSV summaries and the accuracy SVG.
    """
    rank, _, _ = _init_distributed()
    base_cfg = _load_yaml(str(resolve_repo_path(os.environ["BASE_CONFIG"])))
    sweep_root = _resolve_sweep_root(base_cfg, Path(os.environ["SWEEP_ROOT"]))

    out_root_override = os.environ.get("OUT_ROOT")
    if out_root_override is None:
        out_root = (sweep_root / "eval_results").resolve()
    else:
        out_root = resolve_repo_path(out_root_override)

    eval_max_samples = int(os.environ.get("EVAL_MAX_SAMPLES", "200"))
    eval_batch_size = int(os.environ.get("EVAL_BATCH_SIZE", "4"))

    resolved_model_dirs = [
        _resolve_local_model_dir(base_cfg, str(run_dir))
        for run_dir in _discover_model_dirs(sweep_root)
    ]

    outputs_dir = out_root / "outputs"
    if rank == 0:
        out_root.mkdir(parents=True, exist_ok=True)
        outputs_dir.mkdir(parents=True, exist_ok=True)

    if dist.is_initialized():
        dist.barrier()

    summaries: list[dict] = []
    for model_dir in resolved_model_dirs:
        records = evaluate_one_model(
            model_dir=model_dir,
            base_cfg=base_cfg,
            eval_max_samples=eval_max_samples,
            batch_size=eval_batch_size,
        )
        if rank != 0:
            continue
        output_jsonl = outputs_dir / f"{model_dir.name}_outputs.jsonl"
        with output_jsonl.open("w", encoding="utf-8") as handle:
            for row in records:
                handle.write(json.dumps(row, ensure_ascii=True) + "\n")
        summary = _summarize(records, str(model_dir))
        summary["outputs_jsonl"] = str(output_jsonl)
        summaries.append(summary)

    if dist.is_initialized():
        dist.barrier()

    if rank != 0:
        return

    json_path = out_root / "sweep_eval_summary.json"
    csv_path = out_root / "sweep_eval_summary.csv"
    svg_path = out_root / "sweep_eval_accuracy.svg"
    json_path.write_text(json.dumps(summaries, indent=2), encoding="utf-8")

    columns = [
        "name",
        "model_dir",
        "num_examples",
        "accuracy",
        "avg_cot_words",
        "outputs_jsonl",
    ]
    with csv_path.open("w", encoding="utf-8", newline="") as handle:
        writer = csv.DictWriter(handle, fieldnames=columns)
        writer.writeheader()
        writer.writerows(summaries)
    _write_accuracy_svg(summaries, svg_path)

    print(f"Saved summary: {json_path}")
    print(f"Saved summary: {csv_path}")
    print(f"Saved plot: {svg_path}")
    print(f"Saved outputs dir: {out_root / 'outputs'}")
383
+
384
+
385
+ if __name__ == "__main__":
386
+ main()
src/hackable/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ """Hackable GRPO training primitives."""
2
+
3
+ # Ensure default plugins register on import.
4
+ from . import data_plugins as _data_plugins # noqa: F401
5
+ from . import objectives as _objectives # noqa: F401
6
+ from . import reward_plugins as _reward_plugins # noqa: F401
src/hackable/backends.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import importlib.util
4
+ from typing import Any
5
+
6
+ import torch
7
+
8
+
9
def load_model_and_tokenizer(
    model_name: str,
    trust_remote_code: bool = False,
    cache_dir: str | None = None,
    load_in_4bit: bool = False,
    torch_dtype: str = "bfloat16",
):
    """Load a causal LM and its tokenizer through ``transformers``.

    Args:
        model_name: Name or path handed to ``from_pretrained``. When it contains
            ``"llama"`` (case-insensitive) the Liger Llama kernels are applied first.
        trust_remote_code: Forwarded to both ``from_pretrained`` calls.
        cache_dir: Forwarded to both ``from_pretrained`` calls.
        load_in_4bit: Accepted for interface compatibility and ignored.
        torch_dtype: ``"bfloat16"`` selects ``torch.bfloat16``; any other value
            selects ``torch.float16``.

    Returns:
        ``(model, tokenizer, "transformers")``.

    Raises:
        RuntimeError: if a Llama model is requested and the Liger kernel
            patcher cannot be imported.
    """
    del load_in_4bit
    if torch_dtype == "bfloat16":
        dtype = torch.bfloat16
    else:
        dtype = torch.float16

    # Apply Liger kernels before constructing Llama models.
    wants_liger = "llama" in model_name.lower()
    if wants_liger:
        try:
            from liger_kernel.transformers import apply_liger_kernel_to_llama
        except Exception as exc:
            raise RuntimeError(
                "Failed to import Liger kernel patcher for Llama. "
                "Install liger-kernel in the runtime environment."
            ) from exc
        apply_liger_kernel_to_llama()

    # Use FlashAttention-2 only when the package is present AND imports
    # cleanly; otherwise stay on PyTorch SDPA.
    attn_impl = "sdpa"
    if importlib.util.find_spec("flash_attn") is not None:
        try:
            __import__("flash_attn")
        except Exception:
            pass
        else:
            attn_impl = "flash_attention_2"

    from transformers import AutoModelForCausalLM, AutoTokenizer

    shared_kwargs = {
        "trust_remote_code": trust_remote_code,
        "cache_dir": cache_dir,
    }
    tokenizer = AutoTokenizer.from_pretrained(model_name, **shared_kwargs)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        dtype=dtype,
        attn_implementation=attn_impl,
        **shared_kwargs,
    )
    return model, tokenizer, "transformers"
53
+
54
+
55
def generation_kwargs(cfg: Any) -> dict[str, Any]:
    """Copy the five generation settings from ``cfg`` into a plain dict.

    The keys are ``max_prompt_length``, ``max_completion_length``,
    ``num_generations``, ``temperature`` and ``top_p``, in that order; each is
    read as an attribute of ``cfg``.
    """
    field_names = (
        "max_prompt_length",
        "max_completion_length",
        "num_generations",
        "temperature",
        "top_p",
    )
    return {field_name: getattr(cfg, field_name) for field_name in field_names}
src/hackable/config.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass, field
4
+ from pathlib import Path
5
+ from typing import Any
6
+
7
+ import yaml
8
+
9
+ from .utils import resolve_repo_path
10
+
11
+
12
@dataclass
class ModelConfig:
    """Model settings, built by ``load_config`` from the ``model`` section of the YAML file."""

    # The only required field. The evaluation scripts read ``model.name`` from
    # the same YAML as a fallback tokenizer source.
    name: str
    trust_remote_code: bool = False
    # NOTE(review): no consumer of this flag is visible here — confirm where
    # (or whether) 4-bit loading is honored.
    load_in_4bit: bool = False
    # LoRA options — presumably only relevant when ``use_lora_adapters`` is
    # true; TODO confirm against the trainer.
    use_lora_adapters: bool = False
    lora_r: int = 16
    lora_alpha: int = 16
    lora_dropout: float = 0.0
21
+
22
+
23
@dataclass
class TrainerConfig:
    """Trainer settings, built by ``load_config`` from the ``trainer`` section of the YAML file.

    ``load_config`` coerces the listed numeric fields to ``int``/``float`` and
    maps shorthand optimizer/scheduler names before constructing this class.
    """

    # The only required field.
    output_dir: str
    run_name: str = "grpo-run"
    max_steps: int = -1
    num_train_epochs: float = 1.0
    per_device_train_batch_size: int = 1
    gradient_accumulation_steps: int = 8
    learning_rate: float = 1.0e-6
    logging_steps: int = 1
    save_steps: int = 25
    save_total_limit: int = 5
    bf16: bool = True
    seed: int = 42
    report_to: str = "wandb"
    # ``load_config`` also accepts this under the YAML key ``optimizer`` and
    # maps "adamw" -> "adamw_torch", "adamw_fused" -> "adamw_torch_fused".
    optim: str = "adamw_torch"
    gradient_checkpointing: bool = True
    max_grad_norm: float = 1.0
    shuffle_dataset: bool = False
    # ``load_config`` maps "cosine_decay" -> "cosine".
    lr_scheduler_type: str = "cosine"
    lr_scheduler_kwargs: dict[str, Any] = field(default_factory=dict)
    warmup_steps: int = 20
    sanity_log_examples: int = 8
    sanity_log_max_chars: int = 300
    # Presumably the interval and location for checkpoints that are never
    # rotated out — TODO confirm against the training loop.
    permanent_checkpoint_steps: int = 300
    permanent_checkpoint_dir: str = "checkpoints/permanent"
49
+
50
+
51
+
52
@dataclass
class DataConfig:
    """Dataset settings, built by ``load_config`` from the optional ``data`` section."""

    # Presumably the name of a registered data provider — TODO confirm against the registry.
    provider: str = "gsm8k_math_curriculum"
    split: str = "train"
    # NOTE(review): ``None`` looks like "no cap" — confirm against the provider.
    max_samples: int | None = None
57
+
58
+
59
@dataclass
class GenerationConfig:
    """Generation settings, built by ``load_config`` from the optional ``generation`` section.

    These are exactly the five attributes that ``generation_kwargs`` in
    ``backends.py`` copies into a dict.
    """

    max_prompt_length: int = 512
    max_completion_length: int = 256
    num_generations: int = 4
    temperature: float = 0.9
    top_p: float = 0.95
66
+
67
+
68
@dataclass
class ObjectiveConfig:
    """Objective selection, built by ``load_config`` from the optional ``objective`` section."""

    # Presumably the name of a registered objective — TODO confirm against the registry.
    name: str = "token_grpo"
    kwargs: dict[str, Any] = field(default_factory=dict)
    # NOTE(review): looks like an optional import path that overrides ``name``
    # — confirm against the objective loader.
    class_path: str | None = None
73
+
74
+
75
@dataclass
class RewardsConfig:
    """Reward settings, built by ``load_config`` from the optional ``rewards`` section."""

    # Mapping of string keys to keyword-argument dicts; presumably keyed by
    # reward name — TODO confirm against the reward plugins.
    kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
78
+
79
+
80
@dataclass
class AuthConfig:
    """Credential settings, built by ``load_config`` from the optional ``auth`` section."""

    hf_api_key: str | None = None
    wandb_api_key: str | None = None
    # Presumably the environment-variable names consulted for the keys above
    # — TODO confirm against the code that reads this config.
    hf_api_key_env: str = "HF_TOKEN"
    wandb_api_key_env: str = "WANDB_API_KEY"
86
+
87
+
88
@dataclass
class StorageConfig:
    """Storage settings, built by ``load_config`` from the optional ``storage`` section."""

    # Cache directory; the evaluation scripts read the same YAML key
    # (``storage.cache_dir``) with the same ``"cache"`` default.
    cache_dir: str = "cache"
91
+
92
+
93
@dataclass
class ThinkingKLConfig:
    """Scale KL penalty on completion tokens that overlap the *inner* redacted thinking body."""

    # Multiplier for that KL term; built by ``load_config`` from the optional
    # ``thinking_kl`` section. Presumably 1.0 leaves the penalty unscaled —
    # TODO confirm against the objective.
    inner_kl_weight: float = 1.0
98
+
99
+
100
@dataclass
class ExperimentConfig:
    """Top-level experiment configuration returned by ``load_config``.

    ``model`` and ``trainer`` are required; every other section falls back to
    its default-constructed value when absent from the YAML file.
    """

    model: ModelConfig
    trainer: TrainerConfig
    data: DataConfig = field(default_factory=DataConfig)
    generation: GenerationConfig = field(default_factory=GenerationConfig)
    objective: ObjectiveConfig = field(default_factory=ObjectiveConfig)
    rewards: RewardsConfig = field(default_factory=RewardsConfig)
    auth: AuthConfig = field(default_factory=AuthConfig)
    storage: StorageConfig = field(default_factory=StorageConfig)
    thinking_kl: ThinkingKLConfig = field(default_factory=ThinkingKLConfig)
    # Free-form ``grpo`` section, kept as a plain dict copy of the YAML mapping.
    grpo: dict[str, Any] = field(default_factory=dict)
112
+
113
+
114
def load_config(path: str | Path) -> ExperimentConfig:
    """Load a YAML experiment file and build an ``ExperimentConfig`` from it.

    The path goes through ``resolve_repo_path``. The ``model`` and ``trainer``
    sections are required; the remaining sections are optional. Before the
    trainer dataclass is built, its section is normalised:

    * ``optimizer`` is accepted as an alias for ``optim`` (``optim`` wins when
      both are present) and is always removed;
    * shorthand optimizer / scheduler names are expanded;
    * known numeric fields are coerced to ``float`` / ``int`` so values that
      YAML parsed as strings still work.
    """
    config_path = resolve_repo_path(path)
    with config_path.open("r", encoding="utf-8") as stream:
        raw = yaml.safe_load(stream)

    # Normalise a private copy so the parsed YAML mapping itself is never mutated.
    trainer = dict(raw["trainer"])

    # Backward-compatible alias: allow "optimizer" in YAML.
    if "optimizer" in trainer:
        legacy_optim = trainer.pop("optimizer")
        trainer.setdefault("optim", legacy_optim)

    # Accept common shorthand names.
    shorthand = {
        "optim": {
            "adamw": "adamw_torch",
            "adamw_fused": "adamw_torch_fused",
        },
        "lr_scheduler_type": {
            "cosine_decay": "cosine",
        },
    }
    for key, aliases in shorthand.items():
        if key in trainer:
            trainer[key] = aliases.get(trainer[key], trainer[key])

    # Normalize numeric fields that may come from YAML as strings.
    coercions = (
        (
            float,
            (
                "learning_rate",
                "max_grad_norm",
                "num_train_epochs",
            ),
        ),
        (
            int,
            (
                "max_steps",
                "per_device_train_batch_size",
                "gradient_accumulation_steps",
                "logging_steps",
                "save_steps",
                "save_total_limit",
                "seed",
                "warmup_steps",
                "sanity_log_examples",
                "sanity_log_max_chars",
                "permanent_checkpoint_steps",
            ),
        ),
    )
    for cast, keys in coercions:
        for key in keys:
            if key in trainer:
                trainer[key] = cast(trainer[key])

    def section(name: str) -> dict:
        # Optional sections default to an empty mapping.
        return raw.get(name, {})

    return ExperimentConfig(
        model=ModelConfig(**raw["model"]),
        trainer=TrainerConfig(**trainer),
        data=DataConfig(**section("data")),
        generation=GenerationConfig(**section("generation")),
        objective=ObjectiveConfig(**section("objective")),
        rewards=RewardsConfig(**section("rewards")),
        auth=AuthConfig(**section("auth")),
        storage=StorageConfig(**section("storage")),
        thinking_kl=ThinkingKLConfig(**section("thinking_kl")),
        grpo=dict(section("grpo")),
    )
src/hackable/data_plugins.py ADDED
@@ -0,0 +1,291 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from contextlib import contextmanager
4
+ from dataclasses import asdict
5
+ from pathlib import Path
6
+
7
+ from .interfaces import TrainingSample
8
+ from .registry import register_data_provider
9
+
10
+
11
+ @contextmanager
12
+ def _serialized_hf_dataset_download():
13
+ """
14
+ Serialize Hugging Face ``datasets`` downloads/prepare across processes.
15
+
16
+ Multi-GPU ``accelerate launch`` otherwise races on the same ``cache_dir`` and can
17
+ leave a half-written tree (e.g. missing ``dataset_info.json``).
18
+ """
19
+ root = Path.home() / ".cache" / "neuralese"
20
+ root.mkdir(parents=True, exist_ok=True)
21
+ lock_path = root / "hf_dataset_download.lock"
22
+ try:
23
+ from filelock import FileLock
24
+
25
+ with FileLock(str(lock_path), timeout=7200):
26
+ yield
27
+ except ImportError:
28
+ yield
29
+
30
+
31
def _load_hf_split(
    path: str,
    split: str,
    cache_dir: str | None,
    config_name: str | None = None,
):
    """Load one dataset split with ``datasets.load_dataset`` under the download lock.

    ``config_name`` is forwarded as the second positional argument only when it
    is not ``None``. If the first attempt raises ``FileNotFoundError`` and a
    ``cache_dir`` was given, the load is retried once with ``cache_dir=None``.

    Raises:
        RuntimeError: when the files are missing and no ``cache_dir`` was
            given, or when the retry with ``cache_dir=None`` also raises
            ``FileNotFoundError``.
    """
    from datasets import load_dataset

    # Positional part of the call: the config name is only passed when present.
    dataset_args = (path,) if config_name is None else (path, config_name)

    with _serialized_hf_dataset_download():
        try:
            return load_dataset(*dataset_args, split=split, cache_dir=cache_dir)
        except FileNotFoundError as exc:
            if cache_dir:
                # Second chance: look in the default HF cache instead.
                try:
                    return load_dataset(*dataset_args, split=split, cache_dir=None)
                except FileNotFoundError as exc2:
                    raise RuntimeError(
                        "Could not load the dataset from the experiment cache_dir or the default HF cache. "
                        "Seed ~/.cache/huggingface/datasets (or your HF_HOME) with HF_DATASETS_OFFLINE=0, "
                        "or point storage.cache_dir at a shared cache that already contains the dataset."
                    ) from exc2
            raise RuntimeError(
                "Hugging Face dataset files are missing from the default cache. "
                "Run once with HF_DATASETS_OFFLINE=0 (or download the dataset), "
                "or set HF_HOME / HF_DATASETS_CACHE to a populated cache."
            ) from exc
61
+
62
+
63
# Instruction header placed in front of every question: reasoning goes inside
# <think>...</think>, followed by a single LaTeX \boxed{} final answer.
PROMPT_PREFIX = "\n".join(
    (
        "Solve the following math problem.",
        "Think step-by-step inside <think>...</think> tags.",
        "Then output only the final answer in LaTeX boxed format.",
        "Do not include any words or explanations outside the tags/boxed answer.",
        "Output format must be exactly:",
        "<think>your reasoning</think>",
        "\\boxed{your_final_answer}",
    )
) + "\n\n"
72
+
73
+
74
def _build_math_prompt(question: str) -> str:
    """Wrap ``question`` in the shared instruction prefix as a ``user:`` turn.

    The returned text ends with ``assistant:`` (chat-style prefill) so that
    decoding starts right after it.
    """
    return f"user: {PROMPT_PREFIX}Question: {question}\nassistant:"
78
+
79
+
80
+ def _interleave_samples(
81
+ left: list[TrainingSample], right: list[TrainingSample]
82
+ ) -> list[TrainingSample]:
83
+ output: list[TrainingSample] = []
84
+ width = max(len(left), len(right))
85
+ for idx in range(width):
86
+ if idx < len(left):
87
+ output.append(left[idx])
88
+ if idx < len(right):
89
+ output.append(right[idx])
90
+ return output
91
+
92
+
93
+ def _slice_if_needed(
94
+ samples: list[TrainingSample], max_samples: int | None
95
+ ) -> list[TrainingSample]:
96
+ if max_samples is None:
97
+ return samples
98
+ return samples[: max(0, max_samples)]
99
+
100
+
101
+ class _MathProviderBase:
102
+ dataset_name = "EleutherAI/hendrycks_math"
103
+ dataset_configs = (
104
+ "algebra",
105
+ "counting_and_probability",
106
+ "geometry",
107
+ "intermediate_algebra",
108
+ "number_theory",
109
+ "prealgebra",
110
+ "precalculus",
111
+ )
112
+
113
+ def __init__(self, levels: tuple[str, ...]):
114
+ self.levels = levels
115
+
116
+ def load(
117
+ self,
118
+ split: str,
119
+ max_samples: int | None = None,
120
+ cache_dir: str | None = None,
121
+ ) -> list[TrainingSample]:
122
+ try:
123
+ import datasets # noqa: F401
124
+ except Exception as exc:
125
+ raise RuntimeError(
126
+ "datasets is required for Hendrycks MATH providers. Install dependencies first."
127
+ ) from exc
128
+
129
+ level_set = {level.strip() for level in self.levels}
130
+
131
+ output: list[TrainingSample] = []
132
+ for config_name in self.dataset_configs:
133
+ rows = _load_hf_split(
134
+ self.dataset_name,
135
+ split,
136
+ cache_dir,
137
+ config_name=config_name,
138
+ )
139
+ for row in rows:
140
+ level = str(row.get("level", "")).strip()
141
+ if level not in level_set:
142
+ continue
143
+ question = str(row.get("problem", ""))
144
+ target = str(row.get("solution", ""))
145
+ output.append(
146
+ TrainingSample(
147
+ prompt=_build_math_prompt(question),
148
+ target=target,
149
+ metadata={
150
+ "dataset": "hendrycks_math",
151
+ "subject": config_name,
152
+ "level": level,
153
+ },
154
+ )
155
+ )
156
+ if max_samples is not None and len(output) >= max_samples:
157
+ return output
158
+ return output
159
+
160
+
161
@register_data_provider("gsm8k")
class GSM8KProvider:
    """Data provider that turns GSM8K rows into ``TrainingSample`` objects."""

    def __init__(self, dataset_name: str = "openai/gsm8k", subset: str = "main"):
        self.dataset_name = dataset_name
        self.subset = subset

    def load(
        self,
        split: str,
        max_samples: int | None = None,
        cache_dir: str | None = None,
    ) -> list[TrainingSample]:
        """Load ``split`` and build one sample per row.

        When ``max_samples`` is given, only the first ``max_samples`` rows (or
        all rows, if there are fewer) are selected before conversion. Each
        sample's metadata records the dataset name, the row's position in the
        selected rows, and the split.

        Raises:
            RuntimeError: if the ``datasets`` package cannot be imported.
        """
        try:
            import datasets  # noqa: F401
        except Exception as exc:
            raise RuntimeError(
                "datasets is required for GSM8K provider. Install dependencies first."
            ) from exc

        rows = _load_hf_split(
            self.dataset_name, split, cache_dir, config_name=self.subset
        )
        if max_samples is not None:
            keep = min(max_samples, len(rows))
            rows = rows.select(range(keep))

        return [
            TrainingSample(
                prompt=_build_math_prompt(str(row["question"])),
                target=row["answer"],
                metadata={
                    "dataset": "gsm8k",
                    "sample_index": int(position),
                    "split": str(split),
                },
            )
            for position, row in enumerate(rows)
        ]
204
+
205
+
206
@register_data_provider("math_level_1")
class MathLevel1Provider(_MathProviderBase):
    """MATH provider restricted to rows labelled ``Level 1``."""

    def __init__(self):
        selected_levels = ("Level 1",)
        super().__init__(levels=selected_levels)
210
+
211
+
212
@register_data_provider("math_level_2")
class MathLevel2Provider(_MathProviderBase):
    """MATH provider restricted to rows labelled ``Level 2``."""

    def __init__(self):
        selected_levels = ("Level 2",)
        super().__init__(levels=selected_levels)
216
+
217
+
218
@register_data_provider("math_level_3")
class MathLevel3Provider(_MathProviderBase):
    """MATH provider restricted to rows labelled ``Level 3``."""

    def __init__(self):
        selected_levels = ("Level 3",)
        super().__init__(levels=selected_levels)
222
+
223
+
224
@register_data_provider("math_level_4")
class MathLevel4Provider(_MathProviderBase):
    """MATH provider restricted to rows labelled ``Level 4``."""

    def __init__(self):
        selected_levels = ("Level 4",)
        super().__init__(levels=selected_levels)
228
+
229
+
230
@register_data_provider("math_level_5")
class MathLevel5Provider(_MathProviderBase):
    """MATH provider restricted to rows labelled ``Level 5``."""

    def __init__(self):
        selected_levels = ("Level 5",)
        super().__init__(levels=selected_levels)
234
+
235
+
236
@register_data_provider("math_levels_12")
class MathLevels12Provider(_MathProviderBase):
    """MATH provider keeping rows labelled ``Level 1`` or ``Level 2``."""

    def __init__(self):
        selected_levels = ("Level 1", "Level 2")
        super().__init__(levels=selected_levels)
240
+
241
+
242
@register_data_provider("math_levels_345")
class MathLevels345Provider(_MathProviderBase):
    """MATH provider keeping rows labelled ``Level 3``, ``Level 4`` or ``Level 5``."""

    def __init__(self):
        selected_levels = ("Level 3", "Level 4", "Level 5")
        super().__init__(levels=selected_levels)
246
+
247
+
248
@register_data_provider("gsm8k_math_stage12")
class GSM8KMathStage12Provider:
    """Mix of GSM8K and MATH level 1-2 samples, alternating between the two."""

    def load(
        self,
        split: str,
        max_samples: int | None = None,
        cache_dir: str | None = None,
    ) -> list[TrainingSample]:
        """Load both sources in full, interleave them, then cap at ``max_samples``."""
        # Both sources are loaded without a cap; the limit is applied to the mix.
        shared_kwargs = {"split": split, "max_samples": None, "cache_dir": cache_dir}
        gsm_samples = GSM8KProvider().load(**shared_kwargs)
        math_samples = MathLevels12Provider().load(**shared_kwargs)
        return _slice_if_needed(
            _interleave_samples(gsm_samples, math_samples), max_samples
        )
262
+
263
+
264
@register_data_provider("gsm8k_math_curriculum")
class GSM8KMathCurriculumProvider:
    """Curriculum provider: the easier GSM8K + MATH level 1-2 mix, then MATH levels 3-5."""

    def load(
        self,
        split: str,
        max_samples: int | None = None,
        cache_dir: str | None = None,
    ) -> list[TrainingSample]:
        """Return the easy-stage samples followed by the hard-stage samples.

        Args:
            split: dataset split name forwarded to both stage providers.
            max_samples: total budget, split between the stages with the extra
                sample of an odd budget going to the easy stage. ``None`` means
                no limit; a value ``<= 0`` yields ``[]``.
            cache_dir: cache directory forwarded to both stage providers.
        """
        if max_samples is None:
            stage12_budget = None
            stage345_budget = None
        else:
            # Clamp so a negative request cannot produce a negative stage budget.
            total = max(0, max_samples)
            stage12_budget = (total + 1) // 2
            stage345_budget = total // 2

        # A stage with a zero budget is skipped outright. The MATH loader checks
        # its budget only after appending a row, so asking it for 0 samples would
        # still return one and push the result past ``max_samples``.
        if stage12_budget == 0:
            stage12: list[TrainingSample] = []
        else:
            stage12 = GSM8KMathStage12Provider().load(
                split=split, max_samples=stage12_budget, cache_dir=cache_dir
            )

        if stage345_budget == 0:
            stage345: list[TrainingSample] = []
        else:
            stage345 = MathLevels345Provider().load(
                split=split, max_samples=stage345_budget, cache_dir=cache_dir
            )

        # Curriculum order: first easier mixed set, then harder levels.
        return stage12 + stage345
288
+
289
+
290
def to_dataset_rows(samples: list[TrainingSample]) -> list[dict]:
    """Convert each sample to a plain dict with ``dataclasses.asdict``."""
    return list(map(asdict, samples))
src/hackable/interfaces.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Any, Protocol
5
+
6
+
7
@dataclass
class TrainingSample:
    """A single prompt/target pair together with free-form metadata."""

    # Prompt text for this sample.
    prompt: str
    # Reference text associated with the prompt.
    target: str
    # Arbitrary per-sample key/value information.
    metadata: dict[str, Any]
12
+
13
+
14
class DataProvider(Protocol):
    """Structural interface for objects that can load training samples."""

    def load(
        self,
        split: str,
        max_samples: int | None = None,
        cache_dir: str | None = None,
    ) -> list[TrainingSample]:
        """Return the samples of ``split``; both limits are optional."""
        ...
22
+
23
+
24
class RewardFunction(Protocol):
    """Structural interface for callables that score completions."""

    def __call__(
        self,
        prompts: list[str],
        completions: list[str],
        references: list[str],
        metadata: list[dict[str, Any]],
    ) -> list[float]:
        """Return a list of float rewards for the given batch."""
        ...
33
+
34
+
35
class ObjectiveModule(Protocol):
    """Structural interface for objective plugins that contribute extra rewards."""

    # Identifier of the objective.
    name: str

    def reward_names(self) -> list[str]:
        """Return a list of reward names."""
        ...

    def extra_reward(
        self,
        prompts: list[str],
        completions: list[str],
        references: list[str],
        metadata: list[dict[str, Any]],
    ) -> list[float]:
        """Return a list of float rewards for the given batch."""
        ...
+ ...