Avra98 committed on
Commit
76de008
·
verified ·
1 Parent(s): 46d2c93

Initial code dump (rebuttal-ready snapshot)

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .env.example +70 -0
  2. .gitignore +26 -0
  3. REBUTTAL_REPORT.md +589 -0
  4. REBUTTAL_REPORT.tex +711 -0
  5. addition/README.md +114 -0
  6. addition/__init__.py +1 -0
  7. addition/config.py +294 -0
  8. addition/data.py +390 -0
  9. addition/eval.py +326 -0
  10. addition/model.py +190 -0
  11. addition/plots.py +135 -0
  12. addition/run_comparison.py +122 -0
  13. addition/train.py +369 -0
  14. aligned_cell_policy/shared_cell_policy.py +69 -0
  15. analysis/eval_saved_hard9x9_checkpoints.py +273 -0
  16. checkpoint_utils.py +127 -0
  17. format_utils_icon.py +39 -0
  18. formatting_icon.py +34 -0
  19. hard_9x9_10empty/launch_baseline_stage3_pipeline.sh +102 -0
  20. hard_9x9_15empty/launch_baseline_pipeline.sh +65 -0
  21. hard_9x9_15empty_multivalue_stage1/launch_stage1_size2_sft.sh +103 -0
  22. hard_9x9_7empty/launch_stage1_sft.sh +99 -0
  23. hard_9x9_curriculum/build_stage3_hard_dataset.py +448 -0
  24. hard_9x9_stage1_consistency_queue/README.md +117 -0
  25. hard_9x9_stage1_consistency_queue/debug_fixed_slot_latent_one_example.sh +158 -0
  26. hard_9x9_stage1_consistency_queue/launch_10empty_full_pipeline_stages123_value98.sh +62 -0
  27. hard_9x9_stage1_consistency_queue/launch_10empty_post_s1sft_stages123_value98.sh +365 -0
  28. hard_9x9_stage1_consistency_queue/launch_10empty_sft_stage1_98p.sh +112 -0
  29. hard_9x9_stage1_consistency_queue/launch_20empty_fixed_slot_sft_stage1_98p.sh +125 -0
  30. hard_9x9_stage1_consistency_queue/launch_20empty_full_pipeline_stages123_value98.sh +62 -0
  31. hard_9x9_stage1_consistency_queue/launch_20empty_latent_recurrent_stages123_value98.sh +341 -0
  32. hard_9x9_stage1_consistency_queue/launch_20empty_latent_residual_stages123_value98.sh +279 -0
  33. hard_9x9_stage1_consistency_queue/launch_20empty_post_s1sft_stages123_value98.sh +368 -0
  34. hard_9x9_stage1_consistency_queue/launch_20empty_sft_stage1_98p.sh +112 -0
  35. hard_9x9_stage1_consistency_queue/launch_20empty_stage1_sft_all_latent_modes_parallel.sh +187 -0
  36. hard_9x9_stage1_consistency_queue/launch_20empty_warm_baseline_all_latent_modes_stages123.sh +394 -0
  37. hard_9x9_stage1_consistency_queue/launch_7empty_latent_residual_stages123_value98.sh +419 -0
  38. hard_9x9_stage1_consistency_queue/launch_7empty_post_s1sft_stages123_value98.sh +372 -0
  39. hard_9x9_stage1_consistency_queue/launch_sft_stage1_95p.sh +113 -0
  40. hard_9x9_stage1_consistency_queue/recurrent_hidden_stage2_resume_summary_20260516.md +68 -0
  41. hard_9x9_stage1_consistency_queue/recurrent_hidden_stage2sft_resume.md +83 -0
  42. hard_9x9_stage1_consistency_queue/sync_recurrent_hidden_checkpoints_to_hf.sh +84 -0
  43. hard_9x9_stage1_consistency_queue/warm_baseline_all_latent_modes_stages123_results.md +65 -0
  44. large_baseline_extension/README.md +35 -0
  45. large_baseline_extension/launch_nonlocation_grpo.sh +103 -0
  46. large_baseline_extension/launch_nonlocation_pipeline.sh +80 -0
  47. large_baseline_extension/launch_nonlocation_sft.sh +87 -0
  48. large_latent_extension/README.md +32 -0
  49. large_latent_extension/launch_nonlocation_grpo.sh +101 -0
  50. large_latent_extension/launch_nonlocation_pipeline.sh +82 -0
.env.example ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copy to .env for local runs. Do not put real secrets in this example file.
2
+
3
+ # Hugging Face token, if needed for private/gated repos or higher rate limits.
4
+ # HF_TOKEN=hf_xxx
5
+
6
+ # Weights & Biases configuration.
7
+ WANDB_MODE=online
8
+ WANDB_ENTITY=training-dynamics
9
+ WANDB_PROJECT=
10
+ WANDB_RUN_NAME=
11
+ WANDB__SERVICE_WAIT=300
12
+
13
+ # GPU/distributed launch defaults.
14
+ GPU_IDS=0,1,2,3,4,5,6,7
15
+ NUM_PROCESSES=8
16
+ NPROC_PER_JOB=2
17
+ MASTER_PORT=29501
18
+ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
19
+ PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
20
+
21
+ # Model/cache/output defaults used by launch scripts.
22
+ MODEL_NAME=Qwen/Qwen2.5-1.5B-Instruct
23
+ CACHE_DIR=.hf_cache
24
+ RUN_TAG=
25
+ CHECKPOINT_ROOT=
26
+ OUTPUT_ROOT=
27
+
28
+ # Warm-baseline all-latent stage pipeline defaults.
29
+ EMPTIES=20
30
+ MODES_SPEC=recurrent_hidden
31
+ GPU_GROUPS_SPEC=0,1,2,3,4,5,6,7
32
+ TRAIN_PUZZLES=10000
33
+ EVAL_PUZZLES=100
34
+ SOLVE_TARGET=0.95
35
+ VALUE_TARGET=0
36
+ MIN_STEPS_BEFORE_STOP=50
37
+ BASELINE_WARM_MAX_STEPS=1000
38
+ LATENT_SFT_MAX_STEPS=1000
39
+ LATENT_GRPO_MAX_STEPS=500
40
+ SFT_NUM_EPOCHS=64
41
+ GRPO_NUM_TRAIN_EPOCHS=50
42
+
43
+ # SFT/GRPO batch and LoRA defaults.
44
+ SFT_PER_DEVICE_BS=8
45
+ SFT_GRAD_ACCUM=2
46
+ BASELINE_PER_DEVICE_BS=16
47
+ BASELINE_GRAD_ACCUM=2
48
+ GRPO_PER_DEVICE_BS=4
49
+ GRPO_GRAD_ACCUM=2
50
+ LORA_R=32
51
+ LORA_ALPHA=64
52
+ LORA_DROPOUT=0.05
53
+ GRPO_BETA=0.0
54
+
55
+ # Optional resume adapters.
56
+ STAGE1_BASELINE_ADAPTER_DIR=
57
+ STAGE1_LATENT_SFT_ADAPTER_DIR=
58
+ STAGE1_LATENT_GRPO_ADAPTER_DIR=
59
+ STAGE2_BASELINE_WARM_ADAPTER_DIR=
60
+ STAGE2_LATENT_SFT_ADAPTER_DIR=
61
+ STAGE2_LATENT_GRPO_ADAPTER_DIR=
62
+ STAGE3_BASELINE_WARM_ADAPTER_DIR=
63
+ STAGE3_LATENT_SFT_ADAPTER_DIR=
64
+
65
+ # Optional debug knobs.
66
+ FIXED_SLOT_DEBUG_LIMIT=0
67
+ FIXED_SLOT_DECODE_DEBUG_LIMIT=0
68
+ LATENT_VOCAB_DEBUG_TOPK=1
69
+ ATTN_DENSITY_DEBUG_LIMIT=0
70
+ ATTN_DENSITY_THRESHOLD_MULT=1.0
.gitignore ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __pycache__/
2
+ *.pyc
3
+
4
+ .env
5
+ .env.*
6
+ !.env.example
7
+
8
+ _prepared_data/
9
+ data/
10
+ checkpoints/
11
+ final_checkpoint/
12
+ location_learner/
13
+
14
+ *.jsonl
15
+
16
+ .wandb/
17
+ wandb/
18
+ **/.wandb/
19
+ **/wandb/
20
+ **/wandb_runtime/
21
+
22
+ .venv/
23
+ .hf_cache/
24
+ addition_runs/
25
+ tmp_latent_debug/
26
+ logs/
REBUTTAL_REPORT.md ADDED
@@ -0,0 +1,589 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Curriculum CoT for 9x9 Sudoku — Rebuttal/Paper-Section Material
2
+
3
+ _Last updated: 2026-05-24_
4
+
5
+ This document is a comprehensive, paper-ready reference of (a) the data pipeline,
6
+ (b) the instruction-tuning prompt format, (c) the curriculum and reward design,
7
+ (d) the latent thought-token architecture, (e) the multi-stage SFT-then-GRPO
8
+ training recipe, and (f) the headline numerical results — so a rebuttal section
9
+ can be assembled directly from this document.
10
+
11
+ ---
12
+
13
+ ## 1. Task
14
+
15
+ We use the model as a **per-cell value policy** for 9×9 Sudoku. For a fixed
16
+ target empty cell, the model emits a JSON set of candidate digits that are
17
+ "i-consistent" with the current grid (definition in §4). We evaluate two
18
+ metrics:
19
+
20
+ - **per-cell exact set match** (`exact_set_match`) — predicted set equals the
21
+ ground-truth i-consistent set;
22
+ - **whole-puzzle solve rate** (`solve`) — every empty cell on a 20-empty puzzle
23
+ produces an exact set match.
24
+
25
+ Because solve = ∏ exact_set_match across the ~20 empty cells of a puzzle,
26
+ the two metrics are non-linearly coupled:
27
+
28
+ $$ \text{solve} \approx \text{exact\_set\_match}^{N_{\text{empty}}} $$
29
+
30
+ so $0.95^{20} \approx 0.358$ and $0.97^{20} \approx 0.544$ — every percentage
31
+ point of per-cell exact maps to a much larger swing in solve.
32
+
33
+ ---
34
+
35
+ ## 2. Data pipeline
36
+
37
+ ### 2.1 Puzzle generation
38
+
39
+ Generated by `simple_9x9_curriculum/build_dataset.py`:
40
+
41
+ - Start from a base Latin-square grid; randomly relabel digits, permute
42
+ rows and columns within bands, and transpose.
43
+ - Sample `empties=20` cell positions uniformly at random and erase them.
44
+ - Save 10 000 train + 1 000 eval puzzles (seed 0, seed 1).
45
+ - Output JSONL files
46
+ `data/sudoku_t3_20empty_value_qwen_text_stage1_{train,eval}.jsonl`.
47
+
48
+ A single record contains:
49
+
50
+ ```json
51
+ {
52
+ "prompt": "<full Qwen chat-templated prompt for one (puzzle, target_cell) pair>",
53
+ "completion": "[7,3,8,2,6,9,4,5,...]",
54
+ "metadata": {
55
+ "grid_size": 9, "box_size": 3, "empties": 20,
56
+ "empty_locs_1based": [[1,4],[1,9],...],
57
+ "target_triples_1based": [[1,4,7],[1,9,3],...]
58
+ }
59
+ }
60
+ ```
61
+
62
+ The 20 `target_triples` give the **solved** value at each of the 20 empty
63
+ positions, so per-cell training targets are always available. At training
64
+ time we expand each puzzle into 20 (puzzle, target_cell) examples.
65
+
66
+ ### 2.2 Cell-policy framing
67
+
68
+ The model is never asked to solve a whole puzzle in one shot. Each example
69
+ is one (current_grid, target_cell) pair, and the supervised target is the
70
+ set of digits that are "i-consistent" with the current grid (see §4). This
71
+ turns Sudoku into a **classification-into-a-set** problem and lets us share
72
+ parameters across cells, stages, and puzzle sizes.
73
+
74
+ ### 2.3 Multi-value oversampling (data-side trick)
75
+
76
+ Implemented in `multi_output_cell_policy/sft_multi_output_train.py` via
77
+ `tokenizer._multi_value_oversample_factor` and the CLI flags
78
+
79
+ ```
80
+ --multi_value_oversample_factor INT (default 1)
81
+ --train_target_size_min INT (default 0)
82
+ --train_target_size_max INT (default 0)
83
+ ```
84
+
85
+ Inside the dataset builder, examples whose target set has more than one
86
+ digit are repeated `multi_value_oversample_factor` times in the training
87
+ mix. This biases gradient steps toward exactly the cells the model gets
88
+ wrong (multi-value cells). Empirically, this is the single biggest data-side
89
+ lever — see §10.
90
+
91
+ ### 2.4 Where the bottleneck lives
92
+
93
+ For 20-empty puzzles in stage 3, only ~25 % of empty cells have a
94
+ multi-value target set (the rest collapse to one i-consistent value). Yet
95
+ those multi-value cells are responsible for the entire solve-rate gap:
96
+ they are the cells where the model under-predicts (returns a singleton
97
+ when the target is a 2- or 3-element set), and a single failed cell kills
98
+ the whole-puzzle solve. The reward shaping in §6 and the oversample in §2.3
99
+ both attack this single failure mode.
100
+
101
+ ---
102
+
103
+ ## 3. Instruction format
104
+
105
+ ### 3.1 System prompt
106
+
107
+ (verbatim from `multi_output_cell_policy/prompt_builder.py`)
108
+
109
+ ```text
110
+ You are a Sudoku value policy.
111
+ This setup uses puzzles with about 20 empty cells.
112
+ You will be given one target empty cell.
113
+ Return ONLY one JSON object of the form {"values":[...]}.
114
+ The JSON object must contain exactly one key named "values".
115
+ The "values" field must be a JSON array of unique integers in [1,9].
116
+ You may return as many candidate values as you want, including one, several,
117
+ or many values.
118
+ Choose the number of returned values yourself based on which values seem
119
+ i-consistent.
120
+ The order of the values does not matter.
121
+ Do not output any explanation, markdown, punctuation outside JSON, or extra text.
122
+ Current stage objective: i={i} consistency.
123
+ ```
124
+
125
+ ### 3.2 User message
126
+
127
+ ```text
128
+ Sudoku grid (0 means empty):
129
+ <grid_to_text(grid)>
130
+ Empty cells in row-major order (20 total): (1,4), (1,9), (2,8), …
131
+ Target cell to fill now: (R,C).
132
+ Turn: t/T.
133
+ Return only JSON with candidate values for this target cell: {"values":[...]}
134
+ ```
135
+
136
+ We use the Qwen2.5-Instruct chat template (`tokenizer.apply_chat_template`,
137
+ `add_generation_prompt=True`) to wrap system + user into the actual prompt
138
+ ids. `max_prompt_length = 768`.
139
+
140
+ ### 3.3 Output format
141
+
142
+ ```json
143
+ {"values":[3,7]}
144
+ ```
145
+
146
+ Strictly canonical JSON (single key `values`, sorted unique digit list,
147
+ no whitespace). Outputs are scored by `parse_values_json`
148
+ (`shared_multi_output_policy.py`); any deviation collapses the whole
149
+ prediction to `parse_ok=0` and a hard-coded malformed penalty.
150
+
151
+ `max_completion_length = 24` tokens — enough to emit any 9-digit set.
152
+
153
+ ---
154
+
155
+ ## 4. Curriculum: stage-i consistency
156
+
157
+ The curriculum lives in `_stage_i_consistent_values_for_grid`:
158
+
159
+ - **Stage 1 — i=1 (legal moves).** A value v is i=1 consistent at cell c
160
+ iff placing v at c violates no Sudoku constraint (row, column, 3×3 box).
161
+ This is just "legal candidates".
162
+
163
+ - **Stage 2 — i=2.** v is i=2 consistent at c iff (a) it is i=1 consistent
164
+ AND (b) after placing v, every other empty cell in the grid still has at
165
+ least one i=1-consistent value (i.e. placing v does not immediately make
166
+ the puzzle unsolvable by 1-step propagation).
167
+
168
+ - **Stage 3 — i=3.** Same recursion one more level deep: v is i=3 consistent
169
+ iff after placing v, every other empty cell still has at least one i=2
170
+ consistent value.
171
+
172
+ This is bounded look-ahead constraint propagation. Stage-3 sets are tighter
173
+ than stage-2 sets which are tighter than stage-1 sets. The curriculum
174
+ goal at deployment time is stage-3.
175
+
176
+ In data, we use the same source records and just change `--stage_i`; the
177
+ target set is regenerated on the fly by `stage_i_consistent_values`.
178
+
179
+ ---
180
+
181
+ ## 5. Latent thought-token architecture
182
+
183
+ Base model: **Qwen/Qwen2.5-1.5B-Instruct** + LoRA (r=32, α=64, dropout=0.05)
184
+ on `q,k,v,o,gate,up,down`. The latent variant adds **k thought-token slots**
185
+ between the prompt and the next-token logits.
186
+
187
+ Four modes are implemented (`latent_multi_output_cell_policy/`); the winning
188
+ mode for the final number is **`recurrent_hidden`**:
189
+
190
+ > `build_recurrent_hidden_latent_hidden(model, ids, mask, k)`
191
+ >
192
+ > 1. Run the backbone once on the prompt. Keep `base_hidden = h[:,-1,:]`.
193
+ > 2. Set `latent_token = base_hidden`.
194
+ > 3. Repeat k times: append `latent_token` (as an embedding) to the running
195
+ > sequence, run the backbone again on the extended sequence, and replace
196
+ > `latent_token` with the new last hidden state.
197
+ > 4. After k recursions, the final `latent_token` (returned as `latent_hidden`) is fed through the LM head to
198
+ > produce the next-token distribution.
199
+
200
+ In equations, with E the input embedding lookup, f_θ the LoRA-decorated
201
+ backbone, U the LM head:
202
+
203
+ $$ z_0 = f_\theta(E([x_1,\dots,x_T]))_{T} $$
204
+ $$ z_{j+1} = f_\theta\bigl([E(x_1),\dots,E(x_T), z_0, z_1,\dots,z_j]\bigr)_{T+j+1},\ j=0,\dots,k-1 $$
205
+ $$ p(\cdot \mid x_{1:T}) = \mathrm{softmax}(U z_k) $$
206
+
207
+ The model can therefore "iterate" k extra forward passes on the same prompt
208
+ before committing to a token, with the k extra hidden states carrying
209
+ intermediate computation. Setting k=0 recovers the vanilla baseline.
210
+
211
+ The other three latent modes are alternatives that we ablated:
212
+ `fixed_slots` (concatenate k trainable seed embeddings — Option-2),
213
+ `latent_seeds` (similar to fixed_slots), and `residual` (project k extra
214
+ hidden states back onto the base hidden state via a learned residual). All
215
+ modes share the SFT and GRPO trainers; only the next-token logit function
216
+ changes.
217
+
218
+ For the curriculum, we grow k stage by stage:
219
+
220
+ | stage | num_cot_tokens | comment |
221
+ | --- | ---: | --- |
222
+ | 1 | 1 | one extra recursion as soon as the model has the surface form |
223
+ | 2 | 2 | two — needed for 1-step propagation reasoning |
224
+ | 3 | 3 | three — needed for 2-step propagation reasoning |
225
+
226
+ ---
227
+
228
+ ## 6. The reward function
229
+
230
+ Defined in `multi_output_cell_policy/rewards.py`.
231
+
232
+ Given target set T, predicted set P (after JSON parse), let
233
+
234
+ - `num_good = |P ∩ T|`
235
+ - `num_bad = |P \ T|`
236
+ - `num_missing = max(0, |T| − num_good)`
237
+ - `is_exact = (P ≠ ∅) ∧ (P = T)`
238
+ - `tri(n) = n(n+1)/2` (rewards larger correct sets superlinearly)
239
+
240
+ Then
241
+
242
+ $$
243
+ \begin{aligned} r ={}& \mathrm{tri}(\mathrm{num\_good}) \cdot R_g \;-\; \mathrm{num\_bad} \cdot P_b \\
244
+ & - \mathbb{1}[P=\varnothing]\, P_e \;-\; \mathbb{1}[|P|=1, |T|>1, i<2]\, P_s \\
245
+ & - \mathrm{num\_missing}\cdot P_m \;+\; \mathbb{1}[\text{is\_exact}]\, B_x \\
246
+ & - \mathbb{1}[|P|<|T|, |T|>1]\, P_c \end{aligned}
247
+ $$
248
+
249
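+ with parameters (this is the recipe that produced the 0.58 latent solve on the 100-puzzle eval; the 0.675 40-puzzle peak used the sharper "sharp_rwd" values listed in Appendix A):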
250
+
251
+ | symbol | flag | value | role |
252
+ |---|---|---:|---|
253
+ | $R_g$ | `--reward_good_value` | 1.25 | per-correct-value reward (with triangular shape) |
254
+ | $P_b$ | `--penalty_bad_value` | 1.0 | per-extra-wrong-value penalty |
255
+ | $P_{\!\text{mal}}$ | `--penalty_malformed` | 4.0 | flat penalty if JSON parse fails |
256
+ | $P_e$ | `--penalty_empty` | 0.5 | flat penalty if predicted set is empty |
257
+ | $P_s$ | `--penalty_singleton` | 1.5 | only at stage<2: punishes singleton on multi-value targets |
258
+ | $P_m$ | `--penalty_missing` | **0.75** | per-missing-value (recall pressure) — **NEW** |
259
+ | $B_x$ | `--exact_match_bonus` | **2.0** | only when P = T — **NEW** |
260
+ | $P_c$ | `--cardinality_mismatch_penalty` | **1.0** | when $\lvert P\rvert < \lvert T\rvert$ and $\lvert T\rvert > 1$ — **NEW** |
261
+
262
+ Parse failures short-circuit to `r = -P_mal` and zero per-cell metrics.
263
+
264
+ ### 6.1 Why those three new terms exist (the breakthrough)
265
+
266
+ Diagnosis: at the v3/v4 plateau, eval reported
267
+
268
+ ```
269
+ exact=0.95 precision=0.95 recall=0.95 solve=0.30 avg_set_size=1.000
270
+ ```
271
+
272
+ across all checkpoints. Per-cell exact and precision/recall were all near
273
+ 0.95 but the model **always predicted a single digit** (`avg_set_size=1.000`).
274
+ On a multi-value target $T=\{8,9\}$, predicting $\{8\}$ keeps precision=1.0,
275
+ recall=0.5 and yet `exact_set_match=0`. Solve = exact_set_match^N is
276
+ exponentially sensitive to $N$ (=20), so even a small fraction of multi-value cells
277
+ killed it.
278
+
279
+ Without any of the new terms the optimum of $r$ on a multi-value cell is
280
+ trivially "predict the singleton you are most confident about" — there is
281
+ no upside to enumerate the second value. The three new terms close exactly
282
+ that hole:
283
+
284
+ - $P_m$ (`penalty_missing`) directly penalises recall;
285
+ - $B_x$ (`exact_match_bonus`) makes $P=T$ strictly dominate any singleton;
286
+ - $P_c$ (`cardinality_mismatch_penalty`) is a flat hammer whenever $|P|<|T|$.
287
+
288
+ After these terms were added, GRPO on the latent variant moved solve from
289
+ ~0.30 to ~0.58 (100-puzzle eval) over ~200 steps. The same fix is what we
290
+ ported back into the baseline pipeline this evening (see §10).
291
+
292
+ ---
293
+
294
+ ## 7. Multi-stage warm-baseline pipeline (the recipe that worked)
295
+
296
+ Master script:
297
+ `hard_9x9_stage1_consistency_queue/launch_20empty_warm_baseline_all_latent_modes_stages123.sh`.
298
+
299
+ For each curriculum stage we run **three sub-phases in order**:
300
+
301
+ ```
302
+ [stage i]
303
+ (1) baseline warm SFT (no latent tokens, k=0, vanilla LM)
304
+ (2) latent SFT (k = i, latent mode = recurrent_hidden)
305
+ (3) latent GRPO (k = i)
306
+ ```
307
+
308
+ **The warm baseline phase (1) is the trick that makes the curriculum work.**
309
+ At every stage transition the data distribution changes (i increases →
310
+ target sets shrink) and a new latent slot appears. Doing a vanilla SFT on
311
+ the new distribution first lets the LM relearn the surface form on familiar
312
+ parameters; THEN the latent SFT adds the extra thought slot on top of an
313
+ already-good policy. When we tried to add a new latent slot directly on
314
+ top of the previous stage's GRPO checkpoint, training loss did NOT
315
+ decrease.
316
+
317
+ Concrete LR schedule used for the champion run:
318
+
319
+ | phase | init from | LR | k |
320
+ |---|---|---:|---:|
321
+ | S1 baseline SFT | base Qwen | 2e-4 | 0 |
322
+ | S1 latent SFT | S1 baseline | 2e-4 | 1 |
323
+ | S1 latent GRPO | S1 latent SFT | 1e-6 | 1 |
324
+ | S2 baseline warm SFT | S1 GRPO | 5e-5 | 0 |
325
+ | S2 latent SFT | S2 baseline | 5e-5 | 2 |
326
+ | S2 latent GRPO | S2 latent SFT | 1e-6 | 2 |
327
+ | S3 baseline warm SFT | S2 GRPO | 5e-5 | 0 |
328
+ | S3 latent SFT | S3 baseline | 5e-5 → 1e-5 (champion) | 3 |
329
+ | S3 latent GRPO | S3 latent SFT | 5e-6 (β=0) | 3 |
330
+
331
+ Other shared knobs:
332
+
333
+ ```
334
+ LoRA: r=32 α=64 dropout=0.05 on q,k,v,o,gate,up,down
335
+ SFT: per_device_bs=8 grad_accum=2 nproc=8 -> eff_bs=128
336
+ GRPO: per_device_bs=4 grad_accum=2 nproc=8 -> eff_bs=64
337
+ num_generations=4 beta=0.0 max_prompt_length=1024
338
+ max_completion_length=24
339
+ multi_value_oversample_factor=5, exact_match_bonus=2.0,
340
+ penalty_missing=0.75, cardinality_mismatch_penalty=1.0
341
+ ```
342
+
343
+ ---
344
+
345
+ ## 8. GRPO settings that mattered
346
+
347
+ - **β = 0.** The KL anchor was harmful in every sweep where we tried β>0.
348
+ `s3_grpo_kl04` (β=0.04) peaked at solve=0.625 (40p) at step 100 and
349
+ regressed to 0.525 by step 500.
350
+
351
+ - **num_generations = 4.** With num_generations=2 we routinely saw
352
+ `reward_std = 0` (all sampled completions identical → no gradient).
353
+ Bumping to 4 fixed it.
354
+
355
+ - **Low LR.** `lr=5e-6` was the steadiest. `lr=1e-5` peaked at step 200
356
+ (solve 0.65) then collapsed back to 0.54 — classic mode collapse.
357
+
358
+ - **Effective bs ≥ 64.** TRL's GRPOConfig requires
359
+ `(num_processes * per_device_bs) % num_generations == 0`; with 8 GPUs we hit this
360
+ trivially, but we caution single-GPU rerunners to set
361
+ `per_device_bs=4 grad_accum=2 num_generations=4`.
362
+
363
+ - **`enable_input_require_grads()` on the wrapped backbone.** Required for
364
+ TRL 0.15.x + PEFT LoRA + gradient checkpointing — otherwise the loss
365
+ tensor produced by GRPOTrainer has `requires_grad=False` and `.backward()`
366
+ raises. Also `unwrapped.config.use_cache = False`.
367
+
368
+ ---
369
+
370
+ ## 9. Final hyperparameters table — champion latent run
371
+
372
+ | group | hyperparameter | value |
373
+ |---|---|---|
374
+ | Backbone | model | Qwen/Qwen2.5-1.5B-Instruct |
375
+ | Backbone | dtype | bf16 |
376
+ | Backbone | LoRA target modules | q,k,v,o,gate,up,down |
377
+ | Backbone | LoRA r / α / dropout | 32 / 64 / 0.05 |
378
+ | Latent | mode | recurrent_hidden |
379
+ | Latent | num_cot_tokens (S1/S2/S3) | 1 / 2 / 3 |
380
+ | Latent | max_latent_slots / seeds | 8 / 8 |
381
+ | Data | total empties | 20 |
382
+ | Data | train rows / eval rows | 10 000 / 100 |
383
+ | Data | multi_value_oversample_factor | 5 |
384
+ | Data | mixed_stage1_ratio (S1) | 1 |
385
+ | Data | mixed_stage2_ratio (S≥2) | 1 |
386
+ | SFT | per_device_bs / grad_accum | 8 / 2 |
387
+ | SFT | num_epochs (cap) | 64 |
388
+ | SFT | LR (S1 latent) | 2e-4 |
389
+ | SFT | LR (S2/S3 baseline warm + latent) | 5e-5 |
390
+ | SFT | LR (S3 latent champion `s3b_lr1e5_o5`) | 1e-5 |
391
+ | SFT | weight_decay | 0.0 |
392
+ | SFT | gradient checkpointing | on |
393
+ | GRPO | per_device_bs / grad_accum | 4 / 2 |
394
+ | GRPO | num_generations | 4 |
395
+ | GRPO | LR | 5e-6 (S3); 1e-6 (S1, S2) |
396
+ | GRPO | β (KL) | 0.0 |
397
+ | GRPO | max_prompt_length | 1024 |
398
+ | GRPO | max_completion_length | 24 |
399
+ | Reward | reward_good_value | 1.25 |
400
+ | Reward | penalty_bad_value | 1.0 |
401
+ | Reward | penalty_malformed | 4.0 |
402
+ | Reward | penalty_empty | 0.5 |
403
+ | Reward | penalty_singleton | 1.5 |
404
+ | Reward | penalty_missing | 0.75 |
405
+ | Reward | exact_match_bonus | 2.0 |
406
+ | Reward | cardinality_mismatch_penalty | 1.0 |
407
+ | Eval | early-stop on prec/recall | 0.98 |
408
+
409
+ ---
410
+
411
+ ## 10. Headline results
412
+
413
+ ### 10.1 Latent (with thought tokens, recurrent_hidden)
414
+
415
+ | eval | model / phase | step | exact | prec | recall | **solve** |
416
+ |---|---|---:|---:|---:|---:|---:|
417
+ | **100p (auth.)** | `s3_grpo_baseline` (S3 GRPO, β=0, lr=5e-6) | 200 | 0.9665 | 0.9673 | 0.9680 | **0.580 (58/100)** |
418
+ | 40p | `s3_grpo_sharp_rwd` (exact_b=4, card_pen=3) | 300 | — | — | — | **0.675 (27/40)** |
419
+ | 40p | `s3_grpo_lr1e5` | 200 | 0.978 | 0.978 | 0.979 | 0.650 |
420
+ | 40p | `s3b_lr1e5_o5` (S3 SFT champion) | 2400 | 0.974 | 0.974 | 0.975 | 0.600 |
421
+
422
+ ### 10.2 Vanilla baseline (no thought tokens, same Qwen2.5-1.5B + LoRA)
423
+
424
+ | sweep | best variant | best step | exact | **solve (100p)** |
425
+ |---|---|---:|---:|---:|
426
+ | v3 (single-GPU LR=2e-5, no oversample, no new reward terms) | `baseline_3stage_20260522` | — | 0.730 | **0.000** |
427
+ | v4 (LR sweep, multi-GPU, original reward) | `pipe_v_sft_extend` (S3 SFT extended) | 4000 | 0.948 | **0.400** |
428
+ | **v6 (this evening, ports latent reward + oversample)** | `v6_i_sft_v_oversample10` (oversample=10) | running | 0.952+ | **0.440 (best so far)** |
429
+
430
+ The v6 sweep is still running — `v6_e/f/i` are in S3 SFT continuation,
431
+ GRPO follow-on phases queued. The v6_i variant has hit **solve=0.44** at
432
+ SFT eval (new baseline best, +0.04 over v4) and is still climbing.
433
+
434
+ ### 10.3 Stage-by-stage trajectory (latent, 40-puzzle eval)
435
+
436
+ ```
437
+ S1 SFT : exact ≈ 0.85, solve ≈ 0.20
438
+ S1 GRPO : exact ≈ 0.90, solve ≈ 0.20
439
+ S2 SFT (no oversample) : exact ≈ 0.94, solve ≈ 0.20-0.25 <- the wall
440
+ S2 SFT + multi_value_oversample=5 : exact ≈ 0.96, solve ≈ 0.30-0.35
441
+ S2 GRPO + new reward terms : exact ≈ 0.96, solve ≈ 0.35-0.40
442
+ S3 SFT (s3b_lr1e5_o5 step 2400) : exact 0.974, solve 0.600 <- SFT champion
443
+ S3 GRPO (s3_grpo_baseline step 200,100p): exact 0.967, solve 0.580 <- 100p champion
444
+ S3 GRPO (s3_grpo_sharp_rwd step 300,40p): solve 0.675 <- 40p peak
445
+ ```
446
+
447
+ ### 10.4 Latent vs baseline gap (head-to-head, same 100p eval, same prompts)
448
+
449
+ | model | exact | prec | recall | **solve** | solved/100 |
450
+ |---|---:|---:|---:|---:|---:|
451
+ | Latent recurrent_hidden, S3 GRPO | 0.9665 | 0.9673 | 0.9680 | **0.580** | 58 |
452
+ | Vanilla baseline, v6_i (best at time of writing) | 0.952 | 0.952 | 0.952 | **0.440** | 44 |
453
+
454
+ Gap on 100-puzzle solve: ≈ **+0.14 absolute / +32 % relative** for latent
455
+ over the strongest baseline we have.
456
+
457
+ ---
458
+
459
+ ## 11. Why the latent works (interpretation hypotheses)
460
+
461
+ These are the working hypotheses the experiments are consistent with;
462
+ none is fully proven and ablations are still WIP.
463
+
464
+ 1. **Constraint-propagation depth.** Stage-3 i-consistency is essentially
465
+ 2-ply lookahead. With $k=3$ recurrent hidden tokens the model gets
466
+ exactly three extra forward passes between prompt and output — one for
467
+ the legality check, one for 1-step propagation, one for the second
468
+ step of propagation. Empirically the gap to the no-thought-token
469
+ baseline appears at stages where multi-step propagation matters
470
+ (stage 2 onward; stage 1 numbers are essentially identical).
471
+
472
+ 2. **Multi-value cells require enumeration, which a singleton softmax can't
473
+ do in one forward pass.** A vanilla LM at 1.5B parameters predicts
474
+ essentially deterministically once temperature is low; for a target set
475
+ {8, 9} the LM picks one of the two and stops. The latent model can use
476
+ one of the recurrent hidden steps to "consider" each option without
477
+ committing yet, which addresses exactly the failure mode seen in the data
478
+ (`avg_set_size = 1.000` for the baseline, `≈ 1.05` for the latent S3
479
+ model on the same eval).
480
+
481
+ 3. **Stable curriculum capacity growth.** Adding a new latent slot at every
482
+ stage gives the model a "fresh slate" of representational capacity at the
483
+ exact transition where the task gets harder. The warm-baseline SFT
484
+ between stages prevents the new slot from corrupting the previously
485
+ learned policy. Without warm baseline, training loss did not decrease
486
+ at all (we observed this directly when we tried to skip the warm
487
+ baseline yesterday).
488
+
489
+ 4. **GRPO without latent slots is starved of variance.** With
490
+ `max_completion_length = 24` and the model essentially deterministic, GRPO's 4 sampled
491
+ completions per prompt collapse to a single answer — `reward_std = 0`,
492
+ no gradient. With latent recurrence + the new exact_match_bonus reward,
493
+ the model occasionally samples a 2-element set, gets a much higher
494
+ reward, and that prompt gets a real gradient signal.
495
+
496
+ ---
497
+
498
+ ## 12. Reproducibility
499
+
500
+ Code repository: `https://github.com/Avra98/curriculum_cot`
501
+ Latent checkpoints: `https://huggingface.co/Avra98/sudoku-latent-recurrent-hidden-20empty-stages`
502
+ Baseline checkpoints: `https://huggingface.co/Avra98/sudoku-9x9-20empty-baseline-1p5b-sweep`
503
+
504
+ Key scripts:
505
+
506
+ - Master orchestrator (latent, 9-phase warm-baseline pipeline):
507
+ `hard_9x9_stage1_consistency_queue/launch_20empty_warm_baseline_all_latent_modes_stages123.sh`
508
+ - Vanilla baseline pipeline:
509
+ `_runs/baseline_1p5b_pipeline_v4.sh` (with v6 launchers
510
+ `_runs/launch_baseline_push_v6.sh`)
511
+ - SFT trainer (vanilla): `multi_output_cell_policy/sft_multi_output_train.py`
512
+ - GRPO trainer (vanilla): `multi_output_cell_policy/grpo_multi_output_train.py`
513
+ - SFT trainer (latent): `latent_multi_output_cell_policy/sft_latent_multi_output_train.py`
514
+ - GRPO trainer (latent): `latent_multi_output_cell_policy/grpo_residual_projector_latent_train.py`
515
+ - Reward function: `multi_output_cell_policy/rewards.py`
516
+ - Prompt builder: `multi_output_cell_policy/prompt_builder.py`
517
+ - Stage-i consistency: `multi_output_cell_policy/shared_multi_output_policy.py`
518
+ - 100-puzzle evaluator: `analysis/eval_stage2_checkpoint.py`
519
+
520
+ To reproduce the latent champion (1.5B, 9-phase, ~16 GPU·h on 8×H100 80GB):
521
+
522
+ ```bash
523
+ export STAGE1_BASELINE_ADAPTER_DIR=/path/to/stage1_baseline_seed_adapter
524
+ bash hard_9x9_stage1_consistency_queue/launch_20empty_warm_baseline_all_latent_modes_stages123.sh
525
+ ```
526
+
527
+ To reproduce the v6 baseline push (single-GPU per variant, ~6 GPU·h):
528
+
529
+ ```bash
530
+ bash _runs/launch_baseline_push_v6.sh
531
+ ```
532
+
533
+ ---
534
+
535
+ ## Appendix A. The reward fix as a small patch
536
+
537
+ The single most consequential code change in this whole project, as a
538
+ self-contained patch on `multi_output_cell_policy/rewards.py`:
539
+
540
+ ```python
541
+ # new args (default 0 preserves legacy behaviour)
542
+ penalty_missing: float = 0.0
543
+ exact_match_bonus: float = 0.0
544
+ cardinality_mismatch_penalty: float = 0.0
545
+
546
+ num_missing = max(0, len(target_set) - num_good)
547
+ is_exact = bool(predicted_values) and (set(predicted_values) == target_set)
548
+
549
+ # ... base reward (triangular_number(num_good)*reward_good_value - num_bad*penalty_bad_value)
550
+
551
+ if num_missing > 0:
552
+ reward -= num_missing * penalty_missing
553
+ if is_exact:
554
+ reward += exact_match_bonus
555
+ if len(predicted_values) < len(target_values) and len(target_values) > 1:
556
+ reward -= cardinality_mismatch_penalty
557
+ ```
558
+
559
+ Defaults are zero so old runs are unaffected; the recipe sets
560
+ `(P_m, B_x, P_c) = (0.75, 2.0, 1.0)` for the vanilla recipe and
561
+ `(1.0, 4.0, 3.0)` for the "sharp_rwd" variant.
562
+
563
+ ## Appendix B. The warm-baseline trick as a sequence diagram
564
+
565
+ ```
566
+ Stage 1 Stage 2 Stage 3
567
+ ───────── ───────── ─────────
568
+ [base Qwen] ↓ ↓
569
+ ↓ ↓ ↓
570
+ S1 baseline SFT → S2 baseline SFT → S3 baseline SFT
571
+ (no latent, k=0) (no latent, k=0) (no latent, k=0)
572
+ ↓ ↓ ↓
573
+ S1 latent SFT → S2 latent SFT → S3 latent SFT
574
+ (k=1) (k=2) (k=3)
575
+ ↓ ↓ ↓
576
+ S1 latent GRPO → S2 latent GRPO → S3 latent GRPO
577
+ (k=1, β=0, lr 1e-6) (k=2, β=0) (k=3, β=0, lr 5e-6)
578
+ ↓ ↓ ↓
579
+ [final policy]
580
+ ```
581
+
582
+ Every arrow is `init_adapter_dir = <previous output>`. Each row is a
583
+ "slot in the curriculum"; each column adds reasoning capacity (k+=1) and
584
+ moves to a harder target distribution (i+=1). The diagonal across the
585
+ diagram is the actual training trajectory.
586
+
587
+ ---
588
+
589
+ _End of report._
REBUTTAL_REPORT.tex ADDED
@@ -0,0 +1,711 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ \documentclass[11pt]{article}
2
+ \usepackage[a4paper,margin=1in]{geometry}
3
+ \usepackage[utf8]{inputenc}
4
+ \usepackage[T1]{fontenc}
5
+ \usepackage{lmodern}
6
+ \usepackage{microtype}
7
+ \usepackage{amsmath,amssymb}
8
+ \usepackage{booktabs}
9
+ \usepackage{array}
10
+ \usepackage{longtable}
11
+ \usepackage{tabularx}
12
+ \usepackage{xcolor}
13
+ \usepackage{listings}
14
+ \usepackage{hyperref}
15
+ \usepackage{enumitem}
16
+ \usepackage{parskip}
17
+
18
+ \hypersetup{colorlinks=true,linkcolor=blue,urlcolor=blue,citecolor=blue}
19
+
20
+ \lstset{
21
+ basicstyle=\ttfamily\small,
22
+ breaklines=true,
23
+ columns=fullflexible,
24
+ keepspaces=true,
25
+ frame=single,
26
+ framerule=0.4pt,
27
+ xleftmargin=0.5em,
28
+ xrightmargin=0.5em,
29
+ showstringspaces=false,
30
+ }
31
+
32
+ \newcommand{\code}[1]{\texttt{#1}}
33
+ \newcommand{\indic}{\mathbf{1}}
34
+
35
+ \title{Curriculum CoT for $9{\times}9$ Sudoku\\[2pt]
36
+ \large Rebuttal / Paper-Section Material}
37
+ \author{}
38
+ \date{Last updated: 2026--05--24}
39
+
40
+ \begin{document}
41
+ \maketitle
42
+
43
+ \noindent
44
+ This document is a comprehensive, paper-ready reference of (a) the data
45
+ pipeline, (b) the instruction-tuning prompt format, (c) the curriculum
46
+ and reward design, (d) the latent thought-token architecture, (e) the
47
+ multi-stage SFT-then-GRPO training recipe, and (f) the headline numerical
48
+ results --- so a rebuttal section can be assembled directly from this
49
+ document.
50
+
51
+ \bigskip
52
+ \hrule
53
+ \bigskip
54
+
55
+ \section{Task}
56
+
57
+ We use the model as a \textbf{per-cell value policy} for $9\times 9$
58
+ Sudoku. For a fixed target empty cell, the model emits a JSON set of
59
+ candidate digits that are ``i-consistent'' with the current grid
60
+ (definition in \S 4). We evaluate two metrics:
61
+
62
+ \begin{itemize}[leftmargin=*]
63
+ \item \textbf{per-cell exact set match} (\code{exact\_set\_match}) ---
64
+ predicted set equals the ground-truth i-consistent set;
65
+ \item \textbf{whole-puzzle solve rate} (\code{solve}) --- every empty
66
+ cell on a 20-empty puzzle produces an exact set match.
67
+ \end{itemize}
68
+
69
+ Because $\text{solve} = \prod \text{exact\_set\_match}$ across the $\sim 20$
70
+ empty cells of a puzzle, the two metrics are non-linearly coupled:
71
+ \[
72
+ \text{solve} \approx \text{exact\_set\_match}^{N_{\text{empty}}}
73
+ \]
74
+ so $0.95^{20} \approx 0.358$ and $0.97^{20} \approx 0.544$ --- every
75
+ percentage point of per-cell exact maps to a much larger swing in solve.
76
+
77
+ \section{Data pipeline}
78
+
79
+ \subsection{Puzzle generation}
80
+
81
+ Generated by \code{simple\_9x9\_curriculum/build\_dataset.py}:
82
+
83
+ \begin{itemize}[leftmargin=*]
84
+ \item Start from a base Latin-square grid; randomly relabel digits,
85
+ permute rows and columns within bands, and transpose.
86
+ \item Sample \code{empties=20} cell positions uniformly at random and
87
+ erase them.
88
+ \item Save 10\,000 train + 1\,000 eval puzzles (seed 0, seed 1).
89
+ \item Output JSONL files \code{data/sudoku\_t3\_20empty\_value\_qwen\_text\_stage1\_\{train,eval\}.jsonl}.
90
+ \end{itemize}
91
+
92
+ A single record contains:
93
+
94
+ \begin{lstlisting}
95
+ {
96
+ "prompt": "<full Qwen chat-templated prompt for one (puzzle, target_cell) pair>",
97
+ "completion": "[7,3,8,2,6,9,4,5,...]",
98
+ "metadata": {
99
+ "grid_size": 9, "box_size": 3, "empties": 20,
100
+ "empty_locs_1based": [[1,4],[1,9],...],
101
+ "target_triples_1based": [[1,4,7],[1,9,3],...]
102
+ }
103
+ }
104
+ \end{lstlisting}
105
+
106
+ The 20 \code{target\_triples} give the \textbf{solved} value at each of
107
+ the 20 empty positions, so per-cell training targets are always
108
+ available. At training time we expand each puzzle into 20 (puzzle,
109
+ target\_cell) examples.
110
+
111
+ \subsection{Cell-policy framing}
112
+
113
+ The model is never asked to solve a whole puzzle in one shot. Each
114
+ example is one (current\_grid, target\_cell) pair, and the supervised
115
+ target is the set of digits that are ``i-consistent'' with the current
116
+ grid (see \S 4). This turns Sudoku into a
117
+ \textbf{classification-into-a-set} problem and lets us share parameters
118
+ across cells, stages, and puzzle sizes.
119
+
120
+ \subsection{Multi-value oversampling (data-side trick)}
121
+
122
+ Implemented in \code{multi\_output\_cell\_policy/sft\_multi\_output\_train.py}
123
+ via \code{tokenizer.\_multi\_value\_oversample\_factor} and the CLI flags
124
+
125
+ \begin{lstlisting}
126
+ --multi_value_oversample_factor INT (default 1)
127
+ --train_target_size_min INT (default 0)
128
+ --train_target_size_max INT (default 0)
129
+ \end{lstlisting}
130
+
131
+ Inside the dataset builder, examples whose target set has more than one
132
+ digit are repeated \code{multi\_value\_oversample\_factor} times in the
133
+ training mix. This biases gradient steps toward exactly the cells the
134
+ model gets wrong (multi-value cells). Empirically, this is the single
135
+ biggest data-side lever --- see \S 10.
136
+
137
+ \subsection{Where the bottleneck lives}
138
+
139
+ For 20-empty puzzles in stage 3, only $\sim 25\%$ of empty cells have a
140
+ multi-value target set (the rest collapse to one i-consistent value).
141
+ Yet those multi-value cells are responsible for the entire solve-rate
142
+ gap: they are the cells where the model under-predicts (returns a
143
+ singleton when the target is a 2- or 3-element set), and a single
144
+ failed cell kills the whole-puzzle solve. The reward shaping in \S 6
145
+ and the oversample in \S 2.3 both attack this single failure mode.
146
+
147
+ \section{Instruction format}
148
+
149
+ \subsection{System prompt}
150
+
151
+ (verbatim from \code{multi\_output\_cell\_policy/prompt\_builder.py})
152
+
153
+ \begin{lstlisting}
154
+ You are a Sudoku value policy.
155
+ This setup uses puzzles with about 20 empty cells.
156
+ You will be given one target empty cell.
157
+ Return ONLY one JSON object of the form {"values":[...]}.
158
+ The JSON object must contain exactly one key named "values".
159
+ The "values" field must be a JSON array of unique integers in [1,9].
160
+ You may return as many candidate values as you want, including one,
161
+ several, or many values.
162
+ Choose the number of returned values yourself based on which values seem
163
+ i-consistent.
164
+ The order of the values does not matter.
165
+ Do not output any explanation, markdown, punctuation outside JSON, or
166
+ extra text.
167
+ Current stage objective: i={i} consistency.
168
+ \end{lstlisting}
169
+
170
+ \subsection{User message}
171
+
172
+ \begin{lstlisting}
173
+ Sudoku grid (0 means empty):
174
+ <grid_to_text(grid)>
175
+ Empty cells in row-major order (20 total): (1,4), (1,9), (2,8), ...
176
+ Target cell to fill now: (R,C).
177
+ Turn: t/T.
178
+ Return only JSON with candidate values for this target cell: {"values":[...]}
179
+ \end{lstlisting}
180
+
181
+ We use the Qwen2.5-Instruct chat template
182
+ (\code{tokenizer.apply\_chat\_template}, \code{add\_generation\_prompt=True})
183
+ to wrap system + user into the actual prompt ids.
184
+ \code{max\_prompt\_length = 768}.
185
+
186
+ \subsection{Output format}
187
+
188
+ \begin{lstlisting}
189
+ {"values":[3,7]}
190
+ \end{lstlisting}
191
+
192
+ Strictly canonical JSON (single key \code{values}, sorted unique digit
193
+ list, no whitespace). Outputs are scored by \code{parse\_values\_json}
194
+ (\code{shared\_multi\_output\_policy.py}); any deviation collapses the
195
+ whole prediction to \code{parse\_ok=0} and the flat \code{--penalty\_malformed}
196
+ penalty.
197
+
198
+ \code{max\_completion\_length = 24} tokens --- enough to emit any
199
+ 9-digit set.
200
+
201
+ \section{Curriculum: stage-i consistency}
202
+
203
+ The curriculum lives in \code{\_stage\_i\_consistent\_values\_for\_grid}:
204
+
205
+ \begin{itemize}[leftmargin=*]
206
+ \item \textbf{Stage 1 --- $i=1$ (legal moves).} A value $v$ is $i=1$
207
+ consistent at cell $c$ iff placing $v$ at $c$ violates no Sudoku
208
+ constraint (row, column, $3\times 3$ box). This is just ``legal
209
+ candidates''.
210
+
211
+ \item \textbf{Stage 2 --- $i=2$.} $v$ is $i=2$ consistent at $c$ iff
212
+ (a) it is $i=1$ consistent AND (b) after placing $v$, every other
213
+ empty cell in the grid still has at least one $i=1$-consistent value
214
+ (i.e.\ placing $v$ does not immediately make the puzzle unsolvable
215
+ by 1-step propagation).
216
+
217
+ \item \textbf{Stage 3 --- $i=3$.} Same recursion one more level deep:
218
+ $v$ is $i=3$ consistent iff after placing $v$, every other empty cell
219
+ still has at least one $i=2$ consistent value.
220
+ \end{itemize}
221
+
222
+ This is bounded look-ahead constraint propagation. Stage-3 sets are
223
+ tighter than stage-2 sets which are tighter than stage-1 sets. The
224
+ curriculum goal at deployment time is stage-3.
225
+
226
+ In data, we use the same source records and just change \code{--stage\_i};
227
+ the target set is regenerated on the fly by
228
+ \code{stage\_i\_consistent\_values}.
229
+
230
+ \section{Latent thought-token architecture}
231
+
232
+ Base model: \textbf{Qwen/Qwen2.5-1.5B-Instruct} + LoRA
233
+ ($r=32$, $\alpha=64$, dropout $=0.05$) on
234
+ \code{q,k,v,o,gate,up,down}. The latent variant adds \textbf{$k$
235
+ thought-token slots} between the prompt and the next-token logits.
236
+
237
+ Four modes are implemented (\code{latent\_multi\_output\_cell\_policy/});
238
+ the winning mode for the final number is \textbf{\code{recurrent\_hidden}}:
239
+
240
+ \begin{quote}
241
+ \code{build\_recurrent\_hidden\_latent\_hidden(model, ids, mask, k)}
242
+ \begin{enumerate}[leftmargin=*,nosep]
243
+ \item Run the backbone once on the prompt. Keep
244
+ \code{base\_hidden = h[:,-1,:]}.
245
+ \item Set \code{latent\_token = base\_hidden}.
246
+ \item Repeat $k$ times: append \code{latent\_token} (as an embedding)
247
+ to the running sequence, run the backbone again on the extended
248
+ sequence, and replace \code{latent\_token} with the new last hidden
249
+ state.
250
+ \item After $k$ recursions, \code{latent\_hidden} is fed through the LM
251
+ head to produce the next-token distribution.
252
+ \end{enumerate}
253
+ \end{quote}
254
+
255
+ In equations, with $E$ the input embedding lookup, $f_\theta$ the
256
+ LoRA-decorated backbone, $U$ the LM head:
257
+ \begin{align*}
258
+ z_0 &= f_\theta\bigl(E([x_1,\dots,x_T])\bigr)_T \\
259
+ z_{j+1} &= f_\theta\bigl([E(x_1),\dots,E(x_T), z_0, z_1, \dots, z_j]\bigr)_{T+j+1},\quad j=0,\dots,k-1 \\
260
+ p(\cdot \mid x_{1:T}) &= \mathrm{softmax}(U z_k)
261
+ \end{align*}
262
+
263
+ The model can therefore ``iterate'' $k$ extra forward passes on the
264
+ same prompt before committing to a token, with the $k$ extra hidden
265
+ states carrying intermediate computation. Setting $k=0$ recovers the
266
+ vanilla baseline.
267
+
268
+ The other three latent modes are alternatives that we ablated:
269
+ \code{fixed\_slots} (concatenate $k$ trainable seed embeddings ---
270
+ Option-2), \code{latent\_seeds} (similar to \code{fixed\_slots}), and
271
+ \code{residual} (project $k$ extra hidden states back onto the base
272
+ hidden state via a learned residual). All modes share the SFT and GRPO
273
+ trainers; only the next-token logit function changes.
274
+
275
+ For the curriculum, we grow $k$ stage by stage:
276
+
277
+ \begin{center}
278
+ \begin{tabular}{ccl}
279
+ \toprule
280
+ \textbf{stage} & \textbf{num\_cot\_tokens} & \textbf{comment} \\
281
+ \midrule
282
+ 1 & 1 & one extra recursion as soon as the model has the surface form \\
283
+ 2 & 2 & two --- needed for 1-step propagation reasoning \\
284
+ 3 & 3 & three --- needed for 2-step propagation reasoning \\
285
+ \bottomrule
286
+ \end{tabular}
287
+ \end{center}
288
+
289
+ \section{The reward function}
290
+
291
+ Defined in \code{multi\_output\_cell\_policy/rewards.py}.
292
+
293
+ Given target set $T$, predicted set $P$ (after JSON parse), let
294
+ \begin{itemize}[leftmargin=*,nosep]
295
+ \item \code{num\_good} $= |P \cap T|$
296
+ \item \code{num\_bad} $= |P \setminus T|$
297
+ \item \code{num\_missing} $= \max(0, |T| - \text{num\_good})$
298
+ \item \code{is\_exact} $= (P \neq \varnothing) \land (P = T)$
299
+ \item $\mathrm{tri}(n) = n(n+1)/2$ (rewards larger correct sets superlinearly)
300
+ \end{itemize}
301
+
302
+ Then
303
+ \begin{align*}
304
+ r &= \mathrm{tri}(\text{num\_good}) \cdot R_g \;-\; \text{num\_bad} \cdot P_b \\
305
+ &\quad - \indic[P=\varnothing]\, P_e \;-\; \indic[|P|=1, |T|>1, i<2]\, P_s \\
306
+ &\quad - \text{num\_missing}\cdot P_m \;+\; \indic[\text{is\_exact}]\, B_x \\
307
+ &\quad - \indic[|P|<|T|, |T|>1]\, P_c
308
+ \end{align*}
309
+
310
+ with parameters (this is the recipe that produced the 0.58/0.68 latent
311
+ solve):
312
+
313
+ \begin{center}
314
+ \begin{tabular}{cllr}
315
+ \toprule
316
+ \textbf{symbol} & \textbf{flag} & \textbf{role} & \textbf{value} \\
317
+ \midrule
318
+ $R_g$ & \code{--reward\_good\_value} & per-correct-value reward (triangular shape) & 1.25 \\
319
+ $P_b$ & \code{--penalty\_bad\_value} & per-extra-wrong-value penalty & 1.0 \\
320
+ $P_{\text{mal}}$ & \code{--penalty\_malformed} & flat penalty if JSON parse fails & 4.0 \\
321
+ $P_e$ & \code{--penalty\_empty} & flat penalty if predicted set is empty & 0.5 \\
322
+ $P_s$ & \code{--penalty\_singleton} & only at stage$<$2: punishes singleton on multi-value targets & 1.5 \\
323
+ $P_m$ & \code{--penalty\_missing} & per-missing-value (recall pressure) --- \textbf{NEW} & \textbf{0.75} \\
324
+ $B_x$ & \code{--exact\_match\_bonus} & only when $P = T$ --- \textbf{NEW} & \textbf{2.0} \\
325
+ $P_c$ & \code{--cardinality\_mismatch\_penalty} & when $|P| < |T|$ and $|T|>1$ --- \textbf{NEW} & \textbf{1.0} \\
326
+ \bottomrule
327
+ \end{tabular}
328
+ \end{center}
329
+
330
+ Parse failures short-circuit to $r = -P_{\text{mal}}$ and zero per-cell
331
+ metrics.
332
+
333
+ \subsection{Why those three new terms exist (the breakthrough)}
334
+
335
+ Diagnosis: at the v3/v4 plateau, eval reported
336
+
337
+ \begin{lstlisting}
338
+ exact=0.95 precision=0.95 recall=0.95 solve=0.30 avg_set_size=1.000
339
+ \end{lstlisting}
340
+
341
+ across all checkpoints. Per-cell exact and precision/recall were all
342
+ near 0.95 but the model \textbf{always predicted a single digit}
343
+ (\code{avg\_set\_size=1.000}). On a multi-value target $T=\{8,9\}$,
344
+ predicting $\{8\}$ keeps precision $=1.0$, recall $=0.5$ and yet
345
+ \code{exact\_set\_match}$=0$. Solve $= \text{exact\_set\_match}^N$ is
346
+ catastrophic in $N$ ($=20$), so even a small fraction of multi-value
347
+ cells killed it.
348
+
349
+ Without any of the new terms the optimum of $r$ on a multi-value cell
350
+ is trivially ``predict the singleton you are most confident about'' ---
351
+ there is no upside to enumerating the second value. The three new terms
352
+ close exactly that hole:
353
+
354
+ \begin{itemize}[leftmargin=*,nosep]
355
+ \item $P_m$ (\code{penalty\_missing}) directly penalises recall;
356
+ \item $B_x$ (\code{exact\_match\_bonus}) makes $P=T$ strictly dominate any singleton;
357
+ \item $P_c$ (\code{cardinality\_mismatch\_penalty}) is a flat hammer whenever $|P|<|T|$.
358
+ \end{itemize}
359
+
360
+ After these terms were added, GRPO on the latent variant moved solve
361
+ from $\sim 0.30$ to $\sim 0.58$ (100-puzzle eval) over $\sim 200$
362
+ steps. The same fix is what we ported back into the baseline pipeline
363
+ this evening (see \S 10).
364
+
365
+ \section{Multi-stage warm-baseline pipeline (the recipe that worked)}
366
+
367
+ Master script:
368
+ \code{hard\_9x9\_stage1\_consistency\_queue/launch\_20empty\_warm\_baseline\_all\_latent\_modes\_stages123.sh}.
369
+
370
+ For each curriculum stage we run \textbf{three sub-phases in order}:
371
+
372
+ \begin{lstlisting}
373
+ [stage i]
374
+ (1) baseline warm SFT (no latent tokens, k=0, vanilla LM)
375
+ (2) latent SFT (k = i, latent mode = recurrent_hidden)
376
+ (3) latent GRPO (k = i)
377
+ \end{lstlisting}
378
+
379
+ \textbf{The warm baseline phase (1) is the trick that makes the
380
+ curriculum work.} At every stage transition the data distribution
381
+ changes ($i$ increases $\Rightarrow$ target sets shrink) and a new
382
+ latent slot appears. Doing a vanilla SFT on the new distribution first
383
+ lets the LM relearn the surface form on familiar parameters; THEN the
384
+ latent SFT adds the extra thought slot on top of an already-good policy.
385
+ When we tried to add a new latent slot directly on top of the previous
386
+ stage's GRPO checkpoint, training loss did NOT decrease.
387
+
388
+ Concrete LR schedule used for the champion run:
389
+
390
+ \begin{center}
391
+ \begin{tabular}{lllc}
392
+ \toprule
393
+ \textbf{phase} & \textbf{init from} & \textbf{LR} & \textbf{k} \\
394
+ \midrule
395
+ S1 baseline SFT & base Qwen & 2e-4 & 0 \\
396
+ S1 latent SFT & S1 baseline & 2e-4 & 1 \\
397
+ S1 latent GRPO & S1 latent SFT & 1e-6 & 1 \\
398
+ S2 baseline warm SFT & S1 GRPO & 5e-5 & 0 \\
399
+ S2 latent SFT & S2 baseline & 5e-5 & 2 \\
400
+ S2 latent GRPO & S2 latent SFT & 1e-6 & 2 \\
401
+ S3 baseline warm SFT & S2 GRPO & 5e-5 & 0 \\
402
+ S3 latent SFT & S3 baseline & 5e-5 $\rightarrow$ 1e-5 (champion) & 3 \\
403
+ S3 latent GRPO & S3 latent SFT & 5e-6 ($\beta=0$) & 3 \\
404
+ \bottomrule
405
+ \end{tabular}
406
+ \end{center}
407
+
408
+ Other shared knobs:
409
+
410
+ \begin{lstlisting}
411
+ LoRA: r=32 a=64 dropout=0.05 on q,k,v,o,gate,up,down
412
+ SFT: per_device_bs=8 grad_accum=2 nproc=8 -> eff_bs=128
413
+ GRPO: per_device_bs=4 grad_accum=2 nproc=8 -> eff_bs=64
414
+ num_generations=4 beta=0.0 max_prompt_length=1024
415
+ max_completion_length=24
416
+ multi_value_oversample_factor=5, exact_match_bonus=2.0,
417
+ penalty_missing=0.75, cardinality_mismatch_penalty=1.0
418
+ \end{lstlisting}
419
+
420
+ \section{GRPO settings that mattered}
421
+
422
+ \begin{itemize}[leftmargin=*]
423
+ \item \textbf{$\beta = 0$.} The KL anchor was harmful in every sweep
424
+ where we tried $\beta>0$. \code{s3\_grpo\_kl04} ($\beta=0.04$) peaked
425
+ at solve $=0.625$ (40p) at step 100 and regressed to $0.525$ by step
426
+ 500.
427
+
428
+ \item \textbf{\code{num\_generations} $= 4$.} With \code{num\_generations}$=2$
429
+ we routinely saw \code{reward\_std}$=0$ (all sampled completions
430
+ identical $\Rightarrow$ no gradient). Bumping to 4 fixed it.
431
+
432
+ \item \textbf{Low LR.} \code{lr=5e-6} was the steadiest. \code{lr=1e-5}
433
+ peaked at step 200 (solve $0.65$) then collapsed back to $0.54$ ---
434
+ classic mode collapse.
435
+
436
+ \item \textbf{Effective bs $\geq 64$.} TRL's GRPOConfig requires
437
+ \code{(nproc * per\_device\_bs) \% num\_generations == 0}; with 8 GPUs we
438
+ hit this trivially, but we caution single-GPU rerunners to set
439
+ \code{per\_device\_bs=4 grad\_accum=2 num\_generations=4}.
440
+
441
+ \item \textbf{\code{enable\_input\_require\_grads()} on the wrapped backbone.}
442
+ Required for TRL 0.15.x + PEFT LoRA + gradient checkpointing ---
443
+ otherwise the loss tensor produced by GRPOTrainer has
444
+ \code{requires\_grad=False} and \code{.backward()} raises. Also
445
+ \code{unwrapped.config.use\_cache = False}.
446
+ \end{itemize}
447
+
448
+ \section{Final hyperparameters table --- champion latent run}
449
+
450
+ \begin{center}
451
+ \begin{longtable}{lll}
452
+ \toprule
453
+ \textbf{group} & \textbf{hyperparameter} & \textbf{value} \\
454
+ \midrule
455
+ \endfirsthead
456
+ \toprule
457
+ \textbf{group} & \textbf{hyperparameter} & \textbf{value} \\
458
+ \midrule
459
+ \endhead
460
+ Backbone & model & Qwen/Qwen2.5-1.5B-Instruct \\
461
+ Backbone & dtype & bf16 \\
462
+ Backbone & LoRA target modules & q,k,v,o,gate,up,down \\
463
+ Backbone & LoRA $r$ / $\alpha$ / dropout & 32 / 64 / 0.05 \\
464
+ Latent & mode & \code{recurrent\_hidden} \\
465
+ Latent & \code{num\_cot\_tokens} (S1/S2/S3) & 1 / 2 / 3 \\
466
+ Latent & \code{max\_latent\_slots} / seeds & 8 / 8 \\
467
+ Data & total empties & 20 \\
468
+ Data & train rows / eval rows & 10\,000 / 100 \\
469
+ Data & \code{multi\_value\_oversample\_factor} & 5 \\
470
+ Data & \code{mixed\_stage1\_ratio} (S1) & 1 \\
471
+ Data & \code{mixed\_stage2\_ratio} (S$\geq 2$) & 1 \\
472
+ SFT & per\_device\_bs / grad\_accum & 8 / 2 \\
473
+ SFT & \code{num\_epochs} (cap) & 64 \\
474
+ SFT & LR (S1 latent) & 2e-4 \\
475
+ SFT & LR (S2/S3 baseline warm + latent) & 5e-5 \\
476
+ SFT & LR (S3 latent champion \code{s3b\_lr1e5\_o5}) & 1e-5 \\
477
+ SFT & weight\_decay & 0.0 \\
478
+ SFT & gradient checkpointing & on \\
479
+ GRPO & per\_device\_bs / grad\_accum & 4 / 2 \\
480
+ GRPO & \code{num\_generations} & 4 \\
481
+ GRPO & LR & 5e-6 (S3); 1e-6 (S1, S2) \\
482
+ GRPO & $\beta$ (KL) & 0.0 \\
483
+ GRPO & \code{max\_prompt\_length} & 1024 \\
484
+ GRPO & \code{max\_completion\_length} & 24 \\
485
+ Reward & \code{reward\_good\_value} & 1.25 \\
486
+ Reward & \code{penalty\_bad\_value} & 1.0 \\
487
+ Reward & \code{penalty\_malformed} & 4.0 \\
488
+ Reward & \code{penalty\_empty} & 0.5 \\
489
+ Reward & \code{penalty\_singleton} & 1.5 \\
490
+ Reward & \code{penalty\_missing} & 0.75 \\
491
+ Reward & \code{exact\_match\_bonus} & 2.0 \\
492
+ Reward & \code{cardinality\_mismatch\_penalty} & 1.0 \\
493
+ Eval & early-stop on prec/recall & 0.98 \\
494
+ \bottomrule
495
+ \end{longtable}
496
+ \end{center}
497
+
498
+ \section{Headline results}
499
+
500
+ \subsection{Latent (with thought tokens, \code{recurrent\_hidden})}
501
+
502
+ \begin{center}
503
+ \begin{tabular}{llrrrrr}
504
+ \toprule
505
+ \textbf{eval} & \textbf{model / phase} & \textbf{step} & \textbf{exact} & \textbf{prec} & \textbf{recall} & \textbf{solve} \\
506
+ \midrule
507
+ \textbf{100p (auth.)} & \code{s3\_grpo\_baseline} (S3 GRPO, $\beta=0$, lr=5e-6) & 200 & 0.9665 & 0.9673 & 0.9680 & \textbf{0.580 (58/100)} \\
508
+ 40p & \code{s3\_grpo\_sharp\_rwd} ($B_x{=}4$, $P_c{=}3$) & 300 & --- & --- & --- & \textbf{0.675 (27/40)} \\
509
+ 40p & \code{s3\_grpo\_lr1e5} & 200 & 0.978 & 0.978 & 0.979 & 0.650 \\
510
+ 40p & \code{s3b\_lr1e5\_o5} (S3 SFT champion) & 2400 & 0.974 & 0.974 & 0.975 & 0.600 \\
511
+ \bottomrule
512
+ \end{tabular}
513
+ \end{center}
514
+
515
+ \subsection{Vanilla baseline (no thought tokens, same Qwen2.5-1.5B + LoRA)}
516
+
517
+ \begin{center}
518
+ \begin{tabular}{llrrr}
519
+ \toprule
520
+ \textbf{sweep} & \textbf{best variant} & \textbf{best step} & \textbf{exact} & \textbf{solve (100p)} \\
521
+ \midrule
522
+ v3 (single-GPU, no oversample, no new reward) & \code{baseline\_3stage\_20260522} & --- & 0.730 & \textbf{0.000} \\
523
+ v4 (LR sweep, multi-GPU, original reward) & \code{pipe\_v\_sft\_extend} (S3 SFT extended) & 4000 & 0.948 & \textbf{0.400} \\
524
+ \textbf{v6 (this evening; ports latent reward + oversample)} & \code{v6\_i\_sft\_v\_oversample10} & running & 0.952$+$ & \textbf{0.440 (best so far)} \\
525
+ \bottomrule
526
+ \end{tabular}
527
+ \end{center}
528
+
529
+ The v6 sweep is still running --- \code{v6\_e/f/i} are in S3 SFT
530
+ continuation, GRPO follow-on phases queued. The \code{v6\_i} variant
531
+ has hit \textbf{solve $=0.44$} at SFT eval (new baseline best,
532
+ $+0.04$ over v4) and is still climbing.
533
+
534
+ \subsection{Stage-by-stage trajectory (latent, 40-puzzle eval)}
535
+
536
+ \begin{lstlisting}
537
+ S1 SFT : exact ~ 0.85, solve ~ 0.20
538
+ S1 GRPO : exact ~ 0.90, solve ~ 0.20
539
+ S2 SFT (no oversample) : exact ~ 0.94, solve ~ 0.20-0.25 <- the wall
540
+ S2 SFT + multi_value_oversample=5 : exact ~ 0.96, solve ~ 0.30-0.35
541
+ S2 GRPO + new reward terms : exact ~ 0.96, solve ~ 0.35-0.40
542
+ S3 SFT (s3b_lr1e5_o5 step 2400) : exact 0.974, solve 0.600 <- SFT champion
543
+ S3 GRPO (s3_grpo_baseline step 200,100p): exact 0.967, solve 0.580 <- 100p champion
544
+ S3 GRPO (s3_grpo_sharp_rwd step 300,40p): solve 0.675 <- 40p peak
545
+ \end{lstlisting}
546
+
547
+ \subsection{Latent vs baseline gap (head-to-head, same 100p eval, same prompts)}
548
+
549
+ \begin{center}
550
+ \begin{tabular}{lrrrrr}
551
+ \toprule
552
+ \textbf{model} & \textbf{exact} & \textbf{prec} & \textbf{recall} & \textbf{solve} & \textbf{solved/100} \\
553
+ \midrule
554
+ Latent \code{recurrent\_hidden}, S3 GRPO & 0.9665 & 0.9673 & 0.9680 & \textbf{0.580} & 58 \\
555
+ Vanilla baseline, \code{v6\_i} (best at time of writing) & 0.952 & 0.952 & 0.952 & \textbf{0.440} & 44 \\
556
+ \bottomrule
557
+ \end{tabular}
558
+ \end{center}
559
+
560
+ Gap on 100-puzzle solve: $\approx$ \textbf{$+0.14$ absolute / $+32\%$
561
+ relative} for latent over the strongest baseline we have.
562
+
563
+ \section{Why the latent works (interpretation hypotheses)}
564
+
565
+ These are the working hypotheses the experiments are consistent with;
566
+ none is fully proven and ablations are still WIP.
567
+
568
+ \begin{enumerate}[leftmargin=*]
569
+ \item \textbf{Constraint-propagation depth.} Stage-3 i-consistency is
570
+ essentially 2-ply lookahead. With $k=3$ recurrent hidden tokens the
571
+ model gets exactly three extra forward passes between prompt and
572
+ output --- one for the legality check, one for 1-step propagation,
573
+ one for the second step of propagation. Empirically the gap to the
574
+ no-thought-token baseline appears at stages where multi-step
575
+ propagation matters (stage 2 onward; stage 1 numbers are essentially
576
+ identical).
577
+
578
+ \item \textbf{Multi-value cells require enumeration, which a singleton
579
+ softmax can't do in one forward pass.} A vanilla LM at 1.5B
580
+ parameters predicts essentially deterministically once temperature is
581
+ low; for a target set $\{8, 9\}$ the LM picks one of the two and
582
+ stops. The latent model can use one of the recurrent hidden steps to
583
+ ``consider'' each option without committing yet, which is exactly
584
+ the failure mode in the data (\code{avg\_set\_size} $= 1.000$ for the
585
+ baseline, $\approx 1.05$ for the latent S3 model on the same eval).
586
+
587
+ \item \textbf{Stable curriculum capacity growth.} Adding a new latent
588
+ slot at every stage gives the model a ``fresh slate'' of
589
+ representational capacity at the exact transition where the task
590
+ gets harder. The warm-baseline SFT between stages prevents the new
591
+ slot from corrupting the previously learned policy. Without warm
592
+ baseline, training loss did not decrease at all (we observed this
593
+ directly when we tried to skip the warm baseline).
594
+
595
+ \item \textbf{GRPO without latent slots is starved of variance.} With
596
+ \code{max\_completion\_length} 24 and the model essentially
597
+ deterministic, GRPO's 4 sampled completions per prompt collapse to a
598
+ single answer --- \code{reward\_std}$=0$, no gradient. With latent
599
+ recurrence + the new \code{exact\_match\_bonus} reward, the model
600
+ occasionally samples a 2-element set, gets a much higher reward, and
601
+ that prompt gets a real gradient signal.
602
+ \end{enumerate}
603
+
604
+ \section{Reproducibility}
605
+
606
+ \noindent
607
+ Code repository: \url{https://github.com/Avra98/curriculum_cot} \\
608
+ Latent checkpoints: \url{https://huggingface.co/Avra98/sudoku-latent-recurrent-hidden-20empty-stages} \\
609
+ Baseline checkpoints: \url{https://huggingface.co/Avra98/sudoku-9x9-20empty-baseline-1p5b-sweep}
610
+
611
+ Key scripts:
612
+
613
+ \begin{itemize}[leftmargin=*,nosep]
614
+ \item Master orchestrator (latent, 9-phase warm-baseline pipeline):
615
+ \code{hard\_9x9\_stage1\_consistency\_queue/launch\_20empty\_warm\_baseline\_all\_latent\_modes\_stages123.sh}
616
+ \item Vanilla baseline pipeline:
617
+ \code{\_runs/baseline\_1p5b\_pipeline\_v4.sh} (with v6 launchers
618
+ \code{\_runs/launch\_baseline\_push\_v6.sh})
619
+ \item SFT trainer (vanilla):
620
+ \code{multi\_output\_cell\_policy/sft\_multi\_output\_train.py}
621
+ \item GRPO trainer (vanilla):
622
+ \code{multi\_output\_cell\_policy/grpo\_multi\_output\_train.py}
623
+ \item SFT trainer (latent):
624
+ \code{latent\_multi\_output\_cell\_policy/sft\_latent\_multi\_output\_train.py}
625
+ \item GRPO trainer (latent):
626
+ \code{latent\_multi\_output\_cell\_policy/grpo\_residual\_projector\_latent\_train.py}
627
+ \item Reward function: \code{multi\_output\_cell\_policy/rewards.py}
628
+ \item Prompt builder: \code{multi\_output\_cell\_policy/prompt\_builder.py}
629
+ \item Stage-i consistency:
630
+ \code{multi\_output\_cell\_policy/shared\_multi\_output\_policy.py}
631
+ \item 100-puzzle evaluator: \code{analysis/eval\_stage2\_checkpoint.py}
632
+ \end{itemize}
633
+
634
+ To reproduce the latent champion (1.5B, 9-phase, $\sim 16$ GPU$\cdot$h
635
+ on $8\times$H100 80GB):
636
+
637
+ \begin{lstlisting}
638
+ export STAGE1_BASELINE_ADAPTER_DIR=/path/to/stage1_baseline_seed_adapter
639
+ bash hard_9x9_stage1_consistency_queue/launch_20empty_warm_baseline_all_latent_modes_stages123.sh
640
+ \end{lstlisting}
641
+
642
+ To reproduce the v6 baseline push (single-GPU per variant, $\sim 6$
643
+ GPU$\cdot$h):
644
+
645
+ \begin{lstlisting}
646
+ bash _runs/launch_baseline_push_v6.sh
647
+ \end{lstlisting}
648
+
649
+ \appendix
650
+
651
+ \section{The reward fix as a one-line patch}
652
+
653
+ The single most consequential code change in this whole project, as a
654
+ self-contained patch on \code{multi\_output\_cell\_policy/rewards.py}:
655
+
656
+ \begin{lstlisting}[language=Python]
657
+ # new args (default 0 preserves legacy behaviour)
658
+ penalty_missing: float = 0.0
659
+ exact_match_bonus: float = 0.0
660
+ cardinality_mismatch_penalty: float = 0.0
661
+
662
+ num_missing = max(0, len(target_set) - num_good)
663
+ is_exact = bool(predicted_values) and (set(predicted_values) == target_set)
664
+
665
+ # ... base reward (triangular_number(num_good)*reward_good_value
666
+ # - num_bad*penalty_bad_value)
667
+
668
+ if num_missing > 0:
669
+ reward -= num_missing * penalty_missing
670
+ if is_exact:
671
+ reward += exact_match_bonus
672
+ if len(predicted_values) < len(target_values) and len(target_values) > 1:
673
+ reward -= cardinality_mismatch_penalty
674
+ \end{lstlisting}
675
+
676
+ Defaults are zero so old runs are unaffected; the recipe sets
677
+ $(P_m, B_x, P_c) = (0.75, 2.0, 1.0)$ for the vanilla recipe and
678
+ $(1.0, 4.0, 3.0)$ for the ``sharp\_rwd'' variant.
679
+
680
+ \section{The warm-baseline trick as a sequence diagram}
681
+
682
+ \begin{lstlisting}
683
+ Stage 1 Stage 2 Stage 3
684
+ --------- --------- ---------
685
+ [base Qwen] | |
686
+ | | |
687
+ v v v
688
+ S1 baseline SFT -> S2 baseline SFT -> S3 baseline SFT
689
+ (no latent, k=0) (no latent, k=0) (no latent, k=0)
690
+ | | |
691
+ v v v
692
+ S1 latent SFT -> S2 latent SFT -> S3 latent SFT
693
+ (k=1) (k=2) (k=3)
694
+ | | |
695
+ v v v
696
+ S1 latent GRPO -> S2 latent GRPO -> S3 latent GRPO
697
+ (k=1, b=0, lr 1e-6) (k=2, b=0) (k=3, b=0, lr 5e-6)
698
+ | | |
699
+ [final policy]
700
+ \end{lstlisting}
701
+
702
+ Every arrow is \code{init\_adapter\_dir = <previous output>}. Each row
703
+ is a ``slot in the curriculum''; each column to the right adds reasoning capacity
704
+ ($k\mathrel{+}=1$) and moves to a harder target distribution
705
+ ($i\mathrel{+}=1$). The diagonal across the diagram is the actual
706
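+ training trajectory: down one column, then across to the top of the next (per the schedule table in \S 7, each stage's baseline warm SFT is initialised from the previous stage's GRPO checkpoint).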
707
+
708
+ \bigskip
709
+ \noindent\emph{End of report.}
710
+
711
+ \end{document}
addition/README.md ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Addition Carry Experiment
2
+
3
+ This folder contains a standalone PyTorch experiment for algorithmic addition with carry on a one-layer decoder-only Transformer.
4
+
5
+ The comparison includes exactly three methods:
6
+
7
+ - `nocurr_nocot`: no curriculum, no latent chain-of-thought
8
+ - `curr_nocot`: digit-length curriculum, no latent chain-of-thought
9
+ - `curr_cot`: same one-layer backbone plus recurrent latent scratchpad tokens
10
+
11
+ ## Task
12
+
13
+ Each example adds two reversed digit sequences in a configurable radix. Stage `k` means only the first `k` least-significant positions vary and the rest are zero. Every method now trains on the full example in one forward pass:
14
+
15
+ - predict all `k` active sum digits
16
+ - predict the final carry bit as an additional output slot
17
+ - compute masked loss over the active digits plus the final carry
18
+
19
+ This means the baseline and both curriculum variants learn whole-example addition rather than a single queried digit at a time. Internal carry targets are still kept for diagnostics and linear probing, but not as an auxiliary training loss.
20
+
21
+ The latent method reuses the same one-layer Transformer recurrently. After an initial pass over the inputs and output slots, the model inserts continuous latent scratchpad tokens before the output slots and reruns the same layer, giving later curriculum stages more internal workspace for carry-like computation.
22
+
23
+ ## Files
24
+
25
+ - `config.py`: experiment config and CLI handling
26
+ - `data.py`: synthetic data generation, curriculum stages, carry-heavy subsets
27
+ - `model.py`: one-layer decoder-only Transformer and latent recurrence
28
+ - `train.py`: single-run training entrypoint
29
+ - `eval.py`: evaluation and diagnostics
30
+ - `plots.py`: local plotting
31
+ - `run_comparison.py`: multi-seed comparison across all three methods
32
+
33
+ ## Outputs
34
+
35
+ Each run writes:
36
+
37
+ - `config.json`
38
+ - `artifacts/history.jsonl`
39
+ - `artifacts/summary.json`
40
+ - `checkpoints/best.pt`
41
+ - `checkpoints/last.pt`
42
+ - local plots under `plots/`
43
+
44
+ If W&B is enabled, the same run also logs metrics there.
45
+
46
+ ## Run A Single Method
47
+
48
+ Default settings:
49
+
50
+ ```bash
51
+ python addition/train.py --model nocurr_nocot --use_wandb
52
+ python addition/train.py --model curr_nocot --use_wandb
53
+ python addition/train.py --model curr_cot --use_wandb
54
+ ```
55
+
56
+ The default backbone now uses a single attention head. To run a harder hexadecimal setting:
57
+
58
+ ```bash
59
+ python addition/train.py --model curr_cot --radix 16 --use_wandb --output_dir addition_runs/hex_curr_cot
60
+ ```
61
+
62
+ Run offline or local-only:
63
+
64
+ ```bash
65
+ python addition/train.py --model curr_cot --wandb_mode offline
66
+ python addition/train.py --model curr_cot --no_wandb
67
+ ```
68
+
69
+ ## Smoke Test
70
+
71
+ Use the smoke preset to verify the whole pipeline quickly:
72
+
73
+ ```bash
74
+ python addition/train.py --model curr_cot --preset smoke --no_wandb --output_dir addition_runs/smoke_curr_cot
75
+ ```
76
+
77
+ ## Run The Full Comparison
78
+
79
+ This runs all three methods across multiple seeds and saves aggregate plots and JSON:
80
+
81
+ ```bash
82
+ python addition/run_comparison.py --preset default --use_wandb --comparison_output_dir addition_runs/comparison_default
83
+ ```
84
+
85
+ Small fast comparison:
86
+
87
+ ```bash
88
+ python addition/run_comparison.py --preset smoke --no_wandb --comparison_output_dir addition_runs/comparison_smoke
89
+ ```
90
+
91
+ ## Main Metrics
92
+
93
+ The experiment reports:
94
+
95
+ - digit accuracy by output position
96
+ - final-carry accuracy
97
+ - exact whole-sum accuracy by active length
98
+ - average digit accuracy by length
99
+ - in-distribution results up to `train_max_digits`
100
+ - OOD results on longer lengths
101
+ - separate uniform and carry-heavy evaluations
102
+
103
+ ## Diagnostics
104
+
105
+ The evaluation also includes:
106
+
107
+ - a linear probe on output-slot hidden states for carry prediction
108
+ - attention summaries showing how strongly the final carry readout attends to operand digits, previous output slots, and latent tokens
109
+
110
+ ## Notes
111
+
112
+ - The first version is intentionally small enough to iterate locally.
113
+ - The backbone depth stays fixed at one layer in all methods.
114
+ - The latent method gets more recurrent compute, not more layers.
addition/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Standalone addition-with-carry experiment package."""
addition/config.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import dataclasses
5
+ import json
6
+ from dataclasses import dataclass
7
+ from pathlib import Path
8
+ from typing import Iterable
9
+
10
+ import torch
11
+
12
+
13
# Method names accepted by ``ExperimentConfig.model``.
VALID_MODELS = ("nocurr_nocot", "curr_nocot", "curr_cot")
# Preset names accepted by ``ExperimentConfig.preset``.
VALID_PRESETS = ("default", "smoke")


@dataclass
class ExperimentConfig:
    """All settings for a single addition-with-carry run.

    ``__post_init__`` validates the values and normalises ``ood_lengths``.
    Sequence-length properties are derived from ``eval_max_digits``.
    """

    # Run identity and logging.
    model: str = "nocurr_nocot"
    output_dir: str = "addition_runs/default"
    seed: int = 0
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    preset: str = "default"
    run_name: str = ""
    notes: str = ""
    use_wandb: bool = True
    wandb_project: str = "addition-carry"
    wandb_entity: str = ""
    wandb_mode: str = "online"
    # Task definition.
    radix: int = 10
    train_max_digits: int = 12
    eval_max_digits: int = 20
    ood_lengths: tuple[int, ...] = (14, 16, 20)
    # Optimisation.
    train_batch_size: int = 256
    eval_batch_size: int = 512
    learning_rate: float = 3e-4
    weight_decay: float = 1e-2
    grad_clip_norm: float = 1.0
    carry_loss_weight: float = 0.0
    # Schedule.
    train_steps: int = 3600
    max_steps_per_stage: int = 300
    validation_interval: int = 100
    stage_accuracy_threshold: float = 0.99
    initial_stage: int = 1
    # Data volume.
    eval_examples_per_length: int = 256
    carry_heavy_examples_per_length: int = 256
    train_carry_heavy_prob: float = 0.15
    # Architecture.
    d_model: int = 512
    n_heads: int = 1
    ff_dim: int = 2048
    dropout: float = 0.0
    max_latent_steps: int = 12
    # Diagnostics and comparison.
    attention_probe_examples: int = 256
    linear_probe_epochs: int = 150
    linear_probe_lr: float = 1e-2
    comparison_num_seeds: int = 5

    def __post_init__(self) -> None:
        """Reject invalid settings and keep only OOD lengths beyond training.

        Raises:
            ValueError: if any of the checks below fails.
        """
        if self.model not in VALID_MODELS:
            raise ValueError(f"Unsupported model: {self.model}")
        if self.preset not in VALID_PRESETS:
            raise ValueError(f"Unsupported preset: {self.preset}")
        if self.train_max_digits > self.eval_max_digits:
            raise ValueError("train_max_digits must be <= eval_max_digits")
        if self.max_latent_steps < 0:
            raise ValueError("max_latent_steps must be non-negative")
        if not 2 <= self.radix <= 16:
            raise ValueError("radix must be between 2 and 16")
        if not 1 <= self.initial_stage <= self.train_max_digits:
            raise ValueError("initial_stage must be between 1 and train_max_digits")
        beyond_training = tuple(
            int(length) for length in self.ood_lengths if int(length) > self.train_max_digits
        )
        # Fall back to the largest evaluation length when nothing qualifies.
        self.ood_lengths = beyond_training if beyond_training else (self.eval_max_digits,)

    @property
    def uses_curriculum(self) -> bool:
        """True for the two ``curr_*`` methods."""
        return self.model in {"curr_nocot", "curr_cot"}

    @property
    def uses_latent_cot(self) -> bool:
        """True only for the ``curr_cot`` method."""
        return self.model == "curr_cot"

    @property
    def discrete_vocab_size(self) -> int:
        """Number of digit symbols plus two extra token ids."""
        return self.radix + 2

    @property
    def digit_vocab_size(self) -> int:
        """Number of digit symbols (equal to the radix)."""
        return self.radix

    @property
    def input_sequence_length(self) -> int:
        """Input length at ``eval_max_digits``."""
        return self.input_sequence_length_for_digits(self.eval_max_digits)

    @property
    def output_sequence_length(self) -> int:
        """Output length at ``eval_max_digits``."""
        return self.output_sequence_length_for_digits(self.eval_max_digits)

    @property
    def base_sequence_length(self) -> int:
        """Input plus output length at ``eval_max_digits``."""
        return self.base_sequence_length_for_digits(self.eval_max_digits)

    @property
    def max_sequence_length(self) -> int:
        """Base length plus ``max_latent_steps``."""
        return self.base_sequence_length + self.max_latent_steps

    @property
    def effective_run_name(self) -> str:
        """``run_name`` if set, otherwise a name built from model, radix and seed."""
        if self.run_name:
            return self.run_name
        return f"{self.model}_base{self.radix}_seed{self.seed}"

    def input_sequence_length_for_digits(self, active_digits: int) -> int:
        """Return ``2 * active_digits + 2``."""
        return (int(active_digits) * 2) + 2

    def output_sequence_length_for_digits(self, active_digits: int) -> int:
        """Return ``active_digits + 1``."""
        return int(active_digits) + 1

    def base_sequence_length_for_digits(self, active_digits: int) -> int:
        """Return the input length plus the output length for ``active_digits``."""
        input_length = self.input_sequence_length_for_digits(active_digits)
        output_length = self.output_sequence_length_for_digits(active_digits)
        return input_length + output_length

    def latent_steps_for_stage(self, stage: int) -> int:
        """Return ``stage`` clamped to ``[0, max_latent_steps]``; 0 without latent CoT."""
        if not self.uses_latent_cot:
            return 0
        capped = min(int(stage), int(self.max_latent_steps))
        return capped if capped > 0 else 0
126
+
127
+
128
def default_output_root() -> Path:
    """Return the relative directory under which run outputs are placed."""
    root = Path("addition_runs")
    return root
130
+
131
+
132
def apply_preset(config: ExperimentConfig) -> ExperimentConfig:
    """Return a copy of ``config`` with its preset's overrides applied.

    The argument itself is not modified. Only the ``"smoke"`` preset changes
    any values; every other preset yields an unmodified copy.
    """
    adjusted = dataclasses.replace(config)
    if adjusted.preset != "smoke":
        return adjusted

    # Keep an explicit output directory; otherwise use <root>/smoke.
    if not adjusted.output_dir:
        adjusted.output_dir = str(default_output_root() / "smoke")

    smoke_overrides = {
        "train_batch_size": 64,
        "eval_batch_size": 128,
        "d_model": 128,
        "n_heads": 1,
        "ff_dim": 512,
        "train_steps": 180,
        "max_steps_per_stage": 40,
        "validation_interval": 20,
        "eval_examples_per_length": 64,
        "carry_heavy_examples_per_length": 64,
        "attention_probe_examples": 64,
        "linear_probe_epochs": 60,
        "comparison_num_seeds": 2,
    }
    for field_name, value in smoke_overrides.items():
        setattr(adjusted, field_name, value)
    return adjusted
150
+
151
+
152
def config_to_dict(config: ExperimentConfig) -> dict:
    """Convert ``config`` to a plain dict including its derived properties.

    ``ood_lengths`` is stored as a list, and the computed properties are
    appended after the dataclass fields.
    """
    payload = dataclasses.asdict(config)
    payload["ood_lengths"] = list(config.ood_lengths)
    derived_names = (
        "uses_curriculum",
        "uses_latent_cot",
        "discrete_vocab_size",
        "input_sequence_length",
        "output_sequence_length",
        "base_sequence_length",
        "max_sequence_length",
        "effective_run_name",
    )
    for name in derived_names:
        payload[name] = getattr(config, name)
    return payload
164
+
165
+
166
def save_config(config: ExperimentConfig, output_dir: Path) -> None:
    """Write ``config`` as ``config.json`` inside ``output_dir``.

    The directory is created if missing. The JSON is indented by two spaces
    with keys sorted.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    target = output_dir / "config.json"
    serialisable = config_to_dict(config)
    with target.open("w", encoding="utf-8") as handle:
        json.dump(serialisable, handle, indent=2, sort_keys=True)
170
+
171
+
172
def add_config_arguments(parser: argparse.ArgumentParser) -> None:
    """Register every experiment command-line flag on ``parser``.

    Flags are added in the order of the table below, which is also the
    order they appear in ``--help``.
    """
    default_device = "cuda" if torch.cuda.is_available() else "cpu"
    flag_table = [
        ("--model", {"choices": VALID_MODELS, "default": "nocurr_nocot"}),
        ("--output_dir", {"type": str, "default": ""}),
        ("--seed", {"type": int, "default": 0}),
        ("--device", {"type": str, "default": default_device}),
        ("--preset", {"choices": VALID_PRESETS, "default": "default"}),
        ("--run_name", {"type": str, "default": ""}),
        ("--notes", {"type": str, "default": ""}),
        ("--use_wandb", {"action": "store_true"}),
        ("--no_wandb", {"action": "store_true"}),
        ("--wandb_project", {"type": str, "default": "addition-carry"}),
        ("--wandb_entity", {"type": str, "default": ""}),
        ("--wandb_mode", {"type": str, "default": "online", "choices": ("online", "offline", "disabled")}),
        ("--radix", {"type": int, "default": 10}),
        ("--train_max_digits", {"type": int, "default": 12}),
        ("--eval_max_digits", {"type": int, "default": 20}),
        ("--ood_lengths", {"type": int, "nargs": "*", "default": [14, 16, 20]}),
        ("--train_batch_size", {"type": int, "default": 256}),
        ("--eval_batch_size", {"type": int, "default": 512}),
        ("--learning_rate", {"type": float, "default": 3e-4}),
        ("--weight_decay", {"type": float, "default": 1e-2}),
        ("--grad_clip_norm", {"type": float, "default": 1.0}),
        ("--carry_loss_weight", {"type": float, "default": 0.0}),
        ("--train_steps", {"type": int, "default": 3600}),
        ("--max_steps_per_stage", {"type": int, "default": 300}),
        ("--validation_interval", {"type": int, "default": 100}),
        ("--stage_accuracy_threshold", {"type": float, "default": 0.99}),
        ("--initial_stage", {"type": int, "default": 1}),
        ("--eval_examples_per_length", {"type": int, "default": 256}),
        ("--carry_heavy_examples_per_length", {"type": int, "default": 256}),
        ("--train_carry_heavy_prob", {"type": float, "default": 0.15}),
        ("--d_model", {"type": int, "default": 512}),
        ("--n_heads", {"type": int, "default": 1}),
        ("--ff_dim", {"type": int, "default": 2048}),
        ("--dropout", {"type": float, "default": 0.0}),
        ("--max_latent_steps", {"type": int, "default": 12}),
        ("--attention_probe_examples", {"type": int, "default": 256}),
        ("--linear_probe_epochs", {"type": int, "default": 150}),
        ("--linear_probe_lr", {"type": float, "default": 1e-2}),
        ("--comparison_num_seeds", {"type": int, "default": 5}),
    ]
    for flag, options in flag_table:
        parser.add_argument(flag, **options)
212
+
213
+
214
def build_config_from_args(args: argparse.Namespace) -> ExperimentConfig:
    """Build an ``ExperimentConfig`` from parsed arguments and apply its preset.

    W&B stays enabled unless ``--no_wandb`` is given without ``--use_wandb``,
    or ``--wandb_mode disabled`` is selected. An empty ``--output_dir`` is
    replaced by ``<root>/<model>_base<radix>_seed<seed>``.
    """
    wandb_enabled = bool(args.use_wandb or not args.no_wandb)
    if args.wandb_mode == "disabled":
        wandb_enabled = False

    run_dir = args.output_dir
    if not run_dir:
        run_dir = str(default_output_root() / f"{args.model}_base{args.radix}_seed{args.seed}")

    # Fields copied verbatim from the namespace under the same name.
    copied_fields = (
        "model",
        "seed",
        "device",
        "preset",
        "run_name",
        "notes",
        "wandb_project",
        "wandb_entity",
        "wandb_mode",
        "radix",
        "train_max_digits",
        "eval_max_digits",
        "train_batch_size",
        "eval_batch_size",
        "learning_rate",
        "weight_decay",
        "grad_clip_norm",
        "carry_loss_weight",
        "train_steps",
        "max_steps_per_stage",
        "validation_interval",
        "stage_accuracy_threshold",
        "initial_stage",
        "eval_examples_per_length",
        "carry_heavy_examples_per_length",
        "train_carry_heavy_prob",
        "d_model",
        "n_heads",
        "ff_dim",
        "dropout",
        "max_latent_steps",
        "attention_probe_examples",
        "linear_probe_epochs",
        "linear_probe_lr",
        "comparison_num_seeds",
    )
    field_values = {name: getattr(args, name) for name in copied_fields}
    built = ExperimentConfig(
        output_dir=run_dir,
        use_wandb=wandb_enabled,
        ood_lengths=tuple(args.ood_lengths),
        **field_values,
    )
    return apply_preset(built)
260
+
261
+
262
def build_arg_parser(description: str) -> argparse.ArgumentParser:
    """Create an argument parser carrying all experiment flags."""
    cli = argparse.ArgumentParser(description=description)
    add_config_arguments(cli)
    return cli
266
+
267
+
268
def parse_config(description: str) -> ExperimentConfig:
    """Parse the process command line into an ``ExperimentConfig``."""
    namespace = build_arg_parser(description).parse_args()
    return build_config_from_args(namespace)
272
+
273
+
274
def ensure_output_dirs(config: ExperimentConfig) -> dict[str, Path]:
    """Create the run's directory layout and return it keyed by role.

    Keys are ``root``, ``checkpoints``, ``stage_checkpoints``, ``plots`` and
    ``artifacts``. Existing directories are left in place.
    """
    base = Path(config.output_dir)
    checkpoint_dir = base / "checkpoints"
    layout = dict(
        root=base,
        checkpoints=checkpoint_dir,
        stage_checkpoints=checkpoint_dir / "stages",
        plots=base / "plots",
        artifacts=base / "artifacts",
    )
    for path in layout.values():
        path.mkdir(parents=True, exist_ok=True)
    return layout
286
+
287
+
288
def flatten_metric_dict(prefix: str, metrics: dict[str, float | int | str]) -> dict[str, float | int | str]:
    """Return a new dict whose keys are ``prefix`` followed by the original key."""
    prefixed: dict[str, float | int | str] = {}
    for name, value in metrics.items():
        prefixed[f"{prefix}{name}"] = value
    return prefixed
290
+
291
+
292
def iter_stage_lengths(config: ExperimentConfig) -> Iterable[int]:
    """Yield the stage numbers ``1 .. config.train_max_digits`` in order."""
    yield from range(1, config.train_max_digits + 1)
addition/data.py ADDED
@@ -0,0 +1,390 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import dataclasses
4
+ import math
5
+ import random
6
+ from dataclasses import dataclass
7
+ from typing import Iterable
8
+
9
+ import torch
10
+
11
+ from addition.config import ExperimentConfig
12
+
13
+
14
+ DIGIT_OFFSET = 0
15
+ DEFAULT_SYMBOLS = "0123456789ABCDEF"
16
+
17
+
18
@dataclass
class AdditionProblem:
    """One addition example with its per-position sums and carries."""

    # Operand digits; the samplers in this module fill indices 0..active_digits-1.
    a_digits: list[int]
    b_digits: list[int]
    # Per-position result digit and carry produced at that position.
    sum_digits: list[int]
    carry_out: list[int]
    # Number of leading positions that are populated.
    active_digits: int
    # True when the example came from the carry-heavy sampler.
    is_carry_heavy: bool
26
+
27
+
28
@dataclass
class Batch:
    """Tensors for a group of problems, as assembled by ``build_batch``."""

    # Encoded operand tokens, one row per problem.
    input_ids: torch.Tensor
    # Target sum digits and the mask marking which of them count.
    target_digits: torch.Tensor
    target_digit_mask: torch.Tensor
    # Per-position carries and the carry out of the last active position.
    target_carry: torch.Tensor
    target_final_carry: torch.Tensor
    # Per-problem active length and carry-heavy flag.
    active_digits: torch.Tensor
    is_carry_heavy: torch.Tensor
37
+
38
+
39
@dataclass
class EvaluationSuite:
    """Fixed evaluation problem sets, each keyed by active digit length."""

    # Uniformly sampled validation problems.
    validation_uniform: dict[int, list[AdditionProblem]]
    # Uniformly sampled test problems.
    test_uniform: dict[int, list[AdditionProblem]]
    # Test problems from the carry-heavy sampler.
    test_carry_heavy: dict[int, list[AdditionProblem]]
44
+
45
+
46
def a_token_id(radix: int) -> int:
    """Return the token id marking the first operand: the first id after the digits."""
    marker_id = radix
    return marker_id
48
+
49
+
50
def b_token_id(radix: int) -> int:
    """Return the token id marking the second operand: ``radix + 1``."""
    marker_id = radix + 1
    return marker_id
52
+
53
+
54
def seed_everything(seed: int) -> None:
    """Seed Python's ``random``, torch, and (when available) every CUDA device."""
    random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
59
+
60
+
61
def compute_sum_and_carry(a_digits: list[int], b_digits: list[int], radix: int) -> tuple[list[int], list[int]]:
    """Add two least-significant-first digit lists position by position.

    Returns ``(sum_digits, carry_out)``, where ``carry_out[i]`` is the carry
    leaving position ``i``. Iteration stops at the shorter input.
    """
    result_digits: list[int] = []
    result_carries: list[int] = []
    running_carry = 0
    for left, right in zip(a_digits, b_digits):
        running_carry, digit = divmod(int(left) + int(right) + running_carry, radix)
        result_digits.append(digit)
        result_carries.append(running_carry)
    return result_digits, result_carries
71
+
72
+
73
def sample_uniform_problem(max_digits: int, active_digits: int, radix: int, rng: random.Random) -> AdditionProblem:
    """Sample a problem whose first ``active_digits`` digits are uniform random.

    Both operands have ``max_digits`` positions; positions past
    ``active_digits`` stay zero. For each position the ``a`` digit is drawn
    from ``rng`` before the ``b`` digit.
    """
    top_digit = radix - 1
    first_operand = [0] * max_digits
    second_operand = [0] * max_digits
    for position in range(active_digits):
        first_operand[position] = rng.randint(0, top_digit)
        second_operand[position] = rng.randint(0, top_digit)
    sums, carries = compute_sum_and_carry(first_operand, second_operand, radix=radix)
    return AdditionProblem(
        a_digits=first_operand,
        b_digits=second_operand,
        sum_digits=sums,
        carry_out=carries,
        active_digits=active_digits,
        is_carry_heavy=False,
    )
88
+
89
+
90
def sample_carry_heavy_problem(max_digits: int, active_digits: int, radix: int, rng: random.Random) -> AdditionProblem:
    """Sample a problem in which every active position produces a carry.

    The ``a`` digit is drawn from the upper half of the digit range. The
    ``b`` digit is then drawn from the values that make the position's total
    (including the incoming carry) reach at least ``radix``. For each
    position ``a`` is drawn from ``rng`` before ``b``.
    """
    top_digit = radix - 1
    upper_half_start = max(0, radix // 2)
    first_operand = [0] * max_digits
    second_operand = [0] * max_digits
    incoming = 0
    for position in range(active_digits):
        a_value = rng.randint(upper_half_start, top_digit)
        # With an incoming carry, one less is needed from b to overflow.
        shortfall = radix - a_value if incoming == 0 else top_digit - a_value
        b_value = rng.randint(max(0, shortfall), top_digit)
        first_operand[position] = a_value
        second_operand[position] = b_value
        incoming = (a_value + b_value + incoming) // radix
    sums, carries = compute_sum_and_carry(first_operand, second_operand, radix=radix)
    return AdditionProblem(
        a_digits=first_operand,
        b_digits=second_operand,
        sum_digits=sums,
        carry_out=carries,
        active_digits=active_digits,
        is_carry_heavy=True,
    )
115
+
116
+
117
def sample_problem(
    max_digits: int,
    active_digits: int,
    radix: int,
    rng: random.Random,
    carry_heavy: bool = False,
) -> AdditionProblem:
    """Sample one problem with the carry-heavy or the uniform sampler."""
    sampler = sample_carry_heavy_problem if carry_heavy else sample_uniform_problem
    return sampler(max_digits=max_digits, active_digits=active_digits, radix=radix, rng=rng)
127
+
128
+
129
def encode_problem_tokens(problem: AdditionProblem, radix: int) -> list[int]:
    """Encode a problem as ``[A, a digits..., B, b digits...]`` token ids.

    Only the first ``problem.active_digits`` digits of each operand are
    included.
    """
    width = problem.active_digits
    tokens = [a_token_id(radix)]
    tokens.extend(DIGIT_OFFSET + digit for digit in problem.a_digits[:width])
    tokens.append(b_token_id(radix))
    tokens.extend(DIGIT_OFFSET + digit for digit in problem.b_digits[:width])
    return tokens
136
+
137
+
138
def build_batch(
    problems: list[AdditionProblem],
    radix: int,
    device: str,
) -> Batch:
    """Convert ``problems`` into a ``Batch`` of tensors on ``device``.

    The digit width is taken from the first problem (0 for an empty list),
    so all problems are expected to share one ``active_digits`` value.
    """
    width = problems[0].active_digits if problems else 0

    def as_tensor(rows: list, dtype: torch.dtype) -> torch.Tensor:
        # One place for the shared device placement.
        return torch.tensor(rows, dtype=dtype, device=device)

    token_rows = [encode_problem_tokens(problem=problem, radix=radix) for problem in problems]
    digit_rows = [problem.sum_digits[:width] for problem in problems]
    mask_rows = [[1] * width for _ in problems]
    carry_rows = [problem.carry_out[:width] for problem in problems]
    # The final carry is read at each problem's own last active position.
    final_carries = [problem.carry_out[problem.active_digits - 1] for problem in problems]
    lengths = [problem.active_digits for problem in problems]
    heavy_flags = [int(problem.is_carry_heavy) for problem in problems]

    return Batch(
        input_ids=as_tensor(token_rows, torch.long),
        target_digits=as_tensor(digit_rows, torch.long),
        target_digit_mask=as_tensor(mask_rows, torch.bool),
        target_carry=as_tensor(carry_rows, torch.long),
        target_final_carry=as_tensor(final_carries, torch.long),
        active_digits=as_tensor(lengths, torch.long),
        is_carry_heavy=as_tensor(heavy_flags, torch.bool),
    )
181
+
182
+
183
def sample_training_batch(
    config: ExperimentConfig,
    stage: int,
    rng: random.Random,
    device: str,
) -> Batch:
    """Sample ``config.train_batch_size`` problems of length ``stage`` as a batch.

    Each problem is carry-heavy with probability
    ``config.train_carry_heavy_prob``. That draw is taken from ``rng``
    before the problem's digits are sampled.
    """
    sampled = [
        sample_problem(
            max_digits=stage,
            active_digits=stage,
            radix=config.radix,
            rng=rng,
            carry_heavy=rng.random() < config.train_carry_heavy_prob,
        )
        for _ in range(config.train_batch_size)
    ]
    return build_batch(problems=sampled, radix=config.radix, device=device)
205
+
206
+
207
def build_problem_set(
    *,
    max_digits: int,
    active_digits: int,
    radix: int,
    count: int,
    seed: int,
    carry_heavy: bool,
) -> list[AdditionProblem]:
    """Generate ``count`` problems from a private RNG seeded with ``seed``."""
    generator = random.Random(seed)
    generated: list[AdditionProblem] = []
    for _ in range(count):
        generated.append(
            sample_problem(
                max_digits=max_digits,
                active_digits=active_digits,
                radix=radix,
                rng=generator,
                carry_heavy=carry_heavy,
            )
        )
    return generated
221
+
222
+
223
def build_evaluation_suite(config: ExperimentConfig) -> EvaluationSuite:
    """Build the validation and test sets for every evaluated length.

    Lengths cover ``1 .. train_max_digits`` plus ``config.ood_lengths``. Each
    split uses its own seed offset (10000, 20000, 30000) plus the length, so
    the sets are reproducible and independent of each other.
    """
    validation_uniform: dict[int, list[AdditionProblem]] = {}
    test_uniform: dict[int, list[AdditionProblem]] = {}
    test_carry_heavy: dict[int, list[AdditionProblem]] = {}
    # (destination, seed offset, example count, carry-heavy flag)
    split_specs = (
        (validation_uniform, 10_000, config.eval_examples_per_length, False),
        (test_uniform, 20_000, config.eval_examples_per_length, False),
        (test_carry_heavy, 30_000, config.carry_heavy_examples_per_length, True),
    )
    lengths = sorted(set(range(1, config.train_max_digits + 1)).union(config.ood_lengths))
    for length in lengths:
        for destination, seed_offset, example_count, heavy in split_specs:
            destination[length] = build_problem_set(
                max_digits=length,
                active_digits=length,
                radix=config.radix,
                count=example_count,
                seed=seed_offset + length,
                carry_heavy=heavy,
            )
    return EvaluationSuite(
        validation_uniform=validation_uniform,
        test_uniform=test_uniform,
        test_carry_heavy=test_carry_heavy,
    )
258
+
259
+
260
def digits_to_string(digits: Iterable[int], final_carry: int, radix: int) -> str:
    """Render least-significant-first digits as a most-significant-first string.

    A non-zero ``final_carry`` becomes the leading digit. Leading zeros are
    dropped, but at least one digit is kept when any digit exists.
    """
    significant = list(digits)
    if final_carry:
        significant.append(final_carry)
    # The list is least-significant-first, so leading zeros sit at the end.
    while len(significant) > 1 and significant[-1] == 0:
        del significant[-1]
    alphabet = DEFAULT_SYMBOLS[:radix]
    return "".join(alphabet[value] for value in significant[::-1])
269
+
270
+
271
def value_from_digits(digits: Iterable[int], final_carry: int, radix: int) -> int:
    """Return the integer encoded by least-significant-first ``digits``.

    A non-zero ``final_carry`` is added at the position just above the last
    digit.
    """
    digit_values = [int(digit) for digit in digits]
    total = sum(value * radix**position for position, value in enumerate(digit_values))
    if final_carry:
        total += int(final_carry) * radix ** len(digit_values)
    return total
280
+
281
+
282
def exact_sum_matches(
    predicted_digits: list[int],
    predicted_final_carry: int,
    truth_digits: list[int],
    truth_final_carry: int,
) -> bool:
    """Return True when both the digits and the final carry match the truth."""
    if predicted_digits == truth_digits:
        return int(predicted_final_carry) == int(truth_final_carry)
    return False
289
+
290
+
291
def summarize_problem(problem: AdditionProblem, radix: int) -> dict[str, int | str]:
    """Return a readable summary: operands and sum as strings, plus metadata."""
    width = problem.active_digits
    carry_at_end = problem.carry_out[width - 1]
    operand_a = digits_to_string(problem.a_digits[:width], final_carry=0, radix=radix)
    operand_b = digits_to_string(problem.b_digits[:width], final_carry=0, radix=radix)
    total = digits_to_string(problem.sum_digits[:width], final_carry=carry_at_end, radix=radix)
    return {
        "a": operand_a,
        "b": operand_b,
        "sum": total,
        "radix": radix,
        "active_digits": width,
        "carry_heavy": int(problem.is_carry_heavy),
    }
301
+
302
+
303
def count_carry_chain(problem: AdditionProblem) -> int:
    """Return the longest run of consecutive carrying positions among the active ones."""
    best = 0
    run = 0
    for position in range(problem.active_digits):
        # A position without a carry resets the current run.
        run = run + 1 if problem.carry_out[position] else 0
        if run > best:
            best = run
    return best
313
+
314
+
315
def carry_density(problem: AdditionProblem) -> float:
    """Return the fraction of active positions that carry (0.0 if none are active)."""
    width = problem.active_digits
    if width > 0:
        return float(sum(problem.carry_out[:width])) / float(width)
    return 0.0
319
+
320
+
321
def curriculum_stage_lengths(config: ExperimentConfig) -> list[int]:
    """Return the training stages: ``1..train_max_digits`` with a curriculum, else only the maximum."""
    final_stage = config.train_max_digits
    if not config.uses_curriculum:
        return [final_stage]
    return list(range(1, final_stage + 1))
325
+
326
+
327
def infer_eval_lengths(config: ExperimentConfig) -> list[int]:
    """Return the sorted, de-duplicated union of training lengths and OOD lengths."""
    lengths = {*range(1, config.train_max_digits + 1), *config.ood_lengths}
    return sorted(lengths)
329
+
330
+
331
def estimate_train_tokens_per_step(config: ExperimentConfig, stage: int) -> int:
    """Return batch size times the per-example length (base plus latent steps) at ``stage``."""
    per_example = config.base_sequence_length_for_digits(stage) + config.latent_steps_for_stage(stage)
    return config.train_batch_size * per_example
334
+
335
+
336
def stage_fraction(stage: int, max_stage: int) -> float:
    """Map ``stage`` linearly so that 1 gives 0.0 and ``max_stage`` gives 1.0.

    Returns 1.0 when ``max_stage`` is 1 or less.
    """
    return 1.0 if max_stage <= 1 else float(stage - 1) / float(max_stage - 1)
340
+
341
+
342
def maybe_trim_examples(problems: list[AdditionProblem], limit: int) -> list[AdditionProblem]:
    """Return a new list with at most ``limit`` leading items.

    A non-positive ``limit`` means no limit.
    """
    keep_everything = limit <= 0 or len(problems) <= limit
    return list(problems) if keep_everything else list(problems[:limit])
346
+
347
+
348
def stage_display_name(stage: int) -> str:
    """Return ``stage`` as an English ordinal label, e.g. ``"3rd-digit"``."""
    if stage % 100 in (11, 12, 13):
        # 11, 12 and 13 take "th" despite ending in 1, 2 or 3.
        ending = "th"
    else:
        ending = {1: "st", 2: "nd", 3: "rd"}.get(stage % 10, "th")
    return f"{stage}{ending}-digit"
357
+
358
+
359
def ideal_carry_chain_examples(config: ExperimentConfig, active_digits: int) -> list[AdditionProblem]:
    """Return two fixed problems adding ``1`` per position to a repeated digit.

    The first operand repeats ``radix - 2`` in one problem and ``radix - 1``
    in the other; the second operand is all ones. Both are flagged as
    carry-heavy.

    NOTE(review): with ``radix - 2`` each position sums to ``radix - 1``, so
    that example produces no carries at all — confirm this is intended.
    """

    def make_example(repeated_digit: int) -> AdditionProblem:
        first_operand = [repeated_digit] * active_digits
        second_operand = [1] * active_digits
        sums, carries = compute_sum_and_carry(first_operand, second_operand, radix=config.radix)
        return AdditionProblem(
            a_digits=first_operand,
            b_digits=second_operand,
            sum_digits=sums,
            carry_out=carries,
            active_digits=active_digits,
            is_carry_heavy=True,
        )

    repeated_digits = (max(0, config.radix - 2), config.radix - 1)
    return [make_example(digit) for digit in repeated_digits]
376
+
377
+
378
def expected_sum_length(problem: AdditionProblem) -> int:
    """Return the number of digits in the sum, counting a final carry.

    The result is ``active_digits`` plus one when the last active position
    carries out. A problem with no active digits has length 0. Without this
    guard, ``carry_out[active_digits - 1]`` would wrap to the last element,
    or raise ``IndexError`` on an empty list. ``carry_density`` guards the
    same case.
    """
    width = problem.active_digits
    if width <= 0:
        return 0
    carry_at_end = problem.carry_out[width - 1]
    return width + int(carry_at_end > 0)
381
+
382
+
383
def average_query_count(config: ExperimentConfig) -> float:
    """Return the mean of the curriculum stage lengths for ``config``."""
    stage_lengths = curriculum_stage_lengths(config)
    stage_count = float(len(stage_lengths))
    return sum(stage_lengths) / stage_count
386
+
387
+
388
def token_budget(config: ExperimentConfig) -> int:
    """Return the per-example sequence length at the rounded-up average stage."""
    typical_stage = int(math.ceil(average_query_count(config)))
    base_length = config.base_sequence_length_for_digits(typical_stage)
    latent_length = config.latent_steps_for_stage(typical_stage)
    return base_length + latent_length
addition/eval.py ADDED
@@ -0,0 +1,326 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import math
4
+ from dataclasses import dataclass
5
+ from typing import Iterable
6
+
7
+ import torch
8
+ from torch import nn
9
+
10
+ from addition.config import ExperimentConfig
11
+ from addition.data import (
12
+ AdditionProblem,
13
+ EvaluationSuite,
14
+ build_batch,
15
+ carry_density,
16
+ count_carry_chain,
17
+ exact_sum_matches,
18
+ maybe_trim_examples,
19
+ )
20
+ from addition.model import AdditionTransformer
21
+
22
+
23
@dataclass
class LengthMetrics:
    """Accuracy statistics for one problem set evaluated at a fixed length."""

    # Mean correctness over every (example, digit position) prediction.
    digit_accuracy: float
    # Fraction of examples whose predicted final carry equals the true one.
    final_carry_accuracy: float
    # Fraction of examples for which ``exact_sum_matches`` reported a match.
    exact_match: float
    # Mean of ``count_carry_chain`` over the evaluated problems.
    avg_carry_chain: float
    # Mean of ``carry_density`` over the evaluated problems.
    avg_carry_density: float
    # Number of problems that were evaluated.
    example_count: int
    # Digit accuracy broken down by position, averaged over examples.
    per_position_digit_accuracy: list[float]
32
+
33
+
34
+ def _chunked(sequence: list[AdditionProblem], chunk_size: int) -> Iterable[list[AdditionProblem]]:
35
+ for start in range(0, len(sequence), chunk_size):
36
+ yield sequence[start : start + chunk_size]
37
+
38
+
39
@torch.no_grad()
def evaluate_problem_set(
    model: AdditionTransformer,
    config: ExperimentConfig,
    problems: list[AdditionProblem],
    active_digits: int,
    *,
    device: str,
    return_attention: bool = False,
) -> tuple[LengthMetrics, dict[str, float] | None]:
    """Score ``model`` on ``problems`` at a fixed number of active digits.

    The model is switched to eval mode and run in batches of
    ``config.eval_batch_size``. Returns the aggregated ``LengthMetrics`` and,
    when ``return_attention`` is set, an attention summary computed from the
    first batch only (``None`` otherwise, and ``None`` for an empty set).
    """
    model.eval()
    latent_steps = config.latent_steps_for_stage(active_digits)
    total = len(problems)
    if total == 0:
        # Nothing to score: report all-zero metrics and no attention summary.
        return (
            LengthMetrics(
                digit_accuracy=0.0,
                final_carry_accuracy=0.0,
                exact_match=0.0,
                avg_carry_chain=0.0,
                avg_carry_density=0.0,
                example_count=0,
                per_position_digit_accuracy=[0.0] * active_digits,
            ),
            None,
        )

    top_index = active_digits - 1
    guessed_digits = torch.zeros(total, active_digits, dtype=torch.long)
    guessed_carry = torch.zeros(total, dtype=torch.long)
    true_digits = torch.tensor(
        [[problem.sum_digits[position] for position in range(active_digits)] for problem in problems],
        dtype=torch.long,
    )
    true_carry = torch.tensor([problem.carry_out[top_index] for problem in problems], dtype=torch.long)
    attention_stats: dict[str, float] | None = None

    cursor = 0
    for chunk in _chunked(problems, config.eval_batch_size):
        batch = build_batch(problems=chunk, radix=config.radix, device=device)
        outputs = model(batch.input_ids, latent_steps=latent_steps, return_attention=return_attention)
        window = slice(cursor, cursor + len(chunk))
        guessed_digits[window] = outputs.digit_logits.argmax(dim=-1)[:, :active_digits].cpu()
        guessed_carry[window] = outputs.final_carry_logits.argmax(dim=-1).cpu()
        # Attention is summarised once, from the first batch that provides it.
        if return_attention and attention_stats is None:
            attention_stats = summarize_attention(
                attention_weights=outputs.attention_weights,
                active_digits=active_digits,
                input_sequence_length=batch.input_ids.shape[1],
                output_sequence_length=outputs.output_hidden.shape[1],
            )
        cursor = window.stop

    exact_flags = [
        exact_sum_matches(
            predicted_digits=guessed_digits[row].tolist(),
            predicted_final_carry=int(guessed_carry[row].item()),
            truth_digits=problem.sum_digits[:active_digits],
            truth_final_carry=problem.carry_out[top_index],
        )
        for row, problem in enumerate(problems)
    ]

    digit_hits = (guessed_digits == true_digits).float()
    carry_hits = (guessed_carry == true_carry).float()
    per_position = digit_hits.mean(dim=0).tolist()
    metrics = LengthMetrics(
        digit_accuracy=float(digit_hits.mean().item()),
        final_carry_accuracy=float(carry_hits.mean().item()),
        exact_match=float(torch.tensor(exact_flags, dtype=torch.float32).mean().item()),
        avg_carry_chain=float(sum(count_carry_chain(problem) for problem in problems) / len(problems)),
        avg_carry_density=float(sum(carry_density(problem) for problem in problems) / len(problems)),
        example_count=len(problems),
        per_position_digit_accuracy=[float(value) for value in per_position],
    )
    return metrics, attention_stats
112
+
113
+
114
def summarize_attention(
    attention_weights: torch.Tensor | None,
    *,
    active_digits: int,
    input_sequence_length: int,
    output_sequence_length: int,
) -> dict[str, float]:
    """Summarise where the final query position attends.

    The weights are averaged over the first two axes for the last query row,
    giving one value per source position. Returns an empty dict when no
    weights are supplied.

    NOTE(review): the ``lsd_*`` / ``msd_*`` indices assume each operand sits
    behind one leading marker token in the input half it occupies; confirm
    against the batch layout produced by ``build_batch``.
    """
    if attention_weights is None:
        return {}

    # Per the original shape note: [batch, heads, target_len, source_len].
    profile = attention_weights[:, :, -1, :].mean(dim=(0, 1))
    b_offset = input_sequence_length // 2
    latent_part = profile[input_sequence_length:-output_sequence_length]
    earlier_outputs = profile[-output_sequence_length:-1]
    # Clamp before the log so zero-weight positions contribute 0, not NaN.
    log_profile = torch.log(profile.clamp_min(1e-9))
    entropy = -torch.sum(profile * log_profile).item()

    def _weight_at(index: int) -> float:
        return float(profile[index].item())

    def _mass(segment: torch.Tensor) -> float:
        return float(segment.sum().item()) if segment.numel() else 0.0

    return {
        "lsd_a_attention": _weight_at(1),
        "msd_a_attention": _weight_at(active_digits),
        "lsd_b_attention": _weight_at(b_offset + 1),
        "msd_b_attention": _weight_at(b_offset + active_digits),
        "attention_entropy": float(entropy),
        "all_latent_attention": _mass(latent_part),
        "previous_output_attention": _mass(earlier_outputs),
    }
141
+
142
+
143
@torch.no_grad()
def evaluate_length_dict(
    model: AdditionTransformer,
    config: ExperimentConfig,
    problems_by_length: dict[int, list[AdditionProblem]],
    *,
    device: str,
    attention_length: int | None = None,
) -> dict[str, dict]:
    """Evaluate each length bucket and return JSON-friendly metrics.

    Buckets are processed in ascending length order and keyed by
    ``str(length)``. An ``"attention_summary"`` entry is added only for the
    bucket whose length equals ``attention_length`` (and only if the
    evaluation produced one).
    """
    report: dict[str, dict] = {}
    for length in sorted(problems_by_length):
        wants_attention = attention_length is not None and attention_length == length
        length_metrics, attention = evaluate_problem_set(
            model=model,
            config=config,
            problems=problems_by_length[length],
            active_digits=length,
            device=device,
            return_attention=wants_attention,
        )
        entry: dict = {
            "digit_accuracy": length_metrics.digit_accuracy,
            "final_carry_accuracy": length_metrics.final_carry_accuracy,
            "exact_match": length_metrics.exact_match,
            "avg_carry_chain": length_metrics.avg_carry_chain,
            "avg_carry_density": length_metrics.avg_carry_density,
            "example_count": length_metrics.example_count,
            "per_position_digit_accuracy": length_metrics.per_position_digit_accuracy,
        }
        if attention is not None:
            entry["attention_summary"] = attention
        report[str(length)] = entry
    return report
174
+
175
+
176
def collect_hidden_dataset(
    model: AdditionTransformer,
    config: ExperimentConfig,
    problems: list[AdditionProblem],
    *,
    active_digits: int,
    device: str,
    limit_examples: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Gather output-slot hidden states and their carry labels for probing.

    At most ``limit_examples`` problems are used (a non-positive limit keeps
    them all). The model is put in eval mode and run without gradients; for
    every batch the first ``active_digits`` output-slot hidden vectors are
    selected with ``batch.target_digit_mask`` and paired with the matching
    entries of ``batch.target_carry``.

    Returns:
        ``(hidden_states, carry_targets)`` on the CPU, concatenated over all
        batches. When no problems are selected, a ``(0, config.d_model)``
        float tensor and an empty long tensor are returned, which
        ``fit_linear_probe`` treats as "no data".
    """
    model.eval()
    latent_steps = config.latent_steps_for_stage(active_digits)
    selected = maybe_trim_examples(problems, limit_examples)
    hidden_states: list[torch.Tensor] = []
    carry_targets: list[torch.Tensor] = []
    with torch.no_grad():
        for problem_chunk in _chunked(selected, config.eval_batch_size):
            batch = build_batch(
                problems=problem_chunk,
                radix=config.radix,
                device=device,
            )
            outputs = model(batch.input_ids, latent_steps=latent_steps, return_attention=False)
            slot_hidden = outputs.output_hidden[:, :active_digits, :]
            slot_mask = batch.target_digit_mask
            hidden_states.append(slot_hidden[slot_mask].detach().cpu())
            carry_targets.append(batch.target_carry[slot_mask].detach().cpu())

    if not hidden_states:
        # torch.cat rejects an empty list, so an empty problem set needs an
        # explicit empty result instead of crashing the whole evaluation.
        return torch.empty(0, config.d_model), torch.empty(0, dtype=torch.long)
    return torch.cat(hidden_states, dim=0), torch.cat(carry_targets, dim=0)
203
+
204
+
205
def fit_linear_probe(
    hidden_states: torch.Tensor,
    carry_targets: torch.Tensor,
    *,
    epochs: int,
    learning_rate: float,
) -> dict[str, float]:
    """Train a two-class linear probe on hidden states and report accuracy.

    Rows are shuffled, the first 80% (at least one row) are used for
    full-batch AdamW training, and accuracy is measured on the remainder.
    If nothing remains for testing, accuracy is measured on the training
    rows. Empty input yields ``{"probe_accuracy": 0.0}``.
    """
    if hidden_states.numel() == 0:
        return {"probe_accuracy": 0.0}

    shuffle = torch.randperm(hidden_states.shape[0])
    features = hidden_states[shuffle]
    labels = carry_targets[shuffle]

    train_count = max(1, int(0.8 * features.shape[0]))
    train_x, held_x = features[:train_count], features[train_count:]
    train_y, held_y = labels[:train_count], labels[train_count:]
    if held_x.numel() == 0:
        # Too few rows for a held-out split: fall back to training accuracy.
        held_x, held_y = train_x, train_y

    probe = nn.Linear(features.shape[-1], 2)
    optimizer = torch.optim.AdamW(probe.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()
    for _ in range(epochs):
        optimizer.zero_grad()
        criterion(probe(train_x), train_y).backward()
        optimizer.step()

    with torch.no_grad():
        guesses = probe(held_x).argmax(dim=-1)
        hit_rate = float((guesses == held_y).float().mean().item())
    return {"probe_accuracy": hit_rate}
240
+
241
+
242
def evaluate_suite(
    model: AdditionTransformer,
    config: ExperimentConfig,
    suite: EvaluationSuite,
    *,
    device: str,
) -> dict[str, dict]:
    """Run the full evaluation: validation, both test splits and diagnostics.

    In-distribution lengths are ``1..config.train_max_digits``; the test
    splits additionally cover ``config.ood_lengths``. Attention summaries and
    the linear carry probe use the largest OOD length, or
    ``config.train_max_digits`` when no OOD lengths are configured.
    """
    id_lengths = list(range(1, config.train_max_digits + 1))
    ood_lengths = list(config.ood_lengths)
    probe_length = max(ood_lengths) if ood_lengths else config.train_max_digits
    probe_key = str(probe_length)

    validation = evaluate_length_dict(
        model=model,
        config=config,
        problems_by_length={length: suite.validation_uniform[length] for length in id_lengths},
        device=device,
    )

    test_results: dict[str, dict] = {}
    for split_name, split_problems in (
        ("test_uniform", suite.test_uniform),
        ("test_carry_heavy", suite.test_carry_heavy),
    ):
        test_results[split_name] = evaluate_length_dict(
            model=model,
            config=config,
            problems_by_length={length: split_problems[length] for length in sorted(set(id_lengths + ood_lengths))},
            device=device,
            attention_length=probe_length,
        )

    probe_hidden, probe_targets = collect_hidden_dataset(
        model=model,
        config=config,
        problems=suite.test_carry_heavy[probe_length],
        active_digits=probe_length,
        device=device,
        limit_examples=config.attention_probe_examples,
    )
    diagnostics = fit_linear_probe(
        hidden_states=probe_hidden,
        carry_targets=probe_targets,
        epochs=config.linear_probe_epochs,
        learning_rate=config.linear_probe_lr,
    )
    diagnostics["attention_uniform"] = test_results["test_uniform"][probe_key].get("attention_summary", {})
    diagnostics["attention_carry_heavy"] = test_results["test_carry_heavy"][probe_key].get("attention_summary", {})

    return {
        "validation_uniform": validation,
        "test_uniform": test_results["test_uniform"],
        "test_carry_heavy": test_results["test_carry_heavy"],
        "diagnostics": diagnostics,
    }
295
+
296
+
297
def stage_validation_metric(results: dict[str, dict], stage: int) -> float:
    """Return the validation digit accuracy recorded for ``stage``."""
    return float(results["validation_uniform"][str(stage)]["digit_accuracy"])
300
+
301
+
302
def flatten_nested_metrics(prefix: str, nested: dict[str, dict]) -> dict[str, float]:
    """Flatten an evaluation result tree into ``path -> float`` pairs.

    Keys are built as ``{prefix}{split}/...``. The ``"diagnostics"`` split is
    flattened one or two levels deep; every other split is treated as
    ``length -> metrics``, where list-valued metrics are reduced to their
    mean (key suffix ``_mean``; empty lists are dropped) and dict-valued
    metrics are expanded one level. Non-dict length entries are skipped.
    """
    flat: dict[str, float] = {}
    for split_name, split_metrics in nested.items():
        split_root = f"{prefix}{split_name}"

        if split_name == "diagnostics":
            for key, value in split_metrics.items():
                if isinstance(value, dict):
                    flat.update(
                        {f"{split_root}/{key}/{inner_key}": float(inner_value) for inner_key, inner_value in value.items()}
                    )
                else:
                    flat[f"{split_root}/{key}"] = float(value)
            continue

        for length, length_metrics in split_metrics.items():
            if not isinstance(length_metrics, dict):
                continue
            length_root = f"{split_root}/length_{length}"
            for metric_name, metric_value in length_metrics.items():
                if isinstance(metric_value, list):
                    if metric_value:
                        flat[f"{length_root}/{metric_name}_mean"] = float(sum(metric_value) / len(metric_value))
                elif isinstance(metric_value, dict):
                    for inner_key, inner_value in metric_value.items():
                        flat[f"{length_root}/{metric_name}/{inner_key}"] = float(inner_value)
                else:
                    flat[f"{length_root}/{metric_name}"] = float(metric_value)
    return flat
addition/model.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+
5
+ import torch
6
+ from torch import nn
7
+
8
+ from addition.config import ExperimentConfig
9
+
10
+
11
@dataclass
class ModelOutput:
    """Everything ``AdditionTransformer.forward`` returns for one batch."""

    # Digit-head logits for the first ``active_digits`` output slots.
    digit_logits: torch.Tensor
    # Final-carry-head logits computed from the last output slot.
    final_carry_logits: torch.Tensor
    # Hidden states of all output slots from the last block pass.
    output_hidden: torch.Tensor
    # One summary vector per block pass: the initial pass, then each latent step.
    latent_history: list[torch.Tensor]
    # Attention weights from the last block pass, or None when not requested.
    attention_weights: torch.Tensor | None
18
+
19
+
20
class TransformerBlock(nn.Module):
    """Pre-LayerNorm transformer block: causal self-attention then an MLP."""

    def __init__(self, d_model: int, n_heads: int, ff_dim: int, dropout: float) -> None:
        super().__init__()
        self.ln_1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        self.ln_2 = nn.LayerNorm(d_model)
        feed_forward = [
            nn.Linear(d_model, ff_dim),
            nn.GELU(),
            nn.Linear(ff_dim, d_model),
            nn.Dropout(dropout),
        ]
        self.mlp = nn.Sequential(*feed_forward)

    def forward(self, hidden_states: torch.Tensor, need_weights: bool = False) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Apply the block to a ``(batch, seq, d_model)`` tensor.

        Returns the updated hidden states and, when ``need_weights`` is set,
        the per-head attention weights (``None`` otherwise).
        """
        length = hidden_states.shape[1]
        # True above the diagonal: a position may not attend to later ones.
        future_mask = torch.triu(
            torch.ones(length, length, dtype=torch.bool, device=hidden_states.device),
            diagonal=1,
        )
        attn_input = self.ln_1(hidden_states)
        attn_output, per_head_weights = self.attn(
            attn_input,
            attn_input,
            attn_input,
            attn_mask=future_mask,
            need_weights=need_weights,
            average_attn_weights=False,
        )
        residual = hidden_states + self.dropout(attn_output)
        residual = residual + self.mlp(self.ln_2(residual))
        if not need_weights:
            return residual, None
        return residual, per_head_weights
49
+
50
+
51
class AdditionTransformer(nn.Module):
    """Single-block transformer that reads digit tokens and fills output slots.

    The sequence fed to the block is ``input tokens + latent tokens + output
    slots``. Each latent step appends one latent token built from the
    previous pass's summary vector and re-runs the block on the whole
    sequence.
    """

    def __init__(self, config: ExperimentConfig) -> None:
        super().__init__()
        self.config = config
        width = config.d_model
        self.token_embedding = nn.Embedding(config.discrete_vocab_size, width)
        self.position_embedding = nn.Embedding(config.max_sequence_length, width)
        self.latent_type_embedding = nn.Parameter(torch.zeros(width))
        self.output_slot_embeddings = nn.Parameter(torch.zeros(config.output_sequence_length, width))
        self.block = TransformerBlock(
            d_model=width,
            n_heads=config.n_heads,
            ff_dim=config.ff_dim,
            dropout=config.dropout,
        )
        self.final_ln = nn.LayerNorm(width)
        self.digit_head = nn.Linear(width, config.digit_vocab_size)
        self.final_carry_head = nn.Linear(width, 2)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-initialise embeddings (normal, std 0.02) and heads (Xavier, zero bias)."""
        embedding_tables = (
            self.token_embedding.weight,
            self.position_embedding.weight,
            self.latent_type_embedding,
            self.output_slot_embeddings,
        )
        for table in embedding_tables:
            nn.init.normal_(table, mean=0.0, std=0.02)
        for head in (self.digit_head, self.final_carry_head):
            nn.init.xavier_uniform_(head.weight)
            nn.init.zeros_(head.bias)

    def embed_discrete_tokens(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return token embeddings plus position embeddings for ``input_ids``."""
        position_ids = torch.arange(input_ids.shape[1], device=input_ids.device).unsqueeze(0)
        return self.token_embedding(input_ids) + self.position_embedding(position_ids)

    def embed_output_slots(
        self,
        batch_size: int,
        output_length: int,
        latent_count: int,
        input_length: int,
        device: torch.device,
    ) -> torch.Tensor:
        """Return ``(batch_size, output_length, d_model)`` output-slot embeddings.

        Slots are positioned after the input tokens and any latent tokens.
        """
        first_slot_position = input_length + latent_count
        slot_positions = torch.arange(output_length, device=device) + first_slot_position
        slots = self.output_slot_embeddings[:output_length] + self.position_embedding(slot_positions)
        return slots.unsqueeze(0).expand(batch_size, -1, -1)

    def _run_block(
        self,
        embeddings: torch.Tensor,
        *,
        need_attention: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Run the transformer block followed by the final LayerNorm."""
        block_output, attention = self.block(embeddings, need_weights=need_attention)
        return self.final_ln(block_output), attention

    def forward(
        self,
        input_ids: torch.Tensor,
        *,
        latent_steps: int = 0,
        return_attention: bool = False,
    ) -> ModelOutput:
        """Predict sum digits and the final carry for a batch of problems.

        The number of active digits is derived from the input width as
        ``max(1, (input_length - 2) // 2)``; there is one output slot per
        digit plus one for the final carry. ``latent_steps`` extra passes are
        run, each adding one latent token.
        """
        batch_size, input_length = input_ids.shape[0], input_ids.shape[1]
        active_digits = max(1, (input_length - 2) // 2)
        output_length = active_digits + 1
        token_embeddings = self.embed_discrete_tokens(input_ids)

        def _slots(latent_count: int) -> torch.Tensor:
            return self.embed_output_slots(
                batch_size=batch_size,
                output_length=output_length,
                latent_count=latent_count,
                input_length=input_length,
                device=input_ids.device,
            )

        # Initial pass without latent tokens; its last output slot seeds the
        # first latent token.
        hidden_states, attention_weights = self._run_block(
            torch.cat([token_embeddings, _slots(0)], dim=1),
            need_attention=return_attention,
        )
        output_hidden = hidden_states[:, -output_length:, :]
        summary = output_hidden[:, -1, :]
        latent_history: list[torch.Tensor] = [summary]

        latent_tokens: list[torch.Tensor] = []
        for step in range(int(latent_steps)):
            latent_position = input_length + step
            new_latent = summary.unsqueeze(1) + self.latent_type_embedding.view(1, 1, -1)
            new_latent = new_latent + self.position_embedding.weight[latent_position].view(1, 1, -1)
            latent_tokens.append(new_latent)

            hidden_states, attention_weights = self._run_block(
                torch.cat([token_embeddings] + latent_tokens + [_slots(len(latent_tokens))], dim=1),
                need_attention=return_attention,
            )
            # The newest latent position becomes the summary for the next step.
            summary = hidden_states[:, latent_position, :]
            output_hidden = hidden_states[:, -output_length:, :]
            latent_history.append(summary)

        return ModelOutput(
            digit_logits=self.digit_head(output_hidden[:, :active_digits, :]),
            final_carry_logits=self.final_carry_head(output_hidden[:, -1, :]),
            output_hidden=output_hidden,
            latent_history=latent_history,
            attention_weights=attention_weights,
        )

    def parameter_count(self) -> int:
        """Return the total number of scalar parameters in the model."""
        return sum(tensor.numel() for tensor in self.parameters())
170
+
171
+
172
def build_model(config: ExperimentConfig, device: str | None = None) -> AdditionTransformer:
    """Construct an ``AdditionTransformer`` and optionally move it to ``device``."""
    model = AdditionTransformer(config)
    return model if device is None else model.to(device)
177
+
178
+
179
@torch.no_grad()
def describe_model(config: ExperimentConfig) -> dict[str, int]:
    """Build a throwaway model and report its parameter counts by group.

    Groups are chosen by substring match on the parameter name ("head",
    "embedding"). ``backbone_params`` is everything except the heads, so it
    includes the embedding parameters.
    """
    model = build_model(config)
    total = model.parameter_count()
    sizes = [(name, tensor.numel()) for name, tensor in model.named_parameters()]
    in_heads = sum(count for name, count in sizes if "head" in name)
    in_embeddings = sum(count for name, count in sizes if "embedding" in name)
    return {
        "total_params": int(total),
        "embedding_params": int(in_embeddings),
        "head_params": int(in_heads),
        "backbone_params": int(total - in_heads),
    }
addition/plots.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import math
4
+ from pathlib import Path
5
+ from typing import Any
6
+
7
+
8
def _load_pyplot():
    """Import and return ``matplotlib.pyplot`` on first use.

    Importing inside the function keeps matplotlib out of module import time.
    """
    from matplotlib import pyplot

    return pyplot
12
+
13
+
14
def plot_training_history(history: list[dict[str, Any]], output_dir: Path) -> list[Path]:
    """Write the training-curve and stage-progression plots for one run.

    Returns the paths of the saved PNG files, or an empty list (without
    touching the filesystem or importing matplotlib) when ``history`` is
    empty.
    """
    if not history:
        return []
    plt = _load_pyplot()
    output_dir.mkdir(parents=True, exist_ok=True)

    def _series(key: str) -> list[Any]:
        return [entry[key] for entry in history]

    steps = _series("global_step")
    accuracy_curves = [
        ("Val digit acc", _series("validation_digit_accuracy")),
        ("Val final carry acc", _series("validation_final_carry_accuracy")),
        ("Val exact match", _series("validation_exact_match")),
    ]
    stages = _series("stage")

    # Figure 1: the three validation accuracies against the global step.
    metrics_path = output_dir / "training_curves.png"
    plt.figure(figsize=(8, 4.5))
    for label, values in accuracy_curves:
        plt.plot(steps, values, label=label)
    plt.xlabel("Global step")
    plt.ylabel("Accuracy")
    plt.ylim(0.0, 1.01)
    plt.legend()
    plt.tight_layout()
    plt.savefig(metrics_path, dpi=160)
    plt.close()

    # Figure 2: which curriculum stage was active at each step.
    stage_path = output_dir / "stage_progression.png"
    plt.figure(figsize=(8, 4.5))
    plt.step(steps, stages, where="post")
    plt.xlabel("Global step")
    plt.ylabel("Curriculum stage")
    plt.tight_layout()
    plt.savefig(stage_path, dpi=160)
    plt.close()

    return [metrics_path, stage_path]
52
+
53
+
54
+ def _collect_length_metric(aggregate: dict[str, Any], method: str, split: str, metric: str) -> tuple[list[int], list[float], list[float]]:
55
+ lengths = sorted(int(length) for length in aggregate[method][split].keys())
56
+ means = [aggregate[method][split][str(length)][metric]["mean"] for length in lengths]
57
+ stds = [aggregate[method][split][str(length)][metric]["std"] for length in lengths]
58
+ return lengths, means, stds
59
+
60
+
61
def plot_method_comparison(aggregate: dict[str, Any], output_dir: Path) -> list[Path]:
    """Plot per-method comparison curves with +/- one std shaded bands.

    Writes one exact-match-by-length figure per test split and one
    validation-digit-accuracy-by-stage figure, returning the saved paths.
    """
    plt = _load_pyplot()
    output_dir.mkdir(parents=True, exist_ok=True)
    methods = list(aggregate.keys())
    saved_paths: list[Path] = []

    def _band(means: list[float], stds: list[float]) -> tuple[list[float], list[float]]:
        # Clip the band to [0, 1] so it stays inside the accuracy axis.
        lower = [max(0.0, mean - std) for mean, std in zip(means, stds)]
        upper = [min(1.0, mean + std) for mean, std in zip(means, stds)]
        return lower, upper

    split_figures = (
        ("test_uniform", "uniform_exact_match.png", "Uniform exact-match by length"),
        ("test_carry_heavy", "carry_heavy_exact_match.png", "Carry-heavy exact-match by length"),
    )
    for split, filename, title in split_figures:
        plt.figure(figsize=(8, 4.5))
        for method in methods:
            lengths, means, stds = _collect_length_metric(aggregate, method, split, "exact_match")
            plt.plot(lengths, means, marker="o", label=method)
            lower, upper = _band(means, stds)
            plt.fill_between(lengths, lower, upper, alpha=0.15)
        plt.xlabel("Active digits")
        plt.ylabel("Exact-match accuracy")
        plt.title(title)
        plt.ylim(0.0, 1.01)
        plt.legend()
        plt.tight_layout()
        figure_path = output_dir / filename
        plt.savefig(figure_path, dpi=160)
        plt.close()
        saved_paths.append(figure_path)

    plt.figure(figsize=(8, 4.5))
    for method in methods:
        progression = aggregate[method]["stage_progression"]
        stages = sorted(int(stage) for stage in progression.keys())
        means = [progression[str(stage)]["validation_digit_accuracy"]["mean"] for stage in stages]
        stds = [progression[str(stage)]["validation_digit_accuracy"]["std"] for stage in stages]
        plt.plot(stages, means, marker="o", label=method)
        lower, upper = _band(means, stds)
        plt.fill_between(stages, lower, upper, alpha=0.15)
    plt.xlabel("Curriculum stage")
    plt.ylabel("Best validation digit accuracy")
    plt.ylim(0.0, 1.01)
    plt.title("Validation digit accuracy vs stage")
    plt.legend()
    plt.tight_layout()
    stage_curve_path = output_dir / "validation_digit_accuracy_by_stage.png"
    plt.savefig(stage_curve_path, dpi=160)
    plt.close()
    saved_paths.append(stage_curve_path)
    return saved_paths
112
+
113
+
114
def plot_single_run_results(summary: dict[str, Any], output_dir: Path) -> list[Path]:
    """Write every plot for a single run and return the saved paths.

    Produces the training-history figures (if the summary has a non-empty
    ``"history"``) followed by a final exact-match-by-length figure comparing
    the uniform and carry-heavy test splits.
    """
    plt = _load_pyplot()
    output_dir.mkdir(parents=True, exist_ok=True)
    saved_paths = plot_training_history(summary.get("history", []), output_dir)

    final_results = summary["final_results"]
    uniform = final_results["test_uniform"]
    carry_heavy = final_results["test_carry_heavy"]
    # Lengths come from the uniform split; carry-heavy must cover the same ones.
    lengths = sorted(int(length) for length in uniform.keys())
    curves = (
        ("Uniform", [uniform[str(length)]["exact_match"] for length in lengths]),
        ("Carry-heavy", [carry_heavy[str(length)]["exact_match"] for length in lengths]),
    )

    plt.figure(figsize=(8, 4.5))
    for label, values in curves:
        plt.plot(lengths, values, marker="o", label=label)
    plt.xlabel("Active digits")
    plt.ylabel("Exact-match accuracy")
    plt.ylim(0.0, 1.01)
    plt.legend()
    plt.tight_layout()
    final_curve_path = output_dir / "final_exact_match_by_length.png"
    plt.savefig(final_curve_path, dpi=160)
    plt.close()
    saved_paths.append(final_curve_path)
    return saved_paths
addition/run_comparison.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import json
5
+ import math
6
+ from pathlib import Path
7
+ from statistics import mean, pstdev
8
+ from typing import Any
9
+
10
+ from addition.config import VALID_MODELS, add_config_arguments, apply_preset, build_config_from_args
11
+ from addition.plots import plot_method_comparison
12
+ from addition.train import run_experiment
13
+
14
+
15
+ def _mean_std(values: list[float]) -> dict[str, float]:
16
+ if not values:
17
+ return {"mean": 0.0, "std": 0.0}
18
+ if len(values) == 1:
19
+ return {"mean": float(values[0]), "std": 0.0}
20
+ return {"mean": float(mean(values)), "std": float(pstdev(values))}
21
+
22
+
23
+ def _aggregate_split_metrics(run_summaries: list[dict[str, Any]], split_name: str) -> dict[str, Any]:
24
+ lengths = sorted(run_summaries[0]["final_results"][split_name].keys(), key=int)
25
+ metric_names = ["digit_accuracy", "final_carry_accuracy", "exact_match", "avg_carry_chain", "avg_carry_density"]
26
+ aggregated: dict[str, Any] = {}
27
+ for length in lengths:
28
+ aggregated[length] = {}
29
+ for metric_name in metric_names:
30
+ values = [float(summary["final_results"][split_name][length][metric_name]) for summary in run_summaries]
31
+ aggregated[length][metric_name] = _mean_std(values)
32
+ return aggregated
33
+
34
+
35
+ def _aggregate_stage_progression(run_summaries: list[dict[str, Any]]) -> dict[str, Any]:
36
+ max_stage = max(int(entry["stage"]) for summary in run_summaries for entry in summary["history"])
37
+ aggregated: dict[str, Any] = {}
38
+ for stage in range(1, max_stage + 1):
39
+ stage_values = []
40
+ stage_exact = []
41
+ for summary in run_summaries:
42
+ stage_entries = [entry for entry in summary["history"] if int(entry["stage"]) == stage]
43
+ if not stage_entries:
44
+ continue
45
+ stage_values.append(max(float(entry["validation_digit_accuracy"]) for entry in stage_entries))
46
+ stage_exact.append(max(float(entry["validation_exact_match"]) for entry in stage_entries))
47
+ aggregated[str(stage)] = {
48
+ "validation_digit_accuracy": _mean_std(stage_values),
49
+ "validation_exact_match": _mean_std(stage_exact),
50
+ }
51
+ return aggregated
52
+
53
+
54
def _aggregate_diagnostics(run_summaries: list[dict[str, Any]]) -> dict[str, Any]:
    """Aggregate probe accuracy and attention summaries across runs.

    For each attention summary, the union of metric names over all runs is
    aggregated; a run missing a metric contributes 0.0 for it.
    """
    per_run = [summary["final_results"]["diagnostics"] for summary in run_summaries]
    combined: dict[str, Any] = {
        "probe_accuracy": _mean_std([float(diag["probe_accuracy"]) for diag in per_run]),
    }
    for attention_key in ("attention_uniform", "attention_carry_heavy"):
        summaries = [diag.get(attention_key, {}) for diag in per_run]
        names: set[str] = set()
        for attention_summary in summaries:
            names.update(attention_summary.keys())
        combined[attention_key] = {
            metric_name: _mean_std([float(attention_summary.get(metric_name, 0.0)) for attention_summary in summaries])
            for metric_name in sorted(names)
        }
    return combined
66
+
67
+
68
def aggregate_runs(results_by_method: dict[str, list[dict[str, Any]]]) -> dict[str, Any]:
    """Aggregate every method's run summaries into one comparison structure."""
    return {
        method: {
            "test_uniform": _aggregate_split_metrics(run_summaries, "test_uniform"),
            "test_carry_heavy": _aggregate_split_metrics(run_summaries, "test_carry_heavy"),
            "stage_progression": _aggregate_stage_progression(run_summaries),
            "diagnostics": _aggregate_diagnostics(run_summaries),
        }
        for method, run_summaries in results_by_method.items()
    }
78
+
79
+
80
def build_parser() -> argparse.ArgumentParser:
    """Create the CLI parser: shared config options plus comparison options."""
    parser = argparse.ArgumentParser(description="Run the full addition comparison across methods and seeds.")
    add_config_arguments(parser)
    comparison_options = (
        ("--methods", {"nargs": "*", "default": list(VALID_MODELS), "choices": VALID_MODELS}),
        ("--seeds", {"nargs": "*", "type": int, "default": None}),
        ("--comparison_output_dir", {"type": str, "default": ""}),
    )
    for flag, options in comparison_options:
        parser.add_argument(flag, **options)
    return parser
87
+
88
+
89
def main() -> None:
    """Run every requested method/seed pair, then aggregate and plot.

    Results are written under ``--comparison_output_dir`` (default
    ``addition_runs/comparison_<preset>``): one sub-directory per run, an
    ``aggregate_results.json`` file and a ``plots`` directory.
    """
    args = build_parser().parse_args()
    base_config = apply_preset(build_config_from_args(args))
    # An omitted or empty --seeds falls back to 0..comparison_num_seeds-1.
    seeds = args.seeds or list(range(base_config.comparison_num_seeds))
    comparison_root = Path(args.comparison_output_dir or f"addition_runs/comparison_{base_config.preset}")
    comparison_root.mkdir(parents=True, exist_ok=True)

    results_by_method: dict[str, list[dict[str, Any]]] = {}
    for method in args.methods:
        summaries = results_by_method[method] = []
        for seed in seeds:
            run_dir = str(comparison_root / f"{method}_seed{seed}")
            # The parsed namespace is reused and overwritten for each run.
            args.model = method
            args.seed = seed
            args.output_dir = run_dir
            config = apply_preset(build_config_from_args(args))
            config.output_dir = run_dir
            print(f"[addition comparison] running method={method} seed={seed}", flush=True)
            summaries.append(run_experiment(config))

    aggregate = aggregate_runs(results_by_method)
    aggregate_payload = {
        "methods": args.methods,
        "seeds": seeds,
        "aggregate": aggregate,
    }
    results_file = comparison_root / "aggregate_results.json"
    with results_file.open("w", encoding="utf-8") as handle:
        json.dump(aggregate_payload, handle, indent=2, sort_keys=True)
    plot_method_comparison(aggregate, comparison_root / "plots")
119
+
120
+
121
+ if __name__ == "__main__":
122
+ main()
addition/train.py ADDED
@@ -0,0 +1,369 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import time
5
+ from pathlib import Path
6
+ from typing import Any
7
+
8
+ import torch
9
+ from torch import nn
10
+
11
+ from addition.config import ExperimentConfig, ensure_output_dirs, parse_config, save_config
12
+ from addition.data import build_batch, build_evaluation_suite, digits_to_string, exact_sum_matches, sample_training_batch, seed_everything
13
+ from addition.eval import evaluate_problem_set, evaluate_suite, flatten_nested_metrics
14
+ from addition.model import build_model, describe_model
15
+ from addition.plots import plot_single_run_results
16
+
17
+
18
+ def _maybe_init_wandb(config: ExperimentConfig, output_dir: Path):
19
+ if not config.use_wandb or config.wandb_mode == "disabled":
20
+ return None
21
+ try:
22
+ import wandb
23
+ except ImportError:
24
+ print("wandb is not installed; continuing with local logging only.")
25
+ return None
26
+ run = wandb.init(
27
+ project=config.wandb_project,
28
+ entity=config.wandb_entity or None,
29
+ name=config.effective_run_name,
30
+ mode=config.wandb_mode,
31
+ config={"experiment": config.__dict__},
32
+ dir=str(output_dir),
33
+ reinit=True,
34
+ )
35
+ return run
36
+
37
+
38
+ def _save_json(path: Path, payload: dict[str, Any]) -> None:
39
+ with path.open("w", encoding="utf-8") as handle:
40
+ json.dump(payload, handle, indent=2, sort_keys=True)
41
+
42
+
43
+ def _save_checkpoint(path: Path, model: nn.Module, optimizer: torch.optim.Optimizer, metadata: dict[str, Any]) -> None:
44
+ torch.save(
45
+ {
46
+ "model_state": model.state_dict(),
47
+ "optimizer_state": optimizer.state_dict(),
48
+ "metadata": metadata,
49
+ },
50
+ path,
51
+ )
52
+
53
+
54
+ def _stage_checkpoint_path(stage_directory: Path, stage: int) -> Path:
55
+ return stage_directory / f"stage_{stage:02d}_passed.pt"
56
+
57
+
58
def _evaluate_current_stage(
    model: nn.Module,
    config: ExperimentConfig,
    suite,
    stage: int,
    device: str,
) -> dict[str, float]:
    """Score ``model`` on the uniform validation problems of ``stage``.

    Returns a dict with ``digit_accuracy``, ``final_carry_accuracy`` and
    ``exact_match`` taken from the metrics object produced by
    ``evaluate_problem_set``; its second return value is discarded.
    """
    metrics, _ = evaluate_problem_set(
        model=model,
        config=config,
        problems=suite.validation_uniform[stage],
        active_digits=stage,
        device=device,
        return_attention=False,
    )
    reported = ("digit_accuracy", "final_carry_accuracy", "exact_match")
    return {name: getattr(metrics, name) for name in reported}
78
+
79
+
80
+ def _masked_digit_loss(
81
+ logits: torch.Tensor,
82
+ targets: torch.Tensor,
83
+ mask: torch.Tensor,
84
+ loss_fn: nn.Module,
85
+ ) -> torch.Tensor:
86
+ masked_logits = logits[mask]
87
+ masked_targets = targets[mask]
88
+ if masked_logits.numel() == 0:
89
+ return logits.new_zeros(())
90
+ return loss_fn(masked_logits, masked_targets)
91
+
92
+
93
@torch.no_grad()
def _print_model_debug_format(
    model: nn.Module,
    config: ExperimentConfig,
    *,
    stage: int,
    rng,
    device: str,
) -> None:
    """Print the model architecture and the tensor shapes of one sampled batch.

    A training batch is drawn for ``stage`` (which advances ``rng``) and
    passed through the model once, with gradients disabled, so the input,
    target and output shapes can be logged.
    """
    sample = sample_training_batch(config=config, stage=stage, rng=rng, device=device)
    latent_steps = config.latent_steps_for_stage(stage)
    result = model(sample.input_ids, latent_steps=latent_steps, return_attention=False)
    print("[addition debug] model_architecture", flush=True)
    print(model, flush=True)
    shape_report = (
        "[addition debug] batch_format "
        f"stage={stage} input_shape={tuple(sample.input_ids.shape)} "
        f"target_digits_shape={tuple(sample.target_digits.shape)} "
        f"target_mask_shape={tuple(sample.target_digit_mask.shape)} "
        f"target_final_carry_shape={tuple(sample.target_final_carry.shape)} "
        f"digit_logits_shape={tuple(result.digit_logits.shape)} "
        f"final_carry_logits_shape={tuple(result.final_carry_logits.shape)} "
        f"output_hidden_shape={tuple(result.output_hidden.shape)}"
    )
    print(shape_report, flush=True)
117
+
118
+
119
@torch.no_grad()
def _print_validation_samples(
    model: nn.Module,
    config: ExperimentConfig,
    problems,
    *,
    stage: int,
    device: str,
    limit: int = 3,
) -> None:
    """Print predictions versus ground truth for the first ``limit`` problems.

    Only the first ``stage`` digits of each operand, sum and prediction are
    shown, and the true final carry is read from ``carry_out[stage - 1]``.
    Nothing is printed when ``problems`` is empty.
    """
    shown = list(problems[:limit])
    if len(shown) == 0:
        return
    batch = build_batch(problems=shown, radix=config.radix, device=device)
    latent_steps = config.latent_steps_for_stage(stage)
    outputs = model(batch.input_ids, latent_steps=latent_steps, return_attention=False)
    # Greedy decoding: argmax over the last dimension of each logits tensor.
    all_pred_digits = outputs.digit_logits.argmax(dim=-1).cpu().tolist()
    all_pred_carries = outputs.final_carry_logits.argmax(dim=-1).cpu().tolist()

    for example_index, problem in enumerate(shown):
        truth_digits = problem.sum_digits[:stage]
        truth_final_carry = problem.carry_out[stage - 1]
        pred_digits = all_pred_digits[example_index][:stage]
        pred_final_carry = int(all_pred_carries[example_index])
        exact = exact_sum_matches(
            predicted_digits=pred_digits,
            predicted_final_carry=pred_final_carry,
            truth_digits=truth_digits,
            truth_final_carry=truth_final_carry,
        )
        # Operands are rendered without a carry; sums include their final carry.
        a_text = digits_to_string(problem.a_digits[:stage], final_carry=0, radix=config.radix)
        b_text = digits_to_string(problem.b_digits[:stage], final_carry=0, radix=config.radix)
        pred_text = digits_to_string(pred_digits, final_carry=pred_final_carry, radix=config.radix)
        truth_text = digits_to_string(truth_digits, final_carry=truth_final_carry, radix=config.radix)
        sample_line = (
            f"[addition sample] stage={stage} idx={example_index} "
            f"a={a_text} b={b_text} pred={pred_text} true={truth_text} "
            f"pred_digits={pred_digits} pred_carry={pred_final_carry} "
            f"true_digits={truth_digits} true_carry={truth_final_carry} exact={int(exact)}"
        )
        print(sample_line, flush=True)
159
+
160
+
161
def run_experiment(config: ExperimentConfig) -> dict[str, Any]:
    """Train one addition model, evaluate its best checkpoint, and return a summary.

    The loop trains for at most ``config.train_steps`` optimizer steps. With a
    curriculum (``config.uses_curriculum``), training starts at
    ``config.initial_stage`` and advances one stage whenever validation exact
    match reaches ``config.stage_accuracy_threshold``; it stops early once the
    final stage (``config.train_max_digits``) reaches the threshold. Without a
    curriculum it trains at ``config.train_max_digits`` throughout.

    Checkpoints written: ``best.pt`` (see below), ``last.pt`` (state at the end
    of training), and one ``stage_XX_passed.pt`` per completed curriculum stage.

    Best-checkpoint selection: validation is measured on the problems of the
    *current* stage, so scores from different stages are not comparable. A
    checkpoint from a later stage therefore always replaces one from an earlier
    stage; within a stage the highest exact match wins, ties going to the most
    recent. Without a curriculum the stage never changes, so this reduces to
    "highest exact match so far".

    Returns:
        The summary dict that is also written to ``summary.json``: config,
        parameter counts, ``best_validation_exact_match`` (the exact match of
        the checkpoint selected as best), step/stage counters, stop reason,
        elapsed time, validation history, and the final evaluation results.
    """
    import random  # function-scope import: only the batch-sampling RNG needs it

    directories = ensure_output_dirs(config)
    save_config(config, directories["root"])
    seed_everything(config.seed)
    device = config.device
    model = build_model(config, device=device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)
    digit_loss_fn = nn.CrossEntropyLoss()
    final_carry_loss_fn = nn.CrossEntropyLoss()
    suite = build_evaluation_suite(config)
    # Separate RNG stream for batch sampling, offset from the global seed.
    rng = random.Random(config.seed + 12345)
    history: list[dict[str, Any]] = []
    best_validation = -1.0
    best_stage = None  # stage of the checkpoint in best.pt; None until one is saved
    best_checkpoint_path = directories["checkpoints"] / "best.pt"
    last_checkpoint_path = directories["checkpoints"] / "last.pt"
    stage = config.initial_stage if config.uses_curriculum else config.train_max_digits
    stage_steps = 0
    global_step = 0
    stop_reason = "train_steps_exhausted"
    wandb_run = _maybe_init_wandb(config, directories["root"])
    started_at = time.time()
    param_counts = describe_model(config)
    print(
        f"[addition train] model={config.model} seed={config.seed} device={device} "
        f"params={param_counts['total_params']} stage={stage}",
        flush=True,
    )
    _print_model_debug_format(model=model, config=config, stage=stage, rng=rng, device=device)

    while global_step < config.train_steps:
        model.train()
        batch = sample_training_batch(config=config, stage=stage, rng=rng, device=device)
        optimizer.zero_grad(set_to_none=True)
        outputs = model(batch.input_ids, latent_steps=config.latent_steps_for_stage(stage), return_attention=False)
        digit_loss = _masked_digit_loss(
            logits=outputs.digit_logits,
            targets=batch.target_digits,
            mask=batch.target_digit_mask,
            loss_fn=digit_loss_fn,
        )
        final_carry_loss = final_carry_loss_fn(outputs.final_carry_logits, batch.target_final_carry)
        loss = digit_loss + final_carry_loss
        loss.backward()
        if config.grad_clip_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip_norm)
        optimizer.step()

        global_step += 1
        stage_steps += 1

        # Training loss is logged twice per validation interval.
        if global_step % max(1, config.validation_interval // 2) == 0:
            train_message = (
                f"[addition train] step={global_step} stage={stage} "
                f"loss={loss.item():.4f} digit_loss={digit_loss.item():.4f} "
                f"final_carry_loss={final_carry_loss.item():.4f}"
            )
            print(train_message, flush=True)

        # Validate on the regular interval, on the very last step, and exactly
        # when a curriculum stage hits its per-stage step cap.
        should_validate = (
            global_step % config.validation_interval == 0
            or global_step == config.train_steps
            or (config.uses_curriculum and stage_steps == config.max_steps_per_stage)
        )
        if not should_validate:
            continue

        validation = _evaluate_current_stage(model=model, config=config, suite=suite, stage=stage, device=device)
        history_entry = {
            "global_step": global_step,
            "stage": stage,
            "stage_steps": stage_steps,
            "loss": float(loss.item()),
            "digit_loss": float(digit_loss.item()),
            "final_carry_loss": float(final_carry_loss.item()),
            "validation_digit_accuracy": validation["digit_accuracy"],
            "validation_final_carry_accuracy": validation["final_carry_accuracy"],
            "validation_exact_match": validation["exact_match"],
            "latent_steps": config.latent_steps_for_stage(stage),
        }
        history.append(history_entry)
        print(
            f"[addition val] step={global_step} stage={stage} "
            f"digit_acc={validation['digit_accuracy']:.4f} final_carry_acc={validation['final_carry_accuracy']:.4f} "
            f"exact={validation['exact_match']:.4f}",
            flush=True,
        )
        _print_validation_samples(
            model=model,
            config=config,
            problems=suite.validation_uniform[stage],
            stage=stage,
            device=device,
        )
        if wandb_run is not None:
            payload = {
                "train/loss": float(loss.item()),
                "train/digit_loss": float(digit_loss.item()),
                "train/final_carry_loss": float(final_carry_loss.item()),
                "train/stage": stage,
                "train/latent_steps": config.latent_steps_for_stage(stage),
                "validation/digit_accuracy": validation["digit_accuracy"],
                "validation/final_carry_accuracy": validation["final_carry_accuracy"],
                "validation/exact_match": validation["exact_match"],
                "step": global_step,
            }
            wandb_run.log(payload)

        # Exact match is measured on stage-specific problems, so only compare
        # scores within a stage; a later stage always supersedes an earlier one.
        # Otherwise an easy early stage scoring 1.0 would pin best.pt to a model
        # that never trained on the longer problems.
        is_new_best = (
            best_stage is None
            or stage > best_stage
            or (stage == best_stage and validation["exact_match"] >= best_validation)
        )
        if is_new_best:
            best_stage = stage
            best_validation = validation["exact_match"]
            _save_checkpoint(
                best_checkpoint_path,
                model,
                optimizer,
                metadata={
                    "global_step": global_step,
                    "stage": stage,
                    "best_validation_exact_match": best_validation,
                },
            )

        reached_threshold = validation["exact_match"] >= config.stage_accuracy_threshold
        reached_cap = stage_steps >= config.max_steps_per_stage

        if config.uses_curriculum:
            if stage < config.train_max_digits and reached_threshold:
                _save_checkpoint(
                    _stage_checkpoint_path(directories["stage_checkpoints"], stage),
                    model,
                    optimizer,
                    metadata={
                        "global_step": global_step,
                        "stage": stage,
                        "validation_exact_match": validation["exact_match"],
                        "validation_digit_accuracy": validation["digit_accuracy"],
                        "validation_final_carry_accuracy": validation["final_carry_accuracy"],
                    },
                )
                print(
                    f"[addition curriculum] advance {stage} -> {stage + 1} "
                    f"(exact_match={validation['exact_match']:.4f})",
                    flush=True,
                )
                stage += 1
                stage_steps = 0
                continue
            # The step cap never forces an advance; it only reports the hold.
            if reached_cap and not reached_threshold:
                print(
                    f"[addition curriculum] hold stage={stage} after {stage_steps} steps "
                    f"(exact_match={validation['exact_match']:.4f} < threshold={config.stage_accuracy_threshold:.2f})",
                    flush=True,
                )
            if stage == config.train_max_digits and reached_threshold:
                stop_reason = "final_stage_threshold"
                break

    _save_checkpoint(
        last_checkpoint_path,
        model,
        optimizer,
        metadata={
            "global_step": global_step,
            "stage": stage,
            "stop_reason": stop_reason,
        },
    )

    # Evaluate the best checkpoint. If no validation ever ran (for example
    # train_steps <= 0) there is no best.pt from this run, so keep the current
    # weights instead of crashing or loading a stale file from an earlier run.
    if best_stage is not None:
        best_payload = torch.load(best_checkpoint_path, map_location=device)
        model.load_state_dict(best_payload["model_state"])
    final_results = evaluate_suite(model=model, config=config, suite=suite, device=device)
    flat_final_metrics = flatten_nested_metrics("", final_results)
    summary = {
        "config": config.__dict__,
        "param_counts": param_counts,
        "best_validation_exact_match": best_validation,
        "global_step": global_step,
        "final_stage": stage,
        "stop_reason": stop_reason,
        "elapsed_seconds": time.time() - started_at,
        "history": history,
        "final_results": final_results,
        "flat_final_metrics": flat_final_metrics,
    }
    _save_json(directories["artifacts"] / "summary.json", summary)
    with (directories["artifacts"] / "history.jsonl").open("w", encoding="utf-8") as handle:
        for entry in history:
            handle.write(json.dumps(entry, sort_keys=True) + "\n")
    plot_single_run_results(summary, directories["plots"])

    if wandb_run is not None:
        wandb_run.log(flat_final_metrics | {"step": global_step})
        wandb_run.summary.update(
            {
                "best_validation_exact_match": best_validation,
                "final_stage": stage,
                "stop_reason": stop_reason,
            }
        )
        wandb_run.finish()

    return summary
361
+
362
+
363
def main() -> None:
    """CLI entry point: parse the experiment configuration and run training."""
    run_experiment(parse_config("Train the addition carry experiment."))
366
+
367
+
368
+ if __name__ == "__main__":
369
+ main()
aligned_cell_policy/shared_cell_policy.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import re
5
+ import sys
6
+ from dataclasses import dataclass
7
+ from typing import Any, Dict, List
8
+
9
+ import numpy as np
10
+
11
+ CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
12
+ PARENT_DIR = os.path.dirname(CURRENT_DIR)
13
+ if PARENT_DIR not in sys.path:
14
+ sys.path.insert(0, PARENT_DIR)
15
+
16
+ from format_utils_icon import parse_n_value_prediction
17
+
18
+
19
@dataclass(frozen=True)
class CellExample:
    """One single-cell prediction target derived from a puzzle row.

    Instances are immutable. As built by ``build_cell_examples_from_row``,
    each example carries its own copy of the prompt grid.
    """

    # 9x9 integer grid parsed from the prompt.
    grid: np.ndarray
    # (row, col) of the cell to fill; 0-based when built by build_cell_examples_from_row.
    target_cell: tuple[int, int]
    # Value expected in ``target_cell``.
    target_value: int
    # Position of this example within its row's sequence; counted from 1 by the builder.
    turn_idx: int
    # Number of examples generated for the same row.
    total_turns: int
25
+ total_turns: int
26
+
27
+
28
+ _TUPLE_PROMPT_RE = re.compile(r"\((\d+),(\d+),(\d+)\)")
29
+
30
+
31
def parse_grid_from_tuple_prompt(prompt_text: str) -> np.ndarray:
    """Rebuild a 9x9 integer grid from ``(row,col,value)`` tuples in a prompt.

    The first 81 tuples matched by ``_TUPLE_PROMPT_RE`` are used. Row and
    column are 1-based in the text and converted to 0-based indices.

    Raises:
        ValueError: if fewer than 81 tuples are found.
    """
    matches = _TUPLE_PROMPT_RE.findall(str(prompt_text))
    if not len(matches) >= 81:
        raise ValueError("Could not recover 81 (row,col,value) tuples from prompt.")
    board = np.zeros((9, 9), dtype=int)
    for entry in matches[:81]:
        row_1based, col_1based, cell_value = (int(token) for token in entry)
        board[row_1based - 1, col_1based - 1] = cell_value
    return board
39
+
40
+
41
def build_cell_examples_from_row(row: Dict[str, Any]) -> List[CellExample]:
    """Expand one dataset row into one ``CellExample`` per target cell.

    Targets come from ``metadata["target_triples_1based"]`` when the row's
    metadata provides both ``empty_locs_1based`` and that list. Otherwise
    they are reconstructed by pairing the grid's zero cells (in
    ``np.argwhere`` order) with the values parsed from ``row["completion"]``;
    the expected value count is ``metadata["empties"]``.

    Args:
        row: Mapping with a ``"prompt"`` entry and optional ``"metadata"``
            and ``"completion"`` entries. ``metadata`` and ``completion``
            may be missing or ``None``.

    Returns:
        One example per target triple, each holding its own copy of the
        grid, a 0-based ``target_cell`` and a 1-based ``turn_idx``.

    Raises:
        ValueError: if the prompt cannot be parsed into a grid, or if the
            metadata is incomplete and the completion cannot be parsed.
    """
    prompt = str(row["prompt"])
    grid = parse_grid_from_tuple_prompt(prompt)
    # JSON rows may carry an explicit null; dict(None) would raise TypeError.
    metadata = dict(row.get("metadata") or {})
    empty_locs = metadata.get("empty_locs_1based")
    target_triples = metadata.get("target_triples_1based")

    if not empty_locs or not target_triples:
        # A null completion is treated as empty text, not the string "None".
        completion = str(row.get("completion") or "")
        parsed, _ = parse_n_value_prediction(completion, int(metadata.get("empties", 0) or 0))
        if parsed is None:
            raise ValueError("Row is missing metadata and completion could not be parsed.")
        # Zero cells are the empties; convert their indices to 1-based.
        empty_locs = [(r + 1, c + 1) for r, c in np.argwhere(grid == 0).tolist()]
        target_triples = [(int(r), int(c), int(v)) for (r, c), v in zip(empty_locs, parsed)]

    total_turns = len(target_triples)
    out: List[CellExample] = []
    for idx, triple in enumerate(target_triples, start=1):
        rr, cc, value = int(triple[0]) - 1, int(triple[1]) - 1, int(triple[2])
        out.append(
            CellExample(
                # Each example gets an independent copy of the grid.
                grid=np.asarray(grid, dtype=int).copy(),
                target_cell=(rr, cc),
                target_value=value,
                turn_idx=idx,
                total_turns=total_turns,
            )
        )
    return out
analysis/eval_saved_hard9x9_checkpoints.py ADDED
@@ -0,0 +1,273 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import os
5
+ from types import SimpleNamespace
6
+ import sys
7
+
8
+ import torch
9
+ from peft import PeftModel
10
+ from transformers import AutoModelForCausalLM, AutoTokenizer
11
+
12
# Hard-coded project checkout; it is put first on sys.path so the project
# packages imported below resolve from it.
ROOT = "/home/ubuntu/curriculum_cot"
if ROOT and ROOT not in sys.path:
    sys.path.insert(0, ROOT)
15
+
16
+ from multi_output_cell_policy import grpo_multi_output_train as baseline_grpo
17
+ from multi_output_cell_policy import sft_multi_output_train as baseline_sft
18
+ from latent_multi_output_cell_policy import grpo_residual_projector_latent_train as latent_grpo
19
+ from latent_multi_output_cell_policy import residual_projector_warmstart_sft_latent_multi_output_train as latent_sft
20
+
21
+ MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct"
22
+ CACHE_DIR = os.path.join(ROOT, ".hf_cache")
23
+ DATA_PATH = os.path.join(ROOT, "data", "sudoku_t3_30empty_value_qwen_text.jsonl")
24
+ EVAL_ROWS = 20
25
+ TOTAL_EMPTIES_HINT = 30
26
+
27
+
28
def make_tokenizer() -> AutoTokenizer:
    """Load the fast tokenizer for ``MODEL_NAME`` and ensure it has a pad token.

    When no pad token id is set, the EOS token is reused as padding, with
    ``"<|endoftext|>"`` as the fallback if there is no EOS token either.
    """
    tok = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR, use_fast=True)
    if tok.pad_token_id is not None:
        return tok
    tok.pad_token = tok.eos_token or "<|endoftext|>"
    return tok
33
+
34
+
35
def make_device() -> torch.device:
    """Return the CUDA device when available, otherwise the CPU device."""
    backend = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(backend)
37
+
38
+
39
def make_baseline_sft_model(checkpoint_dir: str, device: torch.device) -> torch.nn.Module:
    """Load the base model plus the baseline SFT adapter, ready for evaluation.

    The adapter is loaded as non-trainable; the model is moved to ``device``
    and switched to eval mode. On CPU-only hosts weights load as float32,
    otherwise the dtype comes from ``baseline_sft.pick_dtype()``.
    """
    weight_dtype = baseline_sft.pick_dtype() if torch.cuda.is_available() else torch.float32
    backbone = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        cache_dir=CACHE_DIR,
        torch_dtype=weight_dtype,
        low_cpu_mem_usage=True,
    )
    adapted = PeftModel.from_pretrained(backbone, checkpoint_dir, is_trainable=False)
    adapted.to(device)
    adapted.eval()
    return adapted
50
+
51
+
52
def make_baseline_grpo_model(checkpoint_dir: str, device: torch.device) -> torch.nn.Module:
    """Load the base model plus the baseline GRPO adapter, ready for evaluation.

    The adapter is attached through ``baseline_grpo.load_trainable_adapter``;
    the model is then moved to ``device`` and switched to eval mode. On
    CPU-only hosts weights load as float32, otherwise the dtype comes from
    ``baseline_grpo.pick_dtype()``.
    """
    weight_dtype = baseline_grpo.pick_dtype() if torch.cuda.is_available() else torch.float32
    backbone = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        cache_dir=CACHE_DIR,
        torch_dtype=weight_dtype,
        low_cpu_mem_usage=True,
    )
    adapted = baseline_grpo.load_trainable_adapter(backbone, checkpoint_dir)
    adapted.to(device)
    adapted.eval()
    return adapted
63
+
64
+
65
def make_latent_model(checkpoint_dir: str, device: torch.device) -> torch.nn.Module:
    """Load the base model, latent adapter and residual projector for evaluation.

    The projector width is inferred from the checkpoint when possible and
    otherwise falls back to ``latent_grpo.PROJECTOR_HIDDEN``. Projector
    modules are attached before their saved state is loaded; the model is
    then moved to ``device`` and switched to eval mode.
    """
    weight_dtype = latent_grpo.pick_dtype() if torch.cuda.is_available() else torch.float32
    backbone = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        cache_dir=CACHE_DIR,
        torch_dtype=weight_dtype,
        low_cpu_mem_usage=True,
    )
    latent_model = latent_grpo.load_trainable_adapter(backbone, checkpoint_dir)
    inferred_width = latent_grpo.infer_projector_hidden_from_state(checkpoint_dir)
    projector_hidden = inferred_width or latent_grpo.PROJECTOR_HIDDEN
    backbone_hidden = int(latent_grpo.unwrap_backbone(latent_model).config.hidden_size)
    latent_grpo.attach_residual_projector_modules(
        latent_model,
        hidden_size=backbone_hidden,
        projector_hidden=projector_hidden,
    )
    latent_grpo.maybe_load_projector_state(latent_model, checkpoint_dir)
    latent_model.to(device)
    latent_model.eval()
    return latent_model
83
+
84
+
85
def common_reward_args() -> dict:
    """Return a fresh dict of the reward/penalty settings shared by the evals."""
    return dict(
        reward_good_value=1.0,
        penalty_bad_value=1.75,
        penalty_malformed=4.0,
        penalty_empty=0.5,
        penalty_singleton=1.5,
    )
93
+
94
+
95
def eval_baseline_sft(checkpoint_dir: str, stage_i: int) -> dict:
    """Evaluate a baseline SFT checkpoint on the first ``EVAL_ROWS`` data rows.

    Returns whatever ``baseline_sft.run_eval`` reports. The model reference
    is dropped and the CUDA cache emptied before returning.
    """
    device = make_device()
    tokenizer = make_tokenizer()
    model = make_baseline_sft_model(checkpoint_dir, device)
    rows = baseline_sft.load_jsonl_rows(DATA_PATH, limit_rows=EVAL_ROWS)
    eval_settings = {
        "stage_i": int(stage_i),
        "total_empties_hint": TOTAL_EMPTIES_HINT,
        "max_completion_length": 24,
        "debug_print_limit": 0,
    }
    metrics = baseline_sft.run_eval(SimpleNamespace(**eval_settings), rows, model, tokenizer, device)
    # Release the model before the next checkpoint is loaded.
    del model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return metrics
111
+
112
+
113
def eval_baseline_grpo(checkpoint_dir: str, stage_i: int) -> dict:
    """Evaluate a baseline GRPO checkpoint on the first ``EVAL_ROWS`` data rows.

    The eval namespace combines the generation settings with the shared
    reward settings from ``common_reward_args``. The model reference is
    dropped and the CUDA cache emptied before returning.
    """
    device = make_device()
    tokenizer = make_tokenizer()
    model = make_baseline_grpo_model(checkpoint_dir, device)
    rows = baseline_grpo.load_jsonl_rows(DATA_PATH, limit_rows=EVAL_ROWS)
    eval_settings = {
        "stage_i": int(stage_i),
        "total_empties_hint": TOTAL_EMPTIES_HINT,
        "max_completion_length": 24,
        "debug_print_limit": 0,
    }
    eval_args = SimpleNamespace(**eval_settings, **common_reward_args())
    metrics = baseline_grpo.run_eval(args=eval_args, rows=rows, model=model, tokenizer=tokenizer, device=device)
    # Release the model before the next checkpoint is loaded.
    del model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return metrics
130
+
131
+
132
def eval_latent_sft(checkpoint_dir: str, stage_i: int, num_cot_tokens: int) -> dict:
    """Evaluate a latent SFT checkpoint on the first ``EVAL_ROWS`` data rows.

    Rows are read with the baseline SFT loader and scored by
    ``latent_sft.run_eval``. The model reference is dropped and the CUDA
    cache emptied before returning.
    """
    device = make_device()
    tokenizer = make_tokenizer()
    model = make_latent_model(checkpoint_dir, device)
    # NOTE(review): this uses baseline_sft's loader while eval_latent_grpo uses
    # latent_grpo's; confirm the two loaders are interchangeable.
    rows = baseline_sft.load_jsonl_rows(DATA_PATH, limit_rows=EVAL_ROWS)
    eval_settings = {
        "stage_i": int(stage_i),
        "num_cot_tokens": int(num_cot_tokens),
        "total_empties_hint": TOTAL_EMPTIES_HINT,
        "max_completion_length": 32,
        "debug_print_limit": 0,
    }
    eval_args = SimpleNamespace(**eval_settings, **common_reward_args())
    metrics = latent_sft.run_eval(
        args=eval_args,
        rows=rows,
        model=model,
        tokenizer=tokenizer,
        device=device,
        eval_stage_i=int(stage_i),
    )
    # Release the model before the next checkpoint is loaded.
    del model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return metrics
150
+
151
+
152
def eval_latent_grpo(checkpoint_dir: str, stage_i: int, num_cot_tokens: int) -> dict:
    """Evaluate a latent GRPO checkpoint on the first ``EVAL_ROWS`` data rows.

    Rows are read and scored through the ``latent_grpo`` module. The model
    reference is dropped and the CUDA cache emptied before returning.
    """
    device = make_device()
    tokenizer = make_tokenizer()
    model = make_latent_model(checkpoint_dir, device)
    rows = latent_grpo.load_jsonl_rows(DATA_PATH, limit_rows=EVAL_ROWS)
    eval_settings = {
        "stage_i": int(stage_i),
        "num_cot_tokens": int(num_cot_tokens),
        "total_empties_hint": TOTAL_EMPTIES_HINT,
        "max_completion_length": 32,
        "debug_print_limit": 0,
    }
    eval_args = SimpleNamespace(**eval_settings, **common_reward_args())
    metrics = latent_grpo.run_eval(
        args=eval_args,
        rows=rows,
        model=model,
        tokenizer=tokenizer,
        device=device,
        eval_stage_i=int(stage_i),
    )
    # Release the model before the next checkpoint is loaded.
    del model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return metrics
170
+
171
+
172
def main() -> None:
    """Evaluate the fixed list of baseline and latent checkpoints and print metrics.

    Each checkpoint's metrics are printed as one JSON line as soon as they
    are available; the combined results are printed again at the end.
    """

    def _entry(label, stage_i, kind, relative_dir, num_cot_tokens=None):
        # Build one checkpoint record; only latent records carry num_cot_tokens.
        record = {"label": label, "stage_i": stage_i, "kind": kind}
        if num_cot_tokens is not None:
            record["num_cot_tokens"] = num_cot_tokens
        record["checkpoint_dir"] = os.path.join(ROOT, relative_dir)
        return record

    # Explicit step dirs (not run roots) so metrics match the agreed endpoints.
    checkpoints = [
        _entry(
            "baseline_stage1_sft",
            1,
            "baseline_sft",
            "final_checkpoint/large_baseline_extension/hard_9x9_qwen05b/baseline/20260404_023600_baseline30_clean/baseline_pipeline_30empty_4stage_hard9x9/stage01_sft_i1_30empty/checkpoint-step-01000",
        ),
        _entry(
            "baseline_stage1_grpo",
            1,
            "baseline_grpo",
            "final_checkpoint/large_baseline_extension/hard_9x9_qwen05b/baseline/grpo/i1_20260404_fixed_baseline_grpo_i1/checkpoint-5350",
        ),
        _entry(
            "baseline_stage2_sft",
            2,
            "baseline_sft",
            "final_checkpoint/large_baseline_extension/hard_9x9_qwen05b/baseline/sft/i2_20260404_stage2_baseline_sft_from_grpo5350/checkpoint-step-13100",
        ),
        _entry(
            "baseline_stage2_grpo",
            2,
            "baseline_grpo",
            "final_checkpoint/large_baseline_extension/hard_9x9_qwen05b/baseline/grpo/i2_20260405_stage2_baseline_grpo_from_sft13100/checkpoint-4325",
        ),
        _entry(
            "latent_stage1_sft",
            1,
            "latent_sft",
            "final_checkpoint/large_latent_extension/hard_9x9_qwen05b/latent/20260404_013500_latent30_frombaseline/latent_pipeline_30empty_4stage_hard9x9/stage01_sft_i1_30empty_residual_projector/checkpoint-step-00200",
            num_cot_tokens=1,
        ),
        _entry(
            "latent_stage1_grpo",
            1,
            "latent_grpo",
            "final_checkpoint/large_latent_extension/hard_9x9_qwen05b/latent/grpo/i1_cot1_20260404_fixed_latent_grpo_i1/checkpoint-2740",
            num_cot_tokens=1,
        ),
        _entry(
            "latent_stage2_sft",
            2,
            "latent_sft",
            "final_checkpoint/large_latent_extension/hard_9x9_qwen05b/latent/sft/i2_cot2_20260404_stage2_latent_sft_from_grpo2740/checkpoint-step-00700",
            num_cot_tokens=2,
        ),
        _entry(
            "latent_stage2_grpo",
            2,
            "latent_grpo",
            "final_checkpoint/large_latent_extension/hard_9x9_qwen05b/latent/grpo/i2_cot2_20260405_stage2_latent_grpo_from_sft00700/checkpoint-1620",
            num_cot_tokens=2,
        ),
    ]

    # Baseline evaluators take (checkpoint_dir, stage_i); latent ones also need
    # the number of latent chain-of-thought tokens.
    baseline_runners = {
        "baseline_sft": eval_baseline_sft,
        "baseline_grpo": eval_baseline_grpo,
    }

    results: dict[str, dict] = {}
    for item in checkpoints:
        label = item["label"]
        kind = item["kind"]
        print(f"[eval] starting {label}", flush=True)
        if kind in baseline_runners:
            metrics = baseline_runners[kind](item["checkpoint_dir"], item["stage_i"])
        else:
            # Any kind other than "latent_sft" is evaluated as latent GRPO.
            latent_runner = eval_latent_sft if kind == "latent_sft" else eval_latent_grpo
            metrics = latent_runner(item["checkpoint_dir"], item["stage_i"], item["num_cot_tokens"])
        results[label] = metrics
        print(json.dumps({"label": label, "metrics": metrics}, sort_keys=True), flush=True)

    print("[eval] complete", flush=True)
    print(json.dumps(results, sort_keys=True, indent=2), flush=True)
270
+
271
+
272
+ if __name__ == "__main__":
273
+ main()
checkpoint_utils.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import shutil
5
+ from typing import Any, Callable
6
+
7
+ import torch
8
+ from peft import get_peft_model_state_dict
9
+ from safetensors.torch import save_file as save_safetensors_file
10
+
11
+ FINAL_CHECKPOINT_DIRNAME = "final_checkpoint"
12
+ _WEIGHT_FILENAMES = (
13
+ "adapter_model.safetensors",
14
+ "adapter_model.bin",
15
+ "model.safetensors",
16
+ "pytorch_model.bin",
17
+ )
18
+
19
+
20
def ensure_final_checkpoint_dir(output_dir: str) -> str:
    """Create and return the ``final_checkpoint`` mirror directory for a run.

    The run's path relative to this file's directory is re-rooted under
    ``<repo>/final_checkpoint``. A leading ``final_checkpoint`` component and
    then a leading ``checkpoints`` component are dropped so the mirror does
    not nest those names.

    Parent-directory (``..``) components are discarded. Without that, an
    ``output_dir`` outside the repository produced a path that climbed back
    out of ``final_checkpoint/``; that directory is later emptied by the
    final-checkpoint update, so escaping the mirror root risks deleting
    unrelated files.

    Args:
        output_dir: Run output directory, absolute or relative to the CWD.

    Returns:
        The final-checkpoint directory, created if necessary.
    """
    repo_root = os.path.dirname(os.path.abspath(__file__))
    output_dir_abs = os.path.abspath(output_dir)
    try:
        rel_output_dir = os.path.relpath(output_dir_abs, repo_root)
    except ValueError:
        # relpath raises ValueError when no relative path exists, e.g. the
        # two paths are on different Windows drives; fall back to the leaf.
        rel_output_dir = os.path.basename(output_dir_abs.rstrip(os.sep))
    ignored_parts = ("", os.curdir, os.pardir)
    rel_parts = [part for part in rel_output_dir.split(os.sep) if part not in ignored_parts]
    if rel_parts and rel_parts[0] == FINAL_CHECKPOINT_DIRNAME:
        rel_parts = rel_parts[1:]
    if rel_parts and rel_parts[0] == "checkpoints":
        rel_parts = rel_parts[1:]
    if not rel_parts:
        # Nothing left (e.g. output_dir is the repo root): use the leaf name.
        rel_parts = [os.path.basename(output_dir_abs.rstrip(os.sep)) or "run"]
    final_dir = os.path.join(repo_root, FINAL_CHECKPOINT_DIRNAME, *rel_parts)
    os.makedirs(final_dir, exist_ok=True)
    return final_dir
37
+
38
+
39
def final_checkpoint_root(*parts: str) -> str:
    """Create and return ``<this file's dir>/final_checkpoint/<parts...>``."""
    module_dir = os.path.dirname(os.path.abspath(__file__))
    target = os.path.join(module_dir, FINAL_CHECKPOINT_DIRNAME, *parts)
    os.makedirs(target, exist_ok=True)
    return target
44
+
45
+
46
def normalize_to_final_checkpoint_root(path: str, *default_parts: str) -> str:
    """Map a checkpoint path onto the ``final_checkpoint`` tree where applicable.

    * Empty or blank ``path``: the directory built from ``default_parts``.
    * Path already under ``<repo>/final_checkpoint``: its absolute form.
    * Path under ``<repo>/checkpoints``: the same sub-path re-rooted under
      ``final_checkpoint`` (created if needed).
    * Anything else: the absolute path, unchanged.
    """
    cleaned = str(path or "").strip()
    if cleaned == "":
        return final_checkpoint_root(*default_parts)
    absolute = os.path.abspath(cleaned)
    module_dir = os.path.dirname(os.path.abspath(__file__))
    relative = os.path.relpath(absolute, module_dir)
    components = [piece for piece in relative.split(os.sep) if piece not in ("", ".")]
    leading = components[0] if components else None
    if leading == FINAL_CHECKPOINT_DIRNAME:
        return absolute
    if leading == "checkpoints":
        return final_checkpoint_root(*components[1:])
    return absolute
60
+
61
+
62
def _has_saved_weights(target_dir: str) -> bool:
    """Return True if ``target_dir`` contains any file named in ``_WEIGHT_FILENAMES``."""
    for filename in _WEIGHT_FILENAMES:
        if os.path.exists(os.path.join(target_dir, filename)):
            return True
    return False
64
+
65
+
66
def _fallback_save_adapter_weights(model: Any, target_dir: str) -> None:
    """Write ``adapter_model.safetensors`` if no weight file exists yet.

    Does nothing when ``target_dir`` already holds a recognised weight file.
    Otherwise the tensor entries of the PEFT state dict are detached, moved
    to CPU and made contiguous before saving; nothing is written if there
    are no tensor entries.
    """
    if _has_saved_weights(target_dir):
        return
    cpu_state = {}
    for key, value in get_peft_model_state_dict(model).items():
        if not torch.is_tensor(value):
            continue
        cpu_state[key] = value.detach().cpu().contiguous()
    if not cpu_state:
        return
    save_safetensors_file(cpu_state, os.path.join(target_dir, "adapter_model.safetensors"))
77
+
78
+
79
def save_model_artifacts(
    model: Any,
    tokenizer: Any,
    target_dir: str,
    *,
    extra_save_fn: Callable[[Any, str], None] | None = None,
) -> str:
    """Save the model, and the tokenizer if given, into ``target_dir``.

    After ``save_pretrained`` the adapter-weights fallback runs in case no
    weight file was produced, then ``extra_save_fn(model, target_dir)`` is
    called when provided. Returns ``target_dir``.
    """
    os.makedirs(target_dir, exist_ok=True)
    # Model first, then tokenizer (skipped when None).
    savers = (model,) if tokenizer is None else (model, tokenizer)
    for saver in savers:
        saver.save_pretrained(target_dir)
    _fallback_save_adapter_weights(model, target_dir)
    if extra_save_fn is None:
        return target_dir
    extra_save_fn(model, target_dir)
    return target_dir
94
+
95
+
96
+ def _replace_dir_contents(src_dir: str, dst_dir: str) -> None:
97
+ os.makedirs(dst_dir, exist_ok=True)
98
+ src_dir_abs = os.path.abspath(src_dir)
99
+ for name in os.listdir(dst_dir):
100
+ path = os.path.join(dst_dir, name)
101
+ if os.path.abspath(path) == src_dir_abs:
102
+ continue
103
+ if os.path.isdir(path) and not os.path.islink(path):
104
+ shutil.rmtree(path)
105
+ else:
106
+ os.unlink(path)
107
+ for name in os.listdir(src_dir):
108
+ src_path = os.path.join(src_dir, name)
109
+ dst_path = os.path.join(dst_dir, name)
110
+ if os.path.isdir(src_path) and not os.path.islink(src_path):
111
+ shutil.copytree(src_path, dst_path)
112
+ else:
113
+ shutil.copy2(src_path, dst_path)
114
+
115
+
116
def save_checkpoint_and_update_final(
    model: Any,
    tokenizer: Any,
    output_dir: str,
    checkpoint_name: str,
    *,
    extra_save_fn: Callable[[Any, str], None] | None = None,
) -> str:
    """Save a named checkpoint and mirror it into the run's final-checkpoint dir.

    The checkpoint goes to ``output_dir/checkpoint_name``. The contents of
    the run's final-checkpoint directory are then replaced with a copy of
    it. Returns the checkpoint directory.
    """
    checkpoint_dir = os.path.join(output_dir, checkpoint_name)
    save_model_artifacts(model, tokenizer, checkpoint_dir, extra_save_fn=extra_save_fn)
    final_dir = ensure_final_checkpoint_dir(output_dir)
    _replace_dir_contents(checkpoint_dir, final_dir)
    return checkpoint_dir
format_utils_icon.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import re
5
+ from typing import List, Tuple
6
+
7
+ import numpy as np
8
+
9
+
10
+ _INT_RE = re.compile(r"-?\d+")
11
+
12
+
13
def grid_to_text(grid_9x9: np.ndarray) -> str:
    """Render a 9x9 grid as nine newline-separated rows of space-separated ints.

    The input is coerced to integers and reshaped to ``(9, 9)``, so any
    81-element array-like is accepted.
    """
    board = np.asarray(grid_9x9, dtype=int).reshape(9, 9)
    rendered_rows = []
    for cells in board.tolist():
        rendered_rows.append(" ".join(f"{int(cell)}" for cell in cells))
    return "\n".join(rendered_rows)
16
+
17
+
18
def parse_n_value_prediction(text: str, n: int) -> Tuple[List[int] | None, bool]:
    """Extract exactly ``n`` integers from a model completion.

    Parsing is attempted in two passes:

    1. JSON: either a bare list, or an object whose ``"values"`` entry is a list.
       The result is accepted only when it holds exactly ``n`` items that all
       convert with ``int``.
    2. Fallback: every ``_INT_RE`` match in the raw text, accepted only when
       there are exactly ``n`` matches.

    Returns:
        ``(values, True)`` on success, ``(None, False)`` when the text is empty
        or neither pass yields exactly ``n`` integers.
    """
    raw = str(text or "").strip()
    if not raw:
        return None, False
    expected = int(n)

    try:
        payload = json.loads(raw)
        if isinstance(payload, dict) and isinstance(payload.get("values"), list):
            candidate = payload["values"]
        elif isinstance(payload, list):
            candidate = payload
        else:
            candidate = None
        if candidate is not None:
            decoded = [int(item) for item in candidate]
            if len(decoded) == expected:
                return decoded, True
    except Exception:
        # Deliberately best-effort: any JSON/conversion failure falls through
        # to the regex scrape below.
        pass

    scraped = [int(found.group(0)) for found in _INT_RE.finditer(raw)]
    if len(scraped) != expected:
        return None, False
    return scraped, True
formatting_icon.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import numpy as np
4
+
5
+
6
def is_consistent_pair(grid, *, cell: int, value: int, t: int = 3, n: int = 9) -> bool:
    """Tell whether placing ``value`` at ``cell`` violates no Sudoku constraint.

    Args:
        grid: Array-like reshaped to ``n`` x ``n``; ``0`` marks an empty cell.
        cell: Row-major flat index of the target cell.
        value: Digit to test; anything outside ``1..n`` is rejected.
        t: Side length of a box.
        n: Side length of the grid.

    Returns:
        ``False`` if the cell already holds a different non-zero digit, or if
        ``value`` already appears elsewhere in the cell's row, column or box;
        ``True`` otherwise.
    """
    side = int(n)
    board = np.asarray(grid, dtype=int).reshape(side, side)
    target = int(cell)
    digit = int(value)
    if not 1 <= digit <= side:
        return False

    row_idx, col_idx = divmod(target, side)
    occupant = int(board[row_idx, col_idx])
    if occupant not in (0, digit):
        return False

    # Row peers (every column except the target's).
    if any(int(seen) == digit for pos, seen in enumerate(board[row_idx, :]) if pos != col_idx):
        return False

    # Column peers (every row except the target's).
    if any(int(seen) == digit for pos, seen in enumerate(board[:, col_idx]) if pos != row_idx):
        return False

    # Box peers: the t x t block containing the target, minus the target itself.
    box = int(t)
    top = (row_idx // box) * box
    left = (col_idx // box) * box
    clash_in_box = any(
        int(board[r, c]) == digit
        for r in range(top, top + box)
        for c in range(left, left + box)
        if (r, c) != (row_idx, col_idx)
    )
    return not clash_in_box
hard_9x9_10empty/launch_baseline_stage3_pipeline.sh ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Launch the 3-stage baseline pipeline on the 10-empty 9x9 Sudoku dataset.
#
# Every setting written as VAR="${VAR:-default}" can be overridden from the
# environment. The training JSONL is generated on first use, and (unless
# WAIT_FOR_EXISTING_TRAINING is set to something other than 1) the launch is
# held back while another training process is still running.
set -euo pipefail

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"

PYTHON_BIN="${PYTHON_BIN:-${ROOT}/.venv/bin/python}"
DATASET_BUILDER="${ROOT}/simple_9x9_curriculum/build_dataset.py"
PIPELINE="${ROOT}/multi_output_cell_policy/run_baseline_multi_output_pipeline_resume.py"

TRAIN_JSONL="${TRAIN_JSONL:-${ROOT}/data/sudoku_t3_10empty_value_qwen_text_longrun.jsonl}"
NUM_PUZZLES="${NUM_PUZZLES:-5000}"
DATASET_SEED="${DATASET_SEED:-0}"
# Value passed as --limit_train_rows (was a hard-coded 5000).
LIMIT_TRAIN_ROWS="${LIMIT_TRAIN_ROWS:-5000}"

GPU_IDS="${GPU_IDS:-0,1,2,3,4,5,6,7}"
NUM_PROCESSES="${NUM_PROCESSES:-8}"
RUN_TAG="${RUN_TAG:-$(date +%Y%m%d_%H%M%S)}"

CHECKPOINT_ROOT="${CHECKPOINT_ROOT:-${ROOT}/final_checkpoint/hard_9x9_10empty_qwen05b/baseline}"
OUTPUT_ROOT="${OUTPUT_ROOT:-${CHECKPOINT_ROOT}/${RUN_TAG}/baseline_pipeline_10empty_3stage_hard9x9}"

WANDB_MODE="${WANDB_MODE:-online}"
WANDB_ENTITY="${WANDB_ENTITY:-training-dynamics}"
WAIT_FOR_EXISTING_TRAINING="${WAIT_FOR_EXISTING_TRAINING:-1}"
WAIT_SECONDS="${WAIT_SECONDS:-60}"

# Command lines that count as "training already running". The interpreter part
# follows PYTHON_BIN instead of one machine's absolute path, so the guard also
# works from other checkouts; override TRAINING_PGREP_PATTERN to customise it.
TRAINING_SCRIPTS_REGEX='(run_baseline_multi_output_pipeline_resume.py|run_latent_residual_projector_pipeline.py|sft_multi_output_train.py|grpo_multi_output_train.py|residual_projector_warmstart_sft_latent_multi_output_train.py|grpo_residual_projector_latent_train.py)'
TRAINING_PGREP_PATTERN="${TRAINING_PGREP_PATTERN:-${PYTHON_BIN}.*${TRAINING_SCRIPTS_REGEX}}"

# Build the dataset only when the JSONL file does not exist yet.
if [[ ! -f "${TRAIN_JSONL}" ]]; then
  mkdir -p "$(dirname "${TRAIN_JSONL}")"
  printf 'Building 10-empty dataset: %s\n' "${TRAIN_JSONL}"
  "${PYTHON_BIN}" "${DATASET_BUILDER}" \
    --output "${TRAIN_JSONL}" \
    --num_puzzles "${NUM_PUZZLES}" \
    --empties 10 \
    --seed "${DATASET_SEED}"
fi

# Poll until no matching training process is left.
if [[ "${WAIT_FOR_EXISTING_TRAINING}" == "1" ]]; then
  while pgrep -f "${TRAINING_PGREP_PATTERN}" >/dev/null; do
    printf 'Existing training detected; waiting %ss before launching 10-empty baseline pipeline...\n' "${WAIT_SECONDS}"
    sleep "${WAIT_SECONDS}"
  done
fi

mkdir -p "${CHECKPOINT_ROOT}"

cmd=(
  "${PYTHON_BIN}" "${PIPELINE}"
  --python_executable "${PYTHON_BIN}"
  --train_jsonl "${TRAIN_JSONL}"
  --cache_dir "${ROOT}/.hf_cache"
  --model_name "Qwen/Qwen2.5-0.5B-Instruct"
  --checkpoint_root "${CHECKPOINT_ROOT}"
  --output_root "${OUTPUT_ROOT}"
  --run_tag "${RUN_TAG}"
  --min_stage 1
  --max_stage 3
  --distributed_gpu_ids "${GPU_IDS}"
  --sft_num_processes "${NUM_PROCESSES}"
  --grpo_num_processes "${NUM_PROCESSES}"
  --total_empties_hint 10
  --limit_train_rows "${LIMIT_TRAIN_ROWS}"
  --sft_num_epochs 3.0
  --grpo_num_train_epochs 1.5
  --sft_gradient_accumulation_steps 8
  --grpo_per_device_train_batch_size 8
  --grpo_gradient_accumulation_steps 2
  --grpo_num_generations 4
  --sft_enable_gradient_checkpointing
  --grpo_enable_gradient_checkpointing
  --sft_eval_steps 100
  --sft_save_steps 100
  --grpo_eval_steps 50
  --grpo_save_steps 50
  --sft_eval_rows 100
  --grpo_eval_rows 100
  --sft_stage_max_steps "1:2000,2:2000,3:2000"
  --grpo_stage_max_steps "1:1200,2:1200,3:1200"
  --sft_eval_solve_rate_stop 0.8
  --sft_min_steps_before_stop 100
  --grpo_eval_solve_rate_stop 0.8
  --grpo_min_steps_before_stop 50
  --grpo_reward_good_value 1.25
  --grpo_penalty_bad_value 1.0
  --grpo_penalty_malformed 4.0
  --grpo_penalty_empty 0.5
  --grpo_penalty_singleton 1.0
  --phase_max_wall_clock_seconds 36000
  --wandb_mode "${WANDB_MODE}"
  --use_wandb
)

# An empty WANDB_ENTITY means "do not pass --wandb_entity at all".
if [[ -n "${WANDB_ENTITY}" ]]; then
  cmd+=(--wandb_entity "${WANDB_ENTITY}")
fi

printf 'Launching 10-empty baseline stage-3 pipeline\n'
printf 'Dataset: %s\n' "${TRAIN_JSONL}"
printf 'Checkpoint root: %s\n' "${CHECKPOINT_ROOT}"
printf 'Output root: %s\n' "${OUTPUT_ROOT}"
printf 'GPUs: %s processes=%s\n' "${GPU_IDS}" "${NUM_PROCESSES}"

# Replace this shell with the pipeline process.
exec "${cmd[@]}"
hard_9x9_15empty/launch_baseline_pipeline.sh ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Launch the hard 9x9 baseline pipeline on the 15-empty Sudoku dataset.
#
# This script only prepares the dataset and the environment; the run itself is
# delegated to large_baseline_extension/launch_nonlocation_pipeline.sh, which
# receives its settings through the variables exported below.
set -euo pipefail

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"

PYTHON_BIN="${PYTHON_BIN:-${ROOT}/.venv/bin/python}"
DATASET_BUILDER="${ROOT}/simple_9x9_curriculum/build_dataset.py"
PIPELINE_LAUNCHER="${ROOT}/large_baseline_extension/launch_nonlocation_pipeline.sh"

TRAIN_JSONL="${TRAIN_JSONL:-${ROOT}/data/sudoku_t3_15empty_value_qwen_text.jsonl}"
NUM_PUZZLES="${NUM_PUZZLES:-20000}"
DATASET_SEED="${DATASET_SEED:-0}"

GPU_IDS="${GPU_IDS:-0,1,2,3,4,5,6,7}"
NUM_PROCESSES="${NUM_PROCESSES:-8}"
MIN_STAGE="${MIN_STAGE:-1}"
MAX_STAGE="${MAX_STAGE:-4}"
RUN_TAG="${RUN_TAG:-$(date +%Y%m%d_%H%M%S)}"

CHECKPOINT_ROOT="${CHECKPOINT_ROOT:-${ROOT}/final_checkpoint/hard_9x9_15empty_qwen05b/baseline}"
OUTPUT_ROOT="${OUTPUT_ROOT:-${CHECKPOINT_ROOT}/${RUN_TAG}/baseline_pipeline_15empty_4stage_hard9x9}"

WANDB_MODE="${WANDB_MODE:-online}"
WANDB_ENTITY="${WANDB_ENTITY:-training-dynamics}"
WAIT_FOR_EXISTING_TRAINING="${WAIT_FOR_EXISTING_TRAINING:-1}"
WAIT_SECONDS="${WAIT_SECONDS:-60}"

# Command lines that count as "training already running". The interpreter part
# follows PYTHON_BIN instead of one machine's absolute path, so the guard also
# works from other checkouts; override TRAINING_PGREP_PATTERN to customise it.
TRAINING_SCRIPTS_REGEX='(run_baseline_multi_output_pipeline_resume.py|run_latent_residual_projector_pipeline.py|sft_multi_output_train.py|grpo_multi_output_train.py|residual_projector_warmstart_sft_latent_multi_output_train.py|grpo_residual_projector_latent_train.py)'
TRAINING_PGREP_PATTERN="${TRAINING_PGREP_PATTERN:-${PYTHON_BIN}.*${TRAINING_SCRIPTS_REGEX}}"

# Build the dataset only when the JSONL file does not exist yet.
if [[ ! -f "${TRAIN_JSONL}" ]]; then
  mkdir -p "$(dirname "${TRAIN_JSONL}")"
  printf 'Building 15-empty dataset: %s\n' "${TRAIN_JSONL}"
  "${PYTHON_BIN}" "${DATASET_BUILDER}" \
    --output "${TRAIN_JSONL}" \
    --num_puzzles "${NUM_PUZZLES}" \
    --empties 15 \
    --seed "${DATASET_SEED}"
fi

# Poll until no matching training process is left.
if [[ "${WAIT_FOR_EXISTING_TRAINING}" == "1" ]]; then
  while pgrep -f "${TRAINING_PGREP_PATTERN}" >/dev/null; do
    printf 'Existing training detected; waiting %ss before launching 15-empty baseline...\n' "${WAIT_SECONDS}"
    sleep "${WAIT_SECONDS}"
  done
fi

mkdir -p "${CHECKPOINT_ROOT}"

# Settings handed to the delegated launcher through its environment.
export TRAIN_JSONL
export TOTAL_EMPTIES_HINT=15
export GPU_IDS
export NUM_PROCESSES
export MIN_STAGE
export MAX_STAGE
export RUN_TAG
export CHECKPOINT_ROOT
export OUTPUT_ROOT
export WANDB_MODE
export WANDB_ENTITY

printf 'Launching 15-empty hard 9x9 baseline pipeline\n'
printf 'Dataset: %s\n' "${TRAIN_JSONL}"
printf 'Checkpoint root: %s\n' "${CHECKPOINT_ROOT}"
printf 'Output root: %s\n' "${OUTPUT_ROOT}"

# Replace this shell with the delegated launcher.
exec "${PIPELINE_LAUNCHER}"
hard_9x9_15empty_multivalue_stage1/launch_stage1_size2_sft.sh ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Stage-1 SFT baseline on 15-empty 9x9 Sudoku, restricted to target size 2.
#
# Train and eval JSONL files are generated when absent; the SFT script is then
# started through torch.distributed.run on the GPUs listed in GPU_IDS.
set -euo pipefail

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"

PYTHON_BIN="${PYTHON_BIN:-${ROOT}/.venv/bin/python}"
DATASET_BUILDER="${ROOT}/simple_9x9_curriculum/build_dataset.py"
SFT_SCRIPT="${ROOT}/multi_output_cell_policy/sft_multi_output_train.py"

TRAIN_JSONL="${TRAIN_JSONL:-${ROOT}/data/sudoku_t3_15empty_value_qwen_text_stage1_train.jsonl}"
EVAL_JSONL="${EVAL_JSONL:-${ROOT}/data/sudoku_t3_15empty_value_qwen_text_stage1_eval.jsonl}"
TRAIN_PUZZLES="${TRAIN_PUZZLES:-10000}"
EVAL_PUZZLES="${EVAL_PUZZLES:-2000}"
TRAIN_SEED="${TRAIN_SEED:-0}"
EVAL_SEED="${EVAL_SEED:-1}"

GPU_IDS="${GPU_IDS:-0,1,2,3,4,5,6,7}"
NUM_PROCESSES="${NUM_PROCESSES:-8}"
RUN_TAG="${RUN_TAG:-$(date +%Y%m%d_%H%M%S)}"

CHECKPOINT_ROOT="${CHECKPOINT_ROOT:-${ROOT}/final_checkpoint/hard_9x9_15empty_qwen05b/baseline_stage1_multivalue}"
OUTPUT_DIR="${OUTPUT_DIR:-${CHECKPOINT_ROOT}/${RUN_TAG}/stage01_sft_i1_15empty_size2only}"

WANDB_MODE="${WANDB_MODE:-online}"
WANDB_ENTITY="${WANDB_ENTITY:-training-dynamics}"

# Generate one dataset split unless its JSONL file is already present.
#   $1 split label (train/eval)  $2 output path  $3 puzzle count  $4 RNG seed
build_split_if_missing() {
  local split_label="$1" jsonl_path="$2" puzzle_count="$3" split_seed="$4"
  if [[ -f "${jsonl_path}" ]]; then
    return 0
  fi
  mkdir -p "$(dirname "${jsonl_path}")"
  printf 'Building 15-empty %s dataset: %s\n' "${split_label}" "${jsonl_path}"
  "${PYTHON_BIN}" "${DATASET_BUILDER}" \
    --output "${jsonl_path}" \
    --num_puzzles "${puzzle_count}" \
    --empties 15 \
    --seed "${split_seed}"
}

build_split_if_missing train "${TRAIN_JSONL}" "${TRAIN_PUZZLES}" "${TRAIN_SEED}"
build_split_if_missing eval "${EVAL_JSONL}" "${EVAL_PUZZLES}" "${EVAL_SEED}"

mkdir -p "${CHECKPOINT_ROOT}"
export CUDA_VISIBLE_DEVICES="${GPU_IDS}"

launch_cmd=(
  "${PYTHON_BIN}" -m torch.distributed.run --standalone --nproc_per_node "${NUM_PROCESSES}" "${SFT_SCRIPT}"
  --model_name "Qwen/Qwen2.5-0.5B-Instruct"
  --train_jsonl "${TRAIN_JSONL}"
  --eval_jsonl "${EVAL_JSONL}"
  --output_dir "${OUTPUT_DIR}"
  --cache_dir "${ROOT}/.hf_cache"
  --seed 0
  --gpu_id 0
  --stage_i 1
  --total_empties_hint 15
  --per_device_train_batch_size 16
  --gradient_accumulation_steps 2
  --num_epochs 4.0
  --learning_rate 2e-4
  --enable_gradient_checkpointing
  --logging_steps 10
  --eval_steps 50
  --save_steps 50
  --eval_rows "${EVAL_PUZZLES}"
  --max_completion_length 24
  --limit_train_rows "${TRAIN_PUZZLES}"
  --lora_r 32
  --lora_alpha 64
  --lora_dropout 0.05
  --multi_value_oversample_factor 1
  --train_target_size_min 2
  --train_target_size_max 2
  --eval_target_size_min 2
  --eval_target_size_max 2
  --eval_value_precision_stop 0.95
  --eval_value_recall_stop 0.95
  --min_steps_before_stop 100
  --max_wall_clock_seconds 7200
  --max_steps 600
  --use_wandb
  --wandb_project "sudoku-multi-output-sft"
  --wandb_run_name "baseline_stage01_sft_i1_15empty_size2only"
  --wandb_mode "${WANDB_MODE}"
)

# An empty WANDB_ENTITY means "do not pass --wandb_entity at all".
if [[ -n "${WANDB_ENTITY}" ]]; then
  launch_cmd+=(--wandb_entity "${WANDB_ENTITY}")
fi

printf 'Launching 15-empty stage-1 size-2-only SFT baseline\n'
printf 'Train dataset: %s (%s puzzles)\n' "${TRAIN_JSONL}" "${TRAIN_PUZZLES}"
printf 'Eval dataset: %s (%s puzzles)\n' "${EVAL_JSONL}" "${EVAL_PUZZLES}"
printf 'Output dir: %s\n' "${OUTPUT_DIR}"
printf 'GPUs: %s processes=%s\n' "${GPU_IDS}" "${NUM_PROCESSES}"

# Replace this shell with the distributed launcher.
exec "${launch_cmd[@]}"
hard_9x9_7empty/launch_stage1_sft.sh ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Stage-1 SFT baseline on 7-empty 9x9 Sudoku.
#
# Train and eval JSONL files are generated when absent; the SFT script is then
# started through torch.distributed.run on the GPUs listed in GPU_IDS.
set -euo pipefail

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"

PYTHON_BIN="${PYTHON_BIN:-${ROOT}/.venv/bin/python}"
DATASET_BUILDER="${ROOT}/simple_9x9_curriculum/build_dataset.py"
SFT_SCRIPT="${ROOT}/multi_output_cell_policy/sft_multi_output_train.py"

TRAIN_JSONL="${TRAIN_JSONL:-${ROOT}/data/sudoku_t3_7empty_value_qwen_text_train.jsonl}"
EVAL_JSONL="${EVAL_JSONL:-${ROOT}/data/sudoku_t3_7empty_value_qwen_text_eval.jsonl}"
TRAIN_PUZZLES="${TRAIN_PUZZLES:-400}"
EVAL_PUZZLES="${EVAL_PUZZLES:-200}"
TRAIN_SEED="${TRAIN_SEED:-0}"
EVAL_SEED="${EVAL_SEED:-1}"

GPU_IDS="${GPU_IDS:-0,1,2,3,4,5,6,7}"
NUM_PROCESSES="${NUM_PROCESSES:-8}"
RUN_TAG="${RUN_TAG:-$(date +%Y%m%d_%H%M%S)}"

CHECKPOINT_ROOT="${CHECKPOINT_ROOT:-${ROOT}/final_checkpoint/hard_9x9_7empty_qwen05b/baseline_stage1}"
OUTPUT_DIR="${OUTPUT_DIR:-${CHECKPOINT_ROOT}/${RUN_TAG}/stage01_sft_i1_7empty}"

WANDB_MODE="${WANDB_MODE:-online}"
WANDB_ENTITY="${WANDB_ENTITY:-training-dynamics}"

# Generate one dataset split unless its JSONL file is already present.
#   $1 split label (train/eval)  $2 output path  $3 puzzle count  $4 RNG seed
build_split_if_missing() {
  local split_label="$1" jsonl_path="$2" puzzle_count="$3" split_seed="$4"
  if [[ -f "${jsonl_path}" ]]; then
    return 0
  fi
  mkdir -p "$(dirname "${jsonl_path}")"
  printf 'Building 7-empty %s dataset: %s\n' "${split_label}" "${jsonl_path}"
  "${PYTHON_BIN}" "${DATASET_BUILDER}" \
    --output "${jsonl_path}" \
    --num_puzzles "${puzzle_count}" \
    --empties 7 \
    --seed "${split_seed}"
}

build_split_if_missing train "${TRAIN_JSONL}" "${TRAIN_PUZZLES}" "${TRAIN_SEED}"
build_split_if_missing eval "${EVAL_JSONL}" "${EVAL_PUZZLES}" "${EVAL_SEED}"

mkdir -p "${CHECKPOINT_ROOT}"
export CUDA_VISIBLE_DEVICES="${GPU_IDS}"

launch_cmd=(
  "${PYTHON_BIN}" -m torch.distributed.run --standalone --nproc_per_node "${NUM_PROCESSES}" "${SFT_SCRIPT}"
  --model_name "Qwen/Qwen2.5-0.5B-Instruct"
  --train_jsonl "${TRAIN_JSONL}"
  --eval_jsonl "${EVAL_JSONL}"
  --output_dir "${OUTPUT_DIR}"
  --cache_dir "${ROOT}/.hf_cache"
  --seed 0
  --gpu_id 0
  --stage_i 1
  --total_empties_hint 7
  --num_epochs 3.0
  --learning_rate 2e-4
  --gradient_accumulation_steps 8
  --enable_gradient_checkpointing
  --logging_steps 10
  --eval_steps 25
  --save_steps 25
  --eval_rows "${EVAL_PUZZLES}"
  --max_completion_length 24
  --limit_train_rows "${TRAIN_PUZZLES}"
  --lora_r 32
  --lora_alpha 64
  --lora_dropout 0.05
  --multi_value_oversample_factor 16
  --eval_exact_set_match_stop 0.999
  --eval_value_precision_stop 0.999
  --eval_value_recall_stop 0.999
  --min_steps_before_stop 50
  --max_wall_clock_seconds 1800
  --max_steps 250
  --use_wandb
  --wandb_project "sudoku-multi-output-sft"
  --wandb_run_name "baseline_stage01_sft_i1_7empty"
  --wandb_mode "${WANDB_MODE}"
)

# An empty WANDB_ENTITY means "do not pass --wandb_entity at all".
if [[ -n "${WANDB_ENTITY}" ]]; then
  launch_cmd+=(--wandb_entity "${WANDB_ENTITY}")
fi

printf 'Launching 7-empty stage-1 SFT baseline\n'
printf 'Train dataset: %s (%s puzzles)\n' "${TRAIN_JSONL}" "${TRAIN_PUZZLES}"
printf 'Eval dataset: %s (%s puzzles)\n' "${EVAL_JSONL}" "${EVAL_PUZZLES}"
printf 'Output dir: %s\n' "${OUTPUT_DIR}"
printf 'GPUs: %s processes=%s\n' "${GPU_IDS}" "${NUM_PROCESSES}"

# Replace this shell with the distributed launcher.
exec "${launch_cmd[@]}"
hard_9x9_curriculum/build_stage3_hard_dataset.py ADDED
@@ -0,0 +1,448 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import json
5
+ import random
6
+ import sys
7
+ from dataclasses import asdict, dataclass
8
+ from pathlib import Path
9
+ from typing import Dict, Iterable, List, Sequence, Tuple
10
+
11
+ import numpy as np
12
+
13
+ CURRENT_DIR = Path(__file__).resolve().parent
14
+ PARENT_DIR = CURRENT_DIR.parent
15
+ if str(PARENT_DIR) not in sys.path:
16
+ sys.path.insert(0, str(PARENT_DIR))
17
+
18
+ from formatting_icon import is_consistent_pair
19
+ from multi_output_cell_policy.shared_multi_output_policy import stage_i_consistent_values
20
+
21
+
22
+ GRID_SIZE = 9
23
+ BOX_SIZE = 3
24
+ ALL_VALUES = tuple(range(1, 10))
25
+ DEFAULT_BASE_NAME = "sudoku_t3_30empty_stage3hard"
26
+
27
+
28
@dataclass(frozen=True)
class DifficultyProfile:
    """Outcome of running propagation at stages 1, 2 and 3 on one puzzle.

    ``stageN_solved`` records whether stage-N propagation reproduced the solved
    grid; ``stageN_steps`` is the number of cells that stage assigned.
    """

    # Solved flags, one per propagation stage.
    stage1_solved: bool
    stage2_solved: bool
    stage3_solved: bool
    # Assignment counts, one per propagation stage.
    stage1_steps: int
    stage2_steps: int
    stage3_steps: int
36
+
37
+
38
@dataclass(frozen=True)
class SeedMask:
    """A set of blanked cells together with the difficulty profile it produced."""

    # Row-major flat indices of the cells that are emptied.
    mask_cells: tuple[int, ...]
    # Propagation results measured for the puzzle this mask yields.
    profile: DifficultyProfile
42
+
43
+
44
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the stage-3 hard dataset builder.

    Output paths default to files under ``<repo root>/data`` whose names start
    with ``DEFAULT_BASE_NAME``; the remaining options are integer knobs.
    """
    data_dir = PARENT_DIR / "data"
    parser = argparse.ArgumentParser()

    # (flag, file-name suffix appended to DEFAULT_BASE_NAME)
    path_options = (
        ("--train_output", "_value_qwen_text.jsonl"),
        ("--eval_output", "_eval_value_qwen_text.jsonl"),
        ("--manifest_output", "_manifest.json"),
    )
    for flag, suffix in path_options:
        default_path = data_dir / f"{DEFAULT_BASE_NAME}{suffix}"
        parser.add_argument(flag, type=str, default=str(default_path))

    # (flag, default value)
    int_options = (
        ("--num_train_puzzles", 4000),
        ("--num_eval_puzzles", 200),
        ("--empties", 30),
        ("--seed", 0),
        ("--max_attempts", 200000),
        ("--progress_every", 250),
        ("--num_seed_masks", 8),
    )
    for flag, default_value in int_options:
        parser.add_argument(flag, type=int, default=default_value)

    return parser.parse_args()
61
+
62
+
63
def permute_groups(rng: random.Random, values: Sequence[int], group_size: int) -> List[int]:
    """Shuffle ``values`` while keeping consecutive groups together.

    ``values`` is cut into consecutive chunks of ``group_size``. The chunk order
    is shuffled first, then the items inside each chunk (in the new chunk
    order), and the chunks are concatenated.
    """
    chunks = []
    for start in range(0, len(values), group_size):
        chunks.append(list(values[start : start + group_size]))

    rng.shuffle(chunks)
    for chunk in chunks:
        rng.shuffle(chunk)

    return [item for chunk in chunks for item in chunk]
71
+
72
+
73
def base_solved_grid() -> np.ndarray:
    """Return the canonical solved grid built from a shifted-row pattern.

    Cell ``(r, c)`` holds ``((r * BOX_SIZE + r // BOX_SIZE + c) % GRID_SIZE) + 1``.
    """
    index = np.arange(GRID_SIZE)
    # Per-row cyclic shift applied to the column index.
    row_shift = index * BOX_SIZE + index // BOX_SIZE
    cells = (row_shift[:, None] + index[None, :]) % GRID_SIZE + 1
    return np.asarray(cells, dtype=int)
78
+
79
+
80
def row_major_empty_locs(grid: np.ndarray) -> List[Tuple[int, int]]:
    """List the ``(row, col)`` positions holding ``0``, in row-major order."""
    is_empty = np.asarray(grid, dtype=int) == 0
    locations: List[Tuple[int, int]] = []
    for row, col in np.argwhere(is_empty).tolist():
        locations.append((int(row), int(col)))
    return locations
82
+
83
+
84
def make_prompt(grid: np.ndarray) -> str:
    """Build the text prompt describing ``grid`` as 1-based (row,col,value) tuples.

    The two header lines are followed by one space-separated line listing every
    cell in row-major order; empty cells appear with value ``0``.
    """
    encoded_cells = []
    for row in range(GRID_SIZE):
        for col in range(GRID_SIZE):
            encoded_cells.append(f"({row + 1},{col + 1},{int(grid[row, col])})")

    header = (
        "9x9 Sudoku board encoded as (row,col,value) tuples in row-major order.\n"
        "Value 0 means the cell is empty.\n"
    )
    return header + " ".join(encoded_cells)
91
+
92
+
93
def legal_values(grid: np.ndarray, row: int, col: int) -> List[int]:
    """Return the digits from ``ALL_VALUES`` that ``is_consistent_pair`` accepts at ``(row, col)``."""
    flat_index = int(row) * GRID_SIZE + int(col)
    allowed: List[int] = []
    for digit in ALL_VALUES:
        if is_consistent_pair(grid, cell=flat_index, value=int(digit), t=3, n=9):
            allowed.append(int(digit))
    return allowed
96
+
97
+
98
def count_solutions(grid: np.ndarray, *, limit: int = 2) -> int:
    """Count completions of ``grid`` by backtracking, stopping once ``limit`` are found.

    The search works on a private copy of the grid and always branches on the
    empty cell with the fewest legal digits (first such cell in row-major
    order), stopping the scan early when a cell with a single option is seen.

    Returns:
        The number of solutions found, never more than ``limit``.
    """
    board = np.asarray(grid, dtype=int).copy()
    cap = int(limit)
    found = 0

    def most_constrained() -> Tuple[Tuple[int, int] | None, List[int] | None]:
        """Pick the branching cell.

        Returns ``(None, None)`` when the board is full and ``(cell, None)``
        when ``cell`` has no legal digit (dead end).
        """
        pick: Tuple[int, int] | None = None
        options: List[int] | None = None
        for row, col in row_major_empty_locs(board):
            candidates = legal_values(board, row, col)
            if not candidates:
                return (row, col), None
            if options is None or len(candidates) < len(options):
                pick, options = (row, col), candidates
                if len(candidates) == 1:
                    # Cannot do better than a forced cell.
                    break
        return pick, options

    def explore() -> None:
        nonlocal found
        if found >= cap:
            return
        pick, options = most_constrained()
        if pick is None:
            # No empty cell left: the board is a complete solution.
            found += 1
            return
        if options is None:
            return
        row, col = pick
        for digit in options:
            board[row, col] = int(digit)
            explore()
            board[row, col] = 0
            if found >= cap:
                return

    explore()
    return int(found)
130
+
131
+
132
def propagate_stage(grid: np.ndarray, *, stage_i: int) -> Tuple[np.ndarray | None, int]:
    """Repeatedly fill cells that have exactly one stage-``stage_i`` consistent value.

    Works on a copy of ``grid``. Each pass scans the empty cells in row-major
    order and fills the first one whose candidate list has a single entry, then
    rescans from the start.

    Returns:
        ``(board, assignments)`` once a full pass fills nothing, or
        ``(None, assignments)`` as soon as some empty cell has no candidate.
    """
    board = np.asarray(grid, dtype=int).copy()
    filled = 0
    while True:
        for row, col in row_major_empty_locs(board):
            options = stage_i_consistent_values(board, target_cell=(row, col), stage_i=int(stage_i))
            if not options:
                return None, filled
            if len(options) == 1:
                board[row, col] = int(options[0])
                filled += 1
                break
        else:
            # Full scan without a forced cell: propagation has converged.
            return board, filled
149
+
150
+
151
def build_difficulty_profile(puzzle: np.ndarray, solved: np.ndarray) -> DifficultyProfile | None:
    """Run propagation at stages 1, 2 and 3 and summarise the results.

    Each stage starts from the original ``puzzle``. A stage counts as solved
    when its final board equals ``solved``.

    Returns:
        The profile, or ``None`` as soon as any stage reports a contradiction
        (later stages are then not run).
    """
    solved_flags = []
    step_counts = []
    for level in (1, 2, 3):
        board, steps = propagate_stage(puzzle, stage_i=level)
        if board is None:
            return None
        solved_flags.append(bool(np.array_equal(board, solved)))
        step_counts.append(int(steps))

    return DifficultyProfile(
        stage1_solved=solved_flags[0],
        stage2_solved=solved_flags[1],
        stage3_solved=solved_flags[2],
        stage1_steps=step_counts[0],
        stage2_steps=step_counts[1],
        stage3_steps=step_counts[2],
    )
169
+
170
+
171
def qualifies(profile: DifficultyProfile) -> bool:
    """Accept only puzzles that stage 3 solves but stages 1 and 2 do not."""
    if profile.stage1_solved or profile.stage2_solved:
        return False
    return profile.stage3_solved
173
+
174
+
175
def build_puzzle_from_mask(solved: np.ndarray, mask_cells: Sequence[int]) -> np.ndarray:
    """Return a copy of ``solved`` with every cell listed in ``mask_cells`` set to 0.

    ``mask_cells`` holds row-major flat indices; ``solved`` itself is not modified.
    """
    puzzle = np.array(solved, dtype=int)  # np.array copies by default
    for flat_index in mask_cells:
        puzzle[divmod(int(flat_index), GRID_SIZE)] = 0
    return puzzle
181
+
182
+
183
def sample_mask_cells(*, empties: int, rng: random.Random) -> tuple[int, ...]:
    """Draw ``empties`` distinct cell indices uniformly and return them sorted.

    The draw shuffles the full list of cell indices with ``rng`` and keeps the
    first ``empties`` entries.
    """
    order = list(range(GRID_SIZE * GRID_SIZE))
    rng.shuffle(order)
    chosen = order[: int(empties)]
    chosen.sort()
    return tuple(int(flat_index) for flat_index in chosen)
187
+
188
+
189
+ # Greedily grow a mask of `empties` cells over base_solved_grid() until the
+ # resulting puzzle passes qualifies() and has exactly one solution.
+ # Returns (SeedMask, stats) on success or (None, stats) once `max_attempts`
+ # candidate evaluations have been spent; stats holds "attempts" and "restarts".
+ # NOTE(review): indentation was lost in this copy of the file, so the nesting
+ # level of the progress print below cannot be confirmed from here.
+ def greedy_find_seed_mask(
190
+ *,
191
+ empties: int,
192
+ max_attempts: int,
193
+ rng: random.Random,
194
+ progress_every: int,
195
+ ) -> Tuple[SeedMask | None, Dict[str, int]]:
196
+ solved = base_solved_grid()
197
+ attempts = 0
198
+ restarts = 0
199
+ while attempts < int(max_attempts):
200
+ restarts += 1
201
+ mask: List[int] = []
202
+ remaining = list(range(GRID_SIZE * GRID_SIZE))
203
+ rng.shuffle(remaining)
204
+ current_profile: DifficultyProfile | None = None
205
+
206
+ while len(mask) < int(empties) and attempts < int(max_attempts):
207
+ best_cell: int | None = None
208
+ best_profile: DifficultyProfile | None = None
209
+ best_score: Tuple[int, int, int] | None = None
+ # Only the first (up to) 12 cells of the shuffled remainder are tried per growth step.
210
+ candidate_cells = list(remaining[: min(len(remaining), 12)])
211
+ if not candidate_cells:
212
+ break
213
+
214
+ for cell in candidate_cells:
+ # Every evaluated candidate counts as one attempt.
215
+ attempts += 1
216
+ trial_mask = tuple(sorted(mask + [int(cell)]))
217
+ puzzle = build_puzzle_from_mask(solved, trial_mask)
218
+ profile = build_difficulty_profile(puzzle, solved)
+ # Candidates that stage-3 propagation cannot solve are discarded.
219
+ if profile is None or not profile.stage3_solved:
220
+ continue
+ # Tuple ordering: prefer stage 2 unsolved, then stage 1 unsolved,
+ # then the largest stage3_steps - stage2_steps gap.
221
+ score = (
222
+ int(not profile.stage2_solved),
223
+ int(not profile.stage1_solved),
224
+ int(profile.stage3_steps - profile.stage2_steps),
225
+ )
+ # Strict ">" keeps the earliest candidate on ties.
226
+ if best_score is None or score > best_score:
227
+ best_cell = int(cell)
228
+ best_profile = profile
229
+ best_score = score
230
+
231
+ if attempts == 1 or attempts % max(1, int(progress_every)) == 0:
232
+ print(
233
+ f"[search hard 9x9 masks] attempts={attempts} restarts={restarts} current_empties={len(mask)}",
234
+ flush=True,
235
+ )
236
+
+ # No usable candidate in this batch: abandon this mask and restart.
237
+ if best_cell is None or best_profile is None:
238
+ break
239
+
240
+ mask.append(int(best_cell))
241
+ mask.sort()
242
+ remaining.remove(int(best_cell))
243
+ current_profile = best_profile
244
+
245
+ if len(mask) != int(empties) or current_profile is None:
246
+ continue
247
+
+ # Re-evaluate the finished mask before accepting it.
248
+ final_mask = tuple(sorted(int(cell) for cell in mask))
249
+ final_puzzle = build_puzzle_from_mask(solved, final_mask)
250
+ final_profile = build_difficulty_profile(final_puzzle, solved)
251
+ if final_profile is None or not qualifies(final_profile):
252
+ continue
+ # Reject puzzles whose solution is not unique.
253
+ if count_solutions(final_puzzle, limit=2) != 1:
254
+ continue
255
+ return SeedMask(mask_cells=final_mask, profile=final_profile), {
256
+ "attempts": int(attempts),
257
+ "restarts": int(restarts),
258
+ }
259
+
260
+ return None, {"attempts": int(attempts), "restarts": int(restarts)}
261
+
262
+
263
def random_symmetry(
    rng: random.Random, *, solved: np.ndarray, mask_cells: Sequence[int]
) -> Tuple[np.ndarray, tuple[int, ...]]:
    """Apply a random relabelling/permutation to a solved grid and its mask.

    The transformation is drawn from ``rng`` in this order: a permutation of
    the digits, a group-preserving permutation of the rows, one of the columns
    (both via ``permute_groups``), and a coin flip deciding whether to
    transpose. The mask indices are moved along with the cells they refer to.

    Returns:
        ``(transformed_grid, transformed_mask)`` with the mask sorted.
    """
    shuffled_digits = list(ALL_VALUES)
    rng.shuffle(shuffled_digits)
    relabel = dict(zip(ALL_VALUES, shuffled_digits, strict=True))

    source = np.asarray(solved, dtype=int).copy()
    relabelled = np.array([relabel[int(cell)] for cell in source.ravel()], dtype=int).reshape(source.shape)

    row_order = permute_groups(rng, list(range(GRID_SIZE)), BOX_SIZE)
    col_order = permute_groups(rng, list(range(GRID_SIZE)), BOX_SIZE)
    # New row i shows old row row_order[i]; invert to find where an old row lands.
    new_row_of = {old: new for new, old in enumerate(row_order)}
    new_col_of = {old: new for new, old in enumerate(col_order)}

    permuted = relabelled[np.ix_(row_order, col_order)]

    moved_cells: List[int] = []
    for flat_index in mask_cells:
        old_row, old_col = divmod(int(flat_index), GRID_SIZE)
        moved_cells.append(int(new_row_of[int(old_row)]) * GRID_SIZE + int(new_col_of[int(old_col)]))

    if rng.random() < 0.5:
        permuted = permuted.T
        # Transposition swaps the row and column of every masked cell.
        swapped = []
        for flat_index in moved_cells:
            row, col = divmod(flat_index, GRID_SIZE)
            swapped.append(int(col) * GRID_SIZE + int(row))
        moved_cells = swapped

    return np.asarray(permuted, dtype=int), tuple(sorted(int(flat_index) for flat_index in moved_cells))
291
+
292
+
293
def make_example(solved: np.ndarray, mask_cells: Sequence[int], *, empties: int, profile: DifficultyProfile) -> Dict[str, object]:
    """Assemble one dataset row (prompt, completion, metadata) for a masked puzzle.

    The completion is the compact JSON list of the solved values of the empty
    cells, in row-major order. Coordinates stored in the metadata are 1-based.
    """
    puzzle = build_puzzle_from_mask(solved, mask_cells)
    empty_cells = row_major_empty_locs(puzzle)

    empty_locs_1based = [(row + 1, col + 1) for row, col in empty_cells]
    target_triples_1based = [(row + 1, col + 1, int(solved[row, col])) for row, col in empty_cells]
    answers = [int(triple[2]) for triple in target_triples_1based]

    metadata = {
        "grid_size": GRID_SIZE,
        "box_size": BOX_SIZE,
        "empties": int(empties),
        "empty_locs_1based": empty_locs_1based,
        "target_triples_1based": target_triples_1based,
        "required_consistency_stage": 3,
        "difficulty_profile": asdict(profile),
    }
    return {
        "prompt": make_prompt(puzzle),
        "completion": json.dumps(answers, separators=(",", ":")),
        "metadata": metadata,
    }
311
+
312
+
313
def search_seed_masks(
    *,
    num_seed_masks: int,
    empties: int,
    max_attempts: int,
    seed: int,
    progress_every: int,
) -> Tuple[List[SeedMask], Dict[str, int]]:
    """Collect up to ``num_seed_masks`` distinct qualifying masks.

    ``greedy_find_seed_mask`` is called repeatedly with a single RNG seeded from
    ``seed`` and with whatever is left of the shared ``max_attempts`` budget.
    The search ends when enough masks are found, the budget is spent, or a
    greedy run returns no mask. Duplicate masks are ignored.

    Returns:
        ``(seed_masks, stats)`` where ``stats`` has the keys ``attempts``,
        ``restarts`` and ``accepted_seed_masks``.
    """
    rng = random.Random(int(seed))
    accepted: List[SeedMask] = []
    known_masks = set()
    attempts_used = 0
    restarts_used = 0

    while len(accepted) < int(num_seed_masks) and attempts_used < int(max_attempts):
        budget_left = max(1, int(max_attempts) - int(attempts_used))
        found, run_stats = greedy_find_seed_mask(
            empties=int(empties),
            max_attempts=budget_left,
            rng=rng,
            progress_every=int(progress_every),
        )
        attempts_used += int(run_stats.get("attempts", 0))
        restarts_used += int(run_stats.get("restarts", 0))
        if found is None:
            break
        if found.mask_cells not in known_masks:
            known_masks.add(found.mask_cells)
            accepted.append(found)
            print(
                f"[search hard 9x9 masks] attempts={attempts_used} accepted={len(accepted)}/{num_seed_masks}",
                flush=True,
            )

    summary = {
        "attempts": int(attempts_used),
        "restarts": int(restarts_used),
        "accepted_seed_masks": int(len(accepted)),
    }
    return accepted, summary
353
+
354
+
355
def generate_examples(
    *,
    num_examples: int,
    empties: int,
    seed_masks: Sequence[SeedMask],
    seed: int,
) -> List[Dict[str, object]]:
    """Produce ``num_examples`` dataset rows by randomly transforming the seed masks.

    Seed masks are used round-robin. Each row applies a fresh
    ``random_symmetry`` (RNG seeded with ``seed + 1``) to the base solved grid
    and the mask, and carries the seed mask's difficulty profile.

    Raises:
        ValueError: If ``seed_masks`` is empty.
    """
    if not seed_masks:
        raise ValueError("seed_masks must not be empty")

    rng = random.Random(int(seed) + 1)
    solved = base_solved_grid()
    mask_count = len(seed_masks)

    examples: List[Dict[str, object]] = []
    for position in range(int(num_examples)):
        template = seed_masks[position % mask_count]
        variant_grid, variant_mask = random_symmetry(rng, solved=solved, mask_cells=template.mask_cells)
        example = make_example(
            variant_grid,
            variant_mask,
            empties=int(empties),
            profile=template.profile,
        )
        examples.append(example)
    return examples
381
+
382
+
383
def write_jsonl(path: Path, rows: Iterable[Dict[str, object]]) -> None:
    """Write ``rows`` to ``path`` as compact JSON, one object per line.

    Missing parent directories are created and an existing file is overwritten.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as handle:
        handle.writelines(json.dumps(row, separators=(",", ":")) + "\n" for row in rows)
388
+
389
+
390
def main() -> None:
    """Search for seed masks, generate the puzzles, and write train/eval/manifest files.

    The first ``--num_eval_puzzles`` generated rows form the eval split and the
    rest form the train split.

    Raises:
        RuntimeError: If fewer seed masks than requested were found.
    """
    args = parse_args()
    train_count = int(args.num_train_puzzles)
    eval_count = int(args.num_eval_puzzles)
    total_needed = train_count + eval_count
    # At least one mask, but never more masks than puzzles to generate.
    num_seed_masks = min(max(1, int(args.num_seed_masks)), total_needed)

    seed_masks, search_stats = search_seed_masks(
        num_seed_masks=num_seed_masks,
        empties=int(args.empties),
        max_attempts=int(args.max_attempts),
        seed=int(args.seed),
        progress_every=int(args.progress_every),
    )
    if len(seed_masks) < num_seed_masks:
        raise RuntimeError(
            f"Only found {len(seed_masks)} qualifying seed masks out of requested {num_seed_masks}. "
            f"Try increasing --max_attempts or reducing --num_seed_masks."
        )

    rows = generate_examples(
        num_examples=total_needed,
        empties=int(args.empties),
        seed_masks=seed_masks,
        seed=int(args.seed),
    )
    eval_rows, train_rows = rows[:eval_count], rows[eval_count:]

    train_output = Path(args.train_output).resolve()
    eval_output = Path(args.eval_output).resolve()
    manifest_output = Path(args.manifest_output).resolve()

    write_jsonl(train_output, train_rows)
    write_jsonl(eval_output, eval_rows)

    manifest = {
        "train_output": str(train_output),
        "eval_output": str(eval_output),
        "num_train_puzzles": int(len(train_rows)),
        "num_eval_puzzles": int(len(eval_rows)),
        "empties": int(args.empties),
        "seed": int(args.seed),
        "required_consistency_stage": 3,
        "num_seed_masks": int(num_seed_masks),
        "search_stats": search_stats,
    }
    manifest_output.parent.mkdir(parents=True, exist_ok=True)
    manifest_text = json.dumps(manifest, indent=2, sort_keys=True) + "\n"
    manifest_output.write_text(manifest_text, encoding="utf-8")

    print(f"Wrote {len(train_rows)} train puzzles to {train_output}")
    print(f"Wrote {len(eval_rows)} eval puzzles to {eval_output}")
    print(f"Wrote manifest to {manifest_output}")
445
+
446
+
447
+ if __name__ == "__main__":
448
+ main()
hard_9x9_stage1_consistency_queue/README.md ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Stage-1 Latent SFT Mode Sweep
2
+
3
+ This folder contains launchers for the 9x9 Sudoku curriculum experiments. The
4
+ stage-1 latent sweep launcher is:
5
+
6
+ ```bash
7
+ bash hard_9x9_stage1_consistency_queue/launch_20empty_stage1_sft_all_latent_modes_parallel.sh
8
+ ```
9
+
10
+ The goal of this sweep is to compare the four latent implementations under the
11
+ same stage-1 SFT setup and measure which one gives the fastest useful
12
+ convergence. The main comparison should include training loss, held-out value
13
+ precision/recall, completion quality, wall-clock time, and GPU efficiency. In
14
+ particular, compare both loss vs. optimizer step and loss vs. elapsed time,
15
+ because some methods do more transformer forward passes per step.
16
+
17
+ ## Four Latent Modes
18
+
19
+ ### `residual`
20
+
21
+ The residual mode performs a dynamic latent hidden rollout, then projects the
22
+ difference between the latent hidden state and the base hidden state back into
23
+ the model hidden space. This projected delta is added to the base next-token
24
+ hidden state before computing logits. It is expressive, but it is slower because
25
+ the latent rollout requires repeated transformer passes.
26
+
27
+ ### `fixed_slots`
28
+
29
+ The fixed-slots mode learns a bank of trainable latent slot embeddings plus a
30
+ separate final readout slot. For each prediction, the model runs once on:
31
+
32
+ ```text
33
+ [prompt tokens, slot_1, ..., slot_k, final_slot]
34
+ ```
35
+
36
+ The next token is predicted from the hidden state at `final_slot`. This is a
37
+ parallel latent method: all latent slots are inserted at once, so it avoids the
38
+ repeated, serial transformer passes used by `residual` and `recurrent_hidden`.
39
+
40
+ ### `recurrent_hidden`
41
+
42
+ The recurrent-hidden mode generates latent tokens dynamically from the current
43
+ example. It appends a hidden latent token, reruns the transformer, takes the new
44
+ last hidden state as the next latent token, and repeats for `num_cot_tokens`.
45
+ This is the closest to iterative hidden reasoning, but it is usually the
46
+ slowest because the latent steps are serial.
47
+
48
+ ### `latent_seeds`
49
+
50
+ The latent-seeds mode learns a bank of trainable seed embeddings. For each
51
+ prediction, the model runs once on:
52
+
53
+ ```text
54
+ [prompt tokens, seed_1, ..., seed_k]
55
+ ```
56
+
57
+ The next token is predicted from the hidden state at the last seed position.
58
+ Like fixed slots, this is parallel and avoids repeated transformer passes. The
59
+ main difference from `fixed_slots` is that there is no separate final readout
60
+ slot; the last seed position acts as the readout.
61
+
62
+ ## Experimental Strategy
63
+
64
+ Run all four modes in parallel on stage 1 with the same dataset, LoRA settings,
65
+ number of latent tokens, stopping rule, and evaluation set. The default launcher
66
+ splits an 8-GPU node into four two-GPU jobs:
67
+
68
+ ```text
69
+ residual -> GPUs 0,1
70
+ fixed_slots -> GPUs 2,3
71
+ recurrent_hidden -> GPUs 4,5
72
+ latent_seeds -> GPUs 6,7
73
+ ```
74
+
75
+ Use the results to decide which one or two methods should be promoted to deeper
76
+ curriculum stages. The expected practical tradeoff is that `fixed_slots` and
77
+ `latent_seeds` should be much faster per wall-clock time, while `residual` and
78
+ `recurrent_hidden` test more iterative, example-dependent latent computation.
79
+
80
+ ## Warm Baseline Stages 1-3 Pipeline
81
+
82
+ The full warm-baseline launcher is:
83
+
84
+ ```bash
85
+ STAGE1_BASELINE_ADAPTER_DIR=/path/to/warmed/stage1/baseline/checkpoint \
86
+ bash hard_9x9_stage1_consistency_queue/launch_20empty_warm_baseline_all_latent_modes_stages123.sh
87
+ ```
88
+
89
+ It runs all four latent modes in parallel, two GPUs per mode:
90
+
91
+ ```text
92
+ residual -> GPUs 0,1
93
+ fixed_slots -> GPUs 2,3
94
+ recurrent_hidden -> GPUs 4,5
95
+ latent_seeds -> GPUs 6,7
96
+ ```
97
+
98
+ For each mode, the intended sequence is:
99
+
100
+ ```text
101
+ stage1 latent SFT
102
+ -> stage1 latent GRPO
103
+ -> stage2 baseline warm-up SFT
104
+ -> stage2 latent SFT
105
+ -> stage2 latent GRPO
106
+ -> stage3 baseline warm-up SFT
107
+ -> stage3 latent SFT
108
+ -> stage3 latent GRPO
109
+ ```
110
+
111
+ The run is capped by fixed step budgets by default (`1000` SFT steps and `500`
112
+ GRPO steps per phase) and can stop early when the configured solve-rate target
113
+ is reached. The current 1.5B run uses the warmed Stage-1 baseline adapter from
114
+ `hard_9x9_20empty_baseline_1p5b_warmup`.
115
+
116
+ See `warm_baseline_all_latent_modes_stages123_results.md` for the current
117
+ solve-rate snapshot from the ongoing full-pipeline run.
hard_9x9_stage1_consistency_queue/debug_fixed_slot_latent_one_example.sh ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ set -euo pipefail
3
+
4
+ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
5
+ ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
6
+
7
+ PYTHON_BIN="${PYTHON_BIN:-${ROOT}/.venv/bin/python}"
8
+ export MODEL_NAME="${MODEL_NAME:-Qwen/Qwen2.5-0.5B-Instruct}"  # exported: the embedded Python below reads each of these settings from os.environ
9
+ export DATA_PATH="${DATA_PATH:-${ROOT}/data/sudoku_t3_20empty_value_qwen_text_stage1_train.jsonl}"
10
+ GPU_ID="${GPU_ID:-0}"  # shell-only; it reaches the Python process through CUDA_VISIBLE_DEVICES
11
+ export NUM_COT="${NUM_COT:-3}"  # without the export, Python silently used its own fallback instead of this default
12
+ export MAX_LATENT_SLOTS="${MAX_LATENT_SLOTS:-8}"
13
+ export LIMIT_ROWS="${LIMIT_ROWS:-1}"
14
+ export TRAIN_STEPS="${TRAIN_STEPS:-60}"
15
+ export LR="${LR:-1e-1}"
16
+ export LORA_R="${LORA_R:-32}"
17
+ export LORA_ALPHA="${LORA_ALPHA:-64}"
18
+ export LORA_DROPOUT="${LORA_DROPOUT:-0.05}"
19
+
20
+ export CUDA_DEVICE_ORDER="${CUDA_DEVICE_ORDER:-PCI_BUS_ID}"
21
+ export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-${GPU_ID}}"
22
+
23
+ exec "${PYTHON_BIN}" - <<'PY'
24
+ import os
25
+
26
+ import torch
27
+ import torch.nn.functional as F
28
+ from transformers import AutoModelForCausalLM, AutoTokenizer
29
+
30
+ from aligned_cell_policy.shared_cell_policy import build_cell_examples_from_row
31
+ from latent_multi_output_cell_policy.grpo_residual_projector_latent_train import (
32
+ attach_fixed_latent_slot_modules,
33
+ fixed_slot_next_token_logits_from_ids,
34
+ load_jsonl_rows,
35
+ load_trainable_adapter,
36
+ pick_dtype,
37
+ sample_fixed_slot_completion,
38
+ unwrap_backbone,
39
+ )
40
+ from multi_output_cell_policy.prompt_builder import build_multi_output_cell_prompt
41
+ from multi_output_cell_policy.shared_multi_output_policy import build_supervised_completion
42
+
43
+
44
+ def env_int(name: str, default: int) -> int:
45
+ return int(os.environ.get(name, str(default)))
46
+
47
+
48
+ def env_float(name: str, default: float) -> float:
49
+ return float(os.environ.get(name, str(default)))
50
+
51
+
52
+ model_name = os.environ.get("MODEL_NAME", "Qwen/Qwen2.5-0.5B-Instruct")
53
+ data_path = os.environ.get("DATA_PATH", "data/sudoku_t3_20empty_value_qwen_text_stage1_train.jsonl")
54
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
55
+ num_cot = env_int("NUM_COT", 5)
56
+ max_latent_slots = env_int("MAX_LATENT_SLOTS", 8)
57
+ limit_rows = env_int("LIMIT_ROWS", 1)
58
+ train_steps = env_int("TRAIN_STEPS", 60)
59
+ lr = env_float("LR", 1e-1)
60
+ lora_r = env_int("LORA_R", 32)
61
+ lora_alpha = env_int("LORA_ALPHA", 64)
62
+ lora_dropout = env_float("LORA_DROPOUT", 0.05)
63
+
64
+ rows = load_jsonl_rows(data_path, limit_rows=limit_rows)
65
+ ex = build_cell_examples_from_row(rows[0])[0]
66
+
67
+ tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
68
+ if tokenizer.pad_token_id is None:
69
+ tokenizer.pad_token = tokenizer.eos_token or "<|endoftext|>"
70
+
71
+ prompt = build_multi_output_cell_prompt(
72
+ ex.grid,
73
+ target_cell=ex.target_cell,
74
+ stage_i=1,
75
+ tokenizer=tokenizer,
76
+ turn_idx=ex.turn_idx,
77
+ total_turns=ex.total_turns,
78
+ prev_output_flag=None,
79
+ total_empties_hint=20,
80
+ )
81
+ target_text = build_supervised_completion(ex, stage_i=1) + (tokenizer.eos_token or "")
82
+ print("target_text", target_text)
83
+
84
+ base = AutoModelForCausalLM.from_pretrained(
85
+ model_name,
86
+ torch_dtype=pick_dtype(),
87
+ low_cpu_mem_usage=True,
88
+ )
89
+ model = load_trainable_adapter(base, "", lora_r=lora_r, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
90
+ attach_fixed_latent_slot_modules(
91
+ model,
92
+ hidden_size=int(unwrap_backbone(model).config.hidden_size),
93
+ max_latent_slots=max_latent_slots,
94
+ )
95
+ if hasattr(model, "config"):
96
+ model.config.use_cache = False
97
+ backbone = unwrap_backbone(model)
98
+ if hasattr(backbone, "config"):
99
+ backbone.config.use_cache = False
100
+ model.to(device)
101
+
102
+ for p in model.parameters():
103
+ p.requires_grad = False
104
+ model.fixed_latent_slots.requires_grad_(True)
105
+ model.fixed_final_slot_embed.requires_grad_(True)
106
+ optimizer = torch.optim.AdamW([model.fixed_latent_slots, model.fixed_final_slot_embed], lr=lr)
107
+
108
+ prompt_ids = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).input_ids.to(device)
109
+ completion_ids = tokenizer(target_text, return_tensors="pt", add_special_tokens=False).input_ids.to(device)
110
+
111
+
112
+ @torch.no_grad()
113
+ def sample_now(tag: str) -> None:
114
+ model.eval()
115
+ attn = torch.ones_like(prompt_ids, device=device)
116
+ logits = fixed_slot_next_token_logits_from_ids(model, prompt_ids, attn, num_cot)
117
+ probs = torch.softmax(logits[0].float(), dim=-1)
118
+ top_probs, top_ids = torch.topk(probs, k=5)
119
+ out_ids = sample_fixed_slot_completion(
120
+ model,
121
+ tokenizer,
122
+ prompt_ids,
123
+ attn,
124
+ num_cot_tokens=num_cot,
125
+ max_new_tokens=12,
126
+ do_sample=False,
127
+ )
128
+ top_next = [(tokenizer.decode([int(i)]), round(float(p), 4)) for i, p in zip(top_ids.tolist(), top_probs.tolist())]
129
+ print(tag, tokenizer.decode(out_ids[0], skip_special_tokens=True), "top_next", top_next)
130
+
131
+
132
+ sample_now("before:")
133
+
134
+ for step in range(1, train_steps + 1):
135
+ model.train()
136
+ cur_ids = prompt_ids
137
+ cur_mask = torch.ones_like(prompt_ids, device=device)
138
+ losses = []
139
+ for idx in range(int(completion_ids.shape[1])):
140
+ logits = fixed_slot_next_token_logits_from_ids(model, cur_ids, cur_mask, num_cot)
141
+ target = completion_ids[:, idx]
142
+ losses.append(F.cross_entropy(logits.float(), target, reduction="mean"))
143
+ cur_ids = torch.cat([cur_ids, completion_ids[:, idx : idx + 1]], dim=1)
144
+ cur_mask = torch.cat(
145
+ [
146
+ cur_mask,
147
+ torch.ones((cur_mask.shape[0], 1), dtype=cur_mask.dtype, device=cur_mask.device),
148
+ ],
149
+ dim=1,
150
+ )
151
+ loss = torch.stack(losses).mean()
152
+ optimizer.zero_grad(set_to_none=True)
153
+ loss.backward()
154
+ optimizer.step()
155
+ if step == 1 or step % 10 == 0 or step == train_steps:
156
+ print(f"step={step} loss={float(loss.item()):.6f}")
157
+ sample_now(f"after_step_{step}:")
158
+ PY
hard_9x9_stage1_consistency_queue/launch_10empty_full_pipeline_stages123_value98.sh ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ # Full 10-empty baseline pipeline, matching the successful 7-empty procedure:
3
+ # 1) Stage-1 SFT to value precision/recall >= 0.98
4
+ # 2) Stage-1 GRPO
5
+ # 3) Stage-2 SFT
6
+ # 4) Stage-2 GRPO
7
+ # 5) Stage-3 SFT
8
+ # 6) Stage-3 GRPO
9
+ #
10
+ # This is a wrapper around:
11
+ # - launch_10empty_sft_stage1_98p.sh
12
+ # - launch_10empty_post_s1sft_stages123_value98.sh
13
+ #
14
+ set -euo pipefail
15
+
16
+ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
17
+ ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
18
+
19
+ RUN_TAG="${RUN_TAG:-$(date +%Y%m%d_%H%M%S)}"
20
+ CHECKPOINT_ROOT="${CHECKPOINT_ROOT:-${ROOT}/final_checkpoint/hard_9x9_10empty_full_stages123_value98}"
21
+ OUTPUT_ROOT="${OUTPUT_ROOT:-${CHECKPOINT_ROOT}/${RUN_TAG}}"
22
+
23
+ SFT_STAGE1_SCRIPT="${SCRIPT_DIR}/launch_10empty_sft_stage1_98p.sh"
24
+ POST_S1_SCRIPT="${SCRIPT_DIR}/launch_10empty_post_s1sft_stages123_value98.sh"
25
+ S1_DIR="${OUTPUT_ROOT}/10empty/stage01_sft_i1_10empty_sft98"
26
+
27
+ latest_checkpoint_in_dir() {  # prints the highest-versioned checkpoint-step-* path under directory $1; prints nothing when none exist
28
+   local d="$1"
29
+   shopt -s nullglob  # an unmatched glob must yield an empty array, not the literal pattern
30
+   local checkpoints=("${d}"/checkpoint-step-*)
31
+   shopt -u nullglob
32
+   if (( ${#checkpoints[@]} == 0 )); then
33
+     printf ''
34
+     return 0  # empty stdout is the "not found" signal; a nonzero status inside the caller's VAR="$(...)" would trip `set -e` before its error message can print
35
+   fi
36
+   set +o pipefail
37
+   printf '%s\n' "${checkpoints[@]}" | sort -V | tail -n 1  # version sort, so step-1000 ranks above step-200
38
+   set -o pipefail
39
+ }
40
+
41
+ printf '=== 10-empty full baseline pipeline (stage1 SFT -> stages123) ===\n'
42
+ printf 'run_tag=%s\n' "${RUN_TAG}"
43
+ printf 'output_root=%s\n' "${OUTPUT_ROOT}"
44
+
45
+ OUTPUT_DIR="${S1_DIR}" \
46
+ RUN_TAG="${RUN_TAG}" \
47
+ CHECKPOINT_ROOT="${CHECKPOINT_ROOT}" \
48
+ "${SFT_STAGE1_SCRIPT}"
49
+
50
+ STAGE1_SFT_ADAPTER_DIR="$(latest_checkpoint_in_dir "${S1_DIR}")"
51
+ if [[ -z "${STAGE1_SFT_ADAPTER_DIR}" ]]; then
52
+ printf 'ERROR: No checkpoint-step-* found under %s\n' "${S1_DIR}" >&2
53
+ exit 1
54
+ fi
55
+
56
+ printf '\nStage-1 SFT complete. Using checkpoint: %s\n' "${STAGE1_SFT_ADAPTER_DIR}"
57
+
58
+ STAGE1_SFT_ADAPTER_DIR="${STAGE1_SFT_ADAPTER_DIR}" \
59
+ RUN_TAG="${RUN_TAG}" \
60
+ CHECKPOINT_ROOT="${CHECKPOINT_ROOT}" \
61
+ OUTPUT_ROOT="${OUTPUT_ROOT}" \
62
+ "${POST_S1_SCRIPT}"
hard_9x9_stage1_consistency_queue/launch_10empty_post_s1sft_stages123_value98.sh ADDED
@@ -0,0 +1,365 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ # Run AFTER stage-1 SFT finishes (10-empty). Order:
3
+ # 1) Stage-1 GRPO (init = your stage-1 SFT adapter)
4
+ # 2) Stage-2 SFT (init = stage-1 GRPO adapter)
5
+ # 3) Stage-2 GRPO (init = stage-2 SFT adapter)
6
+ # 4) Stage-3 SFT (init = stage-2 GRPO adapter)
7
+ # 5) Stage-3 GRPO (init = stage-3 SFT adapter)
8
+ #
9
+ # Each SFT/GRPO phase stops early only when BOTH eval value_precision AND value_recall
10
+ # are >= VALUE_TARGET (default 0.98). Other metric gates are disabled (0). Defaults use
11
+ # very large max_steps / epochs so in practice you exit on the 0.98 gate, not a low cap
12
+ # (override SFT_MAX_STEPS / GRPO_MAX_STEPS if you want a hard ceiling).
13
+ #
14
+ # Required (full pipeline from stage-1 SFT):
15
+ # STAGE1_SFT_ADAPTER_DIR=/path/to/checkpoint-step-XXXXX
16
+ #
17
+ # Resume after stage-1 GRPO already ran (skip GRPO i=1, start at stage-2 SFT):
18
+ # RESUME_FROM_STAGE1_GRPO_DIR=/path/to/stage01_grpo_i1_10empty
19
+ # (OUTPUT_ROOT defaults to dirname of that dir.)
20
+ #
21
+ # Resume after stage-2 SFT already ran (skip through stage-2 SFT, start at stage-2 GRPO):
22
+ # START_AT_STAGE2_GRPO_DIR=/path/to/stage02_sft_i2_10empty
23
+ #
24
+ # Resume after stage-2 GRPO finished (stage-3 SFT + stage-3 GRPO only):
25
+ # START_AFTER_STAGE2_GRPO_DIR=/path/to/stage02_grpo_i2_10empty
26
+ #
27
+ # Optional:
28
+ # VALUE_TARGET=0.98 SFT_MAX_STEPS=... GRPO_MAX_STEPS=... SFT_NUM_EPOCHS=... GRPO_NUM_TRAIN_EPOCHS=...
29
+ # TRAIN_PUZZLES=10000 EVAL_PUZZLES=100 RUN_TAG=... CHECKPOINT_ROOT=... USE_GC=1 PHASE_WALL_CLOCK_SECONDS=0
30
+ #
31
+ set -euo pipefail
32
+
33
+ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
34
+ ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
35
+
36
+ PYTHON_BIN="${PYTHON_BIN:-${ROOT}/.venv/bin/python}"
37
+ SFT_SCRIPT="${ROOT}/multi_output_cell_policy/sft_multi_output_train.py"
38
+ GRPO_SCRIPT="${ROOT}/multi_output_cell_policy/grpo_multi_output_train.py"
39
+
40
+ GPU_IDS="${GPU_IDS:-0,1,2,3,4,5,6,7}"
41
+ NUM_PROCESSES="${NUM_PROCESSES:-8}"
42
+ RUN_TAG="${RUN_TAG:-$(date +%Y%m%d_%H%M%S)}"
43
+ WANDB_MODE="${WANDB_MODE:-online}"
44
+ WANDB_ENTITY="${WANDB_ENTITY:-training-dynamics}"
45
+
46
+ EMPTIES=10
47
+ TRAIN_PUZZLES="${TRAIN_PUZZLES:-10000}"
48
+ EVAL_PUZZLES="${EVAL_PUZZLES:-100}"
49
+ VALUE_TARGET="${VALUE_TARGET:-0.98}"
50
+ SFT_MAX_STEPS="${SFT_MAX_STEPS:-10000000}"
51
+ GRPO_MAX_STEPS="${GRPO_MAX_STEPS:-10000000}"
52
+ SFT_NUM_EPOCHS="${SFT_NUM_EPOCHS:-512}"
53
+ GRPO_NUM_TRAIN_EPOCHS="${GRPO_NUM_TRAIN_EPOCHS:-200}"
54
+ PHASE_WALL_CLOCK_SECONDS="${PHASE_WALL_CLOCK_SECONDS:-0}"
55
+
56
+ CHECKPOINT_ROOT="${CHECKPOINT_ROOT:-${ROOT}/final_checkpoint/hard_9x9_10empty_stages123_value98}"
57
+ START_AT_STAGE2_GRPO_DIR="${START_AT_STAGE2_GRPO_DIR:-}"
58
+ START_AFTER_STAGE2_GRPO_DIR="${START_AFTER_STAGE2_GRPO_DIR:-}"
59
+ RESUME_FROM_STAGE1_GRPO_DIR="${RESUME_FROM_STAGE1_GRPO_DIR:-}"
60
+
61
+ if [[ -n "${START_AT_STAGE2_GRPO_DIR}" ]]; then
62
+ if [[ ! -d "${START_AT_STAGE2_GRPO_DIR}" ]]; then
63
+ printf 'ERROR: START_AT_STAGE2_GRPO_DIR is not a directory: %s\n' "${START_AT_STAGE2_GRPO_DIR}" >&2
64
+ exit 1
65
+ fi
66
+ OUTPUT_ROOT="${OUTPUT_ROOT:-$(dirname "${START_AT_STAGE2_GRPO_DIR}")}"
67
+ elif [[ -n "${START_AFTER_STAGE2_GRPO_DIR}" ]]; then
68
+ if [[ ! -d "${START_AFTER_STAGE2_GRPO_DIR}" ]]; then
69
+ printf 'ERROR: START_AFTER_STAGE2_GRPO_DIR is not a directory: %s\n' "${START_AFTER_STAGE2_GRPO_DIR}" >&2
70
+ exit 1
71
+ fi
72
+ OUTPUT_ROOT="${OUTPUT_ROOT:-$(dirname "${START_AFTER_STAGE2_GRPO_DIR}")}"
73
+ elif [[ -n "${RESUME_FROM_STAGE1_GRPO_DIR}" ]]; then
74
+ if [[ ! -d "${RESUME_FROM_STAGE1_GRPO_DIR}" ]]; then
75
+ printf 'ERROR: RESUME_FROM_STAGE1_GRPO_DIR is not a directory: %s\n' "${RESUME_FROM_STAGE1_GRPO_DIR}" >&2
76
+ exit 1
77
+ fi
78
+ OUTPUT_ROOT="${OUTPUT_ROOT:-$(dirname "${RESUME_FROM_STAGE1_GRPO_DIR}")}"
79
+ else
80
+ if [[ -z "${STAGE1_SFT_ADAPTER_DIR:-}" ]] || [[ ! -d "${STAGE1_SFT_ADAPTER_DIR}" ]]; then
81
+ printf 'ERROR: Set STAGE1_SFT_ADAPTER_DIR to a finished stage-1 SFT checkpoint directory, or RESUME_FROM_STAGE1_GRPO_DIR, START_AT_STAGE2_GRPO_DIR, or START_AFTER_STAGE2_GRPO_DIR.\n' >&2
82
+ exit 1
83
+ fi
84
+ OUTPUT_ROOT="${OUTPUT_ROOT:-${CHECKPOINT_ROOT}/${RUN_TAG}}"
85
+ fi
86
+
87
+ train_jsonl="${ROOT}/data/sudoku_t3_${EMPTIES}empty_value_qwen_text_stage1_train.jsonl"
88
+ eval_jsonl="${ROOT}/data/sudoku_t3_${EMPTIES}empty_value_qwen_text_stage1_eval.jsonl"
89
+
90
+ export CUDA_VISIBLE_DEVICES="${GPU_IDS}"
91
+
92
+ mkdir -p "${OUTPUT_ROOT}"
93
+
94
+ latest_sft_step_ckpt() {  # prints the highest-versioned checkpoint-step-* path under directory $1; prints nothing when none exist
95
+   local d="$1"
96
+   shopt -s nullglob  # an unmatched glob must yield an empty array, not the literal pattern
97
+   local cks=("${d}"/checkpoint-step-*)
98
+   shopt -u nullglob
99
+   if (( ${#cks[@]} == 0 )); then
100
+     printf ''
101
+     return 0  # empty stdout is the "not found" signal; a nonzero status inside a caller's VAR="$(...)" would trip `set -e` before its error message can print
102
+   fi
103
+   set +o pipefail
104
+   printf '%s\n' "${cks[@]}" | sort -V | tail -n 1  # version sort, so step-1000 ranks above step-200
105
+   set -o pipefail
106
+ }
107
+
108
+ resolve_grpo_adapter() {  # prints the adapter directory for GRPO output dir $1: the dir itself if it holds adapter_model.safetensors, else its highest-numbered checkpoint-N that does; prints nothing when neither exists
109
+   local d="$1"
110
+   if [[ -f "${d}/adapter_model.safetensors" ]]; then
111
+     printf '%s\n' "${d}"
112
+     return 0
113
+   fi
114
+   local best="" step=-1
115
+   shopt -s nullglob  # an unmatched glob must yield zero iterations, not the literal pattern
116
+   local c
117
+   for c in "${d}"/checkpoint-*; do
118
+     [[ -d "${c}" ]] || continue
119
+     [[ -f "${c}/adapter_model.safetensors" ]] || continue
120
+     local n
121
+     n="${c##*checkpoint-}"  # text after the last "checkpoint-" in the path
122
+     if [[ "${n}" =~ ^[0-9]+$ ]] && (( 10#${n} >= step )); then  # 10# forces base 10 so a zero-padded step is not parsed as octal
123
+       step=$((10#${n}))
124
+       best="${c}"
125
+     fi
126
+   done
127
+   shopt -u nullglob
128
+   if [[ -n "${best}" ]]; then
129
+     printf '%s\n' "${best}"
130
+     return 0
131
+   fi
132
+   printf ''
133
+   return 0  # empty stdout is the "not found" signal; a nonzero status inside a caller's VAR="$(...)" would trip `set -e` before its error message can print
134
+ }
135
+
136
+ GC_FLAGS=()
137
+ if [[ "${USE_GC:-0}" == "1" ]]; then
138
+ GC_FLAGS+=(--enable_gradient_checkpointing)
139
+ fi
140
+
141
+ run_sft() {  # runs one SFT phase; args: stage index, init adapter dir, output dir, learning rate. GC_FLAGS is expanded with the ${arr[@]+...} guard so an empty array is safe under `set -u` on bash older than 4.4
142
+   local stage="$1"
143
+   local init_adapter="$2"
144
+   local out_dir="$3"
145
+   local lr="$4"
146
+   mkdir -p "${out_dir}"
147
+   printf '\n=== Stage %s SFT → stop when value prec+recall >= %s (max_steps=%s epochs=%s) ===\n' "${stage}" "${VALUE_TARGET}" "${SFT_MAX_STEPS}" "${SFT_NUM_EPOCHS}" >&2
148
+   printf 'init=%s\nout=%s\n' "${init_adapter}" "${out_dir}" >&2
149
+   "${PYTHON_BIN}" -m torch.distributed.run --standalone --nproc_per_node "${NUM_PROCESSES}" "${SFT_SCRIPT}" \
150
+     --model_name "Qwen/Qwen2.5-0.5B-Instruct" \
151
+     --train_jsonl "${train_jsonl}" \
152
+     --eval_jsonl "${eval_jsonl}" \
153
+     --output_dir "${out_dir}" \
154
+     --cache_dir "${ROOT}/.hf_cache" \
155
+     --init_adapter_dir "${init_adapter}" \
156
+     --seed 0 \
157
+     --gpu_id 0 \
158
+     --stage_i "${stage}" \
159
+     --total_empties_hint "${EMPTIES}" \
160
+     --per_device_train_batch_size 16 \
161
+     --gradient_accumulation_steps 2 \
162
+     --num_epochs "${SFT_NUM_EPOCHS}" \
163
+     --learning_rate "${lr}" \
164
+     --max_grad_norm 1.0 \
165
+     ${GC_FLAGS[@]+"${GC_FLAGS[@]}"} \
166
+     --logging_steps 20 \
167
+     --eval_steps 250 \
168
+     --save_steps 200 \
169
+     --eval_rows "${EVAL_PUZZLES}" \
170
+     --max_completion_length 24 \
171
+     --limit_train_rows "${TRAIN_PUZZLES}" \
172
+     --lora_r 32 \
173
+     --lora_alpha 64 \
174
+     --lora_dropout 0.05 \
175
+     --eval_value_precision_stop "${VALUE_TARGET}" \
176
+     --eval_value_recall_stop "${VALUE_TARGET}" \
177
+     --eval_exact_set_match_stop 0 \
178
+     --eval_solve_rate_stop 0 \
179
+     --min_steps_before_stop 50 \
180
+     --max_wall_clock_seconds "${PHASE_WALL_CLOCK_SECONDS}" \
181
+     --max_steps "${SFT_MAX_STEPS}" \
182
+     --use_wandb \
183
+     --wandb_project "sudoku-multi-output-sft" \
184
+     --wandb_run_name "postS1_st${stage}_sft_i${stage}_${EMPTIES}empty_val${VALUE_TARGET}_${RUN_TAG}" \
185
+     --wandb_mode "${WANDB_MODE}" \
186
+     --wandb_entity "${WANDB_ENTITY}"
187
+ }
188
+
189
+ run_grpo() {
190
+ local stage="$1"
191
+ local init_adapter="$2"
192
+ local out_dir="$3"
193
+ mkdir -p "${out_dir}"
194
+ printf '\n=== Stage %s GRPO → stop when value prec+recall >= %s (max_steps=%s num_train_epochs=%s) ===\n' "${stage}" "${VALUE_TARGET}" "${GRPO_MAX_STEPS}" "${GRPO_NUM_TRAIN_EPOCHS}" >&2
195
+ printf 'init=%s\nout=%s\n' "${init_adapter}" "${out_dir}" >&2
196
+ "${PYTHON_BIN}" -m torch.distributed.run --standalone --nproc_per_node "${NUM_PROCESSES}" "${GRPO_SCRIPT}" \
197
+ --model_name "Qwen/Qwen2.5-0.5B-Instruct" \
198
+ --train_jsonl "${train_jsonl}" \
199
+ --eval_jsonl "${eval_jsonl}" \
200
+ --output_dir "${out_dir}" \
201
+ --cache_dir "${ROOT}/.hf_cache" \
202
+ --init_adapter_dir "${init_adapter}" \
203
+ --seed 0 \
204
+ --gpu_id 0 \
205
+ --stage_i "${stage}" \
206
+ --total_empties_hint "${EMPTIES}" \
207
+ --per_device_train_batch_size 8 \
208
+ --gradient_accumulation_steps 2 \
209
+ --num_train_epochs "${GRPO_NUM_TRAIN_EPOCHS}" \
210
+ --learning_rate 1e-6 \
211
+ --logging_steps 20 \
212
+ --save_steps 200 \
213
+ --eval_steps 500 \
214
+ --eval_rows "${EVAL_PUZZLES}" \
215
+ --num_generations 4 \
216
+ --max_prompt_length 1024 \
217
+ --max_completion_length 24 \
218
+ --beta 0.0 \
219
+ --enable_gradient_checkpointing \
220
+ --limit_train_rows "${TRAIN_PUZZLES}" \
221
+ --reward_good_value 1.25 \
222
+ --penalty_bad_value 1.0 \
223
+ --penalty_malformed 4.0 \
224
+ --penalty_empty 0.5 \
225
+ --penalty_singleton 1.5 \
226
+ --eval_value_precision_stop "${VALUE_TARGET}" \
227
+ --eval_value_recall_stop "${VALUE_TARGET}" \
228
+ --eval_solve_rate_stop 0 \
229
+ --min_steps_before_stop 50 \
230
+ --max_wall_clock_seconds "${PHASE_WALL_CLOCK_SECONDS}" \
231
+ --max_steps "${GRPO_MAX_STEPS}" \
232
+ --use_wandb \
233
+ --wandb_project "sudoku-multi-output-grpo" \
234
+ --wandb_run_name "postS1_st${stage}_grpo_i${stage}_${EMPTIES}empty_val${VALUE_TARGET}_${RUN_TAG}" \
235
+ --wandb_mode "${WANDB_MODE}" \
236
+ --wandb_entity "${WANDB_ENTITY}"
237
+ }
238
+
239
+ if [[ ! -f "${train_jsonl}" ]] || [[ ! -f "${eval_jsonl}" ]]; then
240
+ printf 'ERROR: Missing train/eval jsonl. Build stage-1 datasets first (see launch_sft_stage1_95p.sh / build_dataset.py).\n' >&2
241
+ printf ' %s\n %s\n' "${train_jsonl}" "${eval_jsonl}" >&2
242
+ exit 1
243
+ fi
244
+
245
+ if [[ -n "${START_AT_STAGE2_GRPO_DIR}" ]]; then
246
+ printf 'Fast-forward: stage-2 SFT dir %s → stage-2 GRPO, then stage 3.\n' "${START_AT_STAGE2_GRPO_DIR}" >&2
247
+ printf 'Pipeline root: %s\n' "${OUTPUT_ROOT}"
248
+ S2_DIR="${START_AT_STAGE2_GRPO_DIR}"
249
+ CKPT_S2="$(latest_sft_step_ckpt "${S2_DIR}")"
250
+ if [[ -z "${CKPT_S2}" ]]; then
251
+ printf 'ERROR: No checkpoint-step-* under %s\n' "${S2_DIR}" >&2
252
+ exit 1
253
+ fi
254
+ printf 'Using SFT checkpoint: %s\n' "${CKPT_S2}" >&2
255
+ G2_DIR="${OUTPUT_ROOT}/stage02_grpo_i2_${EMPTIES}empty"
256
+ run_grpo 2 "${CKPT_S2}" "${G2_DIR}"
257
+ A2="$(resolve_grpo_adapter "${G2_DIR}")"
258
+ if [[ -z "${A2}" ]]; then
259
+ printf 'ERROR: Could not resolve stage-2 GRPO adapter under %s\n' "${G2_DIR}" >&2
260
+ exit 1
261
+ fi
262
+ S3_DIR="${OUTPUT_ROOT}/stage03_sft_i3_${EMPTIES}empty"
263
+ run_sft 3 "${A2}" "${S3_DIR}" "5e-5"
264
+ CKPT_S3="$(latest_sft_step_ckpt "${S3_DIR}")"
265
+ if [[ -z "${CKPT_S3}" ]]; then
266
+ printf 'ERROR: No SFT checkpoint-step-* under %s\n' "${S3_DIR}" >&2
267
+ exit 1
268
+ fi
269
+ G3_DIR="${OUTPUT_ROOT}/stage03_grpo_i3_${EMPTIES}empty"
270
+ run_grpo 3 "${CKPT_S3}" "${G3_DIR}"
271
+ A3="$(resolve_grpo_adapter "${G3_DIR}")"
272
+ if [[ -z "${A3}" ]]; then
273
+ printf 'ERROR: Could not resolve stage-3 GRPO adapter under %s\n' "${G3_DIR}" >&2
274
+ exit 1
275
+ fi
276
+ printf '\nAll phases finished (started at stage-2 GRPO).\n'
277
+ printf 'Outputs under: %s\n' "${OUTPUT_ROOT}"
278
+ printf 'Final GRPO adapter (stage 3): %s\n' "${A3}"
279
+ exit 0
280
+ fi
281
+
282
+ if [[ -n "${START_AFTER_STAGE2_GRPO_DIR}" ]]; then
283
+ printf 'Fast-forward: stage-2 GRPO dir %s → stage-3 SFT + stage-3 GRPO.\n' "${START_AFTER_STAGE2_GRPO_DIR}" >&2
284
+ printf 'Pipeline root: %s\n' "${OUTPUT_ROOT}"
285
+ A2="$(resolve_grpo_adapter "${START_AFTER_STAGE2_GRPO_DIR}")"
286
+ if [[ -z "${A2}" ]]; then
287
+ printf 'ERROR: Could not resolve stage-2 GRPO adapter under %s\n' "${START_AFTER_STAGE2_GRPO_DIR}" >&2
288
+ exit 1
289
+ fi
290
+ printf 'Using stage-2 GRPO adapter: %s\n' "${A2}" >&2
291
+ S3_DIR="${OUTPUT_ROOT}/stage03_sft_i3_${EMPTIES}empty"
292
+ run_sft 3 "${A2}" "${S3_DIR}" "5e-5"
293
+ CKPT_S3="$(latest_sft_step_ckpt "${S3_DIR}")"
294
+ if [[ -z "${CKPT_S3}" ]]; then
295
+ printf 'ERROR: No SFT checkpoint-step-* under %s\n' "${S3_DIR}" >&2
296
+ exit 1
297
+ fi
298
+ G3_DIR="${OUTPUT_ROOT}/stage03_grpo_i3_${EMPTIES}empty"
299
+ run_grpo 3 "${CKPT_S3}" "${G3_DIR}"
300
+ A3="$(resolve_grpo_adapter "${G3_DIR}")"
301
+ if [[ -z "${A3}" ]]; then
302
+ printf 'ERROR: Could not resolve stage-3 GRPO adapter under %s\n' "${G3_DIR}" >&2
303
+ exit 1
304
+ fi
305
+ printf '\nAll phases finished (started after stage-2 GRPO).\n'
306
+ printf 'Outputs under: %s\n' "${OUTPUT_ROOT}"
307
+ printf 'Final GRPO adapter (stage 3): %s\n' "${A3}"
308
+ exit 0
309
+ fi
310
+
311
+ printf 'Pipeline root: %s\n' "${OUTPUT_ROOT}"
312
+ if [[ -n "${RESUME_FROM_STAGE1_GRPO_DIR}" ]]; then
313
+ printf 'Resume: using existing stage-1 GRPO dir %s\n' "${RESUME_FROM_STAGE1_GRPO_DIR}"
314
+ else
315
+ printf 'Stage-1 SFT adapter: %s\n' "${STAGE1_SFT_ADAPTER_DIR}"
316
+ fi
317
+ printf 'Value gate: precision AND recall >= %s | SFT max_steps=%s epochs=%s | GRPO max_steps=%s train_epochs=%s | wall=%s\n' \
318
+ "${VALUE_TARGET}" "${SFT_MAX_STEPS}" "${SFT_NUM_EPOCHS}" "${GRPO_MAX_STEPS}" "${GRPO_NUM_TRAIN_EPOCHS}" "${PHASE_WALL_CLOCK_SECONDS}"
319
+
320
+ G1_DIR="${OUTPUT_ROOT}/stage01_grpo_i1_${EMPTIES}empty"
321
+ if [[ -n "${RESUME_FROM_STAGE1_GRPO_DIR}" ]]; then
322
+ A1="$(resolve_grpo_adapter "${RESUME_FROM_STAGE1_GRPO_DIR}")"
323
+ else
324
+ run_grpo 1 "${STAGE1_SFT_ADAPTER_DIR}" "${G1_DIR}"
325
+ A1="$(resolve_grpo_adapter "${G1_DIR}")"
326
+ fi
327
+ if [[ -z "${A1}" ]]; then
328
+ printf 'ERROR: Could not resolve stage-1 GRPO adapter (resume dir or %s)\n' "${G1_DIR}" >&2
329
+ exit 1
330
+ fi
331
+ printf 'Stage-1 GRPO adapter for stage-2 SFT init: %s\n' "${A1}"
332
+
333
+ S2_DIR="${OUTPUT_ROOT}/stage02_sft_i2_${EMPTIES}empty"
334
+ run_sft 2 "${A1}" "${S2_DIR}" "5e-5"
335
+ CKPT_S2="$(latest_sft_step_ckpt "${S2_DIR}")"
336
+ if [[ -z "${CKPT_S2}" ]]; then
337
+ printf 'ERROR: No SFT checkpoint-step-* under %s\n' "${S2_DIR}" >&2
338
+ exit 1
339
+ fi
340
+ G2_DIR="${OUTPUT_ROOT}/stage02_grpo_i2_${EMPTIES}empty"
341
+ run_grpo 2 "${CKPT_S2}" "${G2_DIR}"
342
+ A2="$(resolve_grpo_adapter "${G2_DIR}")"
343
+ if [[ -z "${A2}" ]]; then
344
+ printf 'ERROR: Could not resolve stage-2 GRPO adapter under %s\n' "${G2_DIR}" >&2
345
+ exit 1
346
+ fi
347
+
348
+ S3_DIR="${OUTPUT_ROOT}/stage03_sft_i3_${EMPTIES}empty"
349
+ run_sft 3 "${A2}" "${S3_DIR}" "5e-5"
350
+ CKPT_S3="$(latest_sft_step_ckpt "${S3_DIR}")"
351
+ if [[ -z "${CKPT_S3}" ]]; then
352
+ printf 'ERROR: No SFT checkpoint-step-* under %s\n' "${S3_DIR}" >&2
353
+ exit 1
354
+ fi
355
+ G3_DIR="${OUTPUT_ROOT}/stage03_grpo_i3_${EMPTIES}empty"
356
+ run_grpo 3 "${CKPT_S3}" "${G3_DIR}"
357
+ A3="$(resolve_grpo_adapter "${G3_DIR}")"
358
+ if [[ -z "${A3}" ]]; then
359
+ printf 'ERROR: Could not resolve stage-3 GRPO adapter under %s\n' "${G3_DIR}" >&2
360
+ exit 1
361
+ fi
362
+
363
+ printf '\nAll phases finished.\n'
364
+ printf 'Outputs under: %s\n' "${OUTPUT_ROOT}"
365
+ printf 'Final GRPO adapter (stage 3): %s\n' "${A3}"
hard_9x9_stage1_consistency_queue/launch_10empty_sft_stage1_98p.sh ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ # Stage-1 SFT only for 10-empty: train until eval value_precision AND value_recall
3
+ # both reach 0.98 (or max_steps / optional wall clock). Use the resulting
4
+ # checkpoint-step-* directory as STAGE1_SFT_ADAPTER_DIR for
5
+ # launch_10empty_post_s1sft_stages123_value98.sh.
6
+ #
7
+ # Fresh LoRA on base model:
8
+ # ./launch_10empty_sft_stage1_98p.sh
9
+ #
10
+ # Continue from a prior SFT checkpoint:
11
+ # INIT_ADAPTER_DIR=/path/to/checkpoint-step-XXXXX ./launch_10empty_sft_stage1_98p.sh
12
+ #
13
+ set -euo pipefail
14
+
15
+ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
16
+ ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
17
+
18
+ PYTHON_BIN="${PYTHON_BIN:-${ROOT}/.venv/bin/python}"
19
+ DATASET_BUILDER="${ROOT}/simple_9x9_curriculum/build_dataset.py"
20
+ SFT_SCRIPT="${ROOT}/multi_output_cell_policy/sft_multi_output_train.py"
21
+
22
+ GPU_IDS="${GPU_IDS:-0,1,2,3,4,5,6,7}"
23
+ NUM_PROCESSES="${NUM_PROCESSES:-8}"
24
+ RUN_TAG="${RUN_TAG:-$(date +%Y%m%d_%H%M%S)}"
25
+ WANDB_MODE="${WANDB_MODE:-online}"
26
+ WANDB_ENTITY="${WANDB_ENTITY:-training-dynamics}"
27
+
28
+ EMPTIES=10
29
+ TRAIN_PUZZLES="${TRAIN_PUZZLES:-10000}"
30
+ EVAL_PUZZLES="${EVAL_PUZZLES:-100}"
31
+ SFT_TARGET="${SFT_TARGET:-0.98}"
32
+ PHASE_WALL_CLOCK_SECONDS="${PHASE_WALL_CLOCK_SECONDS:-0}"
33
+ MAX_STEPS="${MAX_STEPS:-30000}"
34
+
35
+ CHECKPOINT_ROOT="${CHECKPOINT_ROOT:-${ROOT}/final_checkpoint/hard_9x9_10empty_sft98_stage1}"
36
+ OUTPUT_DIR="${OUTPUT_DIR:-${CHECKPOINT_ROOT}/${RUN_TAG}/${EMPTIES}empty/stage01_sft_i1_${EMPTIES}empty_sft98}"
37
+
38
+ train_jsonl="${ROOT}/data/sudoku_t3_${EMPTIES}empty_value_qwen_text_stage1_train.jsonl"
39
+ eval_jsonl="${ROOT}/data/sudoku_t3_${EMPTIES}empty_value_qwen_text_stage1_eval.jsonl"
40
+
41
+ export CUDA_VISIBLE_DEVICES="${GPU_IDS}"
42
+
43
+ if [[ ! -f "${train_jsonl}" ]]; then
44
+ mkdir -p "$(dirname "${train_jsonl}")"
45
+ printf 'Building %s-empty train dataset: %s\n' "${EMPTIES}" "${train_jsonl}"
46
+ "${PYTHON_BIN}" "${DATASET_BUILDER}" --output "${train_jsonl}" --num_puzzles "${TRAIN_PUZZLES}" --empties "${EMPTIES}" --seed 0
47
+ fi
48
+ if [[ ! -f "${eval_jsonl}" ]]; then
49
+ mkdir -p "$(dirname "${eval_jsonl}")"
50
+ printf 'Building %s-empty eval dataset: %s\n' "${EMPTIES}" "${eval_jsonl}"
51
+ "${PYTHON_BIN}" "${DATASET_BUILDER}" --output "${eval_jsonl}" --num_puzzles "${EVAL_PUZZLES}" --empties "${EMPTIES}" --seed 1
52
+ fi
53
+
54
+ mkdir -p "${OUTPUT_DIR}"
55
+
56
+ INIT_FLAGS=()  # optional warm-start args; expanded below with the ${arr[@]+...} guard so an empty array is safe under `set -u` on bash older than 4.4
57
+ if [[ -n "${INIT_ADAPTER_DIR:-}" ]]; then
58
+   INIT_FLAGS+=(--init_adapter_dir "${INIT_ADAPTER_DIR}")
59
+   printf 'Warm-start from adapter: %s\n' "${INIT_ADAPTER_DIR}"
60
+ fi
61
+
62
+ GC_FLAGS=()  # optional gradient-checkpointing flag; expanded below with the same empty-array guard
63
+ if [[ "${USE_GC:-0}" == "1" ]]; then
64
+   GC_FLAGS+=(--enable_gradient_checkpointing)
65
+   printf 'NOTE: USE_GC=1 — slower, less VRAM.\n'
66
+ fi
67
+
68
+ if [[ "${PHASE_WALL_CLOCK_SECONDS}" -gt 0 ]]; then
69
+   printf '\n=== Stage1 SFT %s-empty (prec+recall >= %s, wall %ss) ===\n' "${EMPTIES}" "${SFT_TARGET}" "${PHASE_WALL_CLOCK_SECONDS}"
70
+ else
71
+   printf '\n=== Stage1 SFT %s-empty (prec+recall >= %s, no wall cap) ===\n' "${EMPTIES}" "${SFT_TARGET}"
72
+ fi
73
+ printf 'Output: %s\n' "${OUTPUT_DIR}"
74
+
75
+ exec "${PYTHON_BIN}" -m torch.distributed.run --standalone --nproc_per_node "${NUM_PROCESSES}" "${SFT_SCRIPT}" \
76
+   --model_name "Qwen/Qwen2.5-0.5B-Instruct" \
77
+   --train_jsonl "${train_jsonl}" \
78
+   --eval_jsonl "${eval_jsonl}" \
79
+   --output_dir "${OUTPUT_DIR}" \
80
+   --cache_dir "${ROOT}/.hf_cache" \
81
+   ${INIT_FLAGS[@]+"${INIT_FLAGS[@]}"} \
82
+   --seed 0 \
83
+   --gpu_id 0 \
84
+   --stage_i 1 \
85
+   --total_empties_hint "${EMPTIES}" \
86
+   --per_device_train_batch_size 16 \
87
+   --gradient_accumulation_steps 2 \
88
+   --num_epochs 64.0 \
89
+   --learning_rate 2e-4 \
90
+   --max_grad_norm 1.0 \
91
+   ${GC_FLAGS[@]+"${GC_FLAGS[@]}"} \
92
+   --logging_steps 20 \
93
+   --eval_steps 250 \
94
+   --save_steps 100 \
95
+   --eval_rows "${EVAL_PUZZLES}" \
96
+   --max_completion_length 24 \
97
+   --limit_train_rows "${TRAIN_PUZZLES}" \
98
+   --lora_r 32 \
99
+   --lora_alpha 64 \
100
+   --lora_dropout 0.05 \
101
+   --eval_value_precision_stop "${SFT_TARGET}" \
102
+   --eval_value_recall_stop "${SFT_TARGET}" \
103
+   --eval_exact_set_match_stop 0 \
104
+   --eval_solve_rate_stop 0 \
105
+   --min_steps_before_stop 50 \
106
+   --max_wall_clock_seconds "${PHASE_WALL_CLOCK_SECONDS}" \
107
+   --max_steps "${MAX_STEPS}" \
108
+   --use_wandb \
109
+   --wandb_project "sudoku-multi-output-sft" \
110
+   --wandb_run_name "${WANDB_RUN_NAME:-stage01_sft98_i1_${EMPTIES}empty_${RUN_TAG}}" \
111
+   --wandb_mode "${WANDB_MODE}" \
112
+   --wandb_entity "${WANDB_ENTITY}"
hard_9x9_stage1_consistency_queue/launch_20empty_fixed_slot_sft_stage1_98p.sh ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ # Stage-1 fixed-slot latent SFT for 20-empty: train until eval value_precision AND
3
+ # value_recall both reach 0.98. This uses prompt + z1 + final_slot during stage 1,
4
+ # while still updating LoRA weights so the transformer can learn how to use z1.
5
+ #
6
+ # Fresh run:
7
+ # ./launch_20empty_fixed_slot_sft_stage1_98p.sh
8
+ #
9
+ # Warm-start from a prior checkpoint:
10
+ # INIT_ADAPTER_DIR=/path/to/checkpoint-step-XXXXX ./launch_20empty_fixed_slot_sft_stage1_98p.sh
11
+ #
12
+ set -euo pipefail
13
+
14
+ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
15
+ ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
16
+
17
+ PYTHON_BIN="${PYTHON_BIN:-${ROOT}/.venv/bin/python}"
18
+ DATASET_BUILDER="${ROOT}/simple_9x9_curriculum/build_dataset.py"
19
+ SFT_SCRIPT="${ROOT}/latent_multi_output_cell_policy/sft_latent_multi_output_train.py"
20
+
21
+ GPU_IDS="${GPU_IDS:-0,1,2,3,4,5,6}"
22
+ NUM_PROCESSES="${NUM_PROCESSES:-7}"
23
+ RUN_TAG="${RUN_TAG:-$(date +%Y%m%d_%H%M%S)}"
24
+ WANDB_MODE="${WANDB_MODE:-online}"
25
+ WANDB_ENTITY="${WANDB_ENTITY:-training-dynamics}"
26
+
27
+ EMPTIES=20
28
+ TRAIN_PUZZLES="${TRAIN_PUZZLES:-10000}"
29
+ EVAL_PUZZLES="${EVAL_PUZZLES:-100}"
30
+ SFT_TARGET="${SFT_TARGET:-0.98}"
31
+ PHASE_WALL_CLOCK_SECONDS="${PHASE_WALL_CLOCK_SECONDS:-0}"
32
+ MAX_STEPS="${MAX_STEPS:-30000}"
33
+
34
+ LORA_R="${LORA_R:-32}"
35
+ LORA_ALPHA="${LORA_ALPHA:-64}"
36
+ LORA_DROPOUT="${LORA_DROPOUT:-0.05}"
37
+ MAX_LATENT_SLOTS="${MAX_LATENT_SLOTS:-3}"
38
+
39
+ PER_DEVICE_TRAIN_BATCH_SIZE="${PER_DEVICE_TRAIN_BATCH_SIZE:-4}"
40
+ GRADIENT_ACCUMULATION_STEPS="${GRADIENT_ACCUMULATION_STEPS:-8}"
41
+
42
+ CHECKPOINT_ROOT="${CHECKPOINT_ROOT:-${ROOT}/final_checkpoint/hard_9x9_20empty_fixed_slot_sft98_stage1}"
43
+ OUTPUT_DIR="${OUTPUT_DIR:-${CHECKPOINT_ROOT}/${RUN_TAG}/${EMPTIES}empty/stage01_fixed_slot_sft98_i1_${EMPTIES}empty}"
44
+
45
+ train_jsonl="${ROOT}/data/sudoku_t3_${EMPTIES}empty_value_qwen_text_stage1_train.jsonl"
46
+ eval_jsonl="${ROOT}/data/sudoku_t3_${EMPTIES}empty_value_qwen_text_stage1_eval.jsonl"
47
+
48
+ export CUDA_VISIBLE_DEVICES="${GPU_IDS}"
49
+ export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}"
50
+
51
+ if [[ ! -f "${train_jsonl}" ]]; then
52
+ mkdir -p "$(dirname "${train_jsonl}")"
53
+ printf 'Building %s-empty train dataset: %s\n' "${EMPTIES}" "${train_jsonl}"
54
+ "${PYTHON_BIN}" "${DATASET_BUILDER}" --output "${train_jsonl}" --num_puzzles "${TRAIN_PUZZLES}" --empties "${EMPTIES}" --seed 0
55
+ fi
56
+ if [[ ! -f "${eval_jsonl}" ]]; then
57
+ mkdir -p "$(dirname "${eval_jsonl}")"
58
+ printf 'Building %s-empty eval dataset: %s\n' "${EMPTIES}" "${eval_jsonl}"
59
+ "${PYTHON_BIN}" "${DATASET_BUILDER}" --output "${eval_jsonl}" --num_puzzles "${EVAL_PUZZLES}" --empties "${EMPTIES}" --seed 1
60
+ fi
61
+
62
+ mkdir -p "${OUTPUT_DIR}"
63
+
64
+ INIT_FLAGS=()
65
+ if [[ -n "${INIT_ADAPTER_DIR:-}" ]]; then
66
+ INIT_FLAGS+=(--init_adapter_dir "${INIT_ADAPTER_DIR}")
67
+ printf 'Warm-start from adapter: %s\n' "${INIT_ADAPTER_DIR}"
68
+ fi
69
+
70
+ GC_FLAGS=()
71
+ if [[ "${USE_GC:-1}" == "1" ]]; then
72
+ GC_FLAGS+=(--enable_gradient_checkpointing)
73
+ printf 'NOTE: USE_GC=1 - slower, less VRAM.\n'
74
+ fi
75
+
76
+ if [[ "${PHASE_WALL_CLOCK_SECONDS}" -gt 0 ]]; then
77
+ printf '\n=== Stage1 fixed-slot SFT %s-empty (prec+recall >= %s, wall %ss) ===\n' "${EMPTIES}" "${SFT_TARGET}" "${PHASE_WALL_CLOCK_SECONDS}"
78
+ else
79
+ printf '\n=== Stage1 fixed-slot SFT %s-empty (prec+recall >= %s, no wall cap) ===\n' "${EMPTIES}" "${SFT_TARGET}"
80
+ fi
81
+ printf 'Output: %s\n' "${OUTPUT_DIR}"
82
+ printf 'LoRA: r=%s alpha=%s dropout=%s | latent_mode=fixed_slots | active_z=1 | max_latent_slots=%s\n' "${LORA_R}" "${LORA_ALPHA}" "${LORA_DROPOUT}" "${MAX_LATENT_SLOTS}"
83
+ printf 'DDP: visible_gpus=%s nproc=%s | batch/device=%s grad_accum=%s\n' "${GPU_IDS}" "${NUM_PROCESSES}" "${PER_DEVICE_TRAIN_BATCH_SIZE}" "${GRADIENT_ACCUMULATION_STEPS}"
84
+ # Launch DDP SFT. Optional flag arrays use ${arr[@]+"${arr[@]}"}: under `set -u`, bash < 4.4 rejects "${arr[@]}" on an empty array as unbound.
85
+ exec "${PYTHON_BIN}" -m torch.distributed.run --standalone --nproc_per_node "${NUM_PROCESSES}" "${SFT_SCRIPT}" \
86
+ --model_name "Qwen/Qwen2.5-0.5B-Instruct" \
87
+ --train_jsonl "${train_jsonl}" \
88
+ --eval_jsonl "${eval_jsonl}" \
89
+ --output_dir "${OUTPUT_DIR}" \
90
+ --cache_dir "${ROOT}/.hf_cache" \
91
+ ${INIT_FLAGS[@]+"${INIT_FLAGS[@]}"} \
92
+ --seed 0 \
93
+ --gpu_id 0 \
94
+ --stage_i 1 \
95
+ --num_cot_tokens 1 \
96
+ --latent_mode fixed_slots \
97
+ --max_latent_slots "${MAX_LATENT_SLOTS}" \
98
+ --total_empties_hint "${EMPTIES}" \
99
+ --per_device_train_batch_size "${PER_DEVICE_TRAIN_BATCH_SIZE}" \
100
+ --gradient_accumulation_steps "${GRADIENT_ACCUMULATION_STEPS}" \
101
+ --num_epochs 64.0 \
102
+ --learning_rate 2e-4 \
103
+ --max_grad_norm 1.0 \
104
+ ${GC_FLAGS[@]+"${GC_FLAGS[@]}"} \
105
+ --logging_steps 20 \
106
+ --eval_steps 250 \
107
+ --save_steps 100 \
108
+ --eval_rows "${EVAL_PUZZLES}" \
109
+ --max_completion_length 24 \
110
+ --limit_train_rows "${TRAIN_PUZZLES}" \
111
+ --lora_r "${LORA_R}" \
112
+ --lora_alpha "${LORA_ALPHA}" \
113
+ --lora_dropout "${LORA_DROPOUT}" \
114
+ --eval_value_precision_stop "${SFT_TARGET}" \
115
+ --eval_value_recall_stop "${SFT_TARGET}" \
116
+ --eval_exact_set_match_stop 0 \
117
+ --eval_solve_rate_stop 0 \
118
+ --min_steps_before_stop 50 \
119
+ --max_wall_clock_seconds "${PHASE_WALL_CLOCK_SECONDS}" \
120
+ --max_steps "${MAX_STEPS}" \
121
+ --use_wandb \
122
+ --wandb_project "sudoku-fixed-slot-sft" \
123
+ --wandb_run_name "${WANDB_RUN_NAME:-stage01_fixed_slot_sft98_i1_${EMPTIES}empty_${RUN_TAG}}" \
124
+ --wandb_mode "${WANDB_MODE}" \
125
+ --wandb_entity "${WANDB_ENTITY}"
hard_9x9_stage1_consistency_queue/launch_20empty_full_pipeline_stages123_value98.sh ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ # Full 20-empty baseline pipeline, matching the successful 10-empty procedure:
3
+ # 1) Stage-1 SFT to value precision/recall >= 0.98
4
+ # 2) Stage-1 GRPO
5
+ # 3) Stage-2 SFT
6
+ # 4) Stage-2 GRPO
7
+ # 5) Stage-3 SFT
8
+ # 6) Stage-3 GRPO
9
+ #
10
+ # This is a wrapper around:
11
+ # - launch_20empty_sft_stage1_98p.sh
12
+ # - launch_20empty_post_s1sft_stages123_value98.sh
13
+ #
14
+ set -euo pipefail
15
+
16
+ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
17
+ ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
18
+
19
+ RUN_TAG="${RUN_TAG:-$(date +%Y%m%d_%H%M%S)}"
20
+ CHECKPOINT_ROOT="${CHECKPOINT_ROOT:-${ROOT}/final_checkpoint/hard_9x9_20empty_full_stages123_value98}"
21
+ OUTPUT_ROOT="${OUTPUT_ROOT:-${CHECKPOINT_ROOT}/${RUN_TAG}}"
22
+
23
+ SFT_STAGE1_SCRIPT="${SCRIPT_DIR}/launch_20empty_sft_stage1_98p.sh"
24
+ POST_S1_SCRIPT="${SCRIPT_DIR}/launch_20empty_post_s1sft_stages123_value98.sh"
25
+ S1_DIR="${OUTPUT_ROOT}/20empty/stage01_sft_i1_20empty_sft98"
26
+
27
+ latest_checkpoint_in_dir() {  # Print the highest version-sorted checkpoint-step-* path under dir $1; print nothing when none exist.
28
+ local d="$1"
29
+ shopt -s nullglob  # an unmatched glob must expand to zero entries, not to the literal pattern
30
+ local checkpoints=("${d}"/checkpoint-step-*)
31
+ shopt -u nullglob
32
+ if (( ${#checkpoints[@]} == 0 )); then
33
+ printf ''
34
+ return 0  # was `return 1`: under `set -e` a failing VAR="$(...)" aborts the script before the caller's -z check can print its error
35
+ fi
36
+ set +o pipefail
37
+ printf '%s\n' "${checkpoints[@]}" | sort -V | tail -n 1  # sort -V orders step numbers numerically (step-1000 after step-200)
38
+ set -o pipefail
39
+ }
40
+
41
+ printf '=== 20-empty full baseline pipeline (stage1 SFT -> stages123) ===\n'
42
+ printf 'run_tag=%s\n' "${RUN_TAG}"
43
+ printf 'output_root=%s\n' "${OUTPUT_ROOT}"
44
+
45
+ OUTPUT_DIR="${S1_DIR}" \
46
+ RUN_TAG="${RUN_TAG}" \
47
+ CHECKPOINT_ROOT="${CHECKPOINT_ROOT}" \
48
+ "${SFT_STAGE1_SCRIPT}"
49
+
50
+ STAGE1_SFT_ADAPTER_DIR="$(latest_checkpoint_in_dir "${S1_DIR}")"
51
+ if [[ -z "${STAGE1_SFT_ADAPTER_DIR}" ]]; then
52
+ printf 'ERROR: No checkpoint-step-* found under %s\n' "${S1_DIR}" >&2
53
+ exit 1
54
+ fi
55
+
56
+ printf '\nStage-1 SFT complete. Using checkpoint: %s\n' "${STAGE1_SFT_ADAPTER_DIR}"
57
+
58
+ STAGE1_SFT_ADAPTER_DIR="${STAGE1_SFT_ADAPTER_DIR}" \
59
+ RUN_TAG="${RUN_TAG}" \
60
+ CHECKPOINT_ROOT="${CHECKPOINT_ROOT}" \
61
+ OUTPUT_ROOT="${OUTPUT_ROOT}" \
62
+ "${POST_S1_SCRIPT}"
hard_9x9_stage1_consistency_queue/launch_20empty_latent_recurrent_stages123_value98.sh ADDED
@@ -0,0 +1,341 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ # Latent recurrent-hidden (Coconut-style) pipeline for 20-empty Sudoku.
3
+ #
4
+ # Per-stage latent token count grows with curriculum:
5
+ # stage 1 -> num_cot_tokens = 1
6
+ # stage 2 -> num_cot_tokens = 2
7
+ # stage 3 -> num_cot_tokens = 3
8
+ #
9
+ # Pipeline:
10
+ # Stage 1 SFT (cot=1, fresh LoRA + random latent state)
11
+ # -> Stage 1 GRPO (cot=1)
12
+ # -> Stage 2 SFT (cot=2)
13
+ # -> Stage 2 GRPO (cot=2)
14
+ # -> Stage 3 SFT (cot=3)
15
+ # -> Stage 3 GRPO (cot=3)
16
+ #
17
+ # Mirrors the hyperparameters of the successful 20-empty recurrent-hidden stage-1
18
+ # run (bs=8 per-device, gradient accumulation 2, gradient checkpointing ON).
19
+ #
20
+ # Optional overrides:
21
+ # STAGE1_INIT_ADAPTER_DIR=/path/to/adapter
22
+ # STAGE1_SFT_ADAPTER_DIR=/path/to/stage01_sft/checkpoint-step-XXXX
23
+ # VALUE_TARGET=0.98 TRAIN_PUZZLES=10000 EVAL_PUZZLES=100 RUN_TAG=... CHECKPOINT_ROOT=...
24
+ set -euo pipefail
25
+
26
+ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
27
+ ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
28
+
29
+ PYTHON_BIN="${PYTHON_BIN:-${ROOT}/.venv/bin/python}"
30
+ SFT_SCRIPT="${ROOT}/latent_multi_output_cell_policy/sft_latent_multi_output_train.py"
31
+ GRPO_SCRIPT="${ROOT}/latent_multi_output_cell_policy/grpo_multimode_latent_train.py"
32
+
33
+ GPU_IDS="${GPU_IDS:-0,1,2,3,4,5,6,7}"
34
+ NUM_PROCESSES="${NUM_PROCESSES:-8}"
35
+ RUN_TAG="${RUN_TAG:-$(date +%Y%m%d_%H%M%S)}"
36
+ WANDB_MODE="${WANDB_MODE:-online}"
37
+ WANDB_ENTITY="${WANDB_ENTITY:-training-dynamics}"
38
+
39
+ LATENT_MODE="recurrent_hidden"
40
+ EMPTIES=20
41
+ TAG_SUFFIX="latent_recurrent"
42
+ TRAIN_PUZZLES="${TRAIN_PUZZLES:-10000}"
43
+ EVAL_PUZZLES="${EVAL_PUZZLES:-100}"
44
+ VALUE_TARGET="${VALUE_TARGET:-0.98}"
45
+ # Per-phase early-stop bars. Default behavior preserved: both phases use
46
+ # VALUE_TARGET unless explicitly overridden. Recommended: SFT_VALUE_TARGET=0.95
47
+ # (let SFT do bulk learning quickly) and GRPO_VALUE_TARGET=0.98 (let GRPO push
48
+ # the last few percent of value precision/recall).
49
+ SFT_VALUE_TARGET="${SFT_VALUE_TARGET:-${VALUE_TARGET}}"
50
+ GRPO_VALUE_TARGET="${GRPO_VALUE_TARGET:-${VALUE_TARGET}}"
51
+ MIN_STEPS_BEFORE_STOP="${MIN_STEPS_BEFORE_STOP:-50}"
52
+ SFT_MAX_STEPS="${SFT_MAX_STEPS:-10000000}"
53
+ GRPO_MAX_STEPS="${GRPO_MAX_STEPS:-10000000}"
54
+ SFT_NUM_EPOCHS="${SFT_NUM_EPOCHS:-512}"
55
+ GRPO_NUM_TRAIN_EPOCHS="${GRPO_NUM_TRAIN_EPOCHS:-200}"
56
+ PHASE_WALL_CLOCK_SECONDS="${PHASE_WALL_CLOCK_SECONDS:-0}"
57
+ MODEL_NAME="${MODEL_NAME:-Qwen/Qwen2.5-0.5B-Instruct}"
58
+ # -1 resolves in code to hidden_size, and alpha=-1 resolves to 2 * resolved rank.
59
+ LORA_R="${LORA_R:--1}"
60
+ LORA_ALPHA="${LORA_ALPHA:--1}"
61
+ LORA_DROPOUT="${LORA_DROPOUT:-0.05}"
62
+ STAGE1_SFT_LR="${STAGE1_SFT_LR:-2e-4}"
63
+ SFT_PER_DEVICE_BS="${SFT_PER_DEVICE_BS:-8}"
64
+ SFT_GRAD_ACCUM="${SFT_GRAD_ACCUM:-2}"
65
+ GRPO_PER_DEVICE_BS="${GRPO_PER_DEVICE_BS:-8}"
66
+ GRPO_GRAD_ACCUM="${GRPO_GRAD_ACCUM:-2}"
67
+
68
+ CHECKPOINT_ROOT="${CHECKPOINT_ROOT:-${ROOT}/final_checkpoint/hard_9x9_20empty_latent_recurrent_stages123_value98}"
69
+ OUTPUT_ROOT="${OUTPUT_ROOT:-${CHECKPOINT_ROOT}/${RUN_TAG}}"
70
+ STAGE1_INIT_ADAPTER_DIR="${STAGE1_INIT_ADAPTER_DIR:-}"
71
+ STAGE1_SFT_ADAPTER_DIR="${STAGE1_SFT_ADAPTER_DIR:-}"
72
+ # When set, skip both Stage-1 SFT and Stage-1 GRPO and use this adapter
73
+ # directly as the init for Stage-2 SFT. Useful for resuming after a Stage-1
74
+ # GRPO post-training eval hangs but the LoRA adapter is already on disk.
75
+ STAGE1_GRPO_ADAPTER_DIR="${STAGE1_GRPO_ADAPTER_DIR:-}"
76
+ STAGE2_SFT_ADAPTER_DIR="${STAGE2_SFT_ADAPTER_DIR:-}"
77
+ STAGE2_GRPO_ADAPTER_DIR="${STAGE2_GRPO_ADAPTER_DIR:-}"
78
+ # When set, skip Stage-3 SFT and use this adapter directly as the init for
79
+ # Stage-3 GRPO. Useful when SFT plateaus mid-training and we want GRPO to push
80
+ # the last few percentage points without burning more SFT compute.
81
+ STAGE3_SFT_ADAPTER_DIR="${STAGE3_SFT_ADAPTER_DIR:-}"
82
+ # KL anchor for GRPO. Setting > 0 keeps the policy close to the SFT reference
83
+ # and prevents singleton/mode collapse seen in Stage-2 GRPO. 0.0 = no KL.
84
+ GRPO_BETA="${GRPO_BETA:-0.0}"
85
+
86
+ train_jsonl="${ROOT}/data/sudoku_t3_${EMPTIES}empty_value_qwen_text_stage1_train.jsonl"
87
+ eval_jsonl="${ROOT}/data/sudoku_t3_${EMPTIES}empty_value_qwen_text_stage1_eval.jsonl"
88
+
89
+ export CUDA_VISIBLE_DEVICES="${GPU_IDS}"
90
+ export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}"
91
+
92
+ mkdir -p "${OUTPUT_ROOT}"
93
+
94
+ if [[ ! -f "${train_jsonl}" ]] || [[ ! -f "${eval_jsonl}" ]]; then
95
+ printf 'ERROR: Missing train or eval jsonl.\n' >&2
96
+ printf ' %s\n %s\n' "${train_jsonl}" "${eval_jsonl}" >&2
97
+ exit 1
98
+ fi
99
+
100
+ latest_sft_step_ckpt() {  # Print the highest version-sorted checkpoint-step-* path under dir $1; print nothing when none exist.
101
+ local d="$1"
102
+ shopt -s nullglob  # an unmatched glob must expand to zero entries, not to the literal pattern
103
+ local cks=("${d}"/checkpoint-step-*)
104
+ shopt -u nullglob
105
+ if (( ${#cks[@]} == 0 )); then
106
+ printf ''
107
+ return 0  # was `return 1`: under `set -e` a failing VAR="$(...)" aborts the script before the caller's "No checkpoint-step-*" error prints
108
+ fi
109
+ set +o pipefail
110
+ printf '%s\n' "${cks[@]}" | sort -V | tail -n 1  # sort -V orders step numbers numerically (step-1000 after step-200)
111
+ set -o pipefail
112
+ }
113
+
114
+ resolve_latent_grpo_adapter() {  # Print the adapter dir to load from GRPO output dir $1: the dir itself, else its highest-numbered checkpoint-N; nothing if neither exists.
115
+ local d="$1"
116
+ if [[ -f "${d}/adapter_model.safetensors" ]]; then  # an adapter saved directly in the output dir takes precedence
117
+ printf '%s\n' "${d}"
118
+ return 0
119
+ fi
120
+ local best="" step=-1
121
+ shopt -s nullglob
122
+ local c
123
+ for c in "${d}"/checkpoint-*; do
124
+ [[ -d "${c}" ]] || continue
125
+ [[ -f "${c}/adapter_model.safetensors" ]] || continue
126
+ local n
127
+ n="${c##*checkpoint-}"  # keep only the suffix after the last "checkpoint-"
128
+ if [[ "${n}" =~ ^[0-9]+$ ]] && (( 10#${n} >= step )); then  # 10# forces base 10 so zero-padded steps are not read as octal
129
+ step=$((10#${n}))
130
+ best="${c}"
131
+ fi
132
+ done
133
+ shopt -u nullglob
134
+ if [[ -n "${best}" ]]; then
135
+ printf '%s\n' "${best}"
136
+ return 0
137
+ fi
138
+ printf ''
139
+ return 0  # was `return 1`: under `set -e` a failing A="$(...)" aborts the script before the caller's "Could not resolve" error prints
140
+ }
141
+
142
+ run_latent_sft() {
143
+ local stage="$1"
144
+ local init_adapter="$2"
145
+ local out_dir="$3"
146
+ local lr="$4"
147
+ local cot="$5"
148
+ local ms1=0 ms2=1
149
+ if [[ "${stage}" == "1" ]]; then
150
+ ms1=1
151
+ ms2=0
152
+ fi
153
+ mkdir -p "${out_dir}"
154
+ printf '\n=== Latent(recurrent) stage %s SFT -> stop value prec+recall >= %s (cot=%s) ===\n' "${stage}" "${SFT_VALUE_TARGET}" "${cot}" >&2
155
+ printf 'init=%s\nout=%s num_cot_tokens=%s mixed_s1/s2=%s/%s\n' "${init_adapter}" "${out_dir}" "${cot}" "${ms1}" "${ms2}" >&2
156
+ "${PYTHON_BIN}" -m torch.distributed.run --standalone --nproc_per_node "${NUM_PROCESSES}" "${SFT_SCRIPT}" \
157
+ --model_name "${MODEL_NAME}" \
158
+ --train_jsonl "${train_jsonl}" \
159
+ --eval_jsonl "${eval_jsonl}" \
160
+ --output_dir "${out_dir}" \
161
+ --cache_dir "${ROOT}/.hf_cache" \
162
+ --init_adapter_dir "${init_adapter}" \
163
+ --seed 0 \
164
+ --gpu_id 0 \
165
+ --stage_i "${stage}" \
166
+ --num_cot_tokens "${cot}" \
167
+ --latent_mode "${LATENT_MODE}" \
168
+ --total_empties_hint "${EMPTIES}" \
169
+ --mixed_stage1_ratio "${ms1}" \
170
+ --mixed_stage2_ratio "${ms2}" \
171
+ --per_device_train_batch_size "${SFT_PER_DEVICE_BS}" \
172
+ --gradient_accumulation_steps "${SFT_GRAD_ACCUM}" \
173
+ --num_epochs "${SFT_NUM_EPOCHS}" \
174
+ --learning_rate "${lr}" \
175
+ --weight_decay 0.0 \
176
+ --enable_gradient_checkpointing \
177
+ --logging_steps 20 \
178
+ --eval_steps 250 \
179
+ --save_steps 200 \
180
+ --eval_rows "${EVAL_PUZZLES}" \
181
+ --max_completion_length 24 \
182
+ --limit_train_rows "${TRAIN_PUZZLES}" \
183
+ --eval_value_precision_stop "${SFT_VALUE_TARGET}" \
184
+ --eval_value_recall_stop "${SFT_VALUE_TARGET}" \
185
+ --eval_exact_set_match_stop 0 \
186
+ --eval_solve_rate_stop 0 \
187
+ --min_steps_before_stop "${MIN_STEPS_BEFORE_STOP}" \
188
+ --max_wall_clock_seconds "${PHASE_WALL_CLOCK_SECONDS}" \
189
+ --max_steps "${SFT_MAX_STEPS}" \
190
+ --reward_good_value 1.25 \
191
+ --penalty_bad_value 1.0 \
192
+ --penalty_malformed 4.0 \
193
+ --penalty_empty 0.5 \
194
+ --penalty_singleton 1.5 \
195
+ --lora_r "${LORA_R}" \
196
+ --lora_alpha "${LORA_ALPHA}" \
197
+ --lora_dropout "${LORA_DROPOUT}" \
198
+ --use_wandb \
199
+ --wandb_project "sudoku-latent-multi-output-sft-recurrent" \
200
+ --wandb_run_name "latent20_st${stage}_sft_i${stage}_${TAG_SUFFIX}_cot${cot}_val${SFT_VALUE_TARGET}_${RUN_TAG}" \
201
+ --wandb_mode "${WANDB_MODE}" \
202
+ --wandb_entity "${WANDB_ENTITY}"
203
+ }
204
+
205
+ run_latent_grpo() {
206
+ local stage="$1"
207
+ local init_adapter="$2"
208
+ local out_dir="$3"
209
+ local cot="$4"
210
+ mkdir -p "${out_dir}"
211
+ printf '\n=== Latent(recurrent) stage %s GRPO -> stop value prec+recall >= %s (cot=%s) ===\n' "${stage}" "${GRPO_VALUE_TARGET}" "${cot}" >&2
212
+ printf 'init=%s\nout=%s num_cot_tokens=%s\n' "${init_adapter}" "${out_dir}" "${cot}" >&2
213
+ "${PYTHON_BIN}" -m torch.distributed.run --standalone --nproc_per_node "${NUM_PROCESSES}" "${GRPO_SCRIPT}" \
214
+ --model_name "${MODEL_NAME}" \
215
+ --train_jsonl "${train_jsonl}" \
216
+ --eval_jsonl "${eval_jsonl}" \
217
+ --output_dir "${out_dir}" \
218
+ --cache_dir "${ROOT}/.hf_cache" \
219
+ --init_adapter_dir "${init_adapter}" \
220
+ --seed 0 \
221
+ --gpu_id 0 \
222
+ --stage_i "${stage}" \
223
+ --num_cot_tokens "${cot}" \
224
+ --latent_mode "${LATENT_MODE}" \
225
+ --total_empties_hint "${EMPTIES}" \
226
+ --mixed_stage1_ratio 0 \
227
+ --mixed_stage2_ratio 1 \
228
+ --per_device_train_batch_size "${GRPO_PER_DEVICE_BS}" \
229
+ --gradient_accumulation_steps "${GRPO_GRAD_ACCUM}" \
230
+ --num_train_epochs "${GRPO_NUM_TRAIN_EPOCHS}" \
231
+ --learning_rate 1e-6 \
232
+ --logging_steps 20 \
233
+ --save_steps 200 \
234
+ --eval_steps 500 \
235
+ --eval_rows "${EVAL_PUZZLES}" \
236
+ --num_generations 4 \
237
+ --max_prompt_length 1024 \
238
+ --max_completion_length 24 \
239
+ --beta "${GRPO_BETA}" \
240
+ --enable_gradient_checkpointing \
241
+ --limit_train_rows "${TRAIN_PUZZLES}" \
242
+ --reward_good_value 1.25 \
243
+ --penalty_bad_value 1.0 \
244
+ --penalty_malformed 4.0 \
245
+ --penalty_empty 0.5 \
246
+ --penalty_singleton 1.5 \
247
+ --eval_value_precision_stop "${GRPO_VALUE_TARGET}" \
248
+ --eval_value_recall_stop "${GRPO_VALUE_TARGET}" \
249
+ --eval_solve_rate_stop 0 \
250
+ --min_steps_before_stop "${MIN_STEPS_BEFORE_STOP}" \
251
+ --max_wall_clock_seconds "${PHASE_WALL_CLOCK_SECONDS}" \
252
+ --max_steps "${GRPO_MAX_STEPS}" \
253
+ --lora_r "${LORA_R}" \
254
+ --lora_alpha "${LORA_ALPHA}" \
255
+ --lora_dropout "${LORA_DROPOUT}" \
256
+ --use_wandb \
257
+ --wandb_project "sudoku-latent-multi-output-grpo-recurrent" \
258
+ --wandb_run_name "latent20_st${stage}_grpo_i${stage}_${TAG_SUFFIX}_cot${cot}_val${GRPO_VALUE_TARGET}_${RUN_TAG}" \
259
+ --wandb_mode "${WANDB_MODE}" \
260
+ --wandb_entity "${WANDB_ENTITY}"
261
+ }
262
+
263
+ printf 'Pipeline root: %s\n' "${OUTPUT_ROOT}"
264
+ printf 'Latent mode: %s (cot grows 1->2->3 per stage)\n' "${LATENT_MODE}"
265
+ printf 'Value gate: SFT prec+recall >= %s ; GRPO prec+recall >= %s (min_steps=%s) ; GRPO_BETA=%s\n' "${SFT_VALUE_TARGET}" "${GRPO_VALUE_TARGET}" "${MIN_STEPS_BEFORE_STOP}" "${GRPO_BETA}"
266
+ printf 'Stage-1 init adapter: %s\n' "${STAGE1_INIT_ADAPTER_DIR:-<fresh-lora-random-latent>}"
267
+
268
+ S1_SFT_DIR="${OUTPUT_ROOT}/stage01_sft_i1_${EMPTIES}empty_${TAG_SUFFIX}"
269
+ G1_DIR="${OUTPUT_ROOT}/stage01_grpo_i1_${EMPTIES}empty_${TAG_SUFFIX}"
270
+ if [[ -n "${STAGE1_GRPO_ADAPTER_DIR}" ]]; then
271
+ A1="${STAGE1_GRPO_ADAPTER_DIR}"
272
+ printf 'Using existing stage-1 GRPO adapter (skipping stage-1 SFT + GRPO): %s\n' "${A1}" >&2
273
+ elif [[ -n "${STAGE1_SFT_ADAPTER_DIR}" ]]; then
274
+ G1_SFT_CKPT="${STAGE1_SFT_ADAPTER_DIR}"
275
+ printf 'Using existing stage-1 SFT checkpoint as GRPO init (skipping stage-1 SFT train): %s\n' "${G1_SFT_CKPT}" >&2
276
+ run_latent_grpo 1 "${G1_SFT_CKPT}" "${G1_DIR}" 1
277
+ A1="$(resolve_latent_grpo_adapter "${G1_DIR}")"
278
+ else
279
+ run_latent_sft 1 "${STAGE1_INIT_ADAPTER_DIR}" "${S1_SFT_DIR}" "${STAGE1_SFT_LR}" 1
280
+ G1_SFT_CKPT="$(latest_sft_step_ckpt "${S1_SFT_DIR}")"
281
+ if [[ -z "${G1_SFT_CKPT}" ]]; then
282
+ printf 'ERROR: No checkpoint-step-* under %s\n' "${S1_SFT_DIR}" >&2
283
+ exit 1
284
+ fi
285
+ run_latent_grpo 1 "${G1_SFT_CKPT}" "${G1_DIR}" 1
286
+ A1="$(resolve_latent_grpo_adapter "${G1_DIR}")"
287
+ fi
288
+ if [[ -z "${A1}" ]]; then
289
+ printf 'ERROR: Could not resolve stage-1 latent GRPO adapter.\n' >&2
290
+ exit 1
291
+ fi
292
+ printf 'Stage-1 latent GRPO adapter for stage-2 SFT init: %s\n' "${A1}"
293
+
294
+ S2_DIR="${OUTPUT_ROOT}/stage02_sft_i2_${EMPTIES}empty_${TAG_SUFFIX}"
295
+ G2_DIR="${OUTPUT_ROOT}/stage02_grpo_i2_${EMPTIES}empty_${TAG_SUFFIX}"
296
+ if [[ -n "${STAGE2_GRPO_ADAPTER_DIR}" ]]; then
297
+ A2="${STAGE2_GRPO_ADAPTER_DIR}"
298
+ printf 'Using existing stage-2 GRPO adapter (skipping stage-2 SFT + GRPO): %s\n' "${A2}" >&2
299
+ elif [[ -n "${STAGE2_SFT_ADAPTER_DIR}" ]]; then
300
+ CKPT_S2="${STAGE2_SFT_ADAPTER_DIR}"
301
+ printf 'Using existing stage-2 SFT checkpoint as GRPO init (skipping stage-2 SFT train): %s\n' "${CKPT_S2}" >&2
302
+ run_latent_grpo 2 "${CKPT_S2}" "${G2_DIR}" 2
303
+ A2="$(resolve_latent_grpo_adapter "${G2_DIR}")"
304
+ else
305
+ run_latent_sft 2 "${A1}" "${S2_DIR}" "5e-5" 2
306
+ CKPT_S2="$(latest_sft_step_ckpt "${S2_DIR}")"
307
+ if [[ -z "${CKPT_S2}" ]]; then
308
+ printf 'ERROR: No checkpoint-step-* under %s\n' "${S2_DIR}" >&2
309
+ exit 1
310
+ fi
311
+ run_latent_grpo 2 "${CKPT_S2}" "${G2_DIR}" 2
312
+ A2="$(resolve_latent_grpo_adapter "${G2_DIR}")"
313
+ fi
314
+ if [[ -z "${A2}" ]]; then
315
+ printf 'ERROR: Could not resolve stage-2 latent GRPO adapter under %s\n' "${G2_DIR}" >&2
316
+ exit 1
317
+ fi
318
+
319
+ S3_DIR="${OUTPUT_ROOT}/stage03_sft_i3_${EMPTIES}empty_${TAG_SUFFIX}"
320
+ G3_DIR="${OUTPUT_ROOT}/stage03_grpo_i3_${EMPTIES}empty_${TAG_SUFFIX}"
321
+ if [[ -n "${STAGE3_SFT_ADAPTER_DIR}" ]]; then
322
+ CKPT_S3="${STAGE3_SFT_ADAPTER_DIR}"
323
+ printf 'Using existing stage-3 SFT checkpoint as GRPO init (skipping stage-3 SFT train): %s\n' "${CKPT_S3}" >&2
324
+ else
325
+ run_latent_sft 3 "${A2}" "${S3_DIR}" "5e-5" 3
326
+ CKPT_S3="$(latest_sft_step_ckpt "${S3_DIR}")"
327
+ if [[ -z "${CKPT_S3}" ]]; then
328
+ printf 'ERROR: No checkpoint-step-* under %s\n' "${S3_DIR}" >&2
329
+ exit 1
330
+ fi
331
+ fi
332
+ run_latent_grpo 3 "${CKPT_S3}" "${G3_DIR}" 3
333
+ A3="$(resolve_latent_grpo_adapter "${G3_DIR}")"
334
+ if [[ -z "${A3}" ]]; then
335
+ printf 'ERROR: Could not resolve stage-3 latent GRPO adapter under %s\n' "${G3_DIR}" >&2
336
+ exit 1
337
+ fi
338
+
339
+ printf '\nAll latent(recurrent) phases finished.\n'
340
+ printf 'Outputs under: %s\n' "${OUTPUT_ROOT}"
341
+ printf 'Final latent GRPO adapter (stage 3): %s\n' "${A3}"
hard_9x9_stage1_consistency_queue/launch_20empty_latent_residual_stages123_value98.sh ADDED
@@ -0,0 +1,279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ # Latent residual-projector pipeline for 20-empty Sudoku.
3
+ # Default behavior:
4
+ # - Stage 1 SFT starts from fresh LoRA + random latent residual modules
5
+ # - Then stage 1 GRPO -> stage 2 SFT -> stage 2 GRPO -> stage 3 SFT -> stage 3 GRPO
6
+ # Optional:
7
+ # STAGE1_INIT_ADAPTER_DIR=/path/to/adapter
8
+ # STAGE1_SFT_ADAPTER_DIR=/path/to/stage01_sft/checkpoint-step-XXXX
9
+ # VALUE_TARGET=0.98 TRAIN_PUZZLES=10000 EVAL_PUZZLES=100 RUN_TAG=... CHECKPOINT_ROOT=...
10
+ set -euo pipefail
11
+
12
+ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
13
+ ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
14
+
15
+ PYTHON_BIN="${PYTHON_BIN:-${ROOT}/.venv/bin/python}"
16
+ SFT_SCRIPT="${ROOT}/latent_multi_output_cell_policy/sft_latent_multi_output_train.py"
17
+ GRPO_SCRIPT="${ROOT}/latent_multi_output_cell_policy/grpo_residual_projector_latent_train.py"
18
+
19
+ GPU_IDS="${GPU_IDS:-0,1,2,3,4,5,6,7}"
20
+ NUM_PROCESSES="${NUM_PROCESSES:-8}"
21
+ RUN_TAG="${RUN_TAG:-$(date +%Y%m%d_%H%M%S)}"
22
+ WANDB_MODE="${WANDB_MODE:-online}"
23
+ WANDB_ENTITY="${WANDB_ENTITY:-training-dynamics}"
24
+
25
+ EMPTIES=20
26
+ TAG_SUFFIX="latent_residual"
27
+ TRAIN_PUZZLES="${TRAIN_PUZZLES:-10000}"
28
+ EVAL_PUZZLES="${EVAL_PUZZLES:-100}"
29
+ VALUE_TARGET="${VALUE_TARGET:-0.98}"
30
+ MIN_STEPS_BEFORE_STOP="${MIN_STEPS_BEFORE_STOP:-50}"
31
+ SFT_MAX_STEPS="${SFT_MAX_STEPS:-10000000}"
32
+ GRPO_MAX_STEPS="${GRPO_MAX_STEPS:-10000000}"
33
+ SFT_NUM_EPOCHS="${SFT_NUM_EPOCHS:-512}"
34
+ GRPO_NUM_TRAIN_EPOCHS="${GRPO_NUM_TRAIN_EPOCHS:-200}"
35
+ PHASE_WALL_CLOCK_SECONDS="${PHASE_WALL_CLOCK_SECONDS:-0}"
36
+ MODEL_NAME="${MODEL_NAME:-Qwen/Qwen2.5-0.5B-Instruct}"
37
+ LORA_R="${LORA_R:-32}"
38
+ LORA_ALPHA="${LORA_ALPHA:-64}"
39
+ LORA_DROPOUT="${LORA_DROPOUT:-0.05}"
40
+ STAGE1_SFT_LR="${STAGE1_SFT_LR:-2e-4}"
41
+
42
+ CHECKPOINT_ROOT="${CHECKPOINT_ROOT:-${ROOT}/final_checkpoint/hard_9x9_20empty_latent_residual_stages123_value98}"
43
+ OUTPUT_ROOT="${OUTPUT_ROOT:-${CHECKPOINT_ROOT}/${RUN_TAG}}"
44
+ STAGE1_INIT_ADAPTER_DIR="${STAGE1_INIT_ADAPTER_DIR:-}"
45
+ STAGE1_SFT_ADAPTER_DIR="${STAGE1_SFT_ADAPTER_DIR:-}"
46
+
47
+ train_jsonl="${ROOT}/data/sudoku_t3_${EMPTIES}empty_value_qwen_text_stage1_train.jsonl"
48
+ eval_jsonl="${ROOT}/data/sudoku_t3_${EMPTIES}empty_value_qwen_text_stage1_eval.jsonl"
49
+
50
+ export CUDA_VISIBLE_DEVICES="${GPU_IDS}"
51
+ export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}"
52
+
53
+ mkdir -p "${OUTPUT_ROOT}"
54
+
55
+ if [[ ! -f "${train_jsonl}" ]] || [[ ! -f "${eval_jsonl}" ]]; then
56
+ printf 'ERROR: Missing train or eval jsonl.\n' >&2
57
+ printf ' %s\n %s\n' "${train_jsonl}" "${eval_jsonl}" >&2
58
+ exit 1
59
+ fi
60
+
61
+ latest_sft_step_ckpt() {  # Print the highest version-sorted checkpoint-step-* path under dir $1; print nothing when none exist.
62
+ local d="$1"
63
+ shopt -s nullglob  # an unmatched glob must expand to zero entries, not to the literal pattern
64
+ local cks=("${d}"/checkpoint-step-*)
65
+ shopt -u nullglob
66
+ if (( ${#cks[@]} == 0 )); then
67
+ printf ''
68
+ return 0  # was `return 1`: under `set -e` a failing VAR="$(...)" aborts the script before the caller's "No checkpoint-step-*" error prints
69
+ fi
70
+ set +o pipefail
71
+ printf '%s\n' "${cks[@]}" | sort -V | tail -n 1  # sort -V orders step numbers numerically (step-1000 after step-200)
72
+ set -o pipefail
73
+ }
74
+
75
+ resolve_latent_grpo_adapter() {  # Print the adapter dir to load from GRPO output dir $1: the dir itself, else its highest-numbered complete checkpoint-N; nothing if neither exists.
76
+ local d="$1"
77
+ if [[ -f "${d}/adapter_model.safetensors" ]]; then  # an adapter saved directly in the output dir takes precedence
78
+ printf '%s\n' "${d}"
79
+ return 0
80
+ fi
81
+ local best="" step=-1
82
+ shopt -s nullglob
83
+ local c
84
+ for c in "${d}"/checkpoint-*; do
85
+ [[ -d "${c}" ]] || continue
86
+ [[ -f "${c}/adapter_model.safetensors" ]] || continue
87
+ [[ -f "${c}/latent_cot_state.pt" ]] || continue  # skip checkpoints that lack the latent state file
88
+ local n
89
+ n="${c##*checkpoint-}"  # keep only the suffix after the last "checkpoint-"
90
+ if [[ "${n}" =~ ^[0-9]+$ ]] && (( 10#${n} >= step )); then  # 10# forces base 10 so zero-padded steps are not read as octal
91
+ step=$((10#${n}))
92
+ best="${c}"
93
+ fi
94
+ done
95
+ shopt -u nullglob
96
+ if [[ -n "${best}" ]]; then
97
+ printf '%s\n' "${best}"
98
+ return 0
99
+ fi
100
+ printf ''
101
+ return 0  # was `return 1`: under `set -e` a failing A="$(...)" aborts the script before the caller's "Could not resolve" error prints
102
+ }
103
+
104
+ run_latent_sft() {
105
+ local stage="$1"
106
+ local init_adapter="$2"
107
+ local out_dir="$3"
108
+ local lr="$4"
109
+ local cot="$5"
110
+ local ms1=0 ms2=1
111
+ if [[ "${stage}" == "1" ]]; then
112
+ ms1=1
113
+ ms2=0
114
+ fi
115
+ mkdir -p "${out_dir}"
116
+ printf '\n=== Latent stage %s SFT -> stop value prec+recall >= %s ===\n' "${stage}" "${VALUE_TARGET}" >&2
117
+ printf 'init=%s\nout=%s num_cot_tokens=%s mixed_s1/s2=%s/%s\n' "${init_adapter}" "${out_dir}" "${cot}" "${ms1}" "${ms2}" >&2
118
+ "${PYTHON_BIN}" -m torch.distributed.run --standalone --nproc_per_node "${NUM_PROCESSES}" "${SFT_SCRIPT}" \
119
+ --model_name "${MODEL_NAME}" \
120
+ --train_jsonl "${train_jsonl}" \
121
+ --eval_jsonl "${eval_jsonl}" \
122
+ --output_dir "${out_dir}" \
123
+ --cache_dir "${ROOT}/.hf_cache" \
124
+ --init_adapter_dir "${init_adapter}" \
125
+ --seed 0 \
126
+ --gpu_id 0 \
127
+ --stage_i "${stage}" \
128
+ --num_cot_tokens "${cot}" \
129
+ --total_empties_hint "${EMPTIES}" \
130
+ --mixed_stage1_ratio "${ms1}" \
131
+ --mixed_stage2_ratio "${ms2}" \
132
+ --gradient_accumulation_steps 2 \
133
+ --num_epochs "${SFT_NUM_EPOCHS}" \
134
+ --learning_rate "${lr}" \
135
+ --weight_decay 0.0 \
136
+ --enable_gradient_checkpointing \
137
+ --logging_steps 20 \
138
+ --eval_steps 250 \
139
+ --save_steps 200 \
140
+ --eval_rows "${EVAL_PUZZLES}" \
141
+ --max_completion_length 24 \
142
+ --limit_train_rows "${TRAIN_PUZZLES}" \
143
+ --eval_value_precision_stop "${VALUE_TARGET}" \
144
+ --eval_value_recall_stop "${VALUE_TARGET}" \
145
+ --eval_exact_set_match_stop 0 \
146
+ --eval_solve_rate_stop 0 \
147
+ --min_steps_before_stop "${MIN_STEPS_BEFORE_STOP}" \
148
+ --max_wall_clock_seconds "${PHASE_WALL_CLOCK_SECONDS}" \
149
+ --max_steps "${SFT_MAX_STEPS}" \
150
+ --reward_good_value 1.25 \
151
+ --penalty_bad_value 1.0 \
152
+ --penalty_malformed 4.0 \
153
+ --penalty_empty 0.5 \
154
+ --penalty_singleton 1.5 \
155
+ --lora_r "${LORA_R}" \
156
+ --lora_alpha "${LORA_ALPHA}" \
157
+ --lora_dropout "${LORA_DROPOUT}" \
158
+ --use_wandb \
159
+ --wandb_project "sudoku-latent-multi-output-sft-residual-projector" \
160
+ --wandb_run_name "latent20_st${stage}_sft_i${stage}_${TAG_SUFFIX}_val${VALUE_TARGET}_${RUN_TAG}" \
161
+ --wandb_mode "${WANDB_MODE}" \
162
+ --wandb_entity "${WANDB_ENTITY}"
163
+ }
164
+
165
+ run_latent_grpo() {
166
+ local stage="$1"
167
+ local init_adapter="$2"
168
+ local out_dir="$3"
169
+ local cot="$4"
170
+ mkdir -p "${out_dir}"
171
+ printf '\n=== Latent stage %s GRPO -> stop value prec+recall >= %s ===\n' "${stage}" "${VALUE_TARGET}" >&2
172
+ printf 'init=%s\nout=%s num_cot_tokens=%s\n' "${init_adapter}" "${out_dir}" "${cot}" >&2
173
+ "${PYTHON_BIN}" -m torch.distributed.run --standalone --nproc_per_node "${NUM_PROCESSES}" "${GRPO_SCRIPT}" \
174
+ --model_name "${MODEL_NAME}" \
175
+ --train_jsonl "${train_jsonl}" \
176
+ --eval_jsonl "${eval_jsonl}" \
177
+ --output_dir "${out_dir}" \
178
+ --cache_dir "${ROOT}/.hf_cache" \
179
+ --init_adapter_dir "${init_adapter}" \
180
+ --seed 0 \
181
+ --gpu_id 0 \
182
+ --stage_i "${stage}" \
183
+ --num_cot_tokens "${cot}" \
184
+ --total_empties_hint "${EMPTIES}" \
185
+ --mixed_stage1_ratio 0 \
186
+ --mixed_stage2_ratio 1 \
187
+ --per_device_train_batch_size 8 \
188
+ --gradient_accumulation_steps 2 \
189
+ --num_train_epochs "${GRPO_NUM_TRAIN_EPOCHS}" \
190
+ --learning_rate 1e-6 \
191
+ --logging_steps 20 \
192
+ --save_steps 200 \
193
+ --eval_steps 500 \
194
+ --eval_rows "${EVAL_PUZZLES}" \
195
+ --num_generations 4 \
196
+ --max_prompt_length 1024 \
197
+ --max_completion_length 24 \
198
+ --beta 0.0 \
199
+ --enable_gradient_checkpointing \
200
+ --limit_train_rows "${TRAIN_PUZZLES}" \
201
+ --reward_good_value 1.25 \
202
+ --penalty_bad_value 1.0 \
203
+ --penalty_malformed 4.0 \
204
+ --penalty_empty 0.5 \
205
+ --penalty_singleton 1.5 \
206
+ --eval_value_precision_stop "${VALUE_TARGET}" \
207
+ --eval_value_recall_stop "${VALUE_TARGET}" \
208
+ --eval_solve_rate_stop 0 \
209
+ --min_steps_before_stop "${MIN_STEPS_BEFORE_STOP}" \
210
+ --max_wall_clock_seconds "${PHASE_WALL_CLOCK_SECONDS}" \
211
+ --max_steps "${GRPO_MAX_STEPS}" \
212
+ --lora_r "${LORA_R}" \
213
+ --lora_alpha "${LORA_ALPHA}" \
214
+ --lora_dropout "${LORA_DROPOUT}" \
215
+ --use_wandb \
216
+ --wandb_project "sudoku-latent-multi-output-grpo-residual-projector" \
217
+ --wandb_run_name "latent20_st${stage}_grpo_i${stage}_${TAG_SUFFIX}_val${VALUE_TARGET}_${RUN_TAG}" \
218
+ --wandb_mode "${WANDB_MODE}" \
219
+ --wandb_entity "${WANDB_ENTITY}"
220
+ }
221
+
222
+ printf 'Pipeline root: %s\n' "${OUTPUT_ROOT}"
223
+ printf 'Value gate: precision AND recall >= %s (min_steps=%s)\n' "${VALUE_TARGET}" "${MIN_STEPS_BEFORE_STOP}"
224
+ printf 'Stage-1 init adapter: %s\n' "${STAGE1_INIT_ADAPTER_DIR:-<fresh-lora-random-latent>}"
225
+
226
+ S1_SFT_DIR="${OUTPUT_ROOT}/stage01_sft_i1_${EMPTIES}empty_${TAG_SUFFIX}"
227
+ G1_DIR="${OUTPUT_ROOT}/stage01_grpo_i1_${EMPTIES}empty_${TAG_SUFFIX}"
228
+ if [[ -n "${STAGE1_SFT_ADAPTER_DIR}" ]]; then
229
+ G1_SFT_CKPT="${STAGE1_SFT_ADAPTER_DIR}"
230
+ printf 'Using existing stage-1 SFT checkpoint as GRPO init (skipping stage-1 SFT train): %s\n' "${G1_SFT_CKPT}" >&2
231
+ else
232
+ run_latent_sft 1 "${STAGE1_INIT_ADAPTER_DIR}" "${S1_SFT_DIR}" "${STAGE1_SFT_LR}" 1
233
+ G1_SFT_CKPT="$(latest_sft_step_ckpt "${S1_SFT_DIR}")"
234
+ if [[ -z "${G1_SFT_CKPT}" ]]; then
235
+ printf 'ERROR: No checkpoint-step-* under %s\n' "${S1_SFT_DIR}" >&2
236
+ exit 1
237
+ fi
238
+ fi
239
+ run_latent_grpo 1 "${G1_SFT_CKPT}" "${G1_DIR}" 1
240
+ A1="$(resolve_latent_grpo_adapter "${G1_DIR}")"
241
+ if [[ -z "${A1}" ]]; then
242
+ printf 'ERROR: Could not resolve stage-1 latent GRPO adapter.\n' >&2
243
+ exit 1
244
+ fi
245
+ printf 'Stage-1 latent GRPO adapter for stage-2 SFT init: %s\n' "${A1}"
246
+
247
+ S2_DIR="${OUTPUT_ROOT}/stage02_sft_i2_${EMPTIES}empty_${TAG_SUFFIX}"
248
+ run_latent_sft 2 "${A1}" "${S2_DIR}" "5e-5" 2
249
+ CKPT_S2="$(latest_sft_step_ckpt "${S2_DIR}")"
250
+ if [[ -z "${CKPT_S2}" ]]; then
251
+ printf 'ERROR: No checkpoint-step-* under %s\n' "${S2_DIR}" >&2
252
+ exit 1
253
+ fi
254
+ G2_DIR="${OUTPUT_ROOT}/stage02_grpo_i2_${EMPTIES}empty_${TAG_SUFFIX}"
255
+ run_latent_grpo 2 "${CKPT_S2}" "${G2_DIR}" 2
256
+ A2="$(resolve_latent_grpo_adapter "${G2_DIR}")"
257
+ if [[ -z "${A2}" ]]; then
258
+ printf 'ERROR: Could not resolve stage-2 latent GRPO adapter under %s\n' "${G2_DIR}" >&2
259
+ exit 1
260
+ fi
261
+
262
+ S3_DIR="${OUTPUT_ROOT}/stage03_sft_i3_${EMPTIES}empty_${TAG_SUFFIX}"
263
+ run_latent_sft 3 "${A2}" "${S3_DIR}" "5e-5" 3
264
+ CKPT_S3="$(latest_sft_step_ckpt "${S3_DIR}")"
265
+ if [[ -z "${CKPT_S3}" ]]; then
266
+ printf 'ERROR: No checkpoint-step-* under %s\n' "${S3_DIR}" >&2
267
+ exit 1
268
+ fi
269
+ G3_DIR="${OUTPUT_ROOT}/stage03_grpo_i3_${EMPTIES}empty_${TAG_SUFFIX}"
270
+ run_latent_grpo 3 "${CKPT_S3}" "${G3_DIR}" 3
271
+ A3="$(resolve_latent_grpo_adapter "${G3_DIR}")"
272
+ if [[ -z "${A3}" ]]; then
273
+ printf 'ERROR: Could not resolve stage-3 latent GRPO adapter under %s\n' "${G3_DIR}" >&2
274
+ exit 1
275
+ fi
276
+
277
+ printf '\nAll latent residual phases finished.\n'
278
+ printf 'Outputs under: %s\n' "${OUTPUT_ROOT}"
279
+ printf 'Final latent GRPO adapter (stage 3): %s\n' "${A3}"
hard_9x9_stage1_consistency_queue/launch_20empty_post_s1sft_stages123_value98.sh ADDED
@@ -0,0 +1,368 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ # Run AFTER stage-1 SFT finishes (20-empty). Order:
3
+ # 1) Stage-1 GRPO (init = your stage-1 SFT adapter)
4
+ # 2) Stage-2 SFT (init = stage-1 GRPO adapter)
5
+ # 3) Stage-2 GRPO (init = stage-2 SFT adapter)
6
+ # 4) Stage-3 SFT (init = stage-2 GRPO adapter)
7
+ # 5) Stage-3 GRPO (init = stage-3 SFT adapter)
8
+ #
9
+ # Each SFT/GRPO phase stops early only when BOTH eval value_precision AND value_recall
10
+ # are >= VALUE_TARGET (default 0.98). Other metric gates are disabled (0). Defaults use
11
+ # very large max_steps / epochs so in practice you exit on the 0.98 gate, not a low cap
12
+ # (override SFT_MAX_STEPS / GRPO_MAX_STEPS if you want a hard ceiling).
13
+ #
14
+ # Required (full pipeline from stage-1 SFT):
15
+ # STAGE1_SFT_ADAPTER_DIR=/path/to/checkpoint-step-XXXXX
16
+ #
17
+ # Resume after stage-1 GRPO already ran (skip GRPO i=1, start at stage-2 SFT):
18
+ # RESUME_FROM_STAGE1_GRPO_DIR=/path/to/stage01_grpo_i1_20empty
19
+ # (OUTPUT_ROOT defaults to dirname of that dir.)
20
+ #
21
+ # Resume after stage-2 SFT already ran (skip through stage-2 SFT, start at stage-2 GRPO):
22
+ # START_AT_STAGE2_GRPO_DIR=/path/to/stage02_sft_i2_20empty
23
+ #
24
+ # Resume after stage-2 GRPO finished (stage-3 SFT + stage-3 GRPO only):
25
+ # START_AFTER_STAGE2_GRPO_DIR=/path/to/stage02_grpo_i2_20empty
26
+ #
27
+ # Optional:
28
+ # VALUE_TARGET=0.98 SFT_MAX_STEPS=... GRPO_MAX_STEPS=... SFT_NUM_EPOCHS=... GRPO_NUM_TRAIN_EPOCHS=...
29
+ # TRAIN_PUZZLES=10000 EVAL_PUZZLES=100 RUN_TAG=... CHECKPOINT_ROOT=... USE_GC=1 PHASE_WALL_CLOCK_SECONDS=0
30
+ #
31
+ set -euo pipefail
32
+
33
+ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
34
+ ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
35
+
36
+ PYTHON_BIN="${PYTHON_BIN:-${ROOT}/.venv/bin/python}"
37
+ SFT_SCRIPT="${ROOT}/multi_output_cell_policy/sft_multi_output_train.py"
38
+ GRPO_SCRIPT="${ROOT}/multi_output_cell_policy/grpo_multi_output_train.py"
39
+
40
+ GPU_IDS="${GPU_IDS:-0,1,2,3,4,5,6,7}"
41
+ NUM_PROCESSES="${NUM_PROCESSES:-8}"
42
+ RUN_TAG="${RUN_TAG:-$(date +%Y%m%d_%H%M%S)}"
43
+ WANDB_MODE="${WANDB_MODE:-online}"
44
+ WANDB_ENTITY="${WANDB_ENTITY:-training-dynamics}"
45
+
46
+ EMPTIES=20
47
+ TRAIN_PUZZLES="${TRAIN_PUZZLES:-10000}"
48
+ EVAL_PUZZLES="${EVAL_PUZZLES:-100}"
49
+ VALUE_TARGET="${VALUE_TARGET:-0.98}"
50
+ SFT_MAX_STEPS="${SFT_MAX_STEPS:-10000000}"
51
+ GRPO_MAX_STEPS="${GRPO_MAX_STEPS:-10000000}"
52
+ SFT_NUM_EPOCHS="${SFT_NUM_EPOCHS:-512}"
53
+ GRPO_NUM_TRAIN_EPOCHS="${GRPO_NUM_TRAIN_EPOCHS:-200}"
54
+ PHASE_WALL_CLOCK_SECONDS="${PHASE_WALL_CLOCK_SECONDS:-0}"
55
+
56
+ CHECKPOINT_ROOT="${CHECKPOINT_ROOT:-${ROOT}/final_checkpoint/hard_9x9_20empty_stages123_value98}"
57
+ START_AT_STAGE2_GRPO_DIR="${START_AT_STAGE2_GRPO_DIR:-}"
58
+ START_AFTER_STAGE2_GRPO_DIR="${START_AFTER_STAGE2_GRPO_DIR:-}"
59
+ RESUME_FROM_STAGE1_GRPO_DIR="${RESUME_FROM_STAGE1_GRPO_DIR:-}"
60
+
61
+ if [[ -n "${START_AT_STAGE2_GRPO_DIR}" ]]; then
62
+ if [[ ! -d "${START_AT_STAGE2_GRPO_DIR}" ]]; then
63
+ printf 'ERROR: START_AT_STAGE2_GRPO_DIR is not a directory: %s\n' "${START_AT_STAGE2_GRPO_DIR}" >&2
64
+ exit 1
65
+ fi
66
+ OUTPUT_ROOT="${OUTPUT_ROOT:-$(dirname "${START_AT_STAGE2_GRPO_DIR}")}"
67
+ elif [[ -n "${START_AFTER_STAGE2_GRPO_DIR}" ]]; then
68
+ if [[ ! -d "${START_AFTER_STAGE2_GRPO_DIR}" ]]; then
69
+ printf 'ERROR: START_AFTER_STAGE2_GRPO_DIR is not a directory: %s\n' "${START_AFTER_STAGE2_GRPO_DIR}" >&2
70
+ exit 1
71
+ fi
72
+ OUTPUT_ROOT="${OUTPUT_ROOT:-$(dirname "${START_AFTER_STAGE2_GRPO_DIR}")}"
73
+ elif [[ -n "${RESUME_FROM_STAGE1_GRPO_DIR}" ]]; then
74
+ if [[ ! -d "${RESUME_FROM_STAGE1_GRPO_DIR}" ]]; then
75
+ printf 'ERROR: RESUME_FROM_STAGE1_GRPO_DIR is not a directory: %s\n' "${RESUME_FROM_STAGE1_GRPO_DIR}" >&2
76
+ exit 1
77
+ fi
78
+ OUTPUT_ROOT="${OUTPUT_ROOT:-$(dirname "${RESUME_FROM_STAGE1_GRPO_DIR}")}"
79
+ else
80
+ if [[ -z "${STAGE1_SFT_ADAPTER_DIR:-}" ]] || [[ ! -d "${STAGE1_SFT_ADAPTER_DIR}" ]]; then
81
+ printf 'ERROR: Set STAGE1_SFT_ADAPTER_DIR to a finished stage-1 SFT checkpoint directory, or RESUME_FROM_STAGE1_GRPO_DIR, START_AT_STAGE2_GRPO_DIR, or START_AFTER_STAGE2_GRPO_DIR.\n' >&2
82
+ exit 1
83
+ fi
84
+ OUTPUT_ROOT="${OUTPUT_ROOT:-${CHECKPOINT_ROOT}/${RUN_TAG}}"
85
+ fi
86
+
87
+ train_jsonl="${ROOT}/data/sudoku_t3_${EMPTIES}empty_value_qwen_text_stage1_train.jsonl"
88
+ eval_jsonl="${ROOT}/data/sudoku_t3_${EMPTIES}empty_value_qwen_text_stage1_eval.jsonl"
89
+
90
+ export CUDA_VISIBLE_DEVICES="${GPU_IDS}"
91
+
92
+ mkdir -p "${OUTPUT_ROOT}"
93
+
94
+ latest_sft_step_ckpt() {  # Print (stdout) the version-sorted last checkpoint-step-* path under $1; status 1 and empty stdout if none exist.
95
+ local d="$1"
96
+ shopt -s nullglob  # an unmatched glob must expand to nothing, not to the literal pattern
97
+ local cks=("${d}"/checkpoint-step-*)
98
+ shopt -u nullglob
99
+ if (( ${#cks[@]} == 0 )); then
100
+ printf 'ERROR: latest_sft_step_ckpt: no checkpoint-step-* under %s\n' "${d}" >&2  # callers assign via $(...) under `set -e`, so this status 1 aborts them before their own -z check can report anything
101
+ return 1
102
+ fi
103
+ set +o pipefail
104
+ printf '%s\n' "${cks[@]}" | sort -V | tail -n 1  # sort -V orders step numbers numerically (step-1000 after step-200)
105
+ set -o pipefail
106
+ }
107
+
108
+ resolve_grpo_adapter() {  # Print (stdout) the adapter dir for GRPO output dir $1: $1 itself if it holds adapter_model.safetensors, else its highest-numbered checkpoint-<N> that does; status 1 if neither.
109
+ local d="$1"
110
+ if [[ -f "${d}/adapter_model.safetensors" ]]; then
111
+ printf '%s\n' "${d}"
112
+ return 0
113
+ fi
114
+ local best="" step=-1
115
+ shopt -s nullglob  # an unmatched glob must expand to nothing, not to the literal pattern
116
+ local c
117
+ for c in "${d}"/checkpoint-*; do
118
+ [[ -d "${c}" ]] || continue
119
+ [[ -f "${c}/adapter_model.safetensors" ]] || continue
120
+ local n
121
+ n="${c##*checkpoint-}"
122
+ if [[ "${n}" =~ ^[0-9]+$ ]] && (( 10#${n} >= step )); then  # 10# forces base 10 so a zero-padded step is not parsed as octal
123
+ step=$((10#${n}))
124
+ best="${c}"
125
+ fi
126
+ done
127
+ shopt -u nullglob
128
+ if [[ -n "${best}" ]]; then
129
+ printf '%s\n' "${best}"
130
+ return 0
131
+ fi
132
+ printf 'ERROR: resolve_grpo_adapter: no adapter_model.safetensors in %s or its checkpoint-<N> dirs\n' "${d}" >&2  # callers assign via $(...) under `set -e`, so this status 1 aborts them before their own -z check can report anything
133
+ return 1
134
+ }
135
+
136
+ GC_FLAGS=()
137
+ if [[ "${USE_GC:-0}" == "1" ]]; then
138
+ GC_FLAGS+=(--enable_gradient_checkpointing)
139
+ fi
140
+
141
+ run_sft() {
142
+ local stage="$1"
143
+ local init_adapter="$2"
144
+ local out_dir="$3"
145
+ local lr="$4"
146
+ mkdir -p "${out_dir}"
147
+ printf '\n=== Stage %s SFT -> stop when value prec+recall >= %s (max_steps=%s epochs=%s) ===\n' "${stage}" "${VALUE_TARGET}" "${SFT_MAX_STEPS}" "${SFT_NUM_EPOCHS}" >&2
148
+ printf 'init=%s\nout=%s\n' "${init_adapter}" "${out_dir}" >&2
149
+ "${PYTHON_BIN}" -m torch.distributed.run --standalone --nproc_per_node "${NUM_PROCESSES}" "${SFT_SCRIPT}" \
150
+ --model_name "Qwen/Qwen2.5-0.5B-Instruct" \
151
+ --train_jsonl "${train_jsonl}" \
152
+ --eval_jsonl "${eval_jsonl}" \
153
+ --output_dir "${out_dir}" \
154
+ --cache_dir "${ROOT}/.hf_cache" \
155
+ --init_adapter_dir "${init_adapter}" \
156
+ --seed 0 \
157
+ --gpu_id 0 \
158
+ --stage_i "${stage}" \
159
+ --total_empties_hint "${EMPTIES}" \
160
+ --per_device_train_batch_size 16 \
161
+ --gradient_accumulation_steps 2 \
162
+ --num_epochs "${SFT_NUM_EPOCHS}" \
163
+ --learning_rate "${lr}" \
164
+ --max_grad_norm 1.0 \
165
+ "${GC_FLAGS[@]}" \
166
+ --logging_steps 20 \
167
+ --eval_steps 250 \
168
+ --save_steps 200 \
169
+ --eval_rows "${EVAL_PUZZLES}" \
170
+ --max_completion_length 24 \
171
+ --limit_train_rows "${TRAIN_PUZZLES}" \
172
+ --lora_r 32 \
173
+ --lora_alpha 64 \
174
+ --lora_dropout 0.05 \
175
+ --eval_value_precision_stop "${VALUE_TARGET}" \
176
+ --eval_value_recall_stop "${VALUE_TARGET}" \
177
+ --eval_exact_set_match_stop 0 \
178
+ --eval_solve_rate_stop 0 \
179
+ --min_steps_before_stop 50 \
180
+ --max_wall_clock_seconds "${PHASE_WALL_CLOCK_SECONDS}" \
181
+ --max_steps "${SFT_MAX_STEPS}" \
182
+ --use_wandb \
183
+ --wandb_project "sudoku-multi-output-sft" \
184
+ --wandb_run_name "postS1_st${stage}_sft_i${stage}_${EMPTIES}empty_val${VALUE_TARGET}_${RUN_TAG}" \
185
+ --wandb_mode "${WANDB_MODE}" \
186
+ --wandb_entity "${WANDB_ENTITY}"
187
+ }
188
+
189
+ run_grpo() {
190
+ local stage="$1"
191
+ local init_adapter="$2"
192
+ local out_dir="$3"
193
+ mkdir -p "${out_dir}"
194
+ printf '\n=== Stage %s GRPO -> stop when value prec+recall >= %s (max_steps=%s num_train_epochs=%s) ===\n' "${stage}" "${VALUE_TARGET}" "${GRPO_MAX_STEPS}" "${GRPO_NUM_TRAIN_EPOCHS}" >&2
195
+ printf 'init=%s\nout=%s\n' "${init_adapter}" "${out_dir}" >&2
196
+ "${PYTHON_BIN}" -m torch.distributed.run --standalone --nproc_per_node "${NUM_PROCESSES}" "${GRPO_SCRIPT}" \
197
+ --model_name "Qwen/Qwen2.5-0.5B-Instruct" \
198
+ --train_jsonl "${train_jsonl}" \
199
+ --eval_jsonl "${eval_jsonl}" \
200
+ --output_dir "${out_dir}" \
201
+ --cache_dir "${ROOT}/.hf_cache" \
202
+ --init_adapter_dir "${init_adapter}" \
203
+ --seed 0 \
204
+ --gpu_id 0 \
205
+ --stage_i "${stage}" \
206
+ --total_empties_hint "${EMPTIES}" \
207
+ --per_device_train_batch_size 8 \
208
+ --gradient_accumulation_steps 2 \
209
+ --num_train_epochs "${GRPO_NUM_TRAIN_EPOCHS}" \
210
+ --learning_rate 1e-6 \
211
+ --logging_steps 20 \
212
+ --save_steps 200 \
213
+ --eval_steps 500 \
214
+ --eval_rows "${EVAL_PUZZLES}" \
215
+ --num_generations 4 \
216
+ --max_prompt_length 1024 \
217
+ --max_completion_length 24 \
218
+ --beta 0.0 \
219
+ --enable_gradient_checkpointing \
220
+ --limit_train_rows "${TRAIN_PUZZLES}" \
221
+ --lora_r 32 \
222
+ --lora_alpha 64 \
223
+ --lora_dropout 0.05 \
224
+ --reward_good_value 1.25 \
225
+ --penalty_bad_value 1.0 \
226
+ --penalty_malformed 4.0 \
227
+ --penalty_empty 0.5 \
228
+ --penalty_singleton 1.5 \
229
+ --eval_value_precision_stop "${VALUE_TARGET}" \
230
+ --eval_value_recall_stop "${VALUE_TARGET}" \
231
+ --eval_solve_rate_stop 0 \
232
+ --min_steps_before_stop 50 \
233
+ --max_wall_clock_seconds "${PHASE_WALL_CLOCK_SECONDS}" \
234
+ --max_steps "${GRPO_MAX_STEPS}" \
235
+ --use_wandb \
236
+ --wandb_project "sudoku-multi-output-grpo" \
237
+ --wandb_run_name "postS1_st${stage}_grpo_i${stage}_${EMPTIES}empty_val${VALUE_TARGET}_${RUN_TAG}" \
238
+ --wandb_mode "${WANDB_MODE}" \
239
+ --wandb_entity "${WANDB_ENTITY}"
240
+ }
241
+
242
+ if [[ ! -f "${train_jsonl}" ]] || [[ ! -f "${eval_jsonl}" ]]; then
243
+ printf 'ERROR: Missing train/eval jsonl. Build stage-1 datasets first (see launch_sft_stage1_95p.sh / build_dataset.py).\n' >&2
244
+ printf ' %s\n %s\n' "${train_jsonl}" "${eval_jsonl}" >&2
245
+ exit 1
246
+ fi
247
+
248
+ if [[ -n "${START_AT_STAGE2_GRPO_DIR}" ]]; then
249
+ printf 'Fast-forward: stage-2 SFT dir %s -> stage-2 GRPO, then stage 3.\n' "${START_AT_STAGE2_GRPO_DIR}" >&2
250
+ printf 'Pipeline root: %s\n' "${OUTPUT_ROOT}"
251
+ S2_DIR="${START_AT_STAGE2_GRPO_DIR}"
252
+ CKPT_S2="$(latest_sft_step_ckpt "${S2_DIR}")"
253
+ if [[ -z "${CKPT_S2}" ]]; then
254
+ printf 'ERROR: No checkpoint-step-* under %s\n' "${S2_DIR}" >&2
255
+ exit 1
256
+ fi
257
+ printf 'Using SFT checkpoint: %s\n' "${CKPT_S2}" >&2
258
+ G2_DIR="${OUTPUT_ROOT}/stage02_grpo_i2_${EMPTIES}empty"
259
+ run_grpo 2 "${CKPT_S2}" "${G2_DIR}"
260
+ A2="$(resolve_grpo_adapter "${G2_DIR}")"
261
+ if [[ -z "${A2}" ]]; then
262
+ printf 'ERROR: Could not resolve stage-2 GRPO adapter under %s\n' "${G2_DIR}" >&2
263
+ exit 1
264
+ fi
265
+ S3_DIR="${OUTPUT_ROOT}/stage03_sft_i3_${EMPTIES}empty"
266
+ run_sft 3 "${A2}" "${S3_DIR}" "5e-5"
267
+ CKPT_S3="$(latest_sft_step_ckpt "${S3_DIR}")"
268
+ if [[ -z "${CKPT_S3}" ]]; then
269
+ printf 'ERROR: No SFT checkpoint-step-* under %s\n' "${S3_DIR}" >&2
270
+ exit 1
271
+ fi
272
+ G3_DIR="${OUTPUT_ROOT}/stage03_grpo_i3_${EMPTIES}empty"
273
+ run_grpo 3 "${CKPT_S3}" "${G3_DIR}"
274
+ A3="$(resolve_grpo_adapter "${G3_DIR}")"
275
+ if [[ -z "${A3}" ]]; then
276
+ printf 'ERROR: Could not resolve stage-3 GRPO adapter under %s\n' "${G3_DIR}" >&2
277
+ exit 1
278
+ fi
279
+ printf '\nAll phases finished (started at stage-2 GRPO).\n'
280
+ printf 'Outputs under: %s\n' "${OUTPUT_ROOT}"
281
+ printf 'Final GRPO adapter (stage 3): %s\n' "${A3}"
282
+ exit 0
283
+ fi
284
+
285
+ if [[ -n "${START_AFTER_STAGE2_GRPO_DIR}" ]]; then
286
+ printf 'Fast-forward: stage-2 GRPO dir %s -> stage-3 SFT + stage-3 GRPO.\n' "${START_AFTER_STAGE2_GRPO_DIR}" >&2
287
+ printf 'Pipeline root: %s\n' "${OUTPUT_ROOT}"
288
+ A2="$(resolve_grpo_adapter "${START_AFTER_STAGE2_GRPO_DIR}")"
289
+ if [[ -z "${A2}" ]]; then
290
+ printf 'ERROR: Could not resolve stage-2 GRPO adapter under %s\n' "${START_AFTER_STAGE2_GRPO_DIR}" >&2
291
+ exit 1
292
+ fi
293
+ printf 'Using stage-2 GRPO adapter: %s\n' "${A2}" >&2
294
+ S3_DIR="${OUTPUT_ROOT}/stage03_sft_i3_${EMPTIES}empty"
295
+ run_sft 3 "${A2}" "${S3_DIR}" "5e-5"
296
+ CKPT_S3="$(latest_sft_step_ckpt "${S3_DIR}")"
297
+ if [[ -z "${CKPT_S3}" ]]; then
298
+ printf 'ERROR: No SFT checkpoint-step-* under %s\n' "${S3_DIR}" >&2
299
+ exit 1
300
+ fi
301
+ G3_DIR="${OUTPUT_ROOT}/stage03_grpo_i3_${EMPTIES}empty"
302
+ run_grpo 3 "${CKPT_S3}" "${G3_DIR}"
303
+ A3="$(resolve_grpo_adapter "${G3_DIR}")"
304
+ if [[ -z "${A3}" ]]; then
305
+ printf 'ERROR: Could not resolve stage-3 GRPO adapter under %s\n' "${G3_DIR}" >&2
306
+ exit 1
307
+ fi
308
+ printf '\nAll phases finished (started after stage-2 GRPO).\n'
309
+ printf 'Outputs under: %s\n' "${OUTPUT_ROOT}"
310
+ printf 'Final GRPO adapter (stage 3): %s\n' "${A3}"
311
+ exit 0
312
+ fi
313
+
314
+ printf 'Pipeline root: %s\n' "${OUTPUT_ROOT}"
315
+ if [[ -n "${RESUME_FROM_STAGE1_GRPO_DIR}" ]]; then
316
+ printf 'Resume: using existing stage-1 GRPO dir %s\n' "${RESUME_FROM_STAGE1_GRPO_DIR}"
317
+ else
318
+ printf 'Stage-1 SFT adapter: %s\n' "${STAGE1_SFT_ADAPTER_DIR}"
319
+ fi
320
+ printf 'Value gate: precision AND recall >= %s | SFT max_steps=%s epochs=%s | GRPO max_steps=%s train_epochs=%s | wall=%s\n' \
321
+ "${VALUE_TARGET}" "${SFT_MAX_STEPS}" "${SFT_NUM_EPOCHS}" "${GRPO_MAX_STEPS}" "${GRPO_NUM_TRAIN_EPOCHS}" "${PHASE_WALL_CLOCK_SECONDS}"
322
+
323
+ G1_DIR="${OUTPUT_ROOT}/stage01_grpo_i1_${EMPTIES}empty"
324
+ if [[ -n "${RESUME_FROM_STAGE1_GRPO_DIR}" ]]; then
325
+ A1="$(resolve_grpo_adapter "${RESUME_FROM_STAGE1_GRPO_DIR}")"
326
+ else
327
+ run_grpo 1 "${STAGE1_SFT_ADAPTER_DIR}" "${G1_DIR}"
328
+ A1="$(resolve_grpo_adapter "${G1_DIR}")"
329
+ fi
330
+ if [[ -z "${A1}" ]]; then
331
+ printf 'ERROR: Could not resolve stage-1 GRPO adapter (resume dir or %s)\n' "${G1_DIR}" >&2
332
+ exit 1
333
+ fi
334
+ printf 'Stage-1 GRPO adapter for stage-2 SFT init: %s\n' "${A1}"
335
+
336
+ S2_DIR="${OUTPUT_ROOT}/stage02_sft_i2_${EMPTIES}empty"
337
+ run_sft 2 "${A1}" "${S2_DIR}" "5e-5"
338
+ CKPT_S2="$(latest_sft_step_ckpt "${S2_DIR}")"
339
+ if [[ -z "${CKPT_S2}" ]]; then
340
+ printf 'ERROR: No SFT checkpoint-step-* under %s\n' "${S2_DIR}" >&2
341
+ exit 1
342
+ fi
343
+ G2_DIR="${OUTPUT_ROOT}/stage02_grpo_i2_${EMPTIES}empty"
344
+ run_grpo 2 "${CKPT_S2}" "${G2_DIR}"
345
+ A2="$(resolve_grpo_adapter "${G2_DIR}")"
346
+ if [[ -z "${A2}" ]]; then
347
+ printf 'ERROR: Could not resolve stage-2 GRPO adapter under %s\n' "${G2_DIR}" >&2
348
+ exit 1
349
+ fi
350
+
351
+ S3_DIR="${OUTPUT_ROOT}/stage03_sft_i3_${EMPTIES}empty"
352
+ run_sft 3 "${A2}" "${S3_DIR}" "5e-5"
353
+ CKPT_S3="$(latest_sft_step_ckpt "${S3_DIR}")"
354
+ if [[ -z "${CKPT_S3}" ]]; then
355
+ printf 'ERROR: No SFT checkpoint-step-* under %s\n' "${S3_DIR}" >&2
356
+ exit 1
357
+ fi
358
+ G3_DIR="${OUTPUT_ROOT}/stage03_grpo_i3_${EMPTIES}empty"
359
+ run_grpo 3 "${CKPT_S3}" "${G3_DIR}"
360
+ A3="$(resolve_grpo_adapter "${G3_DIR}")"
361
+ if [[ -z "${A3}" ]]; then
362
+ printf 'ERROR: Could not resolve stage-3 GRPO adapter under %s\n' "${G3_DIR}" >&2
363
+ exit 1
364
+ fi
365
+
366
+ printf '\nAll phases finished.\n'
367
+ printf 'Outputs under: %s\n' "${OUTPUT_ROOT}"
368
+ printf 'Final GRPO adapter (stage 3): %s\n' "${A3}"
hard_9x9_stage1_consistency_queue/launch_20empty_sft_stage1_98p.sh ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ # Stage-1 SFT only for 20-empty: train until eval value_precision AND value_recall
3
+ # both reach 0.98 (or max_steps / optional wall clock). Use the resulting
4
+ # checkpoint-step-* directory as STAGE1_SFT_ADAPTER_DIR for
5
+ # launch_20empty_post_s1sft_stages123_value98.sh.
6
+ #
7
+ # Fresh LoRA on base model:
8
+ # ./launch_20empty_sft_stage1_98p.sh
9
+ #
10
+ # Continue from a prior SFT checkpoint:
11
+ # INIT_ADAPTER_DIR=/path/to/checkpoint-step-XXXXX ./launch_20empty_sft_stage1_98p.sh
12
+ #
13
+ set -euo pipefail
14
+
15
+ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
16
+ ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
17
+
18
+ PYTHON_BIN="${PYTHON_BIN:-${ROOT}/.venv/bin/python}"
19
+ DATASET_BUILDER="${ROOT}/simple_9x9_curriculum/build_dataset.py"
20
+ SFT_SCRIPT="${ROOT}/multi_output_cell_policy/sft_multi_output_train.py"
21
+
22
+ GPU_IDS="${GPU_IDS:-0,1,2,3,4,5,6,7}"
23
+ NUM_PROCESSES="${NUM_PROCESSES:-8}"
24
+ RUN_TAG="${RUN_TAG:-$(date +%Y%m%d_%H%M%S)}"
25
+ WANDB_MODE="${WANDB_MODE:-online}"
26
+ WANDB_ENTITY="${WANDB_ENTITY:-training-dynamics}"
27
+
28
+ EMPTIES=20
29
+ TRAIN_PUZZLES="${TRAIN_PUZZLES:-10000}"
30
+ EVAL_PUZZLES="${EVAL_PUZZLES:-100}"
31
+ SFT_TARGET="${SFT_TARGET:-0.98}"
32
+ PHASE_WALL_CLOCK_SECONDS="${PHASE_WALL_CLOCK_SECONDS:-0}"
33
+ MAX_STEPS="${MAX_STEPS:-30000}"
34
+
35
+ CHECKPOINT_ROOT="${CHECKPOINT_ROOT:-${ROOT}/final_checkpoint/hard_9x9_20empty_sft98_stage1}"
36
+ OUTPUT_DIR="${OUTPUT_DIR:-${CHECKPOINT_ROOT}/${RUN_TAG}/${EMPTIES}empty/stage01_sft_i1_${EMPTIES}empty_sft98}"
37
+
38
+ train_jsonl="${ROOT}/data/sudoku_t3_${EMPTIES}empty_value_qwen_text_stage1_train.jsonl"
39
+ eval_jsonl="${ROOT}/data/sudoku_t3_${EMPTIES}empty_value_qwen_text_stage1_eval.jsonl"
40
+
41
+ export CUDA_VISIBLE_DEVICES="${GPU_IDS}"
42
+
43
+ if [[ ! -f "${train_jsonl}" ]]; then
44
+ mkdir -p "$(dirname "${train_jsonl}")"
45
+ printf 'Building %s-empty train dataset: %s\n' "${EMPTIES}" "${train_jsonl}"
46
+ "${PYTHON_BIN}" "${DATASET_BUILDER}" --output "${train_jsonl}" --num_puzzles "${TRAIN_PUZZLES}" --empties "${EMPTIES}" --seed 0
47
+ fi
48
+ if [[ ! -f "${eval_jsonl}" ]]; then
49
+ mkdir -p "$(dirname "${eval_jsonl}")"
50
+ printf 'Building %s-empty eval dataset: %s\n' "${EMPTIES}" "${eval_jsonl}"
51
+ "${PYTHON_BIN}" "${DATASET_BUILDER}" --output "${eval_jsonl}" --num_puzzles "${EVAL_PUZZLES}" --empties "${EMPTIES}" --seed 1
52
+ fi
53
+
54
+ mkdir -p "${OUTPUT_DIR}"
55
+
56
+ INIT_FLAGS=()
57
+ if [[ -n "${INIT_ADAPTER_DIR:-}" ]]; then
58
+ INIT_FLAGS+=(--init_adapter_dir "${INIT_ADAPTER_DIR}")
59
+ printf 'Warm-start from adapter: %s\n' "${INIT_ADAPTER_DIR}"
60
+ fi
61
+
62
+ GC_FLAGS=()
63
+ if [[ "${USE_GC:-0}" == "1" ]]; then
64
+ GC_FLAGS+=(--enable_gradient_checkpointing)
65
+ printf 'NOTE: USE_GC=1 - slower, less VRAM.\n'
66
+ fi
67
+
68
+ if [[ "${PHASE_WALL_CLOCK_SECONDS}" -gt 0 ]]; then
69
+ printf '\n=== Stage1 SFT %s-empty (prec+recall >= %s, wall %ss) ===\n' "${EMPTIES}" "${SFT_TARGET}" "${PHASE_WALL_CLOCK_SECONDS}"
70
+ else
71
+ printf '\n=== Stage1 SFT %s-empty (prec+recall >= %s, no wall cap) ===\n' "${EMPTIES}" "${SFT_TARGET}"
72
+ fi
73
+ printf 'Output: %s\n' "${OUTPUT_DIR}"
74
+
75
+ exec "${PYTHON_BIN}" -m torch.distributed.run --standalone --nproc_per_node "${NUM_PROCESSES}" "${SFT_SCRIPT}" \
76
+ --model_name "Qwen/Qwen2.5-0.5B-Instruct" \
77
+ --train_jsonl "${train_jsonl}" \
78
+ --eval_jsonl "${eval_jsonl}" \
79
+ --output_dir "${OUTPUT_DIR}" \
80
+ --cache_dir "${ROOT}/.hf_cache" \
81
+ "${INIT_FLAGS[@]}" \
82
+ --seed 0 \
83
+ --gpu_id 0 \
84
+ --stage_i 1 \
85
+ --total_empties_hint "${EMPTIES}" \
86
+ --per_device_train_batch_size 16 \
87
+ --gradient_accumulation_steps 2 \
88
+ --num_epochs 64.0 \
89
+ --learning_rate 2e-4 \
90
+ --max_grad_norm 1.0 \
91
+ "${GC_FLAGS[@]}" \
92
+ --logging_steps 20 \
93
+ --eval_steps 250 \
94
+ --save_steps 100 \
95
+ --eval_rows "${EVAL_PUZZLES}" \
96
+ --max_completion_length 24 \
97
+ --limit_train_rows "${TRAIN_PUZZLES}" \
98
+ --lora_r 32 \
99
+ --lora_alpha 64 \
100
+ --lora_dropout 0.05 \
101
+ --eval_value_precision_stop "${SFT_TARGET}" \
102
+ --eval_value_recall_stop "${SFT_TARGET}" \
103
+ --eval_exact_set_match_stop 0 \
104
+ --eval_solve_rate_stop 0 \
105
+ --min_steps_before_stop 50 \
106
+ --max_wall_clock_seconds "${PHASE_WALL_CLOCK_SECONDS}" \
107
+ --max_steps "${MAX_STEPS}" \
108
+ --use_wandb \
109
+ --wandb_project "sudoku-multi-output-sft" \
110
+ --wandb_run_name "${WANDB_RUN_NAME:-stage01_sft98_i1_${EMPTIES}empty_${RUN_TAG}}" \
111
+ --wandb_mode "${WANDB_MODE}" \
112
+ --wandb_entity "${WANDB_ENTITY}"
hard_9x9_stage1_consistency_queue/launch_20empty_stage1_sft_all_latent_modes_parallel.sh ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ # Stage-1 SFT sweep over all latent modes for 20-empty Sudoku.
3
+ #
4
+ # Runs four independent SFT jobs in parallel:
5
+ # residual, fixed_slots, recurrent_hidden, latent_seeds
6
+ #
7
+ # Default GPU split on an 8-GPU node:
8
+ # residual -> CUDA_VISIBLE_DEVICES=0,1
9
+ # fixed_slots -> CUDA_VISIBLE_DEVICES=2,3
10
+ # recurrent_hidden -> CUDA_VISIBLE_DEVICES=4,5
11
+ # latent_seeds -> CUDA_VISIBLE_DEVICES=6,7
12
+ #
13
+ # Useful overrides:
14
+ # RUN_TAG=... CHECKPOINT_ROOT=...
15
+ # GPU_GROUPS_SPEC="0 1 2 3" NPROC_PER_JOB=1
16
+ # TRAIN_PUZZLES=10000 EVAL_PUZZLES=100 SFT_VALUE_TARGET=0.98
17
+ # STAGE1_INIT_ADAPTER_DIR=/path/to/init_adapter
18
+ set -euo pipefail
19
+
20
+ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
21
+ ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
22
+
23
+ PYTHON_BIN="${PYTHON_BIN:-${ROOT}/.venv/bin/python}"
24
+ SFT_SCRIPT="${ROOT}/latent_multi_output_cell_policy/sft_latent_multi_output_train.py"
25
+
26
+ MODES=("residual" "fixed_slots" "recurrent_hidden" "latent_seeds")
27
+ MODE_TAGS=("latent_residual" "latent_fixed_slots" "latent_recurrent_hidden" "latent_seeds")
28
+
29
+ # Space-separated list of CUDA_VISIBLE_DEVICES groups, one per latent mode.
30
+ # Example for one GPU per method: GPU_GROUPS_SPEC="0 1 2 3" NPROC_PER_JOB=1
31
+ GPU_GROUPS_SPEC="${GPU_GROUPS_SPEC:-0,1 2,3 4,5 6,7}"
32
+ read -r -a GPU_GROUPS <<< "${GPU_GROUPS_SPEC}"
33
+
34
+ NPROC_PER_JOB="${NPROC_PER_JOB:-2}"
35
+ RUN_TAG="${RUN_TAG:-$(date +%Y%m%d_%H%M%S)}"
36
+ WANDB_MODE="${WANDB_MODE:-online}"
37
+ WANDB_ENTITY="${WANDB_ENTITY:-training-dynamics}"
38
+
39
+ EMPTIES="${EMPTIES:-20}"
40
+ TRAIN_PUZZLES="${TRAIN_PUZZLES:-10000}"
41
+ EVAL_PUZZLES="${EVAL_PUZZLES:-100}"
42
+ VALUE_TARGET="${VALUE_TARGET:-0.98}"
43
+ SFT_VALUE_TARGET="${SFT_VALUE_TARGET:-${VALUE_TARGET}}"
44
+ MIN_STEPS_BEFORE_STOP="${MIN_STEPS_BEFORE_STOP:-50}"
45
+ SFT_MAX_STEPS="${SFT_MAX_STEPS:-10000000}"
46
+ SFT_NUM_EPOCHS="${SFT_NUM_EPOCHS:-512}"
47
+ PHASE_WALL_CLOCK_SECONDS="${PHASE_WALL_CLOCK_SECONDS:-0}"
48
+
49
+ MODEL_NAME="${MODEL_NAME:-Qwen/Qwen2.5-0.5B-Instruct}"
50
+ # Match the recurrent 20-empty launcher defaults: -1 resolves inside the
51
+ # trainer to hidden_size, and alpha=-1 resolves to 2 * resolved rank.
52
+ LORA_R="${LORA_R:--1}"
53
+ LORA_ALPHA="${LORA_ALPHA:--1}"
54
+ LORA_DROPOUT="${LORA_DROPOUT:-0.05}"
55
+ STAGE1_SFT_LR="${STAGE1_SFT_LR:-2e-4}"
56
+ SFT_PER_DEVICE_BS="${SFT_PER_DEVICE_BS:-8}"
57
+ SFT_GRAD_ACCUM="${SFT_GRAD_ACCUM:-2}"
58
+ NUM_COT_TOKENS="${NUM_COT_TOKENS:-1}"
59
+ MAX_LATENT_SLOTS="${MAX_LATENT_SLOTS:-8}"
60
+ MAX_LATENT_SEEDS="${MAX_LATENT_SEEDS:-8}"
61
+ STAGE1_INIT_ADAPTER_DIR="${STAGE1_INIT_ADAPTER_DIR:-}"
62
+
63
+ CHECKPOINT_ROOT="${CHECKPOINT_ROOT:-${ROOT}/final_checkpoint/hard_9x9_${EMPTIES}empty_stage1_sft_all_latent_modes}"
64
+ OUTPUT_ROOT="${OUTPUT_ROOT:-${CHECKPOINT_ROOT}/${RUN_TAG}}"
65
+
66
+ train_jsonl="${ROOT}/data/sudoku_t3_${EMPTIES}empty_value_qwen_text_stage1_train.jsonl"
67
+ eval_jsonl="${ROOT}/data/sudoku_t3_${EMPTIES}empty_value_qwen_text_stage1_eval.jsonl"
68
+
69
+ export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}"
70
+ export CUDA_DEVICE_ORDER="${CUDA_DEVICE_ORDER:-PCI_BUS_ID}"
71
+
72
+ if [[ ${#GPU_GROUPS[@]} -ne ${#MODES[@]} ]]; then
73
+ printf 'ERROR: expected %d GPU groups, got %d.\n' "${#MODES[@]}" "${#GPU_GROUPS[@]}" >&2
74
+ printf 'Example: GPU_GROUPS_SPEC="0,1 2,3 4,5 6,7"\n' >&2
75
+ exit 1
76
+ fi
77
+
78
+ if [[ ! -f "${train_jsonl}" ]] || [[ ! -f "${eval_jsonl}" ]]; then
79
+ printf 'ERROR: Missing train or eval jsonl.\n' >&2
80
+ printf ' %s\n %s\n' "${train_jsonl}" "${eval_jsonl}" >&2
81
+ exit 1
82
+ fi
83
+
84
+ if [[ -n "${STAGE1_INIT_ADAPTER_DIR}" ]] && [[ ! -d "${STAGE1_INIT_ADAPTER_DIR}" ]]; then
85
+ printf 'ERROR: STAGE1_INIT_ADAPTER_DIR is not a directory: %s\n' "${STAGE1_INIT_ADAPTER_DIR}" >&2
86
+ exit 1
87
+ fi
88
+
89
+ mkdir -p "${OUTPUT_ROOT}"
90
+
91
+ run_stage1_sft_for_mode() {
92
+ local mode="$1"
93
+ local tag="$2"
94
+ local gpu_group="$3"
95
+ local out_dir="${OUTPUT_ROOT}/stage01_sft_i1_${EMPTIES}empty_${tag}"
96
+ local log_dir="${OUTPUT_ROOT}/logs"
97
+ local log_file="${log_dir}/stage01_sft_${mode}.log"
98
+
99
+ mkdir -p "${out_dir}" "${log_dir}"
100
+ printf '\n=== launching stage-1 SFT: mode=%s gpus=%s out=%s ===\n' "${mode}" "${gpu_group}" "${out_dir}" >&2
101
+
102
+ (
103
+ export CUDA_VISIBLE_DEVICES="${gpu_group}"
104
+ "${PYTHON_BIN}" -m torch.distributed.run --standalone --nproc_per_node "${NPROC_PER_JOB}" "${SFT_SCRIPT}" \
105
+ --model_name "${MODEL_NAME}" \
106
+ --train_jsonl "${train_jsonl}" \
107
+ --eval_jsonl "${eval_jsonl}" \
108
+ --output_dir "${out_dir}" \
109
+ --cache_dir "${ROOT}/.hf_cache" \
110
+ --init_adapter_dir "${STAGE1_INIT_ADAPTER_DIR}" \
111
+ --seed 0 \
112
+ --gpu_id 0 \
113
+ --stage_i 1 \
114
+ --num_cot_tokens "${NUM_COT_TOKENS}" \
115
+ --latent_mode "${mode}" \
116
+ --max_latent_slots "${MAX_LATENT_SLOTS}" \
117
+ --max_latent_seeds "${MAX_LATENT_SEEDS}" \
118
+ --total_empties_hint "${EMPTIES}" \
119
+ --mixed_stage1_ratio 1 \
120
+ --mixed_stage2_ratio 0 \
121
+ --per_device_train_batch_size "${SFT_PER_DEVICE_BS}" \
122
+ --gradient_accumulation_steps "${SFT_GRAD_ACCUM}" \
123
+ --num_epochs "${SFT_NUM_EPOCHS}" \
124
+ --learning_rate "${STAGE1_SFT_LR}" \
125
+ --weight_decay 0.0 \
126
+ --enable_gradient_checkpointing \
127
+ --logging_steps 20 \
128
+ --eval_steps 250 \
129
+ --save_steps 200 \
130
+ --eval_rows "${EVAL_PUZZLES}" \
131
+ --max_completion_length 24 \
132
+ --limit_train_rows "${TRAIN_PUZZLES}" \
133
+ --eval_value_precision_stop "${SFT_VALUE_TARGET}" \
134
+ --eval_value_recall_stop "${SFT_VALUE_TARGET}" \
135
+ --eval_exact_set_match_stop 0 \
136
+ --eval_solve_rate_stop 0 \
137
+ --min_steps_before_stop "${MIN_STEPS_BEFORE_STOP}" \
138
+ --max_wall_clock_seconds "${PHASE_WALL_CLOCK_SECONDS}" \
139
+ --max_steps "${SFT_MAX_STEPS}" \
140
+ --reward_good_value 1.25 \
141
+ --penalty_bad_value 1.0 \
142
+ --penalty_malformed 4.0 \
143
+ --penalty_empty 0.5 \
144
+ --penalty_singleton 1.5 \
145
+ --lora_r "${LORA_R}" \
146
+ --lora_alpha "${LORA_ALPHA}" \
147
+ --lora_dropout "${LORA_DROPOUT}" \
148
+ --use_wandb \
149
+ --wandb_project "sudoku-latent-stage1-sft-all-modes" \
150
+ --wandb_run_name "latent20_stage1_sft_${mode}_cot${NUM_COT_TOKENS}_val${SFT_VALUE_TARGET}_${RUN_TAG}" \
151
+ --wandb_mode "${WANDB_MODE}" \
152
+ --wandb_entity "${WANDB_ENTITY}"
153
+ ) >"${log_file}" 2>&1 &
154
+
155
+ printf '%s\n' "$!"
156
+ }
157
+
158
+ printf 'Output root: %s\n' "${OUTPUT_ROOT}"
159
+ printf 'Stage-1 init adapter: %s\n' "${STAGE1_INIT_ADAPTER_DIR:-<fresh-lora-random-latent>}"
160
+ printf 'Modes: %s\n' "${MODES[*]}"
161
+ printf 'GPU groups: %s\n' "${GPU_GROUPS[*]}"
162
+ printf 'Processes per job: %s\n' "${NPROC_PER_JOB}"
163
+
164
+ pids=()  # background PIDs, index-aligned with names
165
+ names=()
166
+ for i in "${!MODES[@]}"; do
167
+ run_stage1_sft_for_mode "${MODES[$i]}" "${MODE_TAGS[$i]}" "${GPU_GROUPS[$i]}" >/dev/null  # call in THIS shell: a job forked inside $(...) belongs to the substitution subshell, so the later `wait` would return 127 and mark every mode FAILED
168
+ pids+=("$!")  # PID of the background job the launcher just started (its stdout echo of the same PID is discarded above)
169
+ names+=("${MODES[$i]}")
170
+ done
171
+
172
+ failed=0
173
+ for i in "${!pids[@]}"; do
174
+ if wait "${pids[$i]}"; then
175
+ printf 'DONE: %s\n' "${names[$i]}"
176
+ else
177
+ printf 'FAILED: %s (pid=%s). See logs under %s/logs\n' "${names[$i]}" "${pids[$i]}" "${OUTPUT_ROOT}" >&2
178
+ failed=1
179
+ fi
180
+ done
181
+
182
+ if [[ "${failed}" -ne 0 ]]; then
183
+ exit 1
184
+ fi
185
+
186
+ printf '\nAll stage-1 latent SFT jobs finished.\n'
187
+ printf 'Outputs under: %s\n' "${OUTPUT_ROOT}"
hard_9x9_stage1_consistency_queue/launch_20empty_warm_baseline_all_latent_modes_stages123.sh ADDED
@@ -0,0 +1,394 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ # Full 20-empty latent comparison with baseline warm-up before latent stages.
3
+ #
4
+ # Required:
5
+ # STAGE1_BASELINE_ADAPTER_DIR=/path/to/baseline/stage1/checkpoint-step-XXXXX
6
+ #
7
+ # Default mode split on 8 GPUs:
8
+ # residual -> GPUs 0,1
9
+ # fixed_slots -> GPUs 2,3
10
+ # recurrent_hidden -> GPUs 4,5
11
+ # latent_seeds -> GPUs 6,7
12
+ #
13
+ # Per mode:
14
+ # stage1 latent SFT -> stage1 latent GRPO
15
+ # stage2 baseline SFT warm-up -> stage2 latent SFT -> stage2 latent GRPO
16
+ # stage3 baseline SFT warm-up -> stage3 latent SFT -> stage3 latent GRPO
17
+ #
18
+ # Optional resume adapters, intended for single-mode resumes:
19
+ # STAGE1_LATENT_SFT_ADAPTER_DIR=/path/to/stage01_latent_sft_or_checkpoint
20
+ # STAGE1_LATENT_GRPO_ADAPTER_DIR=/path/to/stage01_latent_grpo
21
+ # STAGE2_BASELINE_WARM_ADAPTER_DIR=/path/to/stage02_baseline_warm_sft
22
+ # STAGE2_LATENT_SFT_INIT_ADAPTER_DIR=/path/to/stage02_latent_sft_checkpoint_to_continue_training
23
+ # STAGE2_LATENT_SFT_ADAPTER_DIR=/path/to/stage02_latent_sft_or_checkpoint
24
+ # STAGE2_LATENT_GRPO_ADAPTER_DIR=/path/to/stage02_latent_grpo
25
+ # STAGE3_BASELINE_WARM_ADAPTER_DIR=/path/to/stage03_baseline_warm_sft
26
+ # STAGE3_LATENT_SFT_ADAPTER_DIR=/path/to/stage03_latent_sft_or_checkpoint
27
+ set -euo pipefail
28
+
29
+ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
30
+ ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
31
+
32
+ PYTHON_BIN="${PYTHON_BIN:-${ROOT}/.venv/bin/python}"
33
+ BASELINE_SFT_SCRIPT="${ROOT}/multi_output_cell_policy/sft_multi_output_train.py"
34
+ LATENT_SFT_SCRIPT="${ROOT}/latent_multi_output_cell_policy/sft_latent_multi_output_train.py"
35
+ LATENT_GRPO_SCRIPT="${ROOT}/latent_multi_output_cell_policy/grpo_multimode_latent_train.py"
36
+
37
+ EMPTIES="${EMPTIES:-20}"
38
+ MODEL_NAME="${MODEL_NAME:-Qwen/Qwen2.5-1.5B-Instruct}"
39
+ TRAIN_PUZZLES="${TRAIN_PUZZLES:-10000}"
40
+ EVAL_PUZZLES="${EVAL_PUZZLES:-100}"
41
+ RUN_TAG="${RUN_TAG:-$(date +%Y%m%d_%H%M%S)}"
42
+ WANDB_MODE="${WANDB_MODE:-online}"
43
+ WANDB_ENTITY="${WANDB_ENTITY:-training-dynamics}"
44
+
45
+ MODES_SPEC="${MODES_SPEC:-residual fixed_slots recurrent_hidden latent_seeds}"
46
+ GPU_GROUPS_SPEC="${GPU_GROUPS_SPEC:-0,1 2,3 4,5 6,7}"
47
+ NPROC_PER_JOB="${NPROC_PER_JOB:-2}"
48
+
49
+ STAGE1_BASELINE_ADAPTER_DIR="${STAGE1_BASELINE_ADAPTER_DIR:-}"
50
+ STAGE1_LATENT_SFT_ADAPTER_DIR="${STAGE1_LATENT_SFT_ADAPTER_DIR:-}"
51
+ STAGE1_LATENT_GRPO_ADAPTER_DIR="${STAGE1_LATENT_GRPO_ADAPTER_DIR:-}"
52
+ STAGE2_BASELINE_WARM_ADAPTER_DIR="${STAGE2_BASELINE_WARM_ADAPTER_DIR:-}"
53
+ STAGE2_LATENT_SFT_INIT_ADAPTER_DIR="${STAGE2_LATENT_SFT_INIT_ADAPTER_DIR:-}"
54
+ STAGE2_LATENT_SFT_ADAPTER_DIR="${STAGE2_LATENT_SFT_ADAPTER_DIR:-}"
55
+ STAGE2_LATENT_GRPO_ADAPTER_DIR="${STAGE2_LATENT_GRPO_ADAPTER_DIR:-}"
56
+ STAGE3_BASELINE_WARM_ADAPTER_DIR="${STAGE3_BASELINE_WARM_ADAPTER_DIR:-}"
57
+ STAGE3_LATENT_SFT_ADAPTER_DIR="${STAGE3_LATENT_SFT_ADAPTER_DIR:-}"
58
+ if [[ -z "${STAGE1_LATENT_SFT_ADAPTER_DIR}" && -z "${STAGE1_LATENT_GRPO_ADAPTER_DIR}" ]]; then
59
+ if [[ -z "${STAGE1_BASELINE_ADAPTER_DIR}" ]] || [[ ! -d "${STAGE1_BASELINE_ADAPTER_DIR}" ]]; then
60
+ printf 'ERROR: Set STAGE1_BASELINE_ADAPTER_DIR to a finished baseline SFT checkpoint directory, or provide a stage-1 latent resume adapter.\n' >&2
61
+ exit 1
62
+ fi
63
+ fi
64
+
65
+ SFT_PER_DEVICE_BS="${SFT_PER_DEVICE_BS:-8}"
66
+ SFT_GRAD_ACCUM="${SFT_GRAD_ACCUM:-2}"
67
+ BASELINE_PER_DEVICE_BS="${BASELINE_PER_DEVICE_BS:-16}"
68
+ BASELINE_GRAD_ACCUM="${BASELINE_GRAD_ACCUM:-2}"
69
+ GRPO_PER_DEVICE_BS="${GRPO_PER_DEVICE_BS:-4}"
70
+ GRPO_GRAD_ACCUM="${GRPO_GRAD_ACCUM:-2}"
71
+
72
+ BASELINE_WARM_MAX_STEPS="${BASELINE_WARM_MAX_STEPS:-1000}"
73
+ LATENT_SFT_MAX_STEPS="${LATENT_SFT_MAX_STEPS:-1000}"
74
+ LATENT_GRPO_MAX_STEPS="${LATENT_GRPO_MAX_STEPS:-500}"
75
+ LATENT_SFT_EVAL_STEPS="${LATENT_SFT_EVAL_STEPS:-250}"
76
+ SFT_NUM_EPOCHS="${SFT_NUM_EPOCHS:-64}"
77
+ GRPO_NUM_TRAIN_EPOCHS="${GRPO_NUM_TRAIN_EPOCHS:-50}"
78
+
79
+ SOLVE_TARGET="${SOLVE_TARGET:-0.95}"
80
+ VALUE_TARGET="${VALUE_TARGET:-0}"
81
+ MIN_STEPS_BEFORE_STOP="${MIN_STEPS_BEFORE_STOP:-50}"
82
+ GRPO_BETA="${GRPO_BETA:-0.0}"
83
+
84
+ LORA_R="${LORA_R:-32}"
85
+ LORA_ALPHA="${LORA_ALPHA:-64}"
86
+ LORA_DROPOUT="${LORA_DROPOUT:-0.05}"
87
+
88
+ CHECKPOINT_ROOT="${CHECKPOINT_ROOT:-${ROOT}/final_checkpoint/hard_9x9_${EMPTIES}empty_warm_baseline_all_latent_modes_stages123}"
89
+ OUTPUT_ROOT="${OUTPUT_ROOT:-${CHECKPOINT_ROOT}/${RUN_TAG}}"
90
+ train_jsonl="${ROOT}/data/sudoku_t3_${EMPTIES}empty_value_qwen_text_stage1_train.jsonl"
91
+ eval_jsonl="${ROOT}/data/sudoku_t3_${EMPTIES}empty_value_qwen_text_stage1_eval.jsonl"
92
+
93
+ export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}"
94
+ export CUDA_DEVICE_ORDER="${CUDA_DEVICE_ORDER:-PCI_BUS_ID}"
95
+
96
+ read -r -a MODES <<< "${MODES_SPEC}"
97
+ read -r -a GPU_GROUPS <<< "${GPU_GROUPS_SPEC}"
98
+ if [[ ${#MODES[@]} -ne ${#GPU_GROUPS[@]} ]]; then
99
+ printf 'ERROR: expected one GPU group per mode. modes=%d gpu_groups=%d\n' "${#MODES[@]}" "${#GPU_GROUPS[@]}" >&2
100
+ exit 1
101
+ fi
102
+
103
+ if [[ ! -f "${train_jsonl}" ]] || [[ ! -f "${eval_jsonl}" ]]; then
104
+ printf 'ERROR: Missing train or eval jsonl.\n %s\n %s\n' "${train_jsonl}" "${eval_jsonl}" >&2
105
+ exit 1
106
+ fi
107
+
108
+ mkdir -p "${OUTPUT_ROOT}/logs"
109
+
110
# Map a latent mode name to the tag used in output directory and log names.
#
# Arguments:
#   $1 - latent mode name (e.g. residual, fixed_slots, recurrent_hidden, latent_seeds).
# Output (stdout, no trailing newline):
#   "latent_seeds" for the latent_seeds mode; "latent_<mode>" for every other mode.
mode_tag() {
  local mode_name="$1"
  # latent_seeds already carries the "latent_" prefix, so it is emitted unchanged.
  if [[ "${mode_name}" == "latent_seeds" ]]; then
    printf 'latent_seeds'
    return
  fi
  printf 'latent_%s' "${mode_name}"
}
119
+
120
# Resolve the adapter directory to load from a training output directory.
#
# Arguments:
#   $1 - a training output directory, or an adapter/checkpoint directory itself.
# Output (stdout):
#   The highest-versioned "checkpoint-step-*" subdirectory of $1 (by `sort -V`)
#   when at least one exists; otherwise $1 itself when it directly contains
#   adapter_model.safetensors.
# Returns:
#   0 when a path was printed; 1 with nothing on stdout and a diagnostic on
#   stderr when neither form is found. Callers assign from this function under
#   `set -e`, so the non-zero status aborts them; the diagnostic makes that
#   abort visible in the log instead of silent.
latest_checkpoint_or_dir() {
  local d="$1"
  local candidate
  local checkpoints=()
  shopt -s nullglob
  for candidate in "${d}"/checkpoint-step-*; do
    # Only directories can be checkpoints; skip stray files matching the glob.
    if [[ -d "${candidate}" ]]; then
      checkpoints+=("${candidate}")
    fi
  done
  shopt -u nullglob
  if (( ${#checkpoints[@]} > 0 )); then
    # Version sort so that step-1000 ranks above step-200.
    printf '%s\n' "${checkpoints[@]}" | sort -V | tail -n 1
    return 0
  fi
  if [[ -f "${d}/adapter_model.safetensors" ]]; then
    printf '%s\n' "${d}"
    return 0
  fi
  printf 'ERROR: no checkpoint-step-* directory and no adapter_model.safetensors under: %s\n' "${d}" >&2
  return 1
}
136
+
137
# Launch one baseline (text) SFT warm-up phase through torch.distributed.run.
#
# Arguments:
#   $1 - curriculum stage index, forwarded as --stage_i
#   $2 - adapter directory to initialise from, forwarded as --init_adapter_dir
#   $3 - output directory (created if missing)
#   $4 - learning rate
#   $5 - W&B run name
# Reads the script-level settings (MODEL_NAME, train_jsonl, eval_jsonl, LoRA and
# stop-gate variables, ...) and returns the launcher's exit status.
run_baseline_sft() {
  local stage="$1"
  local init_adapter="$2"
  local out_dir="$3"
  local lr="$4"
  local run_name="$5"

  local -a launcher=(
    "${PYTHON_BIN}" -m torch.distributed.run
    --standalone
    --nproc_per_node "${NPROC_PER_JOB}"
  )
  local -a trainer_args=(
    --model_name "${MODEL_NAME}"
    --train_jsonl "${train_jsonl}"
    --eval_jsonl "${eval_jsonl}"
    --output_dir "${out_dir}"
    --cache_dir "${ROOT}/.hf_cache"
    --init_adapter_dir "${init_adapter}"
    --seed 0
    --gpu_id 0
    --stage_i "${stage}"
    --total_empties_hint "${EMPTIES}"
    --per_device_train_batch_size "${BASELINE_PER_DEVICE_BS}"
    --gradient_accumulation_steps "${BASELINE_GRAD_ACCUM}"
    --num_epochs "${SFT_NUM_EPOCHS}"
    --learning_rate "${lr}"
    --max_grad_norm 1.0
    --logging_steps 20
    --eval_steps 250
    --save_steps 200
    --eval_rows "${EVAL_PUZZLES}"
    --max_completion_length 24
    --limit_train_rows "${TRAIN_PUZZLES}"
    --lora_r "${LORA_R}"
    --lora_alpha "${LORA_ALPHA}"
    --lora_dropout "${LORA_DROPOUT}"
    # Early-stop gates as configured by VALUE_TARGET / SOLVE_TARGET.
    --eval_value_precision_stop "${VALUE_TARGET}"
    --eval_value_recall_stop "${VALUE_TARGET}"
    --eval_exact_set_match_stop 0
    --eval_solve_rate_stop "${SOLVE_TARGET}"
    --min_steps_before_stop "${MIN_STEPS_BEFORE_STOP}"
    --max_wall_clock_seconds 0
    --max_steps "${BASELINE_WARM_MAX_STEPS}"
    --use_wandb
    --wandb_project "sudoku-baseline-stage-warmups"
    --wandb_run_name "${run_name}"
    --wandb_mode "${WANDB_MODE}"
    --wandb_entity "${WANDB_ENTITY}"
  )

  mkdir -p "${out_dir}"
  "${launcher[@]}" "${BASELINE_SFT_SCRIPT}" "${trainer_args[@]}"
}
178
+
179
# Launch one latent SFT phase through torch.distributed.run.
#
# Arguments:
#   $1 - latent mode, forwarded as --latent_mode
#   $2 - curriculum stage index, forwarded as --stage_i
#   $3 - number of latent CoT tokens, forwarded as --num_cot_tokens
#   $4 - adapter directory to initialise from
#   $5 - output directory (created if missing)
#   $6 - learning rate
#   $7 - W&B run name
# Returns the launcher's exit status.
run_latent_sft() {
  local mode="$1"
  local stage="$2"
  local cot="$3"
  local init_adapter="$4"
  local out_dir="$5"
  local lr="$6"
  local run_name="$7"

  # Stage 1 uses mixed ratios 1/0; every other stage uses 0/1.
  local stage1_ratio stage2_ratio
  if [[ "${stage}" == "1" ]]; then
    stage1_ratio=1
    stage2_ratio=0
  else
    stage1_ratio=0
    stage2_ratio=1
  fi

  local -a launcher=(
    "${PYTHON_BIN}" -m torch.distributed.run
    --standalone
    --nproc_per_node "${NPROC_PER_JOB}"
  )
  local -a trainer_args=(
    --model_name "${MODEL_NAME}"
    --train_jsonl "${train_jsonl}"
    --eval_jsonl "${eval_jsonl}"
    --output_dir "${out_dir}"
    --cache_dir "${ROOT}/.hf_cache"
    --init_adapter_dir "${init_adapter}"
    --seed 0
    --gpu_id 0
    --stage_i "${stage}"
    --num_cot_tokens "${cot}"
    --latent_mode "${mode}"
    --max_latent_slots 8
    --max_latent_seeds 8
    --total_empties_hint "${EMPTIES}"
    --mixed_stage1_ratio "${stage1_ratio}"
    --mixed_stage2_ratio "${stage2_ratio}"
    --per_device_train_batch_size "${SFT_PER_DEVICE_BS}"
    --gradient_accumulation_steps "${SFT_GRAD_ACCUM}"
    --num_epochs "${SFT_NUM_EPOCHS}"
    --learning_rate "${lr}"
    --weight_decay 0.0
    --enable_gradient_checkpointing
    --logging_steps 20
    --eval_steps "${LATENT_SFT_EVAL_STEPS}"
    --save_steps 200
    --eval_rows "${EVAL_PUZZLES}"
    --max_completion_length 24
    --limit_train_rows "${TRAIN_PUZZLES}"
    # Early-stop gates as configured by VALUE_TARGET / SOLVE_TARGET.
    --eval_value_precision_stop "${VALUE_TARGET}"
    --eval_value_recall_stop "${VALUE_TARGET}"
    --eval_exact_set_match_stop 0
    --eval_solve_rate_stop "${SOLVE_TARGET}"
    --min_steps_before_stop "${MIN_STEPS_BEFORE_STOP}"
    --max_wall_clock_seconds 0
    --max_steps "${LATENT_SFT_MAX_STEPS}"
    --reward_good_value 1.25
    --penalty_bad_value 1.0
    --penalty_malformed 4.0
    --penalty_empty 0.5
    --penalty_singleton 1.5
    --lora_r "${LORA_R}"
    --lora_alpha "${LORA_ALPHA}"
    --lora_dropout "${LORA_DROPOUT}"
    --use_wandb
    --wandb_project "sudoku-latent-stage-sft-warm-baseline"
    --wandb_run_name "${run_name}"
    --wandb_mode "${WANDB_MODE}"
    --wandb_entity "${WANDB_ENTITY}"
  )

  mkdir -p "${out_dir}"
  "${launcher[@]}" "${LATENT_SFT_SCRIPT}" "${trainer_args[@]}"
}
237
+
238
# Launch one latent GRPO phase through torch.distributed.run.
#
# Arguments:
#   $1 - latent mode, forwarded as --latent_mode
#   $2 - curriculum stage index, forwarded as --stage_i
#   $3 - number of latent CoT tokens, forwarded as --num_cot_tokens
#   $4 - adapter directory to initialise from
#   $5 - output directory (created if missing)
#   $6 - W&B run name
# Returns the launcher's exit status.
#
# NOTE(review): unlike the latent SFT launcher in this script, no
# --max_latent_slots is passed here; confirm the GRPO trainer's default is the
# intended value for fixed_slots runs.
run_latent_grpo() {
  local mode="$1"
  local stage="$2"
  local cot="$3"
  local init_adapter="$4"
  local out_dir="$5"
  local run_name="$6"

  local -a launcher=(
    "${PYTHON_BIN}" -m torch.distributed.run
    --standalone
    --nproc_per_node "${NPROC_PER_JOB}"
  )
  local -a trainer_args=(
    --model_name "${MODEL_NAME}"
    --train_jsonl "${train_jsonl}"
    --eval_jsonl "${eval_jsonl}"
    --output_dir "${out_dir}"
    --cache_dir "${ROOT}/.hf_cache"
    --init_adapter_dir "${init_adapter}"
    --seed 0
    --gpu_id 0
    --stage_i "${stage}"
    --num_cot_tokens "${cot}"
    --latent_mode "${mode}"
    --max_latent_seeds 8
    --total_empties_hint "${EMPTIES}"
    # The mixed ratios are fixed at 0/1 for every GRPO stage.
    --mixed_stage1_ratio 0
    --mixed_stage2_ratio 1
    --per_device_train_batch_size "${GRPO_PER_DEVICE_BS}"
    --gradient_accumulation_steps "${GRPO_GRAD_ACCUM}"
    --num_train_epochs "${GRPO_NUM_TRAIN_EPOCHS}"
    --learning_rate 1e-6
    --logging_steps 20
    --save_steps 200
    --eval_steps 500
    --eval_rows "${EVAL_PUZZLES}"
    --num_generations 4
    --max_prompt_length 1024
    --max_completion_length 24
    --beta "${GRPO_BETA}"
    --enable_gradient_checkpointing
    --limit_train_rows "${TRAIN_PUZZLES}"
    --reward_good_value 1.25
    --penalty_bad_value 1.0
    --penalty_malformed 4.0
    --penalty_empty 0.5
    --penalty_singleton 1.5
    --eval_value_precision_stop "${VALUE_TARGET}"
    --eval_value_recall_stop "${VALUE_TARGET}"
    --eval_solve_rate_stop "${SOLVE_TARGET}"
    --min_steps_before_stop "${MIN_STEPS_BEFORE_STOP}"
    --max_wall_clock_seconds 0
    --max_steps "${LATENT_GRPO_MAX_STEPS}"
    --lora_r "${LORA_R}"
    --lora_alpha "${LORA_ALPHA}"
    --lora_dropout "${LORA_DROPOUT}"
    --use_wandb
    --wandb_project "sudoku-latent-stage-grpo-warm-baseline"
    --wandb_run_name "${run_name}"
    --wandb_mode "${WANDB_MODE}"
    --wandb_entity "${WANDB_ENTITY}"
  )

  mkdir -p "${out_dir}"
  "${launcher[@]}" "${LATENT_GRPO_SCRIPT}" "${trainer_args[@]}"
}
291
+
292
# Run the full three-stage pipeline for one latent mode on one GPU group.
#
# Arguments:
#   $1 - latent mode name
#   $2 - comma-separated GPU ids, exported as CUDA_VISIBLE_DEVICES
#
# Per stage, a phase is skipped when the matching *_ADAPTER_DIR resume variable
# is set; the adapter is then resolved from that directory instead of trained.
#   stage 1: latent SFT (from STAGE1_BASELINE_ADAPTER_DIR) -> latent GRPO
#   stage 2: baseline warm SFT (from stage-1 GRPO) -> latent SFT -> latent GRPO
#   stage 3: baseline warm SFT (from stage-2 GRPO) -> latent SFT -> latent GRPO
# Trainer output is appended to the per-mode log through tee.
run_mode_pipeline() {
  local mode="$1"
  local gpu_group="$2"
  local tag
  tag="$(mode_tag "${mode}")"

  local mode_dir="${OUTPUT_ROOT}/${tag}"
  local log_file="${OUTPUT_ROOT}/logs/${tag}.log"
  local suffix="${EMPTIES}empty_${tag}"

  # Output directories for every phase of this mode.
  local sft1_dir="${mode_dir}/stage01_latent_sft_i1_${suffix}"
  local grpo1_dir="${mode_dir}/stage01_latent_grpo_i1_${suffix}"
  local warm2_dir="${mode_dir}/stage02_baseline_warm_sft_i2_${suffix}"
  local sft2_dir="${mode_dir}/stage02_latent_sft_i2_${suffix}"
  local grpo2_dir="${mode_dir}/stage02_latent_grpo_i2_${suffix}"
  local warm3_dir="${mode_dir}/stage03_baseline_warm_sft_i3_${suffix}"
  local sft3_dir="${mode_dir}/stage03_latent_sft_i3_${suffix}"
  local grpo3_dir="${mode_dir}/stage03_latent_grpo_i3_${suffix}"

  # Resolved adapters. Declared separately from their assignments so that a
  # failing resolver is not masked by `local` and still trips `set -e`.
  local sft1_adapter grpo1_adapter
  local warm2_adapter sft2_adapter grpo2_adapter
  local warm3_adapter sft3_adapter

  mkdir -p "${mode_dir}"
  export CUDA_VISIBLE_DEVICES="${gpu_group}"
  printf 'Mode %s on GPUs %s\n' "${mode}" "${gpu_group}"

  # ---- Stage 1 ----
  if [[ -z "${STAGE1_LATENT_GRPO_ADAPTER_DIR}" ]]; then
    if [[ -z "${STAGE1_LATENT_SFT_ADAPTER_DIR}" ]]; then
      run_latent_sft "${mode}" 1 1 "${STAGE1_BASELINE_ADAPTER_DIR}" "${sft1_dir}" "2e-4" "warmfull_${mode}_st1_latent_sft_${RUN_TAG}" 2>&1 | tee -a "${log_file}"
      sft1_adapter="$(latest_checkpoint_or_dir "${sft1_dir}")"
    else
      sft1_adapter="$(latest_checkpoint_or_dir "${STAGE1_LATENT_SFT_ADAPTER_DIR}")"
      printf 'Using existing stage-1 latent SFT adapter for %s: %s\n' "${mode}" "${sft1_adapter}" | tee -a "${log_file}"
    fi
    run_latent_grpo "${mode}" 1 1 "${sft1_adapter}" "${grpo1_dir}" "warmfull_${mode}_st1_latent_grpo_${RUN_TAG}" 2>&1 | tee -a "${log_file}"
    grpo1_adapter="$(latest_checkpoint_or_dir "${grpo1_dir}")"
  else
    grpo1_adapter="$(latest_checkpoint_or_dir "${STAGE1_LATENT_GRPO_ADAPTER_DIR}")"
    printf 'Using existing stage-1 latent GRPO adapter for %s: %s\n' "${mode}" "${grpo1_adapter}" | tee -a "${log_file}"
  fi

  # ---- Stage 2 ----
  if [[ -z "${STAGE2_LATENT_GRPO_ADAPTER_DIR}" ]]; then
    if [[ -z "${STAGE2_LATENT_SFT_ADAPTER_DIR}" ]]; then
      # Pick the init for stage-2 latent SFT: an explicit SFT continuation
      # adapter wins, then an existing warm-up adapter, else train the warm-up.
      if [[ -n "${STAGE2_LATENT_SFT_INIT_ADAPTER_DIR}" ]]; then
        warm2_adapter="$(latest_checkpoint_or_dir "${STAGE2_LATENT_SFT_INIT_ADAPTER_DIR}")"
        printf 'Continuing stage-2 latent SFT for %s from adapter: %s\n' "${mode}" "${warm2_adapter}" | tee -a "${log_file}"
      elif [[ -n "${STAGE2_BASELINE_WARM_ADAPTER_DIR}" ]]; then
        warm2_adapter="$(latest_checkpoint_or_dir "${STAGE2_BASELINE_WARM_ADAPTER_DIR}")"
        printf 'Using existing stage-2 baseline warm adapter for %s: %s\n' "${mode}" "${warm2_adapter}" | tee -a "${log_file}"
      else
        run_baseline_sft 2 "${grpo1_adapter}" "${warm2_dir}" "5e-5" "warmfull_${mode}_st2_baseline_warm_sft_${RUN_TAG}" 2>&1 | tee -a "${log_file}"
        warm2_adapter="$(latest_checkpoint_or_dir "${warm2_dir}")"
      fi
      run_latent_sft "${mode}" 2 2 "${warm2_adapter}" "${sft2_dir}" "5e-5" "warmfull_${mode}_st2_latent_sft_${RUN_TAG}" 2>&1 | tee -a "${log_file}"
      sft2_adapter="$(latest_checkpoint_or_dir "${sft2_dir}")"
    else
      sft2_adapter="$(latest_checkpoint_or_dir "${STAGE2_LATENT_SFT_ADAPTER_DIR}")"
      printf 'Using existing stage-2 latent SFT adapter for %s: %s\n' "${mode}" "${sft2_adapter}" | tee -a "${log_file}"
    fi
    run_latent_grpo "${mode}" 2 2 "${sft2_adapter}" "${grpo2_dir}" "warmfull_${mode}_st2_latent_grpo_${RUN_TAG}" 2>&1 | tee -a "${log_file}"
    grpo2_adapter="$(latest_checkpoint_or_dir "${grpo2_dir}")"
  else
    grpo2_adapter="$(latest_checkpoint_or_dir "${STAGE2_LATENT_GRPO_ADAPTER_DIR}")"
    printf 'Using existing stage-2 latent GRPO adapter for %s: %s\n' "${mode}" "${grpo2_adapter}" | tee -a "${log_file}"
  fi

  # ---- Stage 3 ----
  if [[ -z "${STAGE3_LATENT_SFT_ADAPTER_DIR}" ]]; then
    if [[ -z "${STAGE3_BASELINE_WARM_ADAPTER_DIR}" ]]; then
      run_baseline_sft 3 "${grpo2_adapter}" "${warm3_dir}" "5e-5" "warmfull_${mode}_st3_baseline_warm_sft_${RUN_TAG}" 2>&1 | tee -a "${log_file}"
      warm3_adapter="$(latest_checkpoint_or_dir "${warm3_dir}")"
    else
      warm3_adapter="$(latest_checkpoint_or_dir "${STAGE3_BASELINE_WARM_ADAPTER_DIR}")"
      printf 'Using existing stage-3 baseline warm adapter for %s: %s\n' "${mode}" "${warm3_adapter}" | tee -a "${log_file}"
    fi
    run_latent_sft "${mode}" 3 3 "${warm3_adapter}" "${sft3_dir}" "5e-5" "warmfull_${mode}_st3_latent_sft_${RUN_TAG}" 2>&1 | tee -a "${log_file}"
    sft3_adapter="$(latest_checkpoint_or_dir "${sft3_dir}")"
  else
    sft3_adapter="$(latest_checkpoint_or_dir "${STAGE3_LATENT_SFT_ADAPTER_DIR}")"
    printf 'Using existing stage-3 latent SFT adapter for %s: %s\n' "${mode}" "${sft3_adapter}" | tee -a "${log_file}"
  fi
  run_latent_grpo "${mode}" 3 3 "${sft3_adapter}" "${grpo3_dir}" "warmfull_${mode}_st3_latent_grpo_${RUN_TAG}" 2>&1 | tee -a "${log_file}"

  printf 'Mode %s finished. Output: %s\n' "${mode}" "${mode_dir}" | tee -a "${log_file}"
}
370
+
371
# Fan out one background pipeline per latent mode, then wait for all of them.
# The script's exit status is 1 if any mode failed, 0 otherwise.
printf 'Output root: %s\n' "${OUTPUT_ROOT}"
printf 'Stage-1 baseline adapter: %s\n' "${STAGE1_BASELINE_ADAPTER_DIR}"
printf 'Solve target: %s (value target: %s)\n' "${SOLVE_TARGET}" "${VALUE_TARGET}"

pids=()
for i in "${!MODES[@]}"; do
  # Each mode runs in its own subshell so its CUDA_VISIBLE_DEVICES export and
  # any `set -e` abort stay isolated; all of its output goes to a supervisor log.
  (
    run_mode_pipeline "${MODES[$i]}" "${GPU_GROUPS[$i]}"
  ) >"${OUTPUT_ROOT}/logs/$(mode_tag "${MODES[$i]}").supervisor.log" 2>&1 &
  # Capture $! directly instead of reading it back with ${pids[-1]}: negative
  # array subscripts are not available in older Bash releases.
  launched_pid="$!"
  pids+=("${launched_pid}")
  printf 'Launched mode=%s pid=%s gpus=%s\n' "${MODES[$i]}" "${launched_pid}" "${GPU_GROUPS[$i]}"
done

failed=0
for i in "${!pids[@]}"; do
  # pids and MODES share indices, so pids[i] is the job for MODES[i].
  if wait "${pids[$i]}"; then
    printf 'DONE: %s\n' "${MODES[$i]}"
  else
    printf 'FAILED: %s (pid=%s). See %s/logs\n' "${MODES[$i]}" "${pids[$i]}" "${OUTPUT_ROOT}" >&2
    failed=1
  fi
done

exit "${failed}"
hard_9x9_stage1_consistency_queue/launch_7empty_latent_residual_stages123_value98.sh ADDED
@@ -0,0 +1,419 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ # Latent residual projector pipeline (7-empty), aligned with the text
3
+ # launch_7empty_post_s1sft_stages123_value98.sh order and value gate:
4
+ # 1) Stage-1 SFT (default: init = STAGE1_INIT_ADAPTER_DIR or fresh LoRA + random residual)
5
+ # 2) Stage-1 GRPO (init = stage-1 SFT checkpoint-step-* dir, or STAGE1_SFT_ADAPTER_DIR if set)
6
+ # 3) Stage-2 SFT (init = stage-1 GRPO adapter)
7
+ # 4) Stage-2 GRPO
8
+ # 5) Stage-3 SFT
9
+ # 6) Stage-3 GRPO
10
+ #
11
+ # Legacy GRPO-first (skip training stage-1 SFT): STAGE1_GRPO_FIRST=1
12
+ #
13
+ # Latent structure (implemented in latent_multi_output_cell_policy/grpo_residual_projector_latent_train.py):
14
+ # - attach_residual_projector_modules(): adds trainable special_thought_embed, latent_mix_logit,
15
+ # and MLP latent_projector_in/out (hidden→4096→hidden) on the Peft-wrapped model.
16
+ # - build_latent_hidden() / residual_next_token_logits_from_ids(): append num_cot_tokens "latent"
17
+ # virtual tokens, run backbone, take (latent_hidden - base_hidden), project through the MLP,
18
+ # mix with base hidden (sigmoid(latent_mix_logit)), then lm_head logits (with optional fallback).
19
+ # - sample_latent_completion() / GRPO use this path for generation; SFT uses the same via
20
+ # residual_projector_warmstart_sft_latent_multi_output_train.py (latent_residual_completion_ce_loss).
21
+ # - latent_cot_state.pt saves/loads the projector + special_thought_embed + mix logit.
22
+ #
23
+ # Each phase stops when eval value_precision AND value_recall are both >= VALUE_TARGET
24
+ # (default 0.98), after MIN_STEPS_BEFORE_STOP optimizer steps (SFT) / GRPO steps (GRPO).
25
+ # Eval rows come from eval_jsonl (same held-out file as the text pipeline).
26
+ #
27
+ # Stage-1 SFT init (when not using STAGE1_SFT_ADAPTER_DIR or STAGE1_GRPO_FIRST):
28
+ # Default: omit STAGE1_INIT_ADAPTER_DIR → fresh LoRA + random residual (same as trainers --init_adapter_dir "").
29
+ # Optional: STAGE1_INIT_ADAPTER_DIR=/path/to/adapter
30
+ #
31
+ # Skip running stage-1 SFT (you already have a finished SFT checkpoint-step-*):
32
+ # STAGE1_SFT_ADAPTER_DIR=/path/to/stage01_sft_.../checkpoint-step-XXXX
33
+ # → first trained phase is stage-1 GRPO with that init.
34
+ #
35
+ # Resume:
36
+ # RESUME_FROM_STAGE1_GRPO_DIR=/path/to/stage01_grpo_i1_7empty_latent_residual
37
+ # START_AT_STAGE2_GRPO_DIR=/path/to/stage02_sft_i2_7empty_latent_residual
38
+ # START_AFTER_STAGE2_GRPO_DIR=/path/to/stage02_grpo_i2_7empty_latent_residual
39
+ #
40
+ # Optional env: VALUE_TARGET, TRAIN_PUZZLES, EVAL_PUZZLES, RUN_TAG, CHECKPOINT_ROOT, GPU_IDS,
41
+ # WANDB_MODE, WANDB_ENTITY, SFT_NUM_EPOCHS, GRPO_NUM_TRAIN_EPOCHS, SFT_MAX_STEPS, GRPO_MAX_STEPS,
42
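+ # STAGE1_SFT_LR (default 2e-4), STAGE1_GRPO_FIRST, STAGE1_SFT_ADAPTER_DIR, STAGE1_INIT_ADAPTER_DIR, PYTHON_BIN, NUM_PROCESSES, MODEL_NAME, OUTPUT_ROOT, MIN_STEPS_BEFORE_STOP, PHASE_WALL_CLOCK_SECONDS, LORA_R, LORA_ALPHA, LORA_DROPOUT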
43
+ #
44
+ set -euo pipefail
45
+
46
+ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
47
+ ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
48
+
49
+ PYTHON_BIN="${PYTHON_BIN:-${ROOT}/.venv/bin/python}"
50
+ SFT_SCRIPT="${ROOT}/latent_multi_output_cell_policy/residual_projector_warmstart_sft_latent_multi_output_train.py"
51
+ GRPO_SCRIPT="${ROOT}/latent_multi_output_cell_policy/grpo_residual_projector_latent_train.py"
52
+
53
+ GPU_IDS="${GPU_IDS:-0,1,2,3,4,5,6,7}"
54
+ NUM_PROCESSES="${NUM_PROCESSES:-8}"
55
+ RUN_TAG="${RUN_TAG:-$(date +%Y%m%d_%H%M%S)}"
56
+ WANDB_MODE="${WANDB_MODE:-online}"
57
+ WANDB_ENTITY="${WANDB_ENTITY:-training-dynamics}"
58
+
59
+ EMPTIES=7
60
+ TAG_SUFFIX="latent_residual"
61
+ TRAIN_PUZZLES="${TRAIN_PUZZLES:-10000}"
62
+ EVAL_PUZZLES="${EVAL_PUZZLES:-100}"
63
+ VALUE_TARGET="${VALUE_TARGET:-0.98}"
64
+ MIN_STEPS_BEFORE_STOP="${MIN_STEPS_BEFORE_STOP:-50}"
65
+ SFT_MAX_STEPS="${SFT_MAX_STEPS:-10000000}"
66
+ GRPO_MAX_STEPS="${GRPO_MAX_STEPS:-10000000}"
67
+ SFT_NUM_EPOCHS="${SFT_NUM_EPOCHS:-512}"
68
+ GRPO_NUM_TRAIN_EPOCHS="${GRPO_NUM_TRAIN_EPOCHS:-200}"
69
+ PHASE_WALL_CLOCK_SECONDS="${PHASE_WALL_CLOCK_SECONDS:-0}"
70
+ MODEL_NAME="${MODEL_NAME:-Qwen/Qwen2.5-0.5B-Instruct}"
71
+ # Fresh-LoRA defaults (match text 7-empty SFT scale); override if you use a different init checkpoint.
72
+ LORA_R="${LORA_R:-32}"
73
+ LORA_ALPHA="${LORA_ALPHA:-64}"
74
+ LORA_DROPOUT="${LORA_DROPOUT:-0.05}"
75
+
76
+ CHECKPOINT_ROOT="${CHECKPOINT_ROOT:-${ROOT}/final_checkpoint/hard_9x9_7empty_latent_residual_stages123_value98}"
77
+ START_AT_STAGE2_GRPO_DIR="${START_AT_STAGE2_GRPO_DIR:-}"
78
+ START_AFTER_STAGE2_GRPO_DIR="${START_AFTER_STAGE2_GRPO_DIR:-}"
79
+ RESUME_FROM_STAGE1_GRPO_DIR="${RESUME_FROM_STAGE1_GRPO_DIR:-}"
80
+ STAGE1_SFT_ADAPTER_DIR="${STAGE1_SFT_ADAPTER_DIR:-}"
81
+ STAGE1_GRPO_FIRST="${STAGE1_GRPO_FIRST:-0}"
82
+ STAGE1_SFT_LR="${STAGE1_SFT_LR:-2e-4}"
83
+
84
+ if [[ -n "${START_AT_STAGE2_GRPO_DIR}" ]]; then
85
+ if [[ ! -d "${START_AT_STAGE2_GRPO_DIR}" ]]; then
86
+ printf 'ERROR: START_AT_STAGE2_GRPO_DIR is not a directory: %s\n' "${START_AT_STAGE2_GRPO_DIR}" >&2
87
+ exit 1
88
+ fi
89
+ OUTPUT_ROOT="${OUTPUT_ROOT:-$(dirname "${START_AT_STAGE2_GRPO_DIR}")}"
90
+ elif [[ -n "${START_AFTER_STAGE2_GRPO_DIR}" ]]; then
91
+ if [[ ! -d "${START_AFTER_STAGE2_GRPO_DIR}" ]]; then
92
+ printf 'ERROR: START_AFTER_STAGE2_GRPO_DIR is not a directory: %s\n' "${START_AFTER_STAGE2_GRPO_DIR}" >&2
93
+ exit 1
94
+ fi
95
+ OUTPUT_ROOT="${OUTPUT_ROOT:-$(dirname "${START_AFTER_STAGE2_GRPO_DIR}")}"
96
+ elif [[ -n "${RESUME_FROM_STAGE1_GRPO_DIR}" ]]; then
97
+ if [[ ! -d "${RESUME_FROM_STAGE1_GRPO_DIR}" ]]; then
98
+ printf 'ERROR: RESUME_FROM_STAGE1_GRPO_DIR is not a directory: %s\n' "${RESUME_FROM_STAGE1_GRPO_DIR}" >&2
99
+ exit 1
100
+ fi
101
+ OUTPUT_ROOT="${OUTPUT_ROOT:-$(dirname "${RESUME_FROM_STAGE1_GRPO_DIR}")}"
102
+ else
103
+ if [[ -n "${STAGE1_INIT_ADAPTER_DIR:-}" ]] && [[ ! -d "${STAGE1_INIT_ADAPTER_DIR}" ]]; then
104
+ printf 'ERROR: STAGE1_INIT_ADAPTER_DIR is not a directory: %s\n' "${STAGE1_INIT_ADAPTER_DIR}" >&2
105
+ exit 1
106
+ fi
107
+ if [[ -n "${STAGE1_SFT_ADAPTER_DIR}" ]] && [[ ! -d "${STAGE1_SFT_ADAPTER_DIR}" ]]; then
108
+ printf 'ERROR: STAGE1_SFT_ADAPTER_DIR is not a directory: %s\n' "${STAGE1_SFT_ADAPTER_DIR}" >&2
109
+ exit 1
110
+ fi
111
+ OUTPUT_ROOT="${OUTPUT_ROOT:-${CHECKPOINT_ROOT}/${RUN_TAG}}"
112
+ fi
113
+
114
+ train_jsonl="${ROOT}/data/sudoku_t3_${EMPTIES}empty_value_qwen_text_stage1_train.jsonl"
115
+ eval_jsonl="${ROOT}/data/sudoku_t3_${EMPTIES}empty_value_qwen_text_stage1_eval.jsonl"
116
+
117
+ export CUDA_VISIBLE_DEVICES="${GPU_IDS}"
118
+ export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}"
119
+
120
+ mkdir -p "${OUTPUT_ROOT}"
121
+
122
# Print the newest "checkpoint-step-*" directory under an SFT output directory.
#
# Arguments:
#   $1 - SFT output directory.
# Output (stdout):
#   The highest-versioned checkpoint-step-* subdirectory (by `sort -V`), or
#   nothing when none exists.
# Returns:
#   Always 0. "Not found" is signalled by empty output only: callers assign
#   with X="$(latest_sft_step_ckpt ...)" under `set -e` and then test -z to
#   print their own error. A non-zero status here would abort the script on
#   the assignment itself, before that diagnostic could ever run.
latest_sft_step_ckpt() {
  local d="$1"
  local entry
  local cks=()
  shopt -s nullglob
  for entry in "${d}"/checkpoint-step-*; do
    # Only directories can be checkpoints; skip stray files matching the glob.
    if [[ -d "${entry}" ]]; then
      cks+=("${entry}")
    fi
  done
  shopt -u nullglob
  if (( ${#cks[@]} == 0 )); then
    return 0
  fi
  # pipefail is relaxed around the pipeline so a failing stage yields empty
  # output (handled by the caller) rather than a hard abort.
  set +o pipefail
  printf '%s\n' "${cks[@]}" | sort -V | tail -n 1
  set -o pipefail
}
135
+
136
# Resolve the adapter directory produced by a latent GRPO run.
#
# Arguments:
#   $1 - GRPO output directory.
# Output (stdout):
#   $1 itself when it directly contains adapter_model.safetensors; otherwise
#   the "checkpoint-<N>" subdirectory with the largest purely numeric N that
#   contains adapter_model.safetensors; nothing when neither exists.
# Returns:
#   Always 0. "Not found" is signalled by empty output only: callers assign
#   with X="$(resolve_latent_grpo_adapter ...)" under `set -e` and then test
#   -z to print their own error. A non-zero status here would abort the script
#   on the assignment itself, before that diagnostic could ever run.
resolve_latent_grpo_adapter() {
  local d="$1"
  if [[ -f "${d}/adapter_model.safetensors" ]]; then
    printf '%s\n' "${d}"
    return 0
  fi

  local best="" best_step=-1
  local candidate suffix
  shopt -s nullglob
  for candidate in "${d}"/checkpoint-*; do
    [[ -d "${candidate}" && -f "${candidate}/adapter_model.safetensors" ]] || continue
    suffix="${candidate##*checkpoint-}"
    # Names such as checkpoint-step-123 have a non-numeric suffix and are skipped.
    [[ "${suffix}" =~ ^[0-9]+$ ]] || continue
    # 10# forces base 10 so zero-padded step numbers are not parsed as octal.
    if (( 10#${suffix} >= best_step )); then
      best_step=$((10#${suffix}))
      best="${candidate}"
    fi
  done
  shopt -u nullglob

  if [[ -n "${best}" ]]; then
    printf '%s\n' "${best}"
  fi
  return 0
}
163
+
164
# Launch one latent residual SFT phase through torch.distributed.run.
#
# Arguments:
#   $1 - curriculum stage index, forwarded as --stage_i
#   $2 - adapter directory to initialise from (may be empty)
#   $3 - output directory (created if missing)
#   $4 - learning rate
#   $5 - number of latent CoT tokens, forwarded as --num_cot_tokens
# Prints a phase banner to stderr and returns the launcher's exit status.
run_latent_sft() {
  local stage="$1"
  local init_adapter="$2"
  local out_dir="$3"
  local lr="$4"
  local cot="$5"

  # Stage-1 SFT weights stage-1 rows only (mixed 1/0); stages 2-3 use the
  # stage-i curriculum (mixed 0/1).
  local stage1_ratio stage2_ratio
  if [[ "${stage}" == "1" ]]; then
    stage1_ratio=1
    stage2_ratio=0
  else
    stage1_ratio=0
    stage2_ratio=1
  fi

  local -a launcher=(
    "${PYTHON_BIN}" -m torch.distributed.run
    --standalone
    --nproc_per_node "${NUM_PROCESSES}"
  )
  local -a trainer_args=(
    --model_name "${MODEL_NAME}"
    --train_jsonl "${train_jsonl}"
    --output_dir "${out_dir}"
    --cache_dir "${ROOT}/.hf_cache"
    --init_adapter_dir "${init_adapter}"
    --seed 0
    --gpu_id 0
    --stage_i "${stage}"
    --num_cot_tokens "${cot}"
    --total_empties_hint "${EMPTIES}"
    --mixed_stage1_ratio "${stage1_ratio}"
    --mixed_stage2_ratio "${stage2_ratio}"
    --gradient_accumulation_steps 2
    --num_epochs "${SFT_NUM_EPOCHS}"
    --learning_rate "${lr}"
    --weight_decay 0.0
    --enable_gradient_checkpointing
    --logging_steps 20
    --eval_steps 250
    --save_steps 200
    --eval_rows "${EVAL_PUZZLES}"
    --eval_jsonl "${eval_jsonl}"
    --max_completion_length 24
    --limit_train_rows "${TRAIN_PUZZLES}"
    # Only the value precision/recall gate is configured; the others are 0.
    --eval_value_precision_stop "${VALUE_TARGET}"
    --eval_value_recall_stop "${VALUE_TARGET}"
    --eval_exact_set_match_stop 0
    --eval_solve_rate_stop 0
    --min_steps_before_stop "${MIN_STEPS_BEFORE_STOP}"
    --max_wall_clock_seconds "${PHASE_WALL_CLOCK_SECONDS}"
    --max_steps "${SFT_MAX_STEPS}"
    --reward_good_value 1.25
    --penalty_bad_value 1.0
    --penalty_malformed 4.0
    --penalty_empty 0.5
    --penalty_singleton 1.5
    --lora_r "${LORA_R}"
    --lora_alpha "${LORA_ALPHA}"
    --lora_dropout "${LORA_DROPOUT}"
    --use_wandb
    --wandb_project "sudoku-latent-multi-output-sft-residual-projector"
    --wandb_run_name "latent7_st${stage}_sft_i${stage}_${TAG_SUFFIX}_val${VALUE_TARGET}_${RUN_TAG}"
    --wandb_mode "${WANDB_MODE}"
    --wandb_entity "${WANDB_ENTITY}"
  )

  mkdir -p "${out_dir}"
  printf '\n=== Latent stage %s SFT (residual) → stop value prec+recall >= %s ===\n' "${stage}" "${VALUE_TARGET}" >&2
  printf 'init=%s\nout=%s num_cot_tokens=%s mixed_s1/s2=%s/%s\n' "${init_adapter}" "${out_dir}" "${cot}" "${stage1_ratio}" "${stage2_ratio}" >&2
  "${launcher[@]}" "${SFT_SCRIPT}" "${trainer_args[@]}"
}
225
+
226
# Launch one latent residual GRPO phase through torch.distributed.run.
#
# Arguments:
#   $1 - curriculum stage index, forwarded as --stage_i
#   $2 - adapter directory to initialise from (may be empty)
#   $3 - output directory (created if missing)
#   $4 - number of latent CoT tokens, forwarded as --num_cot_tokens
# Prints a phase banner to stderr and returns the launcher's exit status.
run_latent_grpo() {
  local stage="$1"
  local init_adapter="$2"
  local out_dir="$3"
  local cot="$4"

  local -a launcher=(
    "${PYTHON_BIN}" -m torch.distributed.run
    --standalone
    --nproc_per_node "${NUM_PROCESSES}"
  )
  local -a trainer_args=(
    --model_name "${MODEL_NAME}"
    --train_jsonl "${train_jsonl}"
    --output_dir "${out_dir}"
    --cache_dir "${ROOT}/.hf_cache"
    --init_adapter_dir "${init_adapter}"
    --seed 0
    --gpu_id 0
    --stage_i "${stage}"
    --num_cot_tokens "${cot}"
    --total_empties_hint "${EMPTIES}"
    # The mixed ratios are fixed at 0/1 for every GRPO stage.
    --mixed_stage1_ratio 0
    --mixed_stage2_ratio 1
    --per_device_train_batch_size 8
    --gradient_accumulation_steps 2
    --num_train_epochs "${GRPO_NUM_TRAIN_EPOCHS}"
    --learning_rate 1e-6
    --logging_steps 20
    --save_steps 200
    --eval_steps 500
    --eval_rows "${EVAL_PUZZLES}"
    --eval_jsonl "${eval_jsonl}"
    --num_generations 4
    --max_prompt_length 1024
    --max_completion_length 24
    --beta 0.0
    --enable_gradient_checkpointing
    --limit_train_rows "${TRAIN_PUZZLES}"
    --reward_good_value 1.25
    --penalty_bad_value 1.0
    --penalty_malformed 4.0
    --penalty_empty 0.5
    --penalty_singleton 1.5
    # Only the value precision/recall gate is configured; solve rate is 0.
    --eval_value_precision_stop "${VALUE_TARGET}"
    --eval_value_recall_stop "${VALUE_TARGET}"
    --eval_solve_rate_stop 0
    --min_steps_before_stop "${MIN_STEPS_BEFORE_STOP}"
    --max_wall_clock_seconds "${PHASE_WALL_CLOCK_SECONDS}"
    --max_steps "${GRPO_MAX_STEPS}"
    --lora_r "${LORA_R}"
    --lora_alpha "${LORA_ALPHA}"
    --lora_dropout "${LORA_DROPOUT}"
    --use_wandb
    --wandb_project "sudoku-latent-multi-output-grpo-residual-projector"
    --wandb_run_name "latent7_st${stage}_grpo_i${stage}_${TAG_SUFFIX}_val${VALUE_TARGET}_${RUN_TAG}"
    --wandb_mode "${WANDB_MODE}"
    --wandb_entity "${WANDB_ENTITY}"
  )

  mkdir -p "${out_dir}"
  printf '\n=== Latent stage %s GRPO (residual) → stop value prec+recall >= %s ===\n' "${stage}" "${VALUE_TARGET}" >&2
  printf 'init=%s\nout=%s num_cot_tokens=%s\n' "${init_adapter}" "${out_dir}" "${cot}" >&2
  "${launcher[@]}" "${GRPO_SCRIPT}" "${trainer_args[@]}"
}
282
+
283
+ if [[ ! -f "${train_jsonl}" ]] || [[ ! -f "${eval_jsonl}" ]]; then
284
+ printf 'ERROR: Missing train or eval jsonl.\n' >&2
285
+ printf ' %s\n %s\n' "${train_jsonl}" "${eval_jsonl}" >&2
286
+ exit 1
287
+ fi
288
+
289
+ if [[ -n "${START_AT_STAGE2_GRPO_DIR}" ]]; then
290
+ printf 'Fast-forward: stage-2 latent SFT dir %s → stage-2 GRPO, then stage 3.\n' "${START_AT_STAGE2_GRPO_DIR}" >&2
291
+ printf 'Pipeline root: %s\n' "${OUTPUT_ROOT}"
292
+ S2_DIR="${START_AT_STAGE2_GRPO_DIR}"
293
+ CKPT_S2="$(latest_sft_step_ckpt "${S2_DIR}")"
294
+ if [[ -z "${CKPT_S2}" ]]; then
295
+ printf 'ERROR: No checkpoint-step-* under %s\n' "${S2_DIR}" >&2
296
+ exit 1
297
+ fi
298
+ G2_DIR="${OUTPUT_ROOT}/stage02_grpo_i2_${EMPTIES}empty_${TAG_SUFFIX}"
299
+ run_latent_grpo 2 "${CKPT_S2}" "${G2_DIR}" 2
300
+ A2="$(resolve_latent_grpo_adapter "${G2_DIR}")"
301
+ if [[ -z "${A2}" ]]; then
302
+ printf 'ERROR: Could not resolve stage-2 latent GRPO adapter under %s\n' "${G2_DIR}" >&2
303
+ exit 1
304
+ fi
305
+ S3_DIR="${OUTPUT_ROOT}/stage03_sft_i3_${EMPTIES}empty_${TAG_SUFFIX}"
306
+ run_latent_sft 3 "${A2}" "${S3_DIR}" "5e-5" 3
307
+ CKPT_S3="$(latest_sft_step_ckpt "${S3_DIR}")"
308
+ if [[ -z "${CKPT_S3}" ]]; then
309
+ printf 'ERROR: No checkpoint-step-* under %s\n' "${S3_DIR}" >&2
310
+ exit 1
311
+ fi
312
+ G3_DIR="${OUTPUT_ROOT}/stage03_grpo_i3_${EMPTIES}empty_${TAG_SUFFIX}"
313
+ run_latent_grpo 3 "${CKPT_S3}" "${G3_DIR}" 3
314
+ A3="$(resolve_latent_grpo_adapter "${G3_DIR}")"
315
+ if [[ -z "${A3}" ]]; then
316
+ printf 'ERROR: Could not resolve stage-3 latent GRPO adapter under %s\n' "${G3_DIR}" >&2
317
+ exit 1
318
+ fi
319
+ printf '\nAll latent phases finished (started at stage-2 GRPO).\n'
320
+ printf 'Outputs under: %s\n' "${OUTPUT_ROOT}"
321
+ printf 'Final latent GRPO adapter: %s\n' "${A3}"
322
+ exit 0
323
+ fi
324
+
325
+ if [[ -n "${START_AFTER_STAGE2_GRPO_DIR}" ]]; then
326
+ printf 'Fast-forward: stage-2 latent GRPO dir %s → stage-3 SFT + GRPO.\n' "${START_AFTER_STAGE2_GRPO_DIR}" >&2
327
+ printf 'Pipeline root: %s\n' "${OUTPUT_ROOT}"
328
+ A2="$(resolve_latent_grpo_adapter "${START_AFTER_STAGE2_GRPO_DIR}")"
329
+ if [[ -z "${A2}" ]]; then
330
+ printf 'ERROR: Could not resolve stage-2 latent GRPO adapter under %s\n' "${START_AFTER_STAGE2_GRPO_DIR}" >&2
331
+ exit 1
332
+ fi
333
+ S3_DIR="${OUTPUT_ROOT}/stage03_sft_i3_${EMPTIES}empty_${TAG_SUFFIX}"
334
+ run_latent_sft 3 "${A2}" "${S3_DIR}" "5e-5" 3
335
+ CKPT_S3="$(latest_sft_step_ckpt "${S3_DIR}")"
336
+ if [[ -z "${CKPT_S3}" ]]; then
337
+ printf 'ERROR: No checkpoint-step-* under %s\n' "${S3_DIR}" >&2
338
+ exit 1
339
+ fi
340
+ G3_DIR="${OUTPUT_ROOT}/stage03_grpo_i3_${EMPTIES}empty_${TAG_SUFFIX}"
341
+ run_latent_grpo 3 "${CKPT_S3}" "${G3_DIR}" 3
342
+ A3="$(resolve_latent_grpo_adapter "${G3_DIR}")"
343
+ if [[ -z "${A3}" ]]; then
344
+ printf 'ERROR: Could not resolve stage-3 latent GRPO adapter under %s\n' "${G3_DIR}" >&2
345
+ exit 1
346
+ fi
347
+ printf '\nAll latent phases finished (started after stage-2 GRPO).\n'
348
+ printf 'Outputs under: %s\n' "${OUTPUT_ROOT}"
349
+ printf 'Final latent GRPO adapter: %s\n' "${A3}"
350
+ exit 0
351
+ fi
352
+
353
+ printf 'Pipeline root: %s\n' "${OUTPUT_ROOT}"
354
+ printf 'Value gate: precision AND recall >= %s (min_steps=%s)\n' "${VALUE_TARGET}" "${MIN_STEPS_BEFORE_STOP}"
355
+
356
+ G1_DIR="${OUTPUT_ROOT}/stage01_grpo_i1_${EMPTIES}empty_${TAG_SUFFIX}"
357
+ S1_SFT_DIR="${OUTPUT_ROOT}/stage01_sft_i1_${EMPTIES}empty_${TAG_SUFFIX}"
358
+ STAGE1_INIT="${STAGE1_INIT_ADAPTER_DIR:-}"
359
+ if [[ -n "${RESUME_FROM_STAGE1_GRPO_DIR}" ]]; then
360
+ A1="$(resolve_latent_grpo_adapter "${RESUME_FROM_STAGE1_GRPO_DIR}")"
361
+ elif [[ "${STAGE1_GRPO_FIRST}" == "1" ]]; then
362
+ # Legacy: stage-1 GRPO first (fresh LoRA + random residual unless STAGE1_INIT_ADAPTER_DIR set).
363
+ run_latent_grpo 1 "${STAGE1_INIT}" "${G1_DIR}" 1
364
+ A1="$(resolve_latent_grpo_adapter "${G1_DIR}")"
365
+ else
366
+ # Default: stage-1 SFT → stage-1 GRPO (matches text post-s1sft pipeline).
367
+ if [[ -n "${STAGE1_SFT_ADAPTER_DIR}" ]]; then
368
+ G1_SFT_CKPT="${STAGE1_SFT_ADAPTER_DIR}"
369
+ printf 'Using existing stage-1 SFT checkpoint as GRPO init (skipping stage-1 SFT train): %s\n' "${G1_SFT_CKPT}" >&2
370
+ else
371
+ run_latent_sft 1 "${STAGE1_INIT}" "${S1_SFT_DIR}" "${STAGE1_SFT_LR}" 1
372
+ G1_SFT_CKPT="$(latest_sft_step_ckpt "${S1_SFT_DIR}")"
373
+ if [[ -z "${G1_SFT_CKPT}" ]]; then
374
+ printf 'ERROR: No checkpoint-step-* under %s\n' "${S1_SFT_DIR}" >&2
375
+ exit 1
376
+ fi
377
+ fi
378
+ run_latent_grpo 1 "${G1_SFT_CKPT}" "${G1_DIR}" 1
379
+ A1="$(resolve_latent_grpo_adapter "${G1_DIR}")"
380
+ fi
381
+ if [[ -z "${A1}" ]]; then
382
+ printf 'ERROR: Could not resolve stage-1 latent GRPO adapter.\n' >&2
383
+ exit 1
384
+ fi
385
+ printf 'Stage-1 latent GRPO adapter for stage-2 SFT init: %s\n' "${A1}"
386
+
387
+ S2_DIR="${OUTPUT_ROOT}/stage02_sft_i2_${EMPTIES}empty_${TAG_SUFFIX}"
388
+ run_latent_sft 2 "${A1}" "${S2_DIR}" "5e-5" 2
389
+ CKPT_S2="$(latest_sft_step_ckpt "${S2_DIR}")"
390
+ if [[ -z "${CKPT_S2}" ]]; then
391
+ printf 'ERROR: No checkpoint-step-* under %s\n' "${S2_DIR}" >&2
392
+ exit 1
393
+ fi
394
+ G2_DIR="${OUTPUT_ROOT}/stage02_grpo_i2_${EMPTIES}empty_${TAG_SUFFIX}"
395
+ run_latent_grpo 2 "${CKPT_S2}" "${G2_DIR}" 2
396
+ A2="$(resolve_latent_grpo_adapter "${G2_DIR}")"
397
+ if [[ -z "${A2}" ]]; then
398
+ printf 'ERROR: Could not resolve stage-2 latent GRPO adapter under %s\n' "${G2_DIR}" >&2
399
+ exit 1
400
+ fi
401
+
402
+ S3_DIR="${OUTPUT_ROOT}/stage03_sft_i3_${EMPTIES}empty_${TAG_SUFFIX}"
403
+ run_latent_sft 3 "${A2}" "${S3_DIR}" "5e-5" 3
404
+ CKPT_S3="$(latest_sft_step_ckpt "${S3_DIR}")"
405
+ if [[ -z "${CKPT_S3}" ]]; then
406
+ printf 'ERROR: No checkpoint-step-* under %s\n' "${S3_DIR}" >&2
407
+ exit 1
408
+ fi
409
+ G3_DIR="${OUTPUT_ROOT}/stage03_grpo_i3_${EMPTIES}empty_${TAG_SUFFIX}"
410
+ run_latent_grpo 3 "${CKPT_S3}" "${G3_DIR}" 3
411
+ A3="$(resolve_latent_grpo_adapter "${G3_DIR}")"
412
+ if [[ -z "${A3}" ]]; then
413
+ printf 'ERROR: Could not resolve stage-3 latent GRPO adapter under %s\n' "${G3_DIR}" >&2
414
+ exit 1
415
+ fi
416
+
417
+ printf '\nAll latent residual phases finished.\n'
418
+ printf 'Outputs under: %s\n' "${OUTPUT_ROOT}"
419
+ printf 'Final latent GRPO adapter (stage 3): %s\n' "${A3}"
hard_9x9_stage1_consistency_queue/launch_7empty_post_s1sft_stages123_value98.sh ADDED
@@ -0,0 +1,372 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ # Run AFTER stage-1 SFT finishes (7-empty). Order:
3
+ # 1) Stage-1 GRPO (init = your stage-1 SFT adapter)
4
+ # 2) Stage-2 SFT (init = stage-1 GRPO adapter)
5
+ # 3) Stage-2 GRPO (init = stage-2 SFT adapter)
6
+ # 4) Stage-3 SFT (init = stage-2 GRPO adapter)
7
+ # 5) Stage-3 GRPO (init = stage-3 SFT adapter)
8
+ #
9
+ # Each SFT/GRPO phase stops early only when BOTH eval value_precision AND value_recall
10
+ # are >= VALUE_TARGET (default 0.98). Other metric gates are disabled (0). Defaults use
11
+ # very large max_steps / epochs so in practice you exit on the 0.98 gate, not a low cap
12
+ # (override SFT_MAX_STEPS / GRPO_MAX_STEPS if you want a hard ceiling).
13
+ #
14
+ # Required (full pipeline from stage-1 SFT):
15
+ # STAGE1_SFT_ADAPTER_DIR=/path/to/checkpoint-step-XXXXX
16
+ #
17
+ # Resume after stage-1 GRPO already ran (skip GRPO i=1, start at stage-2 SFT):
18
+ # RESUME_FROM_STAGE1_GRPO_DIR=/path/to/stage01_grpo_i1_7empty
19
+ # (OUTPUT_ROOT defaults to dirname of that dir.)
20
+ #
21
+ # Resume after stage-2 SFT already ran (skip through stage-2 SFT, start at stage-2 GRPO):
22
+ # START_AT_STAGE2_GRPO_DIR=/path/to/stage02_sft_i2_7empty
23
+ #
24
+ # Resume after stage-2 GRPO finished (stage-3 SFT + stage-3 GRPO only):
25
+ # START_AFTER_STAGE2_GRPO_DIR=/path/to/stage02_grpo_i2_7empty
26
+ #
27
+ # Optional:
28
+ # VALUE_TARGET=0.98 SFT_MAX_STEPS=... GRPO_MAX_STEPS=... SFT_NUM_EPOCHS=... GRPO_NUM_TRAIN_EPOCHS=...
29
+ # TRAIN_PUZZLES=10000 EVAL_PUZZLES=100 RUN_TAG=... CHECKPOINT_ROOT=... USE_GC=1 PHASE_WALL_CLOCK_SECONDS=0
30
+ #
31
+ set -euo pipefail
32
+
33
+ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
34
+ ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
35
+
36
+ PYTHON_BIN="${PYTHON_BIN:-${ROOT}/.venv/bin/python}"
37
+ SFT_SCRIPT="${ROOT}/multi_output_cell_policy/sft_multi_output_train.py"
38
+ GRPO_SCRIPT="${ROOT}/multi_output_cell_policy/grpo_multi_output_train.py"
39
+
40
+ GPU_IDS="${GPU_IDS:-0,1,2,3,4,5,6,7}"
41
+ NUM_PROCESSES="${NUM_PROCESSES:-8}"
42
+ RUN_TAG="${RUN_TAG:-$(date +%Y%m%d_%H%M%S)}"
43
+ WANDB_MODE="${WANDB_MODE:-online}"
44
+ WANDB_ENTITY="${WANDB_ENTITY:-training-dynamics}"
45
+
46
+ EMPTIES=7
47
+ TRAIN_PUZZLES="${TRAIN_PUZZLES:-10000}"
48
+ EVAL_PUZZLES="${EVAL_PUZZLES:-100}"
49
+ VALUE_TARGET="${VALUE_TARGET:-0.98}"
50
+ # Large caps so training is gated by VALUE_TARGET, not an arbitrary step limit.
51
+ SFT_MAX_STEPS="${SFT_MAX_STEPS:-10000000}"
52
+ GRPO_MAX_STEPS="${GRPO_MAX_STEPS:-10000000}"
53
+ SFT_NUM_EPOCHS="${SFT_NUM_EPOCHS:-512}"
54
+ GRPO_NUM_TRAIN_EPOCHS="${GRPO_NUM_TRAIN_EPOCHS:-200}"
55
+ PHASE_WALL_CLOCK_SECONDS="${PHASE_WALL_CLOCK_SECONDS:-0}"
56
+
57
+ CHECKPOINT_ROOT="${CHECKPOINT_ROOT:-${ROOT}/final_checkpoint/hard_9x9_7empty_stages123_value98}"
58
+ START_AT_STAGE2_GRPO_DIR="${START_AT_STAGE2_GRPO_DIR:-}"
59
+ START_AFTER_STAGE2_GRPO_DIR="${START_AFTER_STAGE2_GRPO_DIR:-}"
60
+ RESUME_FROM_STAGE1_GRPO_DIR="${RESUME_FROM_STAGE1_GRPO_DIR:-}"
61
+
62
+ if [[ -n "${START_AT_STAGE2_GRPO_DIR}" ]]; then
63
+ if [[ ! -d "${START_AT_STAGE2_GRPO_DIR}" ]]; then
64
+ printf 'ERROR: START_AT_STAGE2_GRPO_DIR is not a directory: %s\n' "${START_AT_STAGE2_GRPO_DIR}" >&2
65
+ exit 1
66
+ fi
67
+ OUTPUT_ROOT="${OUTPUT_ROOT:-$(dirname "${START_AT_STAGE2_GRPO_DIR}")}"
68
+ elif [[ -n "${START_AFTER_STAGE2_GRPO_DIR}" ]]; then
69
+ if [[ ! -d "${START_AFTER_STAGE2_GRPO_DIR}" ]]; then
70
+ printf 'ERROR: START_AFTER_STAGE2_GRPO_DIR is not a directory: %s\n' "${START_AFTER_STAGE2_GRPO_DIR}" >&2
71
+ exit 1
72
+ fi
73
+ OUTPUT_ROOT="${OUTPUT_ROOT:-$(dirname "${START_AFTER_STAGE2_GRPO_DIR}")}"
74
+ elif [[ -n "${RESUME_FROM_STAGE1_GRPO_DIR}" ]]; then
75
+ if [[ ! -d "${RESUME_FROM_STAGE1_GRPO_DIR}" ]]; then
76
+ printf 'ERROR: RESUME_FROM_STAGE1_GRPO_DIR is not a directory: %s\n' "${RESUME_FROM_STAGE1_GRPO_DIR}" >&2
77
+ exit 1
78
+ fi
79
+ # Keep stage02+ next to stage01 (same parent dir). Override with OUTPUT_ROOT=... if needed.
80
+ OUTPUT_ROOT="${OUTPUT_ROOT:-$(dirname "${RESUME_FROM_STAGE1_GRPO_DIR}")}"
81
+ else
82
+ if [[ -z "${STAGE1_SFT_ADAPTER_DIR:-}" ]] || [[ ! -d "${STAGE1_SFT_ADAPTER_DIR}" ]]; then
83
+ printf 'ERROR: Set STAGE1_SFT_ADAPTER_DIR to a finished stage-1 SFT checkpoint directory, or RESUME_FROM_STAGE1_GRPO_DIR, START_AT_STAGE2_GRPO_DIR, or START_AFTER_STAGE2_GRPO_DIR.\n' >&2
84
+ exit 1
85
+ fi
86
+ OUTPUT_ROOT="${OUTPUT_ROOT:-${CHECKPOINT_ROOT}/${RUN_TAG}}"
87
+ fi
88
+
89
+ train_jsonl="${ROOT}/data/sudoku_t3_${EMPTIES}empty_value_qwen_text_stage1_train.jsonl"
90
+ eval_jsonl="${ROOT}/data/sudoku_t3_${EMPTIES}empty_value_qwen_text_stage1_eval.jsonl"
91
+
92
+ export CUDA_VISIBLE_DEVICES="${GPU_IDS}"
93
+
94
+ mkdir -p "${OUTPUT_ROOT}"
95
+
96
# Print the highest-versioned checkpoint-step-* entry directly under a directory.
#
# Arguments:
#   $1  directory to search (an SFT output directory).
# Output:
#   The selected path on stdout (version-sorted, last entry wins), or nothing
#   when the directory contains no checkpoint-step-* entry.
# Returns:
#   Always 0. Callers capture the output as X="$(latest_sft_step_ckpt ...)"
#   while `set -e` is active; a non-zero status there would abort the script on
#   the assignment itself, before the caller's `[[ -z "${X}" ]]` check can print
#   its error message. "Nothing found" is therefore signalled by empty output.
latest_sft_step_ckpt() {
  local d="$1"
  shopt -s nullglob
  local cks=("${d}"/checkpoint-step-*)
  shopt -u nullglob
  if (( ${#cks[@]} == 0 )); then
    # Empty stdout + success: let the caller report the missing checkpoint.
    return 0
  fi
  # Avoid set -o pipefail treating a pipeline edge case as failure (was breaking the caller).
  set +o pipefail
  printf '%s\n' "${cks[@]}" | sort -V | tail -n 1
  set -o pipefail
}
110
+
111
# GRPO: prefer final root adapter; else latest TRL checkpoint dir with adapter weights.
#
# Arguments:
#   $1  GRPO output directory.
# Output:
#   - $1 itself when $1/adapter_model.safetensors exists; otherwise
#   - the $1/checkpoint-<N> directory with the largest numeric N that contains
#     adapter_model.safetensors; otherwise
#   - nothing.
# Returns:
#   Always 0. Callers capture the output as A="$(resolve_grpo_adapter ...)"
#   while `set -e` is active; a non-zero status there would abort the script on
#   the assignment, so the caller's "Could not resolve ... adapter" message
#   would never be printed. "Nothing found" is signalled by empty output.
resolve_grpo_adapter() {
  local d="$1"
  if [[ -f "${d}/adapter_model.safetensors" ]]; then
    printf '%s\n' "${d}"
    return 0
  fi
  local best="" step=-1
  local c n
  shopt -s nullglob
  for c in "${d}"/checkpoint-*; do
    [[ -d "${c}" ]] || continue
    [[ -f "${c}/adapter_model.safetensors" ]] || continue
    n="${c##*checkpoint-}"
    # 10# forces base 10 so zero-padded step numbers are not parsed as octal.
    if [[ "${n}" =~ ^[0-9]+$ ]] && (( 10#${n} >= step )); then
      step=$((10#${n}))
      best="${c}"
    fi
  done
  shopt -u nullglob
  if [[ -n "${best}" ]]; then
    printf '%s\n' "${best}"
  fi
  # Empty stdout + success: let the caller report the missing adapter.
  return 0
}
139
+
140
+ GC_FLAGS=()
141
+ if [[ "${USE_GC:-0}" == "1" ]]; then
142
+ GC_FLAGS+=(--enable_gradient_checkpointing)
143
+ fi
144
+
145
# Launch one SFT phase with torch.distributed.run and wait for it to finish.
#
# Arguments:
#   $1  stage index, passed as --stage_i and used in the W&B run name.
#   $2  adapter directory to warm-start from (--init_adapter_dir).
#   $3  output directory for this phase (created if missing).
#   $4  learning rate.
# Globals read:
#   PYTHON_BIN, NUM_PROCESSES, SFT_SCRIPT, ROOT, train_jsonl, eval_jsonl,
#   EMPTIES, SFT_NUM_EPOCHS, SFT_MAX_STEPS, EVAL_PUZZLES, TRAIN_PUZZLES,
#   VALUE_TARGET, PHASE_WALL_CLOCK_SECONDS, RUN_TAG, WANDB_MODE, WANDB_ENTITY,
#   GC_FLAGS (optional extra trainer flags).
# Notes:
#   VALUE_TARGET is passed as both the precision and the recall stop threshold;
#   the exact-set-match and solve-rate stop thresholds are passed as 0.
#   Progress banners go to stderr so stdout stays clean for callers.
run_sft() {
  local stage="$1"
  local init_adapter="$2"
  local out_dir="$3"
  local lr="$4"
  mkdir -p "${out_dir}"
  printf '\n=== Stage %s SFT → stop when value prec+recall >= %s (max_steps=%s epochs=%s) ===\n' "${stage}" "${VALUE_TARGET}" "${SFT_MAX_STEPS}" "${SFT_NUM_EPOCHS}" >&2
  printf 'init=%s\nout=%s\n' "${init_adapter}" "${out_dir}" >&2
  # GC_FLAGS is empty unless USE_GC=1. The ${arr[@]+...} guard keeps the empty
  # expansion from tripping `set -u` ("unbound variable") on Bash < 4.4.
  "${PYTHON_BIN}" -m torch.distributed.run --standalone --nproc_per_node "${NUM_PROCESSES}" "${SFT_SCRIPT}" \
    --model_name "Qwen/Qwen2.5-0.5B-Instruct" \
    --train_jsonl "${train_jsonl}" \
    --eval_jsonl "${eval_jsonl}" \
    --output_dir "${out_dir}" \
    --cache_dir "${ROOT}/.hf_cache" \
    --init_adapter_dir "${init_adapter}" \
    --seed 0 \
    --gpu_id 0 \
    --stage_i "${stage}" \
    --total_empties_hint "${EMPTIES}" \
    --per_device_train_batch_size 16 \
    --gradient_accumulation_steps 2 \
    --num_epochs "${SFT_NUM_EPOCHS}" \
    --learning_rate "${lr}" \
    --max_grad_norm 1.0 \
    ${GC_FLAGS[@]+"${GC_FLAGS[@]}"} \
    --logging_steps 20 \
    --eval_steps 250 \
    --save_steps 200 \
    --eval_rows "${EVAL_PUZZLES}" \
    --max_completion_length 24 \
    --limit_train_rows "${TRAIN_PUZZLES}" \
    --lora_r 32 \
    --lora_alpha 64 \
    --lora_dropout 0.05 \
    --eval_value_precision_stop "${VALUE_TARGET}" \
    --eval_value_recall_stop "${VALUE_TARGET}" \
    --eval_exact_set_match_stop 0 \
    --eval_solve_rate_stop 0 \
    --min_steps_before_stop 50 \
    --max_wall_clock_seconds "${PHASE_WALL_CLOCK_SECONDS}" \
    --max_steps "${SFT_MAX_STEPS}" \
    --use_wandb \
    --wandb_project "sudoku-multi-output-sft" \
    --wandb_run_name "postS1_st${stage}_sft_i${stage}_${EMPTIES}empty_val${VALUE_TARGET}_${RUN_TAG}" \
    --wandb_mode "${WANDB_MODE}" \
    --wandb_entity "${WANDB_ENTITY}"
}
192
+
193
# Launch one GRPO phase with torch.distributed.run and wait for it to finish.
#
# Arguments:
#   $1  stage index, passed as --stage_i and used in the W&B run name.
#   $2  adapter directory to warm-start from (--init_adapter_dir).
#   $3  output directory for this phase (created if missing).
# Globals read:
#   PYTHON_BIN, NUM_PROCESSES, GRPO_SCRIPT, ROOT, train_jsonl, eval_jsonl,
#   EMPTIES, GRPO_NUM_TRAIN_EPOCHS, GRPO_MAX_STEPS, EVAL_PUZZLES, TRAIN_PUZZLES,
#   VALUE_TARGET, PHASE_WALL_CLOCK_SECONDS, RUN_TAG, WANDB_MODE, WANDB_ENTITY.
# Notes:
#   Progress banners go to stderr. The function's status is the trainer's status.
run_grpo() {
  local stage_idx="$1" warm_start="$2" dest="$3"
  mkdir -p "${dest}"

  {
    printf '\n=== Stage %s GRPO → stop when value prec+recall >= %s (max_steps=%s num_train_epochs=%s) ===\n' \
      "${stage_idx}" "${VALUE_TARGET}" "${GRPO_MAX_STEPS}" "${GRPO_NUM_TRAIN_EPOCHS}"
    printf 'init=%s\nout=%s\n' "${warm_start}" "${dest}"
  } >&2

  # Distributed launcher prefix.
  local -a launcher=(
    "${PYTHON_BIN}" -m torch.distributed.run
    --standalone
    --nproc_per_node "${NUM_PROCESSES}"
    "${GRPO_SCRIPT}"
  )

  # Trainer arguments, in the order the trainer has always received them.
  local -a trainer_args=(
    --model_name "Qwen/Qwen2.5-0.5B-Instruct"
    --train_jsonl "${train_jsonl}"
    --eval_jsonl "${eval_jsonl}"
    --output_dir "${dest}"
    --cache_dir "${ROOT}/.hf_cache"
    --init_adapter_dir "${warm_start}"
    --seed 0
    --gpu_id 0
    --stage_i "${stage_idx}"
    --total_empties_hint "${EMPTIES}"
    --per_device_train_batch_size 8
    --gradient_accumulation_steps 2
    --num_train_epochs "${GRPO_NUM_TRAIN_EPOCHS}"
    --learning_rate 1e-6
    --logging_steps 20
    --save_steps 200
    --eval_steps 500
    --eval_rows "${EVAL_PUZZLES}"
    --num_generations 4
    --max_prompt_length 1024
    --max_completion_length 24
    --beta 0.0
    --enable_gradient_checkpointing
    --limit_train_rows "${TRAIN_PUZZLES}"
    --reward_good_value 1.25
    --penalty_bad_value 1.0
    --penalty_malformed 4.0
    --penalty_empty 0.5
    --penalty_singleton 1.5
    --eval_value_precision_stop "${VALUE_TARGET}"
    --eval_value_recall_stop "${VALUE_TARGET}"
    --eval_solve_rate_stop 0
    --min_steps_before_stop 50
    --max_wall_clock_seconds "${PHASE_WALL_CLOCK_SECONDS}"
    --max_steps "${GRPO_MAX_STEPS}"
    --use_wandb
    --wandb_project "sudoku-multi-output-grpo"
    --wandb_run_name "postS1_st${stage_idx}_grpo_i${stage_idx}_${EMPTIES}empty_val${VALUE_TARGET}_${RUN_TAG}"
    --wandb_mode "${WANDB_MODE}"
    --wandb_entity "${WANDB_ENTITY}"
  )

  "${launcher[@]}" "${trainer_args[@]}"
}
242
+
243
+ if [[ ! -f "${train_jsonl}" ]] || [[ ! -f "${eval_jsonl}" ]]; then
244
+ printf 'ERROR: Missing train/eval jsonl. Build stage-1 datasets first (see launch_sft_stage1_95p.sh / build_dataset.py).\n' >&2
245
+ printf ' %s\n %s\n' "${train_jsonl}" "${eval_jsonl}" >&2
246
+ exit 1
247
+ fi
248
+
249
+ if [[ -n "${START_AT_STAGE2_GRPO_DIR}" ]]; then
250
+ printf 'Fast-forward: stage-2 SFT dir %s → stage-2 GRPO, then stage 3.\n' "${START_AT_STAGE2_GRPO_DIR}" >&2
251
+ printf 'Pipeline root: %s\n' "${OUTPUT_ROOT}"
252
+ S2_DIR="${START_AT_STAGE2_GRPO_DIR}"
253
+ CKPT_S2="$(latest_sft_step_ckpt "${S2_DIR}")"
254
+ if [[ -z "${CKPT_S2}" ]]; then
255
+ printf 'ERROR: No checkpoint-step-* under %s\n' "${S2_DIR}" >&2
256
+ exit 1
257
+ fi
258
+ printf 'Using SFT checkpoint: %s\n' "${CKPT_S2}" >&2
259
+ G2_DIR="${OUTPUT_ROOT}/stage02_grpo_i2_${EMPTIES}empty"
260
+ run_grpo 2 "${CKPT_S2}" "${G2_DIR}"
261
+ A2="$(resolve_grpo_adapter "${G2_DIR}")"
262
+ if [[ -z "${A2}" ]]; then
263
+ printf 'ERROR: Could not resolve stage-2 GRPO adapter under %s\n' "${G2_DIR}" >&2
264
+ exit 1
265
+ fi
266
+ S3_DIR="${OUTPUT_ROOT}/stage03_sft_i3_${EMPTIES}empty"
267
+ run_sft 3 "${A2}" "${S3_DIR}" "5e-5"
268
+ CKPT_S3="$(latest_sft_step_ckpt "${S3_DIR}")"
269
+ if [[ -z "${CKPT_S3}" ]]; then
270
+ printf 'ERROR: No SFT checkpoint-step-* under %s\n' "${S3_DIR}" >&2
271
+ exit 1
272
+ fi
273
+ G3_DIR="${OUTPUT_ROOT}/stage03_grpo_i3_${EMPTIES}empty"
274
+ run_grpo 3 "${CKPT_S3}" "${G3_DIR}"
275
+ A3="$(resolve_grpo_adapter "${G3_DIR}")"
276
+ if [[ -z "${A3}" ]]; then
277
+ printf 'ERROR: Could not resolve stage-3 GRPO adapter under %s\n' "${G3_DIR}" >&2
278
+ exit 1
279
+ fi
280
+ printf '\nAll phases finished (started at stage-2 GRPO).\n'
281
+ printf 'Outputs under: %s\n' "${OUTPUT_ROOT}"
282
+ printf 'Final GRPO adapter (stage 3): %s\n' "${A3}"
283
+ exit 0
284
+ fi
285
+
286
+ if [[ -n "${START_AFTER_STAGE2_GRPO_DIR}" ]]; then
287
+ printf 'Fast-forward: stage-2 GRPO dir %s → stage-3 SFT + stage-3 GRPO.\n' "${START_AFTER_STAGE2_GRPO_DIR}" >&2
288
+ printf 'Pipeline root: %s\n' "${OUTPUT_ROOT}"
289
+ A2="$(resolve_grpo_adapter "${START_AFTER_STAGE2_GRPO_DIR}")"
290
+ if [[ -z "${A2}" ]]; then
291
+ printf 'ERROR: Could not resolve stage-2 GRPO adapter under %s\n' "${START_AFTER_STAGE2_GRPO_DIR}" >&2
292
+ exit 1
293
+ fi
294
+ printf 'Using stage-2 GRPO adapter: %s\n' "${A2}" >&2
295
+ S3_DIR="${OUTPUT_ROOT}/stage03_sft_i3_${EMPTIES}empty"
296
+ run_sft 3 "${A2}" "${S3_DIR}" "5e-5"
297
+ CKPT_S3="$(latest_sft_step_ckpt "${S3_DIR}")"
298
+ if [[ -z "${CKPT_S3}" ]]; then
299
+ printf 'ERROR: No SFT checkpoint-step-* under %s\n' "${S3_DIR}" >&2
300
+ exit 1
301
+ fi
302
+ G3_DIR="${OUTPUT_ROOT}/stage03_grpo_i3_${EMPTIES}empty"
303
+ run_grpo 3 "${CKPT_S3}" "${G3_DIR}"
304
+ A3="$(resolve_grpo_adapter "${G3_DIR}")"
305
+ if [[ -z "${A3}" ]]; then
306
+ printf 'ERROR: Could not resolve stage-3 GRPO adapter under %s\n' "${G3_DIR}" >&2
307
+ exit 1
308
+ fi
309
+ printf '\nAll phases finished (started after stage-2 GRPO).\n'
310
+ printf 'Outputs under: %s\n' "${OUTPUT_ROOT}"
311
+ printf 'Final GRPO adapter (stage 3): %s\n' "${A3}"
312
+ exit 0
313
+ fi
314
+
315
+ printf 'Pipeline root: %s\n' "${OUTPUT_ROOT}"
316
+ if [[ -n "${RESUME_FROM_STAGE1_GRPO_DIR}" ]]; then
317
+ printf 'Resume: using existing stage-1 GRPO dir %s\n' "${RESUME_FROM_STAGE1_GRPO_DIR}"
318
+ else
319
+ printf 'Stage-1 SFT adapter: %s\n' "${STAGE1_SFT_ADAPTER_DIR}"
320
+ fi
321
+ printf 'Value gate: precision AND recall >= %s | SFT max_steps=%s epochs=%s | GRPO max_steps=%s train_epochs=%s | wall=%s\n' \
322
+ "${VALUE_TARGET}" "${SFT_MAX_STEPS}" "${SFT_NUM_EPOCHS}" "${GRPO_MAX_STEPS}" "${GRPO_NUM_TRAIN_EPOCHS}" "${PHASE_WALL_CLOCK_SECONDS}"
323
+
324
+ # --- Stage 1 GRPO (skip if resuming) ---
325
+ G1_DIR="${OUTPUT_ROOT}/stage01_grpo_i1_${EMPTIES}empty"
326
+ if [[ -n "${RESUME_FROM_STAGE1_GRPO_DIR}" ]]; then
327
+ A1="$(resolve_grpo_adapter "${RESUME_FROM_STAGE1_GRPO_DIR}")"
328
+ else
329
+ run_grpo 1 "${STAGE1_SFT_ADAPTER_DIR}" "${G1_DIR}"
330
+ A1="$(resolve_grpo_adapter "${G1_DIR}")"
331
+ fi
332
+ if [[ -z "${A1}" ]]; then
333
+ printf 'ERROR: Could not resolve stage-1 GRPO adapter (resume dir or %s)\n' "${G1_DIR}" >&2
334
+ exit 1
335
+ fi
336
+ printf 'Stage-1 GRPO adapter for stage-2 SFT init: %s\n' "${A1}"
337
+
338
+ # --- Stage 2 SFT + GRPO ---
339
+ S2_DIR="${OUTPUT_ROOT}/stage02_sft_i2_${EMPTIES}empty"
340
+ run_sft 2 "${A1}" "${S2_DIR}" "5e-5"
341
+ CKPT_S2="$(latest_sft_step_ckpt "${S2_DIR}")"
342
+ if [[ -z "${CKPT_S2}" ]]; then
343
+ printf 'ERROR: No SFT checkpoint-step-* under %s\n' "${S2_DIR}" >&2
344
+ exit 1
345
+ fi
346
+ G2_DIR="${OUTPUT_ROOT}/stage02_grpo_i2_${EMPTIES}empty"
347
+ run_grpo 2 "${CKPT_S2}" "${G2_DIR}"
348
+ A2="$(resolve_grpo_adapter "${G2_DIR}")"
349
+ if [[ -z "${A2}" ]]; then
350
+ printf 'ERROR: Could not resolve stage-2 GRPO adapter under %s\n' "${G2_DIR}" >&2
351
+ exit 1
352
+ fi
353
+
354
+ # --- Stage 3 SFT + GRPO ---
355
+ S3_DIR="${OUTPUT_ROOT}/stage03_sft_i3_${EMPTIES}empty"
356
+ run_sft 3 "${A2}" "${S3_DIR}" "5e-5"
357
+ CKPT_S3="$(latest_sft_step_ckpt "${S3_DIR}")"
358
+ if [[ -z "${CKPT_S3}" ]]; then
359
+ printf 'ERROR: No SFT checkpoint-step-* under %s\n' "${S3_DIR}" >&2
360
+ exit 1
361
+ fi
362
+ G3_DIR="${OUTPUT_ROOT}/stage03_grpo_i3_${EMPTIES}empty"
363
+ run_grpo 3 "${CKPT_S3}" "${G3_DIR}"
364
+ A3="$(resolve_grpo_adapter "${G3_DIR}")"
365
+ if [[ -z "${A3}" ]]; then
366
+ printf 'ERROR: Could not resolve stage-3 GRPO adapter under %s\n' "${G3_DIR}" >&2
367
+ exit 1
368
+ fi
369
+
370
+ printf '\nAll phases finished.\n'
371
+ printf 'Outputs under: %s\n' "${OUTPUT_ROOT}"
372
+ printf 'Final GRPO adapter (stage 3): %s\n' "${A3}"
hard_9x9_stage1_consistency_queue/launch_sft_stage1_95p.sh ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ # Stage-1 SFT only (no GRPO): train until eval value_precision AND value_recall both >= SFT_TARGET
3
+ # (default 0.95), or max_steps / optional wall clock. Use this before tuning GRPO.
4
+ #
5
+ # Fresh LoRA on base model:
6
+ # ./launch_sft_stage1_95p.sh
7
+ # EMPTIES=10 ./launch_sft_stage1_95p.sh
8
+ #
9
+ # Continue from a prior SFT checkpoint (recommended after a 90% run plateaus):
10
+ # INIT_ADAPTER_DIR=/path/to/checkpoint-step-01200 ./launch_sft_stage1_95p.sh
11
+ #
12
+ # Optional W&B display name:
13
+ # WANDB_RUN_NAME=my_run_name SFT_TARGET=0.99 MAX_STEPS=30000 INIT_ADAPTER_DIR=... ./launch_sft_stage1_95p.sh
14
+ #
15
+ set -euo pipefail
16
+
17
+ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
18
+ ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
19
+
20
+ PYTHON_BIN="${PYTHON_BIN:-${ROOT}/.venv/bin/python}"
21
+ DATASET_BUILDER="${ROOT}/simple_9x9_curriculum/build_dataset.py"
22
+ SFT_SCRIPT="${ROOT}/multi_output_cell_policy/sft_multi_output_train.py"
23
+
24
+ GPU_IDS="${GPU_IDS:-0,1,2,3,4,5,6,7}"
25
+ NUM_PROCESSES="${NUM_PROCESSES:-8}"
26
+ RUN_TAG="${RUN_TAG:-$(date +%Y%m%d_%H%M%S)}"
27
+ WANDB_MODE="${WANDB_MODE:-online}"
28
+ WANDB_ENTITY="${WANDB_ENTITY:-training-dynamics}"
29
+
30
+ EMPTIES="${EMPTIES:-7}"
31
+ TRAIN_PUZZLES="${TRAIN_PUZZLES:-10000}"
32
+ # Held-out eval: 100 rows lines up with the baseline queue and stabilizes 95% gate vs tiny eval.
33
+ EVAL_PUZZLES="${EVAL_PUZZLES:-100}"
34
+ SFT_TARGET="${SFT_TARGET:-0.95}"
35
+ PHASE_WALL_CLOCK_SECONDS="${PHASE_WALL_CLOCK_SECONDS:-0}"
36
+ MAX_STEPS="${MAX_STEPS:-12000}"
37
+
38
+ CHECKPOINT_ROOT="${CHECKPOINT_ROOT:-${ROOT}/final_checkpoint/hard_9x9_sft95_stage1}"
39
+ OUTPUT_DIR="${OUTPUT_DIR:-${CHECKPOINT_ROOT}/${RUN_TAG}/${EMPTIES}empty/stage01_sft_i1_${EMPTIES}empty_sft95}"
40
+
41
+ train_jsonl="${ROOT}/data/sudoku_t3_${EMPTIES}empty_value_qwen_text_stage1_train.jsonl"
42
+ eval_jsonl="${ROOT}/data/sudoku_t3_${EMPTIES}empty_value_qwen_text_stage1_eval.jsonl"
43
+
44
+ export CUDA_VISIBLE_DEVICES="${GPU_IDS}"
45
+
46
+ if [[ ! -f "${train_jsonl}" ]]; then
47
+ mkdir -p "$(dirname "${train_jsonl}")"
48
+ printf 'Building %s-empty train dataset: %s\n' "${EMPTIES}" "${train_jsonl}"
49
+ "${PYTHON_BIN}" "${DATASET_BUILDER}" --output "${train_jsonl}" --num_puzzles 10000 --empties "${EMPTIES}" --seed 0
50
+ fi
51
+ if [[ ! -f "${eval_jsonl}" ]]; then
52
+ mkdir -p "$(dirname "${eval_jsonl}")"
53
+ printf 'Building %s-empty eval dataset: %s\n' "${EMPTIES}" "${eval_jsonl}"
54
+ "${PYTHON_BIN}" "${DATASET_BUILDER}" --output "${eval_jsonl}" --num_puzzles "${EVAL_PUZZLES}" --empties "${EMPTIES}" --seed 1
55
+ fi
56
+
57
+ mkdir -p "${OUTPUT_DIR}"
58
+
59
+ INIT_FLAGS=()
60
+ if [[ -n "${INIT_ADAPTER_DIR:-}" ]]; then
61
+ INIT_FLAGS+=(--init_adapter_dir "${INIT_ADAPTER_DIR}")
62
+ printf 'Warm-start from adapter: %s\n' "${INIT_ADAPTER_DIR}"
63
+ fi
64
+
65
+ GC_FLAGS=()
66
+ if [[ "${USE_GC:-0}" == "1" ]]; then
67
+ GC_FLAGS+=(--enable_gradient_checkpointing)
68
+ printf 'NOTE: USE_GC=1 — slower, less VRAM.\n'
69
+ fi
70
+
71
+ if [[ "${PHASE_WALL_CLOCK_SECONDS}" -gt 0 ]]; then
72
+ printf '\n=== Stage1 SFT %s-empty (prec+recall >= %s, wall %ss) ===\n' "${EMPTIES}" "${SFT_TARGET}" "${PHASE_WALL_CLOCK_SECONDS}"
73
+ else
74
+ printf '\n=== Stage1 SFT %s-empty (prec+recall >= %s, no wall cap) ===\n' "${EMPTIES}" "${SFT_TARGET}"
75
+ fi
76
+ printf 'Output: %s\n' "${OUTPUT_DIR}"
77
+
78
+ exec "${PYTHON_BIN}" -m torch.distributed.run --standalone --nproc_per_node "${NUM_PROCESSES}" "${SFT_SCRIPT}" \
79
+ --model_name "Qwen/Qwen2.5-0.5B-Instruct" \
80
+ --train_jsonl "${train_jsonl}" \
81
+ --eval_jsonl "${eval_jsonl}" \
82
+ --output_dir "${OUTPUT_DIR}" \
83
+ --cache_dir "${ROOT}/.hf_cache" \
84
+ "${INIT_FLAGS[@]}" \
85
+ --seed 0 \
86
+ --gpu_id 0 \
87
+ --stage_i 1 \
88
+ --total_empties_hint "${EMPTIES}" \
89
+ --per_device_train_batch_size 16 \
90
+ --gradient_accumulation_steps 2 \
91
+ --num_epochs 24.0 \
92
+ --learning_rate 2e-4 \
93
+ --max_grad_norm 1.0 \
94
+ "${GC_FLAGS[@]}" \
95
+ --logging_steps 20 \
96
+ --eval_steps 250 \
97
+ --save_steps 100 \
98
+ --eval_rows "${EVAL_PUZZLES}" \
99
+ --max_completion_length 24 \
100
+ --limit_train_rows "${TRAIN_PUZZLES}" \
101
+ --lora_r 32 \
102
+ --lora_alpha 64 \
103
+ --lora_dropout 0.05 \
104
+ --eval_value_precision_stop "${SFT_TARGET}" \
105
+ --eval_value_recall_stop "${SFT_TARGET}" \
106
+ --min_steps_before_stop 50 \
107
+ --max_wall_clock_seconds "${PHASE_WALL_CLOCK_SECONDS}" \
108
+ --max_steps "${MAX_STEPS}" \
109
+ --use_wandb \
110
+ --wandb_project "sudoku-multi-output-sft" \
111
+ --wandb_run_name "${WANDB_RUN_NAME:-stage01_sft95_i1_${EMPTIES}empty_${RUN_TAG}}" \
112
+ --wandb_mode "${WANDB_MODE}" \
113
+ --wandb_entity "${WANDB_ENTITY}"
hard_9x9_stage1_consistency_queue/recurrent_hidden_stage2_resume_summary_20260516.md ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Recurrent Hidden Stage 2 Resume Summary
2
+
3
+ This note summarizes the May 16, 2026 stage-2 recurrent-hidden SFT recovery and monitoring changes.
4
+
5
+ ## What Changed
6
+
7
+ - Added explicit eval lifecycle logging in `latent_multi_output_cell_policy/sft_latent_multi_output_train.py`.
8
+ - W&B now receives `eval/in_progress`, `eval/rows`, and `eval/duration_seconds`.
9
+ - Local logs now print `[latent sft eval start ...]` and `[latent sft eval end ...]` markers.
10
+
11
+ ## Why
12
+
13
+ The previous resumed stage-2 run reached step 2000 but appeared silent during validation. Validation metrics were only logged after the whole eval completed, and the old `eval_rows=100` setting made a single validation take roughly 35 minutes. The run then crashed before producing the step-2000 eval metrics or checkpoint.
14
+
15
+ ## Probe Result
16
+
17
+ A one-GPU eval probe from `checkpoint-step-01800` measured validation cost:
18
+
19
+ - Eval rows: 20 puzzles
20
+ - Eval duration: 427.3 seconds, about 7.1 minutes
21
+ - Exact set match: 0.9225
22
+ - Value precision: 0.945
23
+ - Value recall: 0.934
24
+ - Solve rate: 0.15
25
+ - W&B run: `xudqbjqh`
26
+
27
+ ## Active Resume Run
28
+
29
+ The main run was restarted from:
30
+
31
+ `final_checkpoint/hard_9x9_20empty_warm_baseline_all_latent_modes_stages123/recurrent_hidden_resume_stage2sft_from200_20260515_205857/latent_recurrent_hidden/stage02_latent_sft_i2_20empty_latent_recurrent_hidden/checkpoint-step-01800`
32
+
33
+ Run settings:
34
+
35
+ - Stage: 2
36
+ - Latent mode: recurrent_hidden
37
+ - GPUs: 8
38
+ - Eval rows: 20
39
+ - Eval interval: every 100 steps
40
+ - Checkpoint interval: every 100 steps
41
+ - Max steps: 5000
42
+ - Early stop: disabled for solve rate; precision and recall target set to 0.9999
43
+ - W&B run: `h3lxi62v`
44
+
45
+ At the first eval:
46
+
47
+ - Step: 100
48
+ - Eval duration: 427.3 seconds
49
+ - Exact set match: 0.935
50
+ - Value precision: 0.95875
51
+ - Value recall: 0.94875
52
+ - Solve rate: 0.25
53
+
54
+ ## Checkpoint Sync
55
+
56
+ The run output is synced to Hugging Face every 10 minutes:
57
+
58
+ `Avra98/sudoku-latent-recurrent-hidden-20empty-stages/resume_runs/recurrent_hidden_resume_stage2sft_from1800_eval20_long_20260516_090446`
59
+
60
+ Confirmed uploaded checkpoint:
61
+
62
+ - `checkpoint-step-00100/adapter_model.safetensors`
63
+ - `checkpoint-step-00100/adapter_config.json`
64
+ - `checkpoint-step-00100/tokenizer.json`
65
+ - `checkpoint-step-00100/tokenizer_config.json`
66
+ - `checkpoint-step-00100/chat_template.jinja`
67
+ - `checkpoint-step-00100/README.md`
68
+
hard_9x9_stage1_consistency_queue/recurrent_hidden_stage2sft_resume.md ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Recurrent-Hidden 20-Empty Stage-2 SFT Resume
2
+
3
+ This note records the recovered recurrent-hidden run restarted on May 15, 2026.
4
+
5
+ ## Source Checkpoints
6
+
7
+ Recovered adapters were downloaded from:
8
+
9
+ ```text
10
+ https://huggingface.co/Avra98/sudoku-latent-recurrent-hidden-20empty-stages
11
+ ```
12
+
13
+ Local snapshot path:
14
+
15
+ ```text
16
+ /home/ubuntu/curriculum_cot/final_checkpoint/hf_sudoku_latent_recurrent_hidden_20empty_stages
17
+ ```
18
+
19
+ Available recovered folders:
20
+
21
+ ```text
22
+ stage01_latent_sft_i1_20empty_latent_recurrent_hidden
23
+ stage01_latent_grpo_i1_20empty_latent_recurrent_hidden
24
+ stage02_baseline_warm_sft_i2_20empty_latent_recurrent_hidden
25
+ stage02_latent_sft_i2_20empty_latent_recurrent_hidden
26
+ ```
27
+
28
+ The uploaded stage-2 latent SFT checkpoint did not include `trainer_state.json`
29
+ or solve-rate metadata, so the restart intentionally resumes from the stage-2
30
+ baseline warm-up adapter and reruns stage-2 latent SFT instead of jumping to
31
+ stage-2 GRPO.
32
+
33
+ ## Active Resume Run
34
+
35
+ Output root:
36
+
37
+ ```text
38
+ /home/ubuntu/curriculum_cot/final_checkpoint/hard_9x9_20empty_warm_baseline_all_latent_modes_stages123/recurrent_hidden_resume_stage2sft_20260515_184858
39
+ ```
40
+
41
+ W&B run:
42
+
43
+ ```text
44
+ https://wandb.ai/training-dynamics/sudoku-latent-stage-sft-warm-baseline/runs/1vyq1a1n
45
+ ```
46
+
47
+ Launch settings:
48
+
49
+ ```text
50
+ MODEL_NAME=Qwen/Qwen2.5-1.5B-Instruct
51
+ MODES_SPEC=recurrent_hidden
52
+ GPU_GROUPS_SPEC=0,1,2,3,4,5,6,7
53
+ NPROC_PER_JOB=8
54
+ STAGE1_LATENT_GRPO_ADAPTER_DIR=<HF snapshot>/stage01_latent_grpo_i1_20empty_latent_recurrent_hidden
55
+ STAGE2_BASELINE_WARM_ADAPTER_DIR=<HF snapshot>/stage02_baseline_warm_sft_i2_20empty_latent_recurrent_hidden/checkpoint-step-01000
56
+ LATENT_SFT_MAX_STEPS=5000
57
+ LATENT_GRPO_MAX_STEPS=500
58
+ SOLVE_TARGET=0.95
59
+ VALUE_TARGET=0
60
+ MIN_STEPS_BEFORE_STOP=50
61
+ WANDB_MODE=online
62
+ WANDB_ENTITY=training-dynamics
63
+ ```
64
+
65
+ ## Backup Plan
66
+
67
+ Code changes are pushed to GitHub branch:
68
+
69
+ ```text
70
+ llm-policy-icon-code
71
+ ```
72
+
73
+ Checkpoint backups should be pushed periodically to the same Hugging Face repo
74
+ using:
75
+
76
+ ```bash
77
+ HF_TOKEN=hf_xxx \
78
+ RUN_OUTPUT_DIR=/home/ubuntu/curriculum_cot/final_checkpoint/hard_9x9_20empty_warm_baseline_all_latent_modes_stages123/recurrent_hidden_resume_stage2sft_20260515_184858 \
79
+ bash hard_9x9_stage1_consistency_queue/sync_recurrent_hidden_checkpoints_to_hf.sh
80
+ ```
81
+
82
+ The sync script uploads checkpoint folders, adapter files, tokenizer files, and
83
+ logs while ignoring W&B runtime directories and prepared-data caches.
hard_9x9_stage1_consistency_queue/sync_recurrent_hidden_checkpoints_to_hf.sh ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ # Periodically upload the active recurrent-hidden resume output to Hugging Face.
3
+ #
4
+ # Required:
5
+ # RUN_OUTPUT_DIR=/path/to/recurrent_hidden_resume_stage2sft_...
6
+ #
7
+ # Optional:
8
+ # HF_TOKEN=hf_... # otherwise uses `hf auth login` / cached login
9
+ # HF_REPO_ID=Avra98/sudoku-latent-recurrent-hidden-20empty-stages
10
+ # HF_REPO_PREFIX=resume_runs/<run_name>
11
+ # SYNC_INTERVAL_SECONDS=900
12
+
13
+ set -euo pipefail
14
+
15
+ ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
16
+ PYTHON_BIN="${PYTHON_BIN:-${ROOT}/.venv/bin/python}"
17
+ HF_REPO_ID="${HF_REPO_ID:-Avra98/sudoku-latent-recurrent-hidden-20empty-stages}"
18
+ RUN_OUTPUT_DIR="${RUN_OUTPUT_DIR:-}"
19
+ SYNC_INTERVAL_SECONDS="${SYNC_INTERVAL_SECONDS:-900}"
20
+
21
+ if [[ -z "${RUN_OUTPUT_DIR}" ]] || [[ ! -d "${RUN_OUTPUT_DIR}" ]]; then
22
+ printf 'ERROR: Set RUN_OUTPUT_DIR to an existing run output directory.\n' >&2
23
+ exit 1
24
+ fi
25
+
26
+ if [[ ! -x "${PYTHON_BIN}" ]]; then
27
+ printf 'ERROR: Python not found at %s\n' "${PYTHON_BIN}" >&2
28
+ exit 1
29
+ fi
30
+
31
+ RUN_NAME="$(basename "${RUN_OUTPUT_DIR}")"
32
+ HF_REPO_PREFIX="${HF_REPO_PREFIX:-resume_runs/${RUN_NAME}}"
33
+ export HF_REPO_ID RUN_OUTPUT_DIR HF_REPO_PREFIX
34
+
35
+ upload_once() {
36
+ "${PYTHON_BIN}" - <<'PY'
37
+ import os
38
+ from pathlib import Path
39
+ from huggingface_hub import HfApi, get_token
40
+
41
+ repo_id = os.environ["HF_REPO_ID"]
42
+ folder = Path(os.environ["RUN_OUTPUT_DIR"]).resolve()
43
+ path_in_repo = os.environ["HF_REPO_PREFIX"].strip("/")
44
+
45
+ token = os.environ.get("HF_TOKEN") or get_token()
46
+ if not token:
47
+ raise SystemExit("No Hugging Face token found. Run `hf auth login` or set HF_TOKEN.")
48
+
49
+ api = HfApi(token=token)
50
+ api.upload_folder(
51
+ repo_id=repo_id,
52
+ repo_type="model",
53
+ folder_path=str(folder),
54
+ path_in_repo=path_in_repo,
55
+ commit_message=f"Sync recurrent-hidden resume checkpoints: {folder.name}",
56
+ allow_patterns=[
57
+ "logs/**",
58
+ "**/checkpoint*/**",
59
+ "**/adapter_config.json",
60
+ "**/adapter_model.safetensors",
61
+ "**/tokenizer.json",
62
+ "**/tokenizer_config.json",
63
+ "**/chat_template.jinja",
64
+ "**/README.md",
65
+ "**/training_args.bin",
66
+ ],
67
+ ignore_patterns=[
68
+ "**/wandb_runtime/**",
69
+ "**/.wandb/**",
70
+ "**/wandb/**",
71
+ "**/optimizer.pt",
72
+ "**/scheduler.pt",
73
+ "**/rng_state_*.pth",
74
+ ],
75
+ )
76
+ print(f"Uploaded {folder} to {repo_id}/{path_in_repo}")
77
+ PY
78
+ }
79
+
80
# Sync forever, once every SYNC_INTERVAL_SECONDS.
#
# The first upload is allowed to abort the script (via `set -e`) so that a
# missing token or other setup problem is surfaced immediately. After one
# successful sync, a failed upload is only logged and retried on the next
# interval; otherwise a single transient error would silently end the backup.
synced_once=0
while true; do
  date -Is
  if (( synced_once )); then
    if ! upload_once; then
      printf 'WARN: upload failed; retrying in %s seconds.\n' "${SYNC_INTERVAL_SECONDS}" >&2
    fi
  else
    upload_once
    synced_once=1
  fi
  sleep "${SYNC_INTERVAL_SECONDS}"
done
hard_9x9_stage1_consistency_queue/warm_baseline_all_latent_modes_stages123_results.md ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Warm Baseline All-Latent Stages 1-3 Results
2
+
3
+ Run tag: `warmbaseline_alllatent_stages123_20260512_1620`
4
+
5
+ Base model: `Qwen/Qwen2.5-1.5B-Instruct`
6
+
7
+ Stage-1 warm baseline adapter:
8
+
9
+ ```text
10
+ /home/ubuntu/curriculum_cot/final_checkpoint/hard_9x9_20empty_baseline_1p5b_warmup/baseline_1p5b_warmup_bs32_eval100_20260512_203845/20empty/stage01_sft_i1_20empty_1p5b_warmup/checkpoint-step-01000
11
+ ```
12
+
13
+ This file records the solve-rate snapshot from the ongoing full pipeline. Entries
14
+ for later stages should be filled in once all modes finish.
15
+
16
+ ## Current Phase Snapshot
17
+
18
+ | Mode | Current phase at snapshot |
19
+ | --- | --- |
20
+ | `residual` | Stage-2 latent SFT |
21
+ | `fixed_slots` | Stage-2 latent SFT |
22
+ | `recurrent_hidden` | Stage-2 baseline warm-up SFT |
23
+ | `latent_seeds` | Stage-3 baseline warm-up SFT |
24
+
25
+ ## Latest Solve Rates By Phase
26
+
27
+ | Mode | Stage 1 latent SFT | Stage 1 latent GRPO | Stage 2 baseline warm-up | Stage 2 latent SFT | Stage 2 latent GRPO | Stage 3 baseline warm-up |
28
+ | --- | ---: | ---: | ---: | ---: | ---: | ---: |
29
+ | `residual` | 0.470 latest / 0.610 best | 0.620 | 0.110 latest / 0.150 best | in progress | not reached | not reached |
30
+ | `fixed_slots` | 0.770 latest / 0.770 best | 0.870 | 0.140 latest / 0.140 best | 0.100 latest / 0.100 best | not reached | not reached |
31
+ | `recurrent_hidden` | 0.860 latest / 0.860 best | 0.950 | 0.110 latest / 0.110 best | not reached | not reached | not reached |
32
+ | `latent_seeds` | 0.740 latest / 0.740 best | 0.860 | 0.090 latest / 0.100 best | 0.120 latest / 0.120 best | 0.090 | started, no eval yet |
33
+
34
+ ## Stage 1 Solve Trajectories
35
+
36
+ | Mode | Latent SFT solve rates | Post-GRPO solve rate |
37
+ | --- | --- | ---: |
38
+ | `residual` | 0.320 -> 0.610 -> 0.520 -> 0.470 | 0.620 |
39
+ | `fixed_slots` | 0.650 -> 0.200 -> 0.660 -> 0.770 | 0.870 |
40
+ | `recurrent_hidden` | 0.400 -> 0.600 -> 0.800 -> 0.860 | 0.950 |
41
+ | `latent_seeds` | 0.290 -> 0.500 -> 0.640 -> 0.740 | 0.860 |
42
+
43
+ ## Stage 2 Solve Trajectories So Far
44
+
45
+ | Mode | Baseline warm-up solve rates | Latent SFT solve rates | Post-GRPO solve rate |
46
+ | --- | --- | --- | ---: |
47
+ | `residual` | 0.050 -> 0.150 -> 0.110 -> 0.110 | in progress | not reached |
48
+ | `fixed_slots` | 0.090 -> 0.120 -> 0.080 -> 0.140 | 0.080 -> 0.100 | not reached |
49
+ | `recurrent_hidden` | 0.060 -> 0.090 -> 0.100 -> 0.110 | not reached | not reached |
50
+ | `latent_seeds` | 0.090 -> 0.100 -> 0.080 -> 0.090 | 0.080 -> 0.090 -> 0.110 -> 0.120 | 0.090 |
51
+
52
+ ## W&B Links
53
+
54
+ Stage 1 latent SFT:
55
+
56
+ - `residual`: https://wandb.ai/training-dynamics/sudoku-latent-stage-sft-warm-baseline/runs/sp4seb59
57
+ - `fixed_slots`: https://wandb.ai/training-dynamics/sudoku-latent-stage-sft-warm-baseline/runs/d62aiu1g
58
+ - `recurrent_hidden`: https://wandb.ai/training-dynamics/sudoku-latent-stage-sft-warm-baseline/runs/cv3nr7ie
59
+ - `latent_seeds`: https://wandb.ai/training-dynamics/sudoku-latent-stage-sft-warm-baseline/runs/1f818jfg
60
+
61
+ Additional stage runs are logged under:
62
+
63
+ - SFT project: https://wandb.ai/training-dynamics/sudoku-latent-stage-sft-warm-baseline
64
+ - GRPO project: https://wandb.ai/training-dynamics/sudoku-latent-stage-grpo-warm-baseline
65
+ - Baseline warm-up project: https://wandb.ai/training-dynamics/sudoku-baseline-stage-warmups
large_baseline_extension/README.md ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Large Baseline Extension Launchers
2
+
3
+ This folder contains launch scripts for the non-location baseline multi-output runs.
4
+
5
+ - `launch_nonlocation_pipeline.sh`
6
+ - `launch_nonlocation_sft.sh`
7
+ - `launch_nonlocation_grpo.sh`
8
+
9
+ The main entry point for a full staged resume run is `launch_nonlocation_pipeline.sh`.
10
+
11
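+ Useful environment variables for `launch_nonlocation_pipeline.sh` (the single-stage `launch_nonlocation_sft.sh` and `launch_nonlocation_grpo.sh` launchers select their stage with `STAGE_I` instead of `MIN_STAGE`/`MAX_STAGE`, and the GRPO launcher aborts unless `INIT_ADAPTER_DIR` is set):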
12
+
13
+ - `MIN_STAGE`
14
+ - `MAX_STAGE`
15
+ - `NUM_PROCESSES`
16
+ - `GPU_IDS`
17
+ - `BOOTSTRAP_ADAPTER_DIR`
18
+ - `OUTPUT_ROOT`
19
+ - `RUN_TAG`
20
+ - `LIMIT_TRAIN_ROWS`
21
+ - `WANDB_MODE`
22
+ - `WANDB_ENTITY`
23
+
24
+ Example:
25
+
26
+ ```bash
27
+ MIN_STAGE=3 \
28
+ MAX_STAGE=5 \
29
+ NUM_PROCESSES=8 \
30
+ GPU_IDS=0,1,2,3,4,5,6,7 \
31
+ BOOTSTRAP_ADAPTER_DIR=/path/to/stage02_grpo \
32
+ WANDB_MODE=online \
33
+ WANDB_ENTITY=training-dynamics \
34
+ bash launch_nonlocation_pipeline.sh
35
+ ```
large_baseline_extension/launch_nonlocation_grpo.sh ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ set -euo pipefail
3
+
4
+ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
5
+ ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
6
+ PYTHON_BIN="${PYTHON_BIN:-${ROOT}/.venv/bin/python}"
7
+ TRAINER="${ROOT}/multi_output_cell_policy/grpo_multi_output_train.py"
8
+ TRAIN_JSONL="${TRAIN_JSONL:-${ROOT}/data/sudoku_t3_30empty_value_qwen_text.jsonl}"
9
+ CACHE_DIR="${CACHE_DIR:-${ROOT}/.hf_cache}"
10
+ MODEL_NAME="${MODEL_NAME:-Qwen/Qwen2.5-0.5B-Instruct}"
11
+ GPU_ID="${GPU_ID:-0}"
12
+ GPU_IDS="${GPU_IDS:-0,1,2,3,4,5,6,7}"
13
+ NUM_PROCESSES="${NUM_PROCESSES:-1}"
14
+ STAGE_I="${STAGE_I:-2}"
15
+ RUN_TAG="${RUN_TAG:-$(date +%Y%m%d_%H%M%S)}"
16
+ OUTPUT_ROOT="${OUTPUT_ROOT:-${ROOT}/final_checkpoint/large_baseline_extension/hard_9x9_qwen05b/baseline/grpo}"
17
+ OUTPUT_DIR="${OUTPUT_DIR:-${OUTPUT_ROOT}/i${STAGE_I}_${RUN_TAG}}"
18
+ WANDB_PROJECT="${WANDB_PROJECT:-sudoku-multi-output-grpo}"
19
+ WANDB_RUN_NAME="${WANDB_RUN_NAME:-large_baseline_noloc_grpo_i${STAGE_I}_${RUN_TAG}}"
20
+ WANDB_GROUP="${WANDB_GROUP:-large_baseline_extension_noloc_grpo_i${STAGE_I}}"
21
+
22
+ case "${STAGE_I}" in
23
+ 2) default_bs=4; default_gas=2 ;;
24
+ 3) default_bs=3; default_gas=2 ;;
25
+ 4) default_bs=2; default_gas=4 ;;
26
+ 5) default_bs=2; default_gas=4 ;;
27
+ *) default_bs=2; default_gas=4 ;;
28
+ esac
29
+
30
+ mkdir -p "${OUTPUT_DIR}"
31
+ export CUDA_DEVICE_ORDER=PCI_BUS_ID
32
+ if [[ "${NUM_PROCESSES}" -gt 1 ]]; then
33
+ export CUDA_VISIBLE_DEVICES="${GPU_IDS}"
34
+ else
35
+ export CUDA_VISIBLE_DEVICES="${GPU_ID}"
36
+ fi
37
+ export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}"
38
+
39
+ if [[ "${NUM_PROCESSES}" -gt 1 ]]; then
40
+ cmd=(
41
+ "${PYTHON_BIN}" -m torch.distributed.run --standalone --nproc_per_node "${NUM_PROCESSES}" "${TRAINER}"
42
+ )
43
+ else
44
+ cmd=(
45
+ "${PYTHON_BIN}" -u "${TRAINER}"
46
+ )
47
+ fi
48
+
49
+ cmd+=(
50
+ --model_name "${MODEL_NAME}"
51
+ --train_jsonl "${TRAIN_JSONL}"
52
+ --output_dir "${OUTPUT_DIR}"
53
+ --init_adapter_dir "${INIT_ADAPTER_DIR:?INIT_ADAPTER_DIR must be set}"
54
+ --cache_dir "${CACHE_DIR}"
55
+ --gpu_id 0
56
+ --stage_i "${STAGE_I}"
57
+ --total_empties_hint "${TOTAL_EMPTIES_HINT:-30}"
58
+ --per_device_train_batch_size "${PER_DEVICE_TRAIN_BATCH_SIZE:-${default_bs}}"
59
+ --gradient_accumulation_steps "${GRADIENT_ACCUMULATION_STEPS:-${default_gas}}"
60
+ --num_train_epochs "${NUM_TRAIN_EPOCHS:-0.5}"
61
+ --learning_rate "${LEARNING_RATE:-1e-6}"
62
+ --logging_steps "${LOGGING_STEPS:-5}"
63
+ --save_steps "${SAVE_STEPS:-25}"
64
+ --eval_steps "${EVAL_STEPS:-25}"
65
+ --eval_rows "${EVAL_ROWS:-20}"
66
+ --num_generations "${NUM_GENERATIONS:-2}"
67
+ --max_prompt_length "${MAX_PROMPT_LENGTH:-1024}"
68
+ --max_completion_length "${MAX_COMPLETION_LENGTH:-24}"
69
+ --beta "${BETA:-0.0}"
70
+ --enable_gradient_checkpointing
71
+ --wandb_project "${WANDB_PROJECT}"
72
+ --wandb_run_name "${WANDB_RUN_NAME}"
73
+ --wandb_group "${WANDB_GROUP}"
74
+ --wandb_mode "${WANDB_MODE:-offline}"
75
+ )
76
+
77
+ if [[ "${WANDB_MODE:-offline}" != "offline" ]]; then
78
+ cmd+=(--use_wandb)
79
+ fi
80
+
81
+ if [[ -n "${WANDB_ENTITY:-}" ]]; then
82
+ cmd+=(--wandb_entity "${WANDB_ENTITY}")
83
+ fi
84
+
85
+ if [[ -n "${LIMIT_TRAIN_ROWS:-}" ]]; then
86
+ cmd+=(--limit_train_rows "${LIMIT_TRAIN_ROWS}")
87
+ fi
88
+
89
+ if [[ -n "${MAX_STEPS:-}" ]]; then
90
+ cmd+=(--max_steps "${MAX_STEPS}")
91
+ fi
92
+
93
+ if [[ -n "${RESUME_FROM_CHECKPOINT:-}" ]]; then
94
+ cmd+=(--resume_from_checkpoint "${RESUME_FROM_CHECKPOINT}")
95
+ fi
96
+
97
+ printf 'Launching hard 9x9 baseline GRPO on GPUs %s\n' "${CUDA_VISIBLE_DEVICES}"
98
+ printf 'Output dir: %s\n' "${OUTPUT_DIR}"
99
+ printf 'Init adapter: %s\n' "${INIT_ADAPTER_DIR}"
100
+ printf 'stage_i=%s batch=%s grad_accum=%s processes=%s\n' \
101
+ "${STAGE_I}" "${PER_DEVICE_TRAIN_BATCH_SIZE:-${default_bs}}" "${GRADIENT_ACCUMULATION_STEPS:-${default_gas}}" "${NUM_PROCESSES}"
102
+
103
+ "${cmd[@]}"
large_baseline_extension/launch_nonlocation_pipeline.sh ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ set -euo pipefail
3
+
4
+ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
5
+ ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
6
+ PYTHON_BIN="${PYTHON_BIN:-${ROOT}/.venv/bin/python}"
7
+ PIPELINE="${ROOT}/multi_output_cell_policy/run_baseline_multi_output_pipeline_resume.py"
8
+ TRAIN_JSONL="${TRAIN_JSONL:-${ROOT}/data/sudoku_t3_30empty_value_qwen_text.jsonl}"
9
+ CACHE_DIR="${CACHE_DIR:-${ROOT}/.hf_cache}"
10
+ MODEL_NAME="${MODEL_NAME:-Qwen/Qwen2.5-0.5B-Instruct}"
11
+ GPU_IDS="${GPU_IDS:-0,1,2,3,4,5,6,7}"
12
+ NUM_PROCESSES="${NUM_PROCESSES:-8}"
13
+ MIN_STAGE="${MIN_STAGE:-1}"
14
+ MAX_STAGE="${MAX_STAGE:-4}"
15
+ RUN_TAG="${RUN_TAG:-$(date +%Y%m%d_%H%M%S)}"
16
+ CHECKPOINT_ROOT="${CHECKPOINT_ROOT:-${ROOT}/final_checkpoint/large_baseline_extension/hard_9x9_qwen05b/baseline}"
17
+ OUTPUT_ROOT="${OUTPUT_ROOT:-${CHECKPOINT_ROOT}/${RUN_TAG}/baseline_pipeline_30empty_4stage_hard9x9}"
18
+
19
+ cmd=(
20
+ "${PYTHON_BIN}" "${PIPELINE}"
21
+ --python_executable "${PYTHON_BIN}"
22
+ --train_jsonl "${TRAIN_JSONL}"
23
+ --cache_dir "${CACHE_DIR}"
24
+ --model_name "${MODEL_NAME}"
25
+ --checkpoint_root "${CHECKPOINT_ROOT}"
26
+ --output_root "${OUTPUT_ROOT}"
27
+ --run_tag "${RUN_TAG}"
28
+ --min_stage "${MIN_STAGE}"
29
+ --max_stage "${MAX_STAGE}"
30
+ --distributed_gpu_ids "${GPU_IDS}"
31
+ --sft_num_processes "${NUM_PROCESSES}"
32
+ --grpo_num_processes "${NUM_PROCESSES}"
33
+ --total_empties_hint "${TOTAL_EMPTIES_HINT:-30}"
34
+ --sft_num_epochs "${SFT_NUM_EPOCHS:-1.0}"
35
+ --grpo_num_train_epochs "${GRPO_NUM_TRAIN_EPOCHS:-0.5}"
36
+ --sft_gradient_accumulation_steps "${SFT_GRADIENT_ACCUMULATION_STEPS:-8}"
37
+ --grpo_per_device_train_batch_size "${GRPO_PER_DEVICE_TRAIN_BATCH_SIZE:-2}"
38
+ --grpo_gradient_accumulation_steps "${GRPO_GRADIENT_ACCUMULATION_STEPS:-4}"
39
+ --grpo_num_generations "${GRPO_NUM_GENERATIONS:-2}"
40
+ --grpo_eval_solve_rate_stop "${GRPO_EVAL_SOLVE_RATE_STOP:-0.8}"
41
+ --grpo_min_steps_before_stop "${GRPO_MIN_STEPS_BEFORE_STOP:-100}"
42
+ --sft_enable_gradient_checkpointing
43
+ --grpo_enable_gradient_checkpointing
44
+ --sft_save_steps "${SFT_SAVE_STEPS:-100}"
45
+ --sft_eval_steps "${SFT_EVAL_STEPS:-100}"
46
+ --grpo_save_steps "${GRPO_SAVE_STEPS:-25}"
47
+ --grpo_eval_steps "${GRPO_EVAL_STEPS:-25}"
48
+ --phase_max_wall_clock_seconds "${PHASE_MAX_WALL_CLOCK_SECONDS:-21600}"
49
+ --wandb_mode "${WANDB_MODE:-offline}"
50
+ )
51
+
52
+ if [[ -n "${BOOTSTRAP_ADAPTER_DIR:-}" ]]; then
53
+ cmd+=(--bootstrap_adapter_dir "${BOOTSTRAP_ADAPTER_DIR}")
54
+ fi
55
+
56
+ if [[ -n "${LIMIT_TRAIN_ROWS:-}" ]]; then
57
+ cmd+=(--limit_train_rows "${LIMIT_TRAIN_ROWS}")
58
+ fi
59
+
60
+ if [[ -n "${SFT_STAGE_MAX_STEPS:-}" ]]; then
61
+ cmd+=(--sft_stage_max_steps "${SFT_STAGE_MAX_STEPS}")
62
+ fi
63
+
64
+ if [[ -n "${GRPO_STAGE_MAX_STEPS:-}" ]]; then
65
+ cmd+=(--grpo_stage_max_steps "${GRPO_STAGE_MAX_STEPS}")
66
+ fi
67
+
68
+ if [[ "${WANDB_MODE:-offline}" != "offline" ]]; then
69
+ cmd+=(--use_wandb)
70
+ fi
71
+
72
+ if [[ -n "${WANDB_ENTITY:-}" ]]; then
73
+ cmd+=(--wandb_entity "${WANDB_ENTITY}")
74
+ fi
75
+
76
+ printf 'Launching hard 9x9 baseline pipeline on GPUs %s\n' "${GPU_IDS}"
77
+ printf 'Output root: %s\n' "${OUTPUT_ROOT}"
78
+ printf 'Stages: %s -> %s, processes=%s\n' "${MIN_STAGE}" "${MAX_STAGE}" "${NUM_PROCESSES}"
79
+
80
+ "${cmd[@]}"
large_baseline_extension/launch_nonlocation_sft.sh ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ set -euo pipefail
3
+
4
+ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
5
+ ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
6
+ PYTHON_BIN="${PYTHON_BIN:-${ROOT}/.venv/bin/python}"
7
+ TRAINER="${ROOT}/multi_output_cell_policy/sft_multi_output_train.py"
8
+ TRAIN_JSONL="${TRAIN_JSONL:-${ROOT}/data/sudoku_t3_30empty_value_qwen_text.jsonl}"
9
+ CACHE_DIR="${CACHE_DIR:-${ROOT}/.hf_cache}"
10
+ MODEL_NAME="${MODEL_NAME:-Qwen/Qwen2.5-0.5B-Instruct}"
11
+ GPU_ID="${GPU_ID:-0}"
12
+ GPU_IDS="${GPU_IDS:-0,1,2,3,4,5,6,7}"
13
+ NUM_PROCESSES="${NUM_PROCESSES:-1}"
14
+ STAGE_I="${STAGE_I:-2}"
15
+ RUN_TAG="${RUN_TAG:-$(date +%Y%m%d_%H%M%S)}"
16
+ OUTPUT_ROOT="${OUTPUT_ROOT:-${ROOT}/final_checkpoint/large_baseline_extension/hard_9x9_qwen05b/baseline/sft}"
17
+ OUTPUT_DIR="${OUTPUT_DIR:-${OUTPUT_ROOT}/i${STAGE_I}_${RUN_TAG}}"
18
+ WANDB_PROJECT="${WANDB_PROJECT:-sudoku-multi-output-sft}"
19
+ WANDB_RUN_NAME="${WANDB_RUN_NAME:-large_baseline_noloc_sft_i${STAGE_I}_${RUN_TAG}}"
20
+
21
+ mkdir -p "${OUTPUT_DIR}"
22
+ export CUDA_DEVICE_ORDER=PCI_BUS_ID
23
+ if [[ "${NUM_PROCESSES}" -gt 1 ]]; then
24
+ export CUDA_VISIBLE_DEVICES="${GPU_IDS}"
25
+ else
26
+ export CUDA_VISIBLE_DEVICES="${GPU_ID}"
27
+ fi
28
+ export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}"
29
+
30
+ if [[ "${NUM_PROCESSES}" -gt 1 ]]; then
31
+ cmd=(
32
+ "${PYTHON_BIN}" -m torch.distributed.run --standalone --nproc_per_node "${NUM_PROCESSES}" "${TRAINER}"
33
+ )
34
+ else
35
+ cmd=(
36
+ "${PYTHON_BIN}" -u "${TRAINER}"
37
+ )
38
+ fi
39
+
40
+ cmd+=(
41
+ --model_name "${MODEL_NAME}"
42
+ --train_jsonl "${TRAIN_JSONL}"
43
+ --output_dir "${OUTPUT_DIR}"
44
+ --cache_dir "${CACHE_DIR}"
45
+ --gpu_id 0
46
+ --stage_i "${STAGE_I}"
47
+ --total_empties_hint "${TOTAL_EMPTIES_HINT:-30}"
48
+ --gradient_accumulation_steps "${GRADIENT_ACCUMULATION_STEPS:-4}"
49
+ --num_epochs "${NUM_EPOCHS:-1.0}"
50
+ --learning_rate "${LEARNING_RATE:-2e-4}"
51
+ --weight_decay "${WEIGHT_DECAY:-0.0}"
52
+ --enable_gradient_checkpointing
53
+ --logging_steps "${LOGGING_STEPS:-10}"
54
+ --save_steps "${SAVE_STEPS:-100}"
55
+ --eval_steps "${EVAL_STEPS:-100}"
56
+ --eval_rows "${EVAL_ROWS:-20}"
57
+ --max_completion_length "${MAX_COMPLETION_LENGTH:-24}"
58
+ --wandb_project "${WANDB_PROJECT}"
59
+ --wandb_run_name "${WANDB_RUN_NAME}"
60
+ --wandb_mode "${WANDB_MODE:-offline}"
61
+ )
62
+
63
+ if [[ -n "${INIT_ADAPTER_DIR:-}" ]]; then
64
+ cmd+=(--init_adapter_dir "${INIT_ADAPTER_DIR}")
65
+ fi
66
+
67
+ if [[ "${WANDB_MODE:-offline}" != "offline" ]]; then
68
+ cmd+=(--use_wandb)
69
+ fi
70
+
71
+ if [[ -n "${WANDB_ENTITY:-}" ]]; then
72
+ cmd+=(--wandb_entity "${WANDB_ENTITY}")
73
+ fi
74
+
75
+ if [[ -n "${LIMIT_TRAIN_ROWS:-}" ]]; then
76
+ cmd+=(--limit_train_rows "${LIMIT_TRAIN_ROWS}")
77
+ fi
78
+
79
+ if [[ -n "${MAX_STEPS:-}" ]]; then
80
+ cmd+=(--max_steps "${MAX_STEPS}")
81
+ fi
82
+
83
+ printf 'Launching hard 9x9 baseline SFT on GPUs %s\n' "${CUDA_VISIBLE_DEVICES}"
84
+ printf 'Output dir: %s\n' "${OUTPUT_DIR}"
85
+ printf 'Stage=%s processes=%s\n' "${STAGE_I}" "${NUM_PROCESSES}"
86
+
87
+ "${cmd[@]}"
large_latent_extension/README.md ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Large Latent Extension Launchers
2
+
3
+ This folder contains the launch scripts for the non-location latent CoT runs.
4
+
5
+ - `launch_nonlocation_sft.sh`
6
+ - `launch_nonlocation_grpo.sh`
7
+
8
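+ These are the scripts used for the distributed multi-GPU non-location curriculum. The folder also contains `launch_nonlocation_pipeline.sh`, which invokes `run_latent_residual_projector_pipeline.py` for the stage range `MIN_STAGE` to `MAX_STAGE`.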
9
+
10
+ Useful environment variables (`NUM_COT_TOKENS` is mandatory for `launch_nonlocation_grpo.sh`, which aborts if it is unset):
11
+
12
+ - `NUM_COT_TOKENS`
13
+ - `STAGE_I`
14
+ - `NUM_PROCESSES`
15
+ - `GPU_IDS`
16
+ - `INIT_ADAPTER_DIR`
17
+ - `OUTPUT_DIR`
18
+ - `LIMIT_TRAIN_ROWS`
19
+ - `WANDB_MODE`
20
+ - `WANDB_ENTITY`
21
+
22
+ Example:
23
+
24
+ ```bash
25
+ NUM_COT_TOKENS=3 \
26
+ STAGE_I=3 \
27
+ NUM_PROCESSES=8 \
28
+ GPU_IDS=0,1,2,3,4,5,6,7 \
29
+ WANDB_MODE=online \
30
+ WANDB_ENTITY=training-dynamics \
31
+ bash launch_nonlocation_sft.sh
32
+ ```
large_latent_extension/launch_nonlocation_grpo.sh ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ set -euo pipefail
3
+
4
+ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
5
+ ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
6
+ PYTHON_BIN="${PYTHON_BIN:-${ROOT}/.venv/bin/python}"
7
+ TRAINER="${ROOT}/latent_multi_output_cell_policy/grpo_residual_projector_latent_train.py"
8
+ TRAIN_JSONL="${TRAIN_JSONL:-${ROOT}/data/sudoku_t3_30empty_value_qwen_text.jsonl}"
9
+ CACHE_DIR="${CACHE_DIR:-${ROOT}/.hf_cache}"
10
+ MODEL_NAME="${MODEL_NAME:-Qwen/Qwen2.5-0.5B-Instruct}"
11
+ GPU_ID="${GPU_ID:-0}"
12
+ GPU_IDS="${GPU_IDS:-0,1,2,3,4,5,6,7}"
13
+ NUM_PROCESSES="${NUM_PROCESSES:-1}"
14
+ NUM_COT_TOKENS="${NUM_COT_TOKENS:?NUM_COT_TOKENS must be set}"
15
+ STAGE_I="${STAGE_I:-2}"
16
+ RUN_TAG="${RUN_TAG:-$(date +%Y%m%d_%H%M%S)}"
17
+ OUTPUT_ROOT="${OUTPUT_ROOT:-${ROOT}/final_checkpoint/large_latent_extension/hard_9x9_qwen05b/latent/grpo}"
18
+ OUTPUT_DIR="${OUTPUT_DIR:-${OUTPUT_ROOT}/i${STAGE_I}_cot${NUM_COT_TOKENS}_${RUN_TAG}}"
19
+ INIT_ADAPTER_DIR="${INIT_ADAPTER_DIR:-}"
20
+ WANDB_PROJECT="${WANDB_PROJECT:-sudoku-latent-multi-output-grpo-residual-projector}"
21
+ WANDB_RUN_NAME="${WANDB_RUN_NAME:-large_latent_noloc_grpo_i${STAGE_I}_cot${NUM_COT_TOKENS}_${RUN_TAG}}"
22
+ WANDB_GROUP="${WANDB_GROUP:-large_latent_extension_noloc_grpo_i${STAGE_I}}"
23
+
24
+ case "${NUM_COT_TOKENS}" in
25
+ 2) default_bs=4; default_gas=2 ;;
26
+ 4) default_bs=2; default_gas=4 ;;
27
+ 5) default_bs=2; default_gas=4 ;;
28
+ *) default_bs=2; default_gas=4 ;;
29
+ esac
30
+
31
+ mkdir -p "${OUTPUT_DIR}"
32
+ export CUDA_DEVICE_ORDER=PCI_BUS_ID
33
+ if [[ "${NUM_PROCESSES}" -gt 1 ]]; then
34
+ export CUDA_VISIBLE_DEVICES="${GPU_IDS}"
35
+ else
36
+ export CUDA_VISIBLE_DEVICES="${GPU_ID}"
37
+ fi
38
+ export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}"
39
+
40
+ if [[ "${NUM_PROCESSES}" -gt 1 ]]; then
41
+ cmd=(
42
+ "${PYTHON_BIN}" -m torch.distributed.run --standalone --nproc_per_node "${NUM_PROCESSES}" "${TRAINER}"
43
+ )
44
+ else
45
+ cmd=(
46
+ "${PYTHON_BIN}" -u "${TRAINER}"
47
+ )
48
+ fi
49
+
50
+ cmd+=(
51
+ --model_name "${MODEL_NAME}"
52
+ --train_jsonl "${TRAIN_JSONL}"
53
+ --output_dir "${OUTPUT_DIR}"
54
+ --init_adapter_dir "${INIT_ADAPTER_DIR}"
55
+ --cache_dir "${CACHE_DIR}"
56
+ --gpu_id 0
57
+ --stage_i "${STAGE_I}"
58
+ --num_cot_tokens "${NUM_COT_TOKENS}"
59
+ --total_empties_hint "${TOTAL_EMPTIES_HINT:-30}"
60
+ --per_device_train_batch_size "${PER_DEVICE_TRAIN_BATCH_SIZE:-${default_bs}}"
61
+ --gradient_accumulation_steps "${GRADIENT_ACCUMULATION_STEPS:-${default_gas}}"
62
+ --num_train_epochs "${NUM_TRAIN_EPOCHS:-0.5}"
63
+ --learning_rate "${LEARNING_RATE:-7e-7}"
64
+ --logging_steps "${LOGGING_STEPS:-5}"
65
+ --save_steps "${SAVE_STEPS:-10}"
66
+ --eval_steps "${EVAL_STEPS:-25}"
67
+ --eval_rows "${EVAL_ROWS:-20}"
68
+ --num_generations "${NUM_GENERATIONS:-2}"
69
+ --max_prompt_length "${MAX_PROMPT_LENGTH:-1024}"
70
+ --max_completion_length "${MAX_COMPLETION_LENGTH:-32}"
71
+ --beta "${BETA:-0.01}"
72
+ --enable_gradient_checkpointing
73
+ --wandb_project "${WANDB_PROJECT}"
74
+ --wandb_run_name "${WANDB_RUN_NAME}"
75
+ --wandb_group "${WANDB_GROUP}"
76
+ --wandb_mode "${WANDB_MODE:-offline}"
77
+ )
78
+
79
+ if [[ "${WANDB_MODE:-offline}" != "offline" ]]; then
80
+ cmd+=(--use_wandb)
81
+ fi
82
+
83
+ if [[ -n "${WANDB_ENTITY:-}" ]]; then
84
+ cmd+=(--wandb_entity "${WANDB_ENTITY}")
85
+ fi
86
+
87
+ if [[ -n "${LIMIT_TRAIN_ROWS:-}" ]]; then
88
+ cmd+=(--limit_train_rows "${LIMIT_TRAIN_ROWS}")
89
+ fi
90
+
91
+ if [[ -n "${RESUME_FROM_CHECKPOINT:-}" ]]; then
92
+ cmd+=(--resume_from_checkpoint "${RESUME_FROM_CHECKPOINT}")
93
+ fi
94
+
95
+ printf 'Launching hard 9x9 latent GRPO on GPUs %s\n' "${CUDA_VISIBLE_DEVICES}"
96
+ printf 'Output dir: %s\n' "${OUTPUT_DIR}"
97
+ printf 'Init adapter: %s\n' "${INIT_ADAPTER_DIR}"
98
+ printf 'num_cot_tokens=%s batch=%s grad_accum=%s stage_i=%s num_processes=%s\n' \
99
+ "${NUM_COT_TOKENS}" "${PER_DEVICE_TRAIN_BATCH_SIZE:-${default_bs}}" "${GRADIENT_ACCUMULATION_STEPS:-${default_gas}}" "${STAGE_I}" "${NUM_PROCESSES}"
100
+
101
+ "${cmd[@]}"
large_latent_extension/launch_nonlocation_pipeline.sh ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ set -euo pipefail
3
+
4
+ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
5
+ ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
6
+ PYTHON_BIN="${PYTHON_BIN:-${ROOT}/.venv/bin/python}"
7
+ PIPELINE="${ROOT}/run_latent_residual_projector_pipeline.py"
8
+ TRAIN_JSONL="${TRAIN_JSONL:-${ROOT}/data/sudoku_t3_30empty_value_qwen_text.jsonl}"
9
+ CACHE_DIR="${CACHE_DIR:-${ROOT}/.hf_cache}"
10
+ MODEL_NAME="${MODEL_NAME:-Qwen/Qwen2.5-0.5B-Instruct}"
11
+ GPU_IDS="${GPU_IDS:-0,1,2,3,4,5,6,7}"
12
+ NUM_PROCESSES="${NUM_PROCESSES:-8}"
13
+ MIN_STAGE="${MIN_STAGE:-1}"
14
+ MAX_STAGE="${MAX_STAGE:-4}"
15
+ RUN_TAG="${RUN_TAG:-$(date +%Y%m%d_%H%M%S)}"
16
+ CHECKPOINT_ROOT="${CHECKPOINT_ROOT:-${ROOT}/final_checkpoint/large_latent_extension/hard_9x9_qwen05b/latent}"
17
+ OUTPUT_ROOT="${OUTPUT_ROOT:-${CHECKPOINT_ROOT}/${RUN_TAG}/latent_pipeline_30empty_4stage_hard9x9}"
18
+
19
+ cmd=(
20
+ "${PYTHON_BIN}" "${PIPELINE}"
21
+ --python_executable "${PYTHON_BIN}"
22
+ --train_jsonl "${TRAIN_JSONL}"
23
+ --cache_dir "${CACHE_DIR}"
24
+ --model_name "${MODEL_NAME}"
25
+ --checkpoint_root "${CHECKPOINT_ROOT}"
26
+ --output_root "${OUTPUT_ROOT}"
27
+ --run_tag "${RUN_TAG}"
28
+ --min_stage "${MIN_STAGE}"
29
+ --max_stage "${MAX_STAGE}"
30
+ --distributed_gpu_ids "${GPU_IDS}"
31
+ --sft_num_processes "${NUM_PROCESSES}"
32
+ --grpo_num_processes "${NUM_PROCESSES}"
33
+ --total_empties_hint "${TOTAL_EMPTIES_HINT:-30}"
34
+ --sft_num_epochs "${SFT_NUM_EPOCHS:-1.0}"
35
+ --grpo_num_train_epochs "${GRPO_NUM_TRAIN_EPOCHS:-0.5}"
36
+ --sft_gradient_accumulation_steps "${SFT_GRADIENT_ACCUMULATION_STEPS:-8}"
37
+ --grpo_per_device_train_batch_size "${GRPO_PER_DEVICE_TRAIN_BATCH_SIZE:-2}"
38
+ --grpo_gradient_accumulation_steps "${GRPO_GRADIENT_ACCUMULATION_STEPS:-4}"
39
+ --grpo_num_generations "${GRPO_NUM_GENERATIONS:-2}"
40
+ --sft_enable_gradient_checkpointing
41
+ --grpo_enable_gradient_checkpointing
42
+ --sft_save_steps "${SFT_SAVE_STEPS:-100}"
43
+ --sft_eval_steps "${SFT_EVAL_STEPS:-100}"
44
+ --grpo_save_steps "${GRPO_SAVE_STEPS:-25}"
45
+ --grpo_eval_steps "${GRPO_EVAL_STEPS:-25}"
46
+ --phase_max_wall_clock_seconds "${PHASE_MAX_WALL_CLOCK_SECONDS:-21600}"
47
+ --wandb_mode "${WANDB_MODE:-offline}"
48
+ )
49
+
50
+ if [[ -n "${BOOTSTRAP_ADAPTER_DIR:-}" ]]; then
51
+ cmd+=(--bootstrap_adapter_dir "${BOOTSTRAP_ADAPTER_DIR}")
52
+ fi
53
+
54
+ if [[ -n "${STAGE1_INIT_ADAPTER_DIR:-}" ]]; then
55
+ cmd+=(--stage1_init_adapter_dir "${STAGE1_INIT_ADAPTER_DIR}")
56
+ fi
57
+
58
+ if [[ -n "${LIMIT_TRAIN_ROWS:-}" ]]; then
59
+ cmd+=(--limit_train_rows "${LIMIT_TRAIN_ROWS}")
60
+ fi
61
+
62
+ if [[ -n "${SFT_STAGE_MAX_STEPS:-}" ]]; then
63
+ cmd+=(--sft_stage_max_steps "${SFT_STAGE_MAX_STEPS}")
64
+ fi
65
+
66
+ if [[ -n "${GRPO_STAGE_MAX_STEPS:-}" ]]; then
67
+ cmd+=(--grpo_stage_max_steps "${GRPO_STAGE_MAX_STEPS}")
68
+ fi
69
+
70
+ if [[ "${WANDB_MODE:-offline}" != "offline" ]]; then
71
+ cmd+=(--use_wandb)
72
+ fi
73
+
74
+ if [[ -n "${WANDB_ENTITY:-}" ]]; then
75
+ cmd+=(--wandb_entity "${WANDB_ENTITY}")
76
+ fi
77
+
78
+ printf 'Launching hard 9x9 latent pipeline on GPUs %s\n' "${GPU_IDS}"
79
+ printf 'Output root: %s\n' "${OUTPUT_ROOT}"
80
+ printf 'Stages: %s -> %s, processes=%s\n' "${MIN_STAGE}" "${MAX_STAGE}" "${NUM_PROCESSES}"
81
+
82
+ "${cmd[@]}"