# Curriculum CoT for 9x9 Sudoku: Rebuttal/Paper-Section Material

_Last updated: 2026-05-24_

This document is a comprehensive, paper-ready reference of (a) the data pipeline,
(b) the instruction-tuning prompt format, (c) the curriculum and reward design,
(d) the latent thought-token architecture, (e) the multi-stage SFT-then-GRPO
training recipe, and (f) the headline numerical results, so that a rebuttal
section can be assembled directly from it.

---

## 1. Task

We use the model as a **per-cell value policy** for 9×9 Sudoku. For a fixed
target empty cell, the model emits a JSON set of candidate digits that are
"i-consistent" with the current grid (definition in §4). We evaluate two
metrics:

- **per-cell exact set match** (`exact_set_match`): the predicted set equals the
  ground-truth i-consistent set;
- **whole-puzzle solve rate** (`solve`): every empty cell on a 20-empty puzzle
  produces an exact set match.

Because a puzzle counts as solved only if all of its ~20 empty cells match
exactly, and treating per-cell outcomes as roughly independent,
solve ≈ ∏ exact_set_match over those cells. The two metrics are therefore
non-linearly coupled:

$$ \text{solve} \approx \text{exact\_set\_match}^{N_{\text{empty}}} $$

so $0.95^{20} \approx 0.358$ and $0.97^{20} \approx 0.544$: every percentage
point of per-cell exact maps to a much larger swing in solve.
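
The compounding can be checked numerically. The snippet below is a minimal sketch, assuming every cell succeeds independently with the same probability (an idealisation); `solve_rate` is a hypothetical helper, not a function from the repository.

```python
def solve_rate(exact_set_match: float, n_empty: int = 20) -> float:
    """Whole-puzzle solve rate if each empty cell succeeds independently."""
    return exact_set_match ** n_empty


if __name__ == "__main__":
    # Small per-cell gains compound into large solve-rate gains.
    for acc in (0.95, 0.97, 0.99):
        print(f"exact={acc:.2f} -> solve~{solve_rate(acc):.3f}")
```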

---

## 2. Data pipeline

### 2.1 Puzzle generation

Generated by `simple_9x9_curriculum/build_dataset.py`:

- Start from a base Latin-square grid; randomly relabel digits, permute
  rows and columns within bands, and transpose.
- Sample `empties=20` cell positions uniformly at random and erase them.
- Save 10 000 train + 1 000 eval puzzles (seed 0, seed 1).
- Output JSONL files
  `data/sudoku_t3_20empty_value_qwen_text_stage1_{train,eval}.jsonl`.
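
The generation steps above can be sketched as follows. This is an illustrative reconstruction, not the code in `build_dataset.py`: the base pattern, the helper names, and the 50 % transpose probability are all assumptions.

```python
import random


def base_grid():
    """A valid solved grid built from the standard shifted-row pattern."""
    return [[(3 * (r % 3) + r // 3 + c) % 9 + 1 for c in range(9)]
            for r in range(9)]


def within_band_permutation(rng):
    """Permute indices 0..8 without moving anything across a band of three."""
    return [3 * band + i for band in range(3) for i in rng.sample(range(3), 3)]


def is_valid_solution(grid):
    """True if every row, column and 3x3 box contains the digits 1..9."""
    full = set(range(1, 10))
    rows = all(set(row) == full for row in grid)
    cols = all({grid[r][c] for r in range(9)} == full for c in range(9))
    boxes = all(
        {grid[br + i][bc + j] for i in range(3) for j in range(3)} == full
        for br in (0, 3, 6) for bc in (0, 3, 6)
    )
    return rows and cols and boxes


def generate_puzzle(rng, empties=20):
    """Return (puzzle, solved, erased_cells); erased cells are 0-based."""
    grid = base_grid()
    relabel = rng.sample(range(1, 10), 9)          # random digit relabelling
    rows = within_band_permutation(rng)
    cols = within_band_permutation(rng)
    solved = [[relabel[grid[r][c] - 1] for c in cols] for r in rows]
    if rng.random() < 0.5:                         # assumed transpose probability
        solved = [list(col) for col in zip(*solved)]
    all_cells = [(r, c) for r in range(9) for c in range(9)]
    cells = sorted(rng.sample(all_cells, empties)) # uniform erase positions
    puzzle = [row[:] for row in solved]
    for r, c in cells:
        puzzle[r][c] = 0
    return puzzle, solved, cells
```

Each transformation preserves Sudoku validity, so the solved grid never needs a solver to verify it.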

A single record contains:

```json
{
  "prompt":     "<full Qwen chat-templated prompt for one (puzzle, target_cell) pair>",
  "completion": "[7,3,8,2,6,9,4,5,...]",
  "metadata": {
    "grid_size": 9, "box_size": 3, "empties": 20,
    "empty_locs_1based":     [[1,4],[1,9],...],
    "target_triples_1based": [[1,4,7],[1,9,3],...]
  }
}
```

The 20 `target_triples` give the **solved** value at each of the 20 empty
positions, so per-cell training targets are always available. At training
time we expand each puzzle into 20 (puzzle, target_cell) examples.
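
A minimal sketch of that expansion, assuming only the `metadata` fields shown in the record above. The output keys are hypothetical, and the sketch does not model how the grid state evolves between turns, which this section does not specify.

```python
def expand_record(record):
    """Turn one puzzle record into one example per empty cell."""
    triples = record["metadata"]["target_triples_1based"]
    total = len(triples)
    return [
        {
            "target_cell": (row, col),   # 1-based, as stored in the record
            "solved_value": value,       # value of this cell in the solution
            "turn": turn,
            "total_turns": total,
        }
        for turn, (row, col, value) in enumerate(triples, start=1)
    ]
```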

### 2.2 Cell-policy framing

The model is never asked to solve a whole puzzle in one shot. Each example
is one (current_grid, target_cell) pair, and the supervised target is the
set of digits that are "i-consistent" with the current grid (see §4). This
turns Sudoku into a **classification-into-a-set** problem and lets us share
parameters across cells, stages, and puzzle sizes.

### 2.3 Multi-value oversampling (data-side trick)

Implemented in `multi_output_cell_policy/sft_multi_output_train.py` via
`tokenizer._multi_value_oversample_factor` and the CLI flags

```
--multi_value_oversample_factor INT          (default 1)
--train_target_size_min  INT                 (default 0)
--train_target_size_max  INT                 (default 0)
```

Inside the dataset builder, examples whose target set has more than one
digit are repeated `multi_value_oversample_factor` times in the training
mix. This biases gradient steps toward exactly the cells the model gets
wrong (multi-value cells). Empirically, this is the single biggest data-side
lever (see §10).
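
The oversampling rule can be sketched in a few lines. This is an assumption-labelled illustration: the `target_values` key and the function name are hypothetical, and the real builder may shuffle or filter differently.

```python
def oversample_multi_value(examples, factor=1):
    """Repeat examples whose target set has more than one digit."""
    factor = max(1, factor)
    mixed = []
    for example in examples:
        repeats = factor if len(example["target_values"]) > 1 else 1
        mixed.extend([example] * repeats)
    return mixed
```

With roughly a quarter of cells being multi-value, a factor of 5 moves them from about 25 % to about 62 % of the training mix (5 × 0.25 / (5 × 0.25 + 0.75)).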

### 2.4 Where the bottleneck lives

For 20-empty puzzles in stage 3, only ~25 % of empty cells have a
multi-value target set (the rest collapse to one i-consistent value). Yet
those multi-value cells are responsible for the entire solve-rate gap:
they are the cells where the model under-predicts (returns a singleton
when the target is a 2- or 3-element set), and a single failed cell kills
the whole-puzzle solve. The reward shaping in §6 and the oversample in §2.3
both attack this single failure mode.

---

## 3. Instruction format

### 3.1 System prompt

(verbatim from `multi_output_cell_policy/prompt_builder.py`)

```text
You are a Sudoku value policy.
This setup uses puzzles with about 20 empty cells.
You will be given one target empty cell.
Return ONLY one JSON object of the form {"values":[...]}.
The JSON object must contain exactly one key named "values".
The "values" field must be a JSON array of unique integers in [1,9].
You may return as many candidate values as you want, including one, several,
or many values.
Choose the number of returned values yourself based on which values seem
i-consistent.
The order of the values does not matter.
Do not output any explanation, markdown, punctuation outside JSON, or extra text.
Current stage objective: i={i} consistency.
```

### 3.2 User message

```text
Sudoku grid (0 means empty):
<grid_to_text(grid)>
Empty cells in row-major order (20 total): (1,4), (1,9), (2,8), …
Target cell to fill now: (R,C).
Turn: t/T.
Return only JSON with candidate values for this target cell: {"values":[...]}
```

We use the Qwen2.5-Instruct chat template (`tokenizer.apply_chat_template`,
`add_generation_prompt=True`) to wrap system + user into the actual prompt
ids. `max_prompt_length = 768`.
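
As an illustration, the user message can be assembled as below. This is a sketch rather than the code in `prompt_builder.py`: the space-separated layout produced by `grid_to_text` and the `build_user_message` signature are assumptions.

```python
def grid_to_text(grid):
    """Assumed layout: one row per line, digits separated by spaces."""
    return "\n".join(" ".join(str(v) for v in row) for row in grid)


def build_user_message(grid, target, turn, total_turns):
    """Build the user turn for one (grid, target_cell) pair; target is 1-based."""
    empties = [(r + 1, c + 1) for r in range(9) for c in range(9)
               if grid[r][c] == 0]
    cells = ", ".join(f"({r},{c})" for r, c in empties)
    return (
        "Sudoku grid (0 means empty):\n"
        f"{grid_to_text(grid)}\n"
        f"Empty cells in row-major order ({len(empties)} total): {cells}\n"
        f"Target cell to fill now: ({target[0]},{target[1]}).\n"
        f"Turn: {turn}/{total_turns}.\n"
        'Return only JSON with candidate values for this target cell: {"values":[...]}'
    )
```

The resulting string would be passed as the `user` message, alongside the system prompt, to `tokenizer.apply_chat_template`.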

### 3.3 Output format

```json
{"values":[3,7]}
```

Strictly canonical JSON (single key `values`, sorted unique digit list,
no whitespace). Outputs are scored by `parse_values_json`
(`shared_multi_output_policy.py`); any deviation collapses the whole
prediction to `parse_ok=0` and a hard-coded malformed penalty.

`max_completion_length = 24` tokens, enough to emit any 9-digit set.
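
A strict parser along these lines could look like the sketch below. It is a hypothetical reimplementation based only on the description above, not the actual `parse_values_json`; in particular, enforcing sorted order by comparing against a re-serialised canonical string is an assumption.

```python
import json


def parse_values_strict(text):
    """Return (parse_ok, values); any deviation from canonical form fails."""
    try:
        obj = json.loads(text)
    except (json.JSONDecodeError, TypeError):
        return False, []
    if not isinstance(obj, dict) or set(obj) != {"values"}:
        return False, []
    values = obj["values"]
    if not isinstance(values, list):
        return False, []
    # Reject booleans, floats and out-of-range digits.
    if any(type(v) is not int or not 1 <= v <= 9 for v in values):
        return False, []
    # Canonical form: sorted, unique, no whitespace.
    canonical = json.dumps({"values": sorted(set(values))}, separators=(",", ":"))
    if text != canonical:
        return False, []
    return True, values
```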

---

## 4. Curriculum: stage-i consistency

The curriculum lives in `_stage_i_consistent_values_for_grid`:

- **Stage 1 (i=1, legal moves).** A value v is i=1 consistent at cell c
  iff placing v at c violates no Sudoku constraint (row, column, 3×3 box).
  This is just "legal candidates".

- **Stage 2 (i=2).** v is i=2 consistent at c iff (a) it is i=1 consistent
  AND (b) after placing v, every other empty cell in the grid still has at
  least one i=1-consistent value (i.e. placing v does not immediately make
  the puzzle unsolvable by 1-step propagation).

- **Stage 3 (i=3).** Same recursion one more level deep: v is i=3 consistent
  iff after placing v, every other empty cell still has at least one i=2
  consistent value.

This is bounded look-ahead constraint propagation. The sets are nested:
every stage-3 set is a subset of the stage-2 set for the same cell, which in
turn is a subset of the stage-1 set. The curriculum goal at deployment time
is stage-3.

In data, we use the same source records and just change `--stage_i`; the
target set is regenerated on the fly by `stage_i_consistent_values`.
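
The recursion can be written down directly from the three definitions. The following is a sketch under stated assumptions: the function names are hypothetical, the grid is a 0-based 9×9 list of lists with 0 for empty cells, and at every stage a candidate is required to be a legal move first.

```python
def legal_values(grid, r, c):
    """Stage 1: digits that clash with nothing in the row, column or box."""
    used = set(grid[r]) | {grid[k][c] for k in range(9)}
    br, bc = 3 * (r // 3), 3 * (c // 3)
    used |= {grid[i][j] for i in range(br, br + 3) for j in range(bc, bc + 3)}
    return [v for v in range(1, 10) if v not in used]


def stage_i_values(grid, r, c, i):
    """Digits that are stage-i consistent at the empty cell (r, c)."""
    candidates = legal_values(grid, r, c)
    if i <= 1:
        return candidates
    consistent = []
    for v in candidates:
        grid[r][c] = v                       # tentatively place v
        ok = all(
            stage_i_values(grid, rr, cc, i - 1)   # non-empty list is truthy
            for rr in range(9) for cc in range(9)
            if grid[rr][cc] == 0
        )
        grid[r][c] = 0                       # undo the placement
        if ok:
            consistent.append(v)
    return consistent
```

The cost grows quickly with `i` because each level re-scans every remaining empty cell, which is one reason to keep the look-ahead bounded.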

---

## 5. Latent thought-token architecture

Base model: **Qwen/Qwen2.5-1.5B-Instruct** + LoRA (r=32, α=64, dropout=0.05)
on `q,k,v,o,gate,up,down`. The latent variant adds **k thought-token slots**
between the prompt and the next-token logits.

Four modes are implemented (`latent_multi_output_cell_policy/`); the winning
mode for the final number is **`recurrent_hidden`**:

> `build_recurrent_hidden_latent_hidden(model, ids, mask, k)`
>
> 1. Run the backbone once on the prompt. Keep `base_hidden = h[:,-1,:]`.
> 2. Set `latent_token = base_hidden`.
> 3. Repeat k times: append `latent_token` (as an embedding) to the running
>    sequence, run the backbone again on the extended sequence, and replace
>    `latent_token` with the new last hidden state.
> 4. After k recursions, `latent_hidden` is fed through the LM head to
>    produce the next-token distribution.

In equations, with E the input embedding lookup, f_θ the LoRA-adapted
backbone (the subscript selects a sequence position of its final hidden
states), and U the LM head:

$$ z_0 = f_\theta\bigl([E(x_1),\dots,E(x_T)]\bigr)_{T} $$
$$ z_{j+1} = f_\theta\bigl([E(x_1),\dots,E(x_T), z_0, z_1,\dots,z_j]\bigr)_{T+j+1},\quad j=0,\dots,k-1 $$
$$ p(\cdot \mid x_{1:T}) = \mathrm{softmax}(U z_k) $$

The model can therefore "iterate" k extra forward passes on the same prompt
before committing to a token, with the k extra hidden states carrying
intermediate computation. Setting k=0 recovers the vanilla baseline.
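
The control flow of the recurrence is easy to see with a toy stand-in for the backbone. The sketch below is framework-free and purely illustrative: `toy_backbone` (a causal running sum) and the list-of-floats vectors are assumptions, and the real `build_recurrent_hidden_latent_hidden` operates on model tensors.

```python
from typing import Callable, List

Vector = List[float]


def toy_backbone(seq: List[Vector]) -> List[Vector]:
    """Stand-in for f_theta: a causal running sum over the sequence."""
    out, acc = [], [0.0] * len(seq[0])
    for x in seq:
        acc = [a + xi for a, xi in zip(acc, x)]
        out.append(acc)
    return out


def recurrent_hidden(backbone: Callable[[List[Vector]], List[Vector]],
                     prompt_embeds: List[Vector], k: int) -> Vector:
    """Return z_k: the last hidden state after k latent recursions."""
    seq = list(prompt_embeds)
    latent = backbone(seq)[-1]        # z_0 = last hidden state of the prompt
    for _ in range(k):
        seq.append(latent)            # feed the hidden state back as an embedding
        latent = backbone(seq)[-1]    # z_{j+1}
    return latent                     # the LM head would be applied to this
```

Note that the backbone is called k+1 times in total, so inference cost grows linearly in the number of thought slots.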

The other three latent modes are alternatives that we ablated:
`fixed_slots` (concatenate k trainable seed embeddings; Option-2),
`latent_seeds` (similar to fixed_slots), and `residual` (project k extra
hidden states back onto the base hidden state via a learned residual). All
modes share the SFT and GRPO trainers; only the next-token logit function
changes.

For the curriculum, we grow k stage by stage:

| stage | num_cot_tokens | comment |
| --- | ---: | --- |
| 1 | 1 | one extra recursion as soon as the model has the surface form |
| 2 | 2 | two, needed for 1-step propagation reasoning |
| 3 | 3 | three, needed for 2-step propagation reasoning |

---

## 6. The reward function

Defined in `multi_output_cell_policy/rewards.py`.

Given target set T, predicted set P (after JSON parse), let

- `num_good = |P ∩ T|`
- `num_bad  = |P \ T|`
- `num_missing = max(0, |T| − num_good)`
- `is_exact = (P ≠ ∅) ∧ (P = T)`
- `tri(n) = n(n+1)/2` (rewards larger correct sets superlinearly)

Then

$$
\begin{aligned}
r ={}& \mathrm{tri}(\mathrm{num\_good}) \cdot R_g \;-\; \mathrm{num\_bad} \cdot P_b \\
& - \mathbb{1}[P=\varnothing]\, P_e \;-\; \mathbb{1}[|P|=1,\ |T|>1,\ i<2]\, P_s \\
& - \mathrm{num\_missing}\cdot P_m \;+\; \mathbb{1}[\text{is\_exact}]\, B_x \\
& - \mathbb{1}[|P|<|T|,\ |T|>1]\, P_c
\end{aligned}
$$

with parameters (this is the recipe that produced the 0.58/0.68 latent solve):

| symbol | flag | value | role |
|---|---|---:|---|
| $R_g$ | `--reward_good_value` | 1.25 | per-correct-value reward (with triangular shape) |
| $P_b$ | `--penalty_bad_value` | 1.0 | per-extra-wrong-value penalty |
| $P_{\!\text{mal}}$ | `--penalty_malformed` | 4.0 | flat penalty if JSON parse fails |
| $P_e$ | `--penalty_empty` | 0.5 | flat penalty if predicted set is empty |
| $P_s$ | `--penalty_singleton` | 1.5 | only at stage<2: punishes singleton on multi-value targets |
| $P_m$ | `--penalty_missing` | **0.75** | per-missing-value (recall pressure); **NEW** |
| $B_x$ | `--exact_match_bonus` | **2.0** | only when $P = T$; **NEW** |
| $P_c$ | `--cardinality_mismatch_penalty` | **1.0** | when $\lvert P\rvert < \lvert T\rvert$ and $\lvert T\rvert > 1$; **NEW** |

Parse failures short-circuit to `r = -P_mal` and zero per-cell metrics.
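
Put together, the formula and the table translate into a short function. This is a sketch written from the definitions in this section, not a copy of `rewards.py`; the function name and signature are hypothetical, while the keyword defaults are the table's values.

```python
def tri(n):
    """Triangular number n(n+1)/2."""
    return n * (n + 1) // 2


def cell_reward(predicted, target, stage_i, parse_ok=True, *,
                R_g=1.25, P_b=1.0, P_mal=4.0, P_e=0.5, P_s=1.5,
                P_m=0.75, B_x=2.0, P_c=1.0):
    """Reward for one cell, following the formula term by term."""
    if not parse_ok:
        return -P_mal                        # malformed JSON short-circuits
    P, T = set(predicted), set(target)
    num_good = len(P & T)
    num_bad = len(P - T)
    num_missing = max(0, len(T) - num_good)
    reward = tri(num_good) * R_g - num_bad * P_b
    if not P:
        reward -= P_e                        # empty prediction
    if len(P) == 1 and len(T) > 1 and stage_i < 2:
        reward -= P_s                        # singleton penalty, stage 1 only
    reward -= num_missing * P_m              # recall pressure
    if P and P == T:
        reward += B_x                        # exact-match bonus
    if len(P) < len(T) and len(T) > 1:
        reward -= P_c                        # cardinality mismatch
    return reward
```

Under this sketch, with $T=\{8,9\}$ at stage 3, the exact set scores 5.75 and the singleton $\{8\}$ scores -0.5; with the three new terms set to zero the same singleton scores 1.25.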

### 6.1 Why those three new terms exist (the breakthrough)

Diagnosis: at the v3/v4 plateau, eval reported

```
exact=0.95  precision=0.95  recall=0.95  solve=0.30  avg_set_size=1.000
```

across all checkpoints. Per-cell exact, precision, and recall were all near
0.95, but the model **always predicted a single digit** (`avg_set_size=1.000`).
On a multi-value target $T=\{8,9\}$, predicting $\{8\}$ keeps precision=1.0
and recall=0.5, yet `exact_set_match=0`. Because solve ≈ exact_set_match^N
decays exponentially in $N$ (= 20), even a small fraction of multi-value
cells killed it.

Without any of the new terms, a confident singleton is a safe local optimum
of $r$ on a multi-value cell: it earns $R_g$ with no penalty at stage ≥ 2,
whereas a wrong guess at the second value costs $P_b$, so the policy sees
little upside in enumerating the second value. The three new terms close
exactly that hole:

- $P_m$ (`penalty_missing`) directly penalises recall;
- $B_x$ (`exact_match_bonus`) makes $P=T$ strictly dominate any singleton;
- $P_c$ (`cardinality_mismatch_penalty`) is a flat hammer whenever $|P|<|T|$.

After these terms were added, GRPO on the latent variant moved solve from
~0.30 to ~0.58 (100-puzzle eval) over ~200 steps. The same fix is what we
ported back into the baseline pipeline this evening (see §10).

---

## 7. Multi-stage warm-baseline pipeline (the recipe that worked)

Master script:
`hard_9x9_stage1_consistency_queue/launch_20empty_warm_baseline_all_latent_modes_stages123.sh`.

For each curriculum stage we run **three sub-phases in order**:

```
[stage i]
  (1) baseline warm SFT     (no latent tokens, k=0, vanilla LM)
  (2) latent SFT            (k = i, latent mode = recurrent_hidden)
  (3) latent GRPO           (k = i)
```

**The warm baseline phase (1) is the trick that makes the curriculum work.**
At every stage transition the data distribution changes (i increases, so
target sets shrink) and a new latent slot appears. Doing a vanilla SFT on
the new distribution first lets the LM relearn the surface form on familiar
parameters; only then does the latent SFT add the extra thought slot on top
of an already-good policy. When we tried to add a new latent slot directly on
top of the previous stage's GRPO checkpoint, training loss did not
decrease.

Concrete LR schedule used for the champion run:

| phase | init from | LR | k |
|---|---|---:|---:|
| S1 baseline SFT | base Qwen | 2e-4 | 0 |
| S1 latent SFT   | S1 baseline | 2e-4 | 1 |
| S1 latent GRPO  | S1 latent SFT | 1e-6 | 1 |
| S2 baseline warm SFT | S1 GRPO | 5e-5 | 0 |
| S2 latent SFT   | S2 baseline | 5e-5 | 2 |
| S2 latent GRPO  | S2 latent SFT | 1e-6 | 2 |
| S3 baseline warm SFT | S2 GRPO | 5e-5 | 0 |
| S3 latent SFT   | S3 baseline | 5e-5 → 1e-5 (champion) | 3 |
| S3 latent GRPO  | S3 latent SFT | 5e-6 (β=0) | 3 |
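
The initialisation chain in this table follows a regular pattern that can be generated programmatically. The sketch below is illustrative only (`build_phase_plan` is a hypothetical helper, not part of the launch script) and omits learning rates.

```python
def build_phase_plan(num_stages=3):
    """List the phases in training order with their init source and k."""
    plan, previous = [], "base Qwen"
    for stage in range(1, num_stages + 1):
        phases = (
            (f"S{stage} baseline SFT", 0),       # warm baseline, no latent slots
            (f"S{stage} latent SFT", stage),     # k grows with the stage
            (f"S{stage} latent GRPO", stage),
        )
        for name, k in phases:
            plan.append({"phase": name, "init_from": previous, "k": k})
            previous = name                      # next phase starts from this output
    return plan


if __name__ == "__main__":
    for row in build_phase_plan():
        print(f'{row["phase"]:<18} <- {row["init_from"]:<18} k={row["k"]}')
```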

Other shared knobs:

```
LoRA: r=32 α=64 dropout=0.05 on q,k,v,o,gate,up,down
SFT:   per_device_bs=8 grad_accum=2 nproc=8  -> eff_bs=128
GRPO:  per_device_bs=4 grad_accum=2 nproc=8  -> eff_bs=64
       num_generations=4    beta=0.0    max_prompt_length=1024
       max_completion_length=24
multi_value_oversample_factor=5,  exact_match_bonus=2.0,
penalty_missing=0.75, cardinality_mismatch_penalty=1.0
```

---

## 8. GRPO settings that mattered

- **β = 0.** The KL anchor was harmful in every sweep where we tried β>0.
  `s3_grpo_kl04` (β=0.04) peaked at solve=0.625 (40p) at step 100 and
  regressed to 0.525 by step 500.

- **num_generations = 4.** With num_generations=2 we routinely saw
  `reward_std = 0` (all sampled completions identical, so no gradient).
  Bumping to 4 fixed it.

- **Low LR.** `lr=5e-6` was the steadiest. `lr=1e-5` peaked at step 200
  (solve 0.65) then collapsed back to 0.54, a classic mode collapse.

- **Effective bs ≥ 64.** TRL's GRPO trainer requires the global batch size
  to be divisible by `num_generations` (in TRL 0.15.x the check is on
  `num_processes * per_device_bs`); with 8 GPUs we satisfy this
  trivially, but we caution single-GPU rerunners to set
  `per_device_bs=4 grad_accum=2 num_generations=4`.

- **`enable_input_require_grads()` on the wrapped backbone.** Required for
  TRL 0.15.x + PEFT LoRA + gradient checkpointing; otherwise the loss
  tensor produced by GRPOTrainer has `requires_grad=False` and `.backward()`
  raises. Also set `unwrapped.config.use_cache = False`.
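
The batch-size arithmetic behind these settings can be sanity-checked before launching a run. The helpers below are hypothetical, and the divisibility rule mirrors our understanding of the TRL 0.15.x check (global batch = processes × per-device batch); other TRL versions may compute it differently, so treat this as a sketch.

```python
def effective_batch_size(per_device_bs, grad_accum, nproc):
    """Examples contributing to one optimizer step."""
    return per_device_bs * grad_accum * nproc


def generations_fit(per_device_bs, nproc, num_generations):
    """Assumed TRL-style rule: global batch divisible by num_generations."""
    return (per_device_bs * nproc) % num_generations == 0


if __name__ == "__main__":
    print("SFT  eff_bs:", effective_batch_size(8, 2, 8))
    print("GRPO eff_bs:", effective_batch_size(4, 2, 8))
    print("single GPU ok:", generations_fit(4, 1, 4))
```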

---

## 9. Final hyperparameters table: champion latent run

| group | hyperparameter | value |
|---|---|---|
| Backbone | model | Qwen/Qwen2.5-1.5B-Instruct |
| Backbone | dtype | bf16 |
| Backbone | LoRA target modules | q,k,v,o,gate,up,down |
| Backbone | LoRA r / α / dropout | 32 / 64 / 0.05 |
| Latent  | mode | recurrent_hidden |
| Latent  | num_cot_tokens (S1/S2/S3) | 1 / 2 / 3 |
| Latent  | max_latent_slots / seeds | 8 / 8 |
| Data    | total empties | 20 |
| Data    | train rows / eval rows | 10 000 / 100 |
| Data    | multi_value_oversample_factor | 5 |
| Data    | mixed_stage1_ratio (S1) | 1 |
| Data    | mixed_stage2_ratio (S≥2) | 1 |
| SFT     | per_device_bs / grad_accum | 8 / 2 |
| SFT     | num_epochs (cap) | 64 |
| SFT     | LR (S1 latent) | 2e-4 |
| SFT     | LR (S2/S3 baseline warm + latent) | 5e-5 |
| SFT     | LR (S3 latent champion `s3b_lr1e5_o5`) | 1e-5 |
| SFT     | weight_decay | 0.0 |
| SFT     | gradient checkpointing | on |
| GRPO    | per_device_bs / grad_accum | 4 / 2 |
| GRPO    | num_generations | 4 |
| GRPO    | LR | 5e-6 (S3); 1e-6 (S1, S2) |
| GRPO    | β (KL) | 0.0 |
| GRPO    | max_prompt_length | 1024 |
| GRPO    | max_completion_length | 24 |
| Reward  | reward_good_value | 1.25 |
| Reward  | penalty_bad_value | 1.0 |
| Reward  | penalty_malformed | 4.0 |
| Reward  | penalty_empty | 0.5 |
| Reward  | penalty_singleton | 1.5 |
| Reward  | penalty_missing | 0.75 |
| Reward  | exact_match_bonus | 2.0 |
| Reward  | cardinality_mismatch_penalty | 1.0 |
| Eval    | early-stop on prec/recall | 0.98 |

---

## 10. Headline results

### 10.1 Latent (with thought tokens, recurrent_hidden)

| eval | model / phase | step | exact | prec | recall | **solve** |
|---|---|---:|---:|---:|---:|---:|
| **100p (auth.)** | `s3_grpo_baseline` (S3 GRPO, β=0, lr=5e-6) | 200 | 0.9665 | 0.9673 | 0.9680 | **0.580 (58/100)** |
| 40p | `s3_grpo_sharp_rwd` (exact_b=4, card_pen=3) | 300 | n/a | n/a | n/a | **0.675 (27/40)** |
| 40p | `s3_grpo_lr1e5` | 200 | 0.978 | 0.978 | 0.979 | 0.650 |
| 40p | `s3b_lr1e5_o5` (S3 SFT champion) | 2400 | 0.974 | 0.974 | 0.975 | 0.600 |

### 10.2 Vanilla baseline (no thought tokens, same Qwen2.5-1.5B + LoRA)

| sweep | best variant | best step | exact | **solve (100p)** |
|---|---|---:|---:|---:|
| v3 (single-GPU LR=2e-5, no oversample, no new reward terms) | `baseline_3stage_20260522` | n/a | 0.730 | **0.000** |
| v4 (LR sweep, multi-GPU, original reward) | `pipe_v_sft_extend` (S3 SFT extended) | 4000 | 0.948 | **0.400** |
| **v6 (this evening, ports latent reward + oversample)** | `v6_i_sft_v_oversample10` (oversample=10) | running | 0.952+ | **0.440 (best so far)** |

The v6 sweep is still running: `v6_e/f/i` are in S3 SFT continuation, with
GRPO follow-on phases queued. The v6_i variant has hit **solve=0.44** at
SFT eval (new baseline best, +0.04 over v4) and is still climbing.

### 10.3 Stage-by-stage trajectory (latent, 40-puzzle eval)

```
S1 SFT                                  : exact ≈ 0.85,  solve ≈ 0.20
S1 GRPO                                 : exact ≈ 0.90,  solve ≈ 0.20
S2 SFT (no oversample)                  : exact ≈ 0.94,  solve ≈ 0.20-0.25  <- the wall
S2 SFT  + multi_value_oversample=5      : exact ≈ 0.96,  solve ≈ 0.30-0.35
S2 GRPO + new reward terms              : exact ≈ 0.96,  solve ≈ 0.35-0.40
S3 SFT  (s3b_lr1e5_o5 step 2400)        : exact 0.974,   solve 0.600       <- SFT champion
S3 GRPO (s3_grpo_baseline step 200,100p): exact 0.967,   solve 0.580       <- 100p champion
S3 GRPO (s3_grpo_sharp_rwd step 300,40p):                solve 0.675       <- 40p peak
```

### 10.4 Latent vs baseline gap (head-to-head, same 100p eval, same prompts)

| model | exact | prec | recall | **solve** | solved/100 |
|---|---:|---:|---:|---:|---:|
| Latent recurrent_hidden, S3 GRPO | 0.9665 | 0.9673 | 0.9680 | **0.580** | 58 |
| Vanilla baseline, v6_i (best at time of writing) | 0.952 | 0.952 | 0.952 | **0.440** | 44 |

Gap on 100-puzzle solve: ≈ **+0.14 absolute / +32 % relative** for latent
over the strongest baseline we have.

---

## 11. Why the latent works (interpretation hypotheses)

These are the working hypotheses the experiments are consistent with;
none is fully proven and ablations are still WIP.

1. **Constraint-propagation depth.** Stage-3 i-consistency is essentially
   2-ply lookahead. With $k=3$ recurrent hidden tokens the model gets
   exactly three extra forward passes between prompt and output: one for
   the legality check, one for 1-step propagation, one for the second
   step of propagation. Empirically the gap to the no-thought-token
   baseline appears at stages where multi-step propagation matters
   (stage 2 onward; stage 1 numbers are essentially identical).

2. **Multi-value cells require enumeration, which a singleton softmax can't
   do in one forward pass.** A vanilla LM at 1.5B parameters predicts
   essentially deterministically once temperature is low; for a target set
   {8, 9} the LM picks one of the two and stops. The latent model can use
   one of the recurrent hidden steps to "consider" each option without
   committing yet. This matches the failure mode seen in the data
   (`avg_set_size = 1.000` for the baseline, `≈ 1.05` for the latent S3
   model on the same eval).

3. **Stable curriculum capacity growth.** Adding a new latent slot at every
   stage gives the model a "fresh slate" of representational capacity at the
   exact transition where the task gets harder. The warm-baseline SFT
   between stages prevents the new slot from corrupting the previously
   learned policy. Without warm baseline, training loss did not decrease
   at all (we observed this directly when we tried to skip the warm
   baseline yesterday).

4. **GRPO without latent slots is starved of variance.** With
   `max_completion_length=24` and the model essentially deterministic, GRPO's
   4 sampled completions per prompt collapse to a single answer, giving
   `reward_std = 0` and no gradient. With latent recurrence + the new
   exact_match_bonus reward, the model occasionally samples a 2-element set,
   gets a much higher reward, and that prompt gets a real gradient signal.

---

## 12. Reproducibility

Code repository: `https://github.com/Avra98/curriculum_cot`
Latent checkpoints: `https://huggingface.co/Avra98/sudoku-latent-recurrent-hidden-20empty-stages`
Baseline checkpoints: `https://huggingface.co/Avra98/sudoku-9x9-20empty-baseline-1p5b-sweep`

Key scripts:

- Master orchestrator (latent, 9-phase warm-baseline pipeline):
  `hard_9x9_stage1_consistency_queue/launch_20empty_warm_baseline_all_latent_modes_stages123.sh`
- Vanilla baseline pipeline:
  `_runs/baseline_1p5b_pipeline_v4.sh` (with v6 launchers
  `_runs/launch_baseline_push_v6.sh`)
- SFT trainer (vanilla):  `multi_output_cell_policy/sft_multi_output_train.py`
- GRPO trainer (vanilla): `multi_output_cell_policy/grpo_multi_output_train.py`
- SFT trainer (latent):   `latent_multi_output_cell_policy/sft_latent_multi_output_train.py`
- GRPO trainer (latent):  `latent_multi_output_cell_policy/grpo_residual_projector_latent_train.py`
- Reward function:        `multi_output_cell_policy/rewards.py`
- Prompt builder:         `multi_output_cell_policy/prompt_builder.py`
- Stage-i consistency:    `multi_output_cell_policy/shared_multi_output_policy.py`
- 100-puzzle evaluator:   `analysis/eval_stage2_checkpoint.py`

To reproduce the latent champion (1.5B, 9-phase, ~16 GPU·h on 8×H100 80GB):

```bash
export STAGE1_BASELINE_ADAPTER_DIR=/path/to/stage1_baseline_seed_adapter
bash hard_9x9_stage1_consistency_queue/launch_20empty_warm_baseline_all_latent_modes_stages123.sh
```

To reproduce the v6 baseline push (single-GPU per variant, ~6 GPU·h):

```bash
bash _runs/launch_baseline_push_v6.sh
```

---

## Appendix A. The reward fix as a one-line patch

The single most consequential code change in this whole project, as a
self-contained patch on `multi_output_cell_policy/rewards.py`:

```python
# new args (default 0 preserves legacy behaviour)
penalty_missing: float = 0.0
exact_match_bonus: float = 0.0
cardinality_mismatch_penalty: float = 0.0

num_missing = max(0, len(target_set) - num_good)
is_exact = bool(predicted_values) and (set(predicted_values) == target_set)

# ... base reward (triangular_number(num_good)*reward_good_value - num_bad*penalty_bad_value)

if num_missing > 0:
    reward -= num_missing * penalty_missing
if is_exact:
    reward += exact_match_bonus
if len(predicted_values) < len(target_values) and len(target_values) > 1:
    reward -= cardinality_mismatch_penalty
```

Defaults are zero so old runs are unaffected; the standard recipe sets
`(P_m, B_x, P_c) = (0.75, 2.0, 1.0)` and the "sharp_rwd" variant uses
`(1.0, 4.0, 3.0)`.

## Appendix B. The warm-baseline trick as a sequence diagram

```
Stage 1                Stage 2                Stage 3
─────────              ─────────              ─────────
[base Qwen]            ↓                      ↓
   ↓                   ↓                      ↓
S1 baseline SFT  →  S2 baseline SFT  →  S3 baseline SFT
(no latent, k=0)    (no latent, k=0)    (no latent, k=0)
   ↓                   ↓                      ↓
S1 latent SFT     →  S2 latent SFT     →  S3 latent SFT
(k=1)                (k=2)                  (k=3)
   ↓                   ↓                      ↓
S1 latent GRPO    →  S2 latent GRPO    →  S3 latent GRPO
(k=1, β=0, lr 1e-6)  (k=2, β=0)            (k=3, β=0, lr 5e-6)
   ↓                   ↓                      ↓
                                          [final policy]
```

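Within a column, every vertical arrow is `init_adapter_dir = <previous output>`;
per the schedule in §7, each stage's baseline SFT is itself initialised from
the previous stage's GRPO output. Each row is a "slot in the curriculum";
moving one column to the right adds reasoning capacity (k+=1) and moves to a
harder target distribution (i+=1). The actual training trajectory therefore
runs down each column and then on to the top of the next, tracing a diagonal
across the diagram.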

---

_End of report._