| \documentclass[11pt]{article} |
| \usepackage[a4paper,margin=1in]{geometry} |
| \usepackage[utf8]{inputenc} |
| \usepackage[T1]{fontenc} |
| \usepackage{lmodern} |
| \usepackage{microtype} |
| \usepackage{amsmath,amssymb} |
| \usepackage{booktabs} |
| \usepackage{array} |
| \usepackage{longtable} |
| \usepackage{tabularx} |
| \usepackage{xcolor} |
| \usepackage{listings} |
| \usepackage{hyperref} |
| \usepackage{enumitem} |
| \usepackage{parskip} |
|
|
| \hypersetup{colorlinks=true,linkcolor=blue,urlcolor=blue,citecolor=blue} |
|
|
| \lstset{ |
| basicstyle=\ttfamily\small, |
| breaklines=true, |
| columns=fullflexible, |
| keepspaces=true, |
| frame=single, |
| framerule=0.4pt, |
| xleftmargin=0.5em, |
| xrightmargin=0.5em, |
| showstringspaces=false, |
| } |
|
|
| \newcommand{\code}[1]{\texttt{#1}} |
| \newcommand{\indic}{\mathbf{1}} |
|
|
| \title{Curriculum CoT for $9{\times}9$ Sudoku\\[2pt] |
| \large Rebuttal / Paper-Section Material} |
| \author{} |
| \date{Last updated: 2026--05--24} |
|
|
| \begin{document} |
| \maketitle |
|
|
| \noindent |
This document collects paper-ready reference material on (a) the data
pipeline, (b) the instruction-tuning prompt format, (c) the curriculum
and reward design, (d) the latent thought-token architecture, (e) the
multi-stage SFT-then-GRPO training recipe, and (f) the headline numerical
results, so that a rebuttal section can be assembled directly from it.
|
|
| \bigskip |
| \hrule |
| \bigskip |
|
|
| \section{Task} |
|
|
We use the model as a \textbf{per-cell value policy} for $9\times 9$
Sudoku. For a fixed target empty cell, the model emits a JSON object whose
\code{values} array lists the candidate digits that are ``i-consistent''
with the current grid (definition in \S 4). We evaluate two metrics:
|
|
| \begin{itemize}[leftmargin=*] |
| \item \textbf{per-cell exact set match} (\code{exact\_set\_match}) --- |
| predicted set equals the ground-truth i-consistent set; |
| \item \textbf{whole-puzzle solve rate} (\code{solve}) --- every empty |
| cell on a 20-empty puzzle produces an exact set match. |
| \end{itemize} |
|
|
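A puzzle counts as solved only if every one of its $N_{\text{empty}} \approx 20$
empty cells is an exact set match. Per puzzle, solve is therefore the product
of the per-cell exact-match indicators. If per-cell errors were independent
with a common exact-match rate, the two metrics would be coupled
non-linearly:
\[
\text{solve} \approx \text{exact\_set\_match}^{N_{\text{empty}}}
\]
so $0.95^{20} \approx 0.358$ and $0.97^{20} \approx 0.544$. Near
$\text{exact\_set\_match} = 0.95$, one percentage point of per-cell exact is
worth roughly 7.5 points of solve (the derivative $N p^{N-1}$ at $N=20$).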
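The arithmetic above can be reproduced in a few lines. The sketch below assumes, as the approximation does, that each of the $N$ empty cells is an exact match independently with the same probability $p$. It evaluates $p^{N}$ and the sensitivity $N p^{N-1}$. The function names are illustrative and do not come from the repository.

\begin{lstlisting}[language=Python]
"""Independence model linking per-cell exact match to whole-puzzle solve rate."""


def predicted_solve_rate(per_cell_exact: float, n_empty: int = 20) -> float:
    """P(all n_empty cells are exact) if each is exact independently with prob p."""
    return per_cell_exact ** n_empty


def solve_sensitivity(per_cell_exact: float, n_empty: int = 20) -> float:
    """d(solve)/dp = N * p**(N - 1): solve gained per unit of per-cell exact."""
    return n_empty * per_cell_exact ** (n_empty - 1)


if __name__ == "__main__":
    for p in (0.95, 0.96, 0.97):
        print(f"p={p:.2f}  solve~{predicted_solve_rate(p):.3f}  "
              f"d(solve)/dp~{solve_sensitivity(p):.2f}")
\end{lstlisting}

Under this model, the per-cell exact rates of the two 100-puzzle rows in \S 10 predict $0.9665^{20} \approx 0.506$ and $0.952^{20} \approx 0.374$. Both are below the measured 0.580 and 0.440. One plausible reading, not tested here, is that per-cell errors concentrate in a subset of puzzles rather than striking independently. On that reading, $p^{N}$ is best used as a rough guide to the coupling rather than as a predictor.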
|
|
| \section{Data pipeline} |
|
|
| \subsection{Puzzle generation} |
|
|
| Generated by \code{simple\_9x9\_curriculum/build\_dataset.py}: |
|
|
| \begin{itemize}[leftmargin=*] |
\item Start from a base solved grid (a Latin square that also satisfies
the $3\times 3$ box constraint); randomly relabel digits, permute rows
within bands and columns within stacks, and transpose.
| \item Sample \code{empties=20} cell positions uniformly at random and |
| erase them. |
| \item Save 10\,000 train + 1\,000 eval puzzles (seed 0, seed 1). |
| \item Output JSONL files \code{data/sudoku\_t3\_20empty\_value\_qwen\_text\_stage1\_\{train,eval\}.jsonl}. |
| \end{itemize} |
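Below is a minimal sketch of the generation steps above. It is written for illustration rather than taken from \code{build\_dataset.py}. The base grid uses the common shift pattern $(3(r \bmod 3) + \lfloor r/3 \rfloor + c) \bmod 9 + 1$, and the 50\% transpose probability is ours. Both are assumptions, as are the function names. Every transformation preserves Sudoku validity, so the shuffled grid needs no re-checking before cells are erased.

\begin{lstlisting}[language=Python]
import random
from typing import List, Tuple

Grid = List[List[int]]
Cell = Tuple[int, int]  # 0-based (row, col)


def base_solution() -> Grid:
    """A valid solved grid from a fixed shift pattern (assumed base grid)."""
    return [[(3 * (r % 3) + r // 3 + c) % 9 + 1 for c in range(9)] for r in range(9)]


def shuffle_solution(grid: Grid, rng: random.Random) -> Grid:
    """Validity-preserving symmetries: relabel digits, permute rows within bands
    and columns within stacks, then transpose with probability 0.5 (assumed)."""
    relabel = list(range(1, 10))
    rng.shuffle(relabel)                                  # digit d -> relabel[d - 1]
    rows = [3 * band + i for band in range(3) for i in rng.sample(range(3), 3)]
    cols = [3 * stack + j for stack in range(3) for j in rng.sample(range(3), 3)]
    out = [[relabel[grid[r][c] - 1] for c in cols] for r in rows]
    if rng.random() < 0.5:
        out = [list(col) for col in zip(*out)]           # transpose
    return out


def make_puzzle(rng: random.Random, empties: int = 20) -> Tuple[Grid, Grid, List[Cell]]:
    """Return (puzzle, solution, erased cells); erased cells are set to 0."""
    solution = shuffle_solution(base_solution(), rng)
    all_cells = [(r, c) for r in range(9) for c in range(9)]
    erased = sorted(rng.sample(all_cells, empties))     # uniform, without replacement
    puzzle = [row[:] for row in solution]
    for r, c in erased:
        puzzle[r][c] = 0
    return puzzle, solution, erased
\end{lstlisting}

The erased cells, paired with their solution digits, carry the information stored in \code{target\_triples\_1based} (after shifting to 1-based indices).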
|
|
| A single record contains: |
|
|
| \begin{lstlisting} |
| { |
| "prompt": "<full Qwen chat-templated prompt for one (puzzle, target_cell) pair>", |
| "completion": "[7,3,8,2,6,9,4,5,...]", |
| "metadata": { |
| "grid_size": 9, "box_size": 3, "empties": 20, |
| "empty_locs_1based": [[1,4],[1,9],...], |
| "target_triples_1based": [[1,4,7],[1,9,3],...] |
| } |
| } |
| \end{lstlisting} |
|
|
The 20 \code{target\_triples} give the \textbf{solved} value at each of
the 20 empty positions, so the ground-truth solution is stored with every
record. The per-cell i-consistent target sets are computed from the
current grid (\S 4). At training time we expand each puzzle into 20
(puzzle, target\_cell) examples.
|
|
| \subsection{Cell-policy framing} |
|
|
| The model is never asked to solve a whole puzzle in one shot. Each |
| example is one (current\_grid, target\_cell) pair, and the supervised |
| target is the set of digits that are ``i-consistent'' with the current |
| grid (see \S 4). This turns Sudoku into a |
| \textbf{classification-into-a-set} problem and lets us share parameters |
| across cells, stages, and puzzle sizes. |
|
|
| \subsection{Multi-value oversampling (data-side trick)} |
|
|
| Implemented in \code{multi\_output\_cell\_policy/sft\_multi\_output\_train.py} |
| via \code{tokenizer.\_multi\_value\_oversample\_factor} and the CLI flags |
|
|
| \begin{lstlisting} |
| --multi_value_oversample_factor INT (default 1) |
| --train_target_size_min INT (default 0) |
| --train_target_size_max INT (default 0) |
| \end{lstlisting} |
|
|
Inside the dataset builder, examples whose target set has more than one
digit are repeated \code{multi\_value\_oversample\_factor} times in the
training mix. This concentrates gradient steps on multi-value cells,
which are exactly the cells the model gets wrong (\S 2.4). Empirically,
this is the single biggest data-side lever (see \S 10).
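Below is a minimal sketch of the oversampling rule. It reads ``repeated \code{multi\_value\_oversample\_factor} times'' as each multi-value example appearing that many times in total, so the default factor of 1 is a no-op. That reading, the \code{target} field name, and the final shuffle are assumptions. The \code{train\_target\_size\_min/max} filters are not modelled.

\begin{lstlisting}[language=Python]
import random
from typing import Dict, List, Optional


def oversample_multi_value(examples: List[Dict], factor: int,
                           seed: Optional[int] = 0) -> List[Dict]:
    """Repeat each example whose target set has >1 digit so it appears `factor` times."""
    if factor < 1:
        raise ValueError("factor must be >= 1")
    mixed: List[Dict] = []
    for ex in examples:
        copies = factor if len(ex["target"]) > 1 else 1
        mixed.extend([ex] * copies)
    if seed is not None:
        random.Random(seed).shuffle(mixed)   # spread the duplicates across batches
    return mixed
\end{lstlisting}

Consider a mix with $\sim 25\%$ multi-value cells, as in \S 2.4. A factor of 5 raises the multi-value share of training rows from 25\% to $125/(125+75) = 62.5\%$.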
|
|
| \subsection{Where the bottleneck lives} |
|
|
For 20-empty puzzles in stage 3, only $\sim 25\%$ of empty cells have a
multi-value target set (the rest have exactly one i-consistent value).
Yet those multi-value cells account for the solve-rate gap. They are the
cells where the model under-predicts, returning a singleton when the
target is a 2- or 3-element set, and a single failed cell fails the whole
puzzle. The reward shaping in \S 6 and the oversampling in \S 2.3 both
target this one failure mode.
|
|
| \section{Instruction format} |
|
|
| \subsection{System prompt} |
|
|
| (verbatim from \code{multi\_output\_cell\_policy/prompt\_builder.py}) |
|
|
| \begin{lstlisting} |
| You are a Sudoku value policy. |
| This setup uses puzzles with about 20 empty cells. |
| You will be given one target empty cell. |
| Return ONLY one JSON object of the form {"values":[...]}. |
| The JSON object must contain exactly one key named "values". |
| The "values" field must be a JSON array of unique integers in [1,9]. |
| You may return as many candidate values as you want, including one, |
| several, or many values. |
| Choose the number of returned values yourself based on which values seem |
| i-consistent. |
| The order of the values does not matter. |
| Do not output any explanation, markdown, punctuation outside JSON, or |
| extra text. |
| Current stage objective: i={i} consistency. |
| \end{lstlisting} |
|
|
| \subsection{User message} |
|
|
| \begin{lstlisting} |
| Sudoku grid (0 means empty): |
| <grid_to_text(grid)> |
| Empty cells in row-major order (20 total): (1,4), (1,9), (2,8), ... |
| Target cell to fill now: (R,C). |
| Turn: t/T. |
| Return only JSON with candidate values for this target cell: {"values":[...]} |
| \end{lstlisting} |
|
|
| We use the Qwen2.5-Instruct chat template |
| (\code{tokenizer.apply\_chat\_template}, \code{add\_generation\_prompt=True}) |
| to wrap system + user into the actual prompt ids. |
| \code{max\_prompt\_length = 768}. |
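The sketch below shows one way the messages above could be assembled before the chat template is applied. \code{grid\_to\_text} is not shown in the source, so the space-separated row format is an assumption, as are the helper names. \code{apply\_chat\_template} with \code{add\_generation\_prompt=True} is the standard Hugging Face tokenizer method named in the text. The verbatim system prompt also contains literal braces (\code{\{"values":[...]\}}), so a plain \code{str.format} call on it would raise. For that reason the sketch fills \code{i=\{i\}} by direct substitution.

\begin{lstlisting}[language=Python]
from typing import Dict, List, Sequence, Tuple

Cell = Tuple[int, int]  # 1-based (row, col), as printed in the prompt


def fill_stage(system_template: str, stage_i: int) -> str:
    """Fill the i={i} placeholder by substitution; str.format would fail on
    the literal {"values":[...]} braces in the prompt."""
    return system_template.replace("{i}", str(stage_i))


def grid_to_text(grid: Sequence[Sequence[int]]) -> str:
    """Assumed rendering: one row per line, digits separated by spaces, 0 = empty."""
    return "\n".join(" ".join(str(v) for v in row) for row in grid)


def build_user_message(grid: Sequence[Sequence[int]], empties: List[Cell],
                       target: Cell, turn: int, total_turns: int) -> str:
    cells = ", ".join(f"({r},{c})" for r, c in empties)
    return "\n".join([
        "Sudoku grid (0 means empty):",
        grid_to_text(grid),
        f"Empty cells in row-major order ({len(empties)} total): {cells}",
        f"Target cell to fill now: ({target[0]},{target[1]}).",
        f"Turn: {turn}/{total_turns}.",
        'Return only JSON with candidate values for this target cell: {"values":[...]}',
    ])


def build_messages(system_prompt: str, user_message: str) -> List[Dict[str, str]]:
    return [{"role": "system", "content": system_prompt},
            {"role": "user", "content": user_message}]


def render_prompt(tokenizer, messages: List[Dict[str, str]]) -> str:
    """Wrap the messages with the model's chat template and open the assistant turn."""
    return tokenizer.apply_chat_template(messages, tokenize=False,
                                         add_generation_prompt=True)
\end{lstlisting}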
|
|
| \subsection{Output format} |
|
|
| \begin{lstlisting} |
| {"values":[3,7]} |
| \end{lstlisting} |
|
|
Outputs must be strictly canonical JSON (single key \code{values}, sorted
unique digit list, no whitespace). They are scored by
\code{parse\_values\_json} (\code{shared\_multi\_output\_policy.py}). Any
deviation sets \code{parse\_ok=0} for the whole prediction and incurs the
fixed malformed penalty $P_{\text{mal}}$ (\S 6).
|
|
| \code{max\_completion\_length = 24} tokens --- enough to emit any |
| 9-digit set. |
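To make the canonical-format rule concrete, here is a hypothetical strict parser in the spirit of \code{parse\_values\_json}. It is not the repository's implementation. It returns \code{None} (the \code{parse\_ok=0} case) unless the text is exactly the compact serialisation of a single \code{values} key holding unique integers in $[1,9]$. The system prompt tells the model that order does not matter. Whether the real parser also enforces sorting is therefore an assumption, exposed here as \code{require\_sorted}.

\begin{lstlisting}[language=Python]
import json
from typing import Optional, Set


def parse_values_json(text: str, require_sorted: bool = True) -> Optional[Set[int]]:
    """Return the predicted digit set, or None (parse_ok=0) on any deviation."""
    try:
        obj = json.loads(text)            # also rejects trailing text ("Extra data")
    except json.JSONDecodeError:
        return None
    if not isinstance(obj, dict) or list(obj) != ["values"]:
        return None
    values = obj["values"]
    if not isinstance(values, list):
        return None
    # bool is a subclass of int in Python, so compare the exact type
    if any(type(v) is not int or not 1 <= v <= 9 for v in values):
        return None
    if len(set(values)) != len(values):
        return None
    if require_sorted and values != sorted(values):
        return None
    canonical = json.dumps({"values": values}, separators=(",", ":"))
    if text != canonical:                 # rejects any whitespace or other variation
        return None
    return set(values)
\end{lstlisting}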
|
|
| \section{Curriculum: stage-i consistency} |
|
|
| The curriculum lives in \code{\_stage\_i\_consistent\_values\_for\_grid}: |
|
|
| \begin{itemize}[leftmargin=*] |
| \item \textbf{Stage 1 --- $i=1$ (legal moves).} A value $v$ is $i=1$ |
| consistent at cell $c$ iff placing $v$ at $c$ violates no Sudoku |
| constraint (row, column, $3\times 3$ box). This is just ``legal |
| candidates''. |
|
|
| \item \textbf{Stage 2 --- $i=2$.} $v$ is $i=2$ consistent at $c$ iff |
| (a) it is $i=1$ consistent AND (b) after placing $v$, every other |
| empty cell in the grid still has at least one $i=1$-consistent value |
| (i.e.\ placing $v$ does not immediately make the puzzle unsolvable |
| by 1-step propagation). |
|
|
| \item \textbf{Stage 3 --- $i=3$.} Same recursion one more level deep: |
| $v$ is $i=3$ consistent iff after placing $v$, every other empty cell |
| still has at least one $i=2$ consistent value. |
| \end{itemize} |
|
|
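This is bounded look-ahead constraint propagation. Write $C_i(g,c)$ for
the $i$-consistent set of cell $c$ in grid $g$, and $g[c \leftarrow v]$ for
$g$ with $v$ placed at $c$. The general rule is
\[
C_i(g,c) = \bigl\{ v \in C_1(g,c) : C_{i-1}(g[c \leftarrow v], c') \neq \varnothing
\text{ for every other empty } c' \bigr\}, \quad i \geq 2 .
\]
The sets are nested, $C_3 \subseteq C_2 \subseteq C_1$. Whenever the
current grid agrees with the solution, they also contain the solution
value, since placing it leaves every other cell's solution value
available. The curriculum goal at deployment time is stage 3.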
|
|
| In data, we use the same source records and just change \code{--stage\_i}; |
| the target set is regenerated on the fly by |
| \code{stage\_i\_consistent\_values}. |
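The recursive definition translates directly into a small search. The sketch below is our illustration, not the code of \code{\_stage\_i\_consistent\_values\_for\_grid}. Grids are $9\times9$ lists with 0 for empty, and cells are 0-based. Every tentative placement is undone, so the caller's grid is unchanged.

\begin{lstlisting}[language=Python]
from typing import List, Set, Tuple

Grid = List[List[int]]


def legal_values(grid: Grid, r: int, c: int) -> Set[int]:
    """Stage-1 (i=1) candidates: digits unused in the row, column, and 3x3 box."""
    used = set(grid[r]) | {grid[k][c] for k in range(9)}
    br, bc = 3 * (r // 3), 3 * (c // 3)
    used |= {grid[br + dr][bc + dc] for dr in range(3) for dc in range(3)}
    return set(range(1, 10)) - used


def empty_cells(grid: Grid) -> List[Tuple[int, int]]:
    return [(r, c) for r in range(9) for c in range(9) if grid[r][c] == 0]


def consistent_values(grid: Grid, r: int, c: int, i: int) -> Set[int]:
    """i-consistent values at empty cell (r, c).

    i = 1: legal values.
    i > 1: legal values v such that, after placing v, every other empty
           cell still has at least one (i-1)-consistent value.
    """
    candidates = legal_values(grid, r, c)
    if i <= 1:
        return candidates
    result = set()
    for v in sorted(candidates):
        grid[r][c] = v                          # tentative placement
        try:
            ok = all(consistent_values(grid, rr, cc, i - 1)
                     for rr, cc in empty_cells(grid))
        finally:
            grid[r][c] = 0                      # undo, even on error
        if ok:
            result.add(v)
    return result
\end{lstlisting}

Each extra level multiplies the work by roughly (candidates) $\times$ (remaining empty cells). In this naive form, stage-3 targets therefore cost far more to regenerate on the fly than stage-1 targets.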
|
|
| \section{Latent thought-token architecture} |
|
|
| Base model: \textbf{Qwen/Qwen2.5-1.5B-Instruct} + LoRA |
| ($r=32$, $\alpha=64$, dropout $=0.05$) on |
| \code{q,k,v,o,gate,up,down}. The latent variant adds \textbf{$k$ |
| thought-token slots} between the prompt and the next-token logits. |
|
|
Four modes are implemented (\code{latent\_multi\_output\_cell\_policy/});
the mode behind the headline numbers is \textbf{\code{recurrent\_hidden}}:
|
|
| \begin{quote} |
| \code{build\_recurrent\_hidden\_latent\_hidden(model, ids, mask, k)} |
| \begin{enumerate}[leftmargin=*,nosep] |
| \item Run the backbone once on the prompt. Keep |
| \code{base\_hidden = h[:,-1,:]}. |
| \item Set \code{latent\_token = base\_hidden}. |
| \item Repeat $k$ times: append \code{latent\_token} (as an embedding) |
| to the running sequence, run the backbone again on the extended |
| sequence, and replace \code{latent\_token} with the new last hidden |
| state. |
\item After $k$ recursions, the final \code{latent\_token}
(\code{latent\_hidden} $= z_k$ below) is fed through the LM head to
produce the next-token distribution.
| \end{enumerate} |
| \end{quote} |
|
|
| In equations, with $E$ the input embedding lookup, $f_\theta$ the |
| LoRA-decorated backbone, $U$ the LM head: |
| \begin{align*} |
| z_0 &= f_\theta\bigl(E([x_1,\dots,x_T])\bigr)_T \\ |
| z_{j+1} &= f_\theta\bigl([E(x_1),\dots,E(x_T), z_0, z_1, \dots, z_j]\bigr)_{T+j+1},\quad j=0,\dots,k-1 \\ |
| p(\cdot \mid x_{1:T}) &= \mathrm{softmax}(U z_k) |
| \end{align*} |
|
|
| The model can therefore ``iterate'' $k$ extra forward passes on the |
| same prompt before committing to a token, with the $k$ extra hidden |
| states carrying intermediate computation. Setting $k=0$ recovers the |
| vanilla baseline. |
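The recursion can be checked on a toy model. The sketch uses NumPy and replaces $f_\theta$ with a deliberately tiny, hypothetical causal backbone: a running mean followed by $\tanh$ of a linear map. Only the control flow mirrors \code{build\_recurrent\_hidden\_latent\_hidden}: append $z_j$ as an input embedding, rerun, and keep the last hidden state. It assumes the input embedding and the hidden state share one dimension, as in decoder-only LMs such as Qwen2.5.

\begin{lstlisting}[language=Python]
import numpy as np


def toy_backbone(seq: np.ndarray, W: np.ndarray) -> np.ndarray:
    """Stand-in for f_theta: causal running mean of inputs, then tanh(W h) per position."""
    prefix_mean = np.cumsum(seq, axis=0) / np.arange(1, len(seq) + 1)[:, None]
    return np.tanh(prefix_mean @ W.T)                  # (L, d) hidden states


def softmax(x: np.ndarray) -> np.ndarray:
    e = np.exp(x - x.max())
    return e / e.sum()


def recurrent_hidden_next_token(prompt_ids, E, W, U, k: int):
    """Return (next-token distribution, number of backbone passes) after k recursions."""
    seq = E[prompt_ids]                                # E(x_1..x_T): (T, d)
    latent = toy_backbone(seq, W)[-1]                  # z_0: last hidden state of the prompt
    passes = 1
    for _ in range(k):
        seq = np.vstack([seq, latent[None, :]])        # append z_j as an input embedding
        latent = toy_backbone(seq, W)[-1]              # z_{j+1}
        passes += 1
    return softmax(U @ latent), passes                 # softmax(U z_k)
\end{lstlisting}

In this naive form every recursion reruns the backbone over the whole, growing sequence. The cost is therefore $k+1$ full forward passes per next-token distribution.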
|
|
| The other three latent modes are alternatives that we ablated: |
| \code{fixed\_slots} (concatenate $k$ trainable seed embeddings --- |
| Option-2), \code{latent\_seeds} (similar to \code{fixed\_slots}), and |
| \code{residual} (project $k$ extra hidden states back onto the base |
| hidden state via a learned residual). All modes share the SFT and GRPO |
| trainers; only the next-token logit function changes. |
|
|
| For the curriculum, we grow $k$ stage by stage: |
|
|
| \begin{center} |
| \begin{tabular}{ccl} |
| \toprule |
| \textbf{stage} & \textbf{num\_cot\_tokens} & \textbf{comment} \\ |
| \midrule |
1 & 1 & one extra recursion once the model has learned the output format \\
2 & 2 & two recursions, matched to 1-step propagation reasoning \\
3 & 3 & three recursions, matched to 2-step propagation reasoning \\
| \bottomrule |
| \end{tabular} |
| \end{center} |
|
|
| \section{The reward function} |
|
|
| Defined in \code{multi\_output\_cell\_policy/rewards.py}. |
|
|
| Given target set $T$, predicted set $P$ (after JSON parse), let |
| \begin{itemize}[leftmargin=*,nosep] |
| \item \code{num\_good} $= |P \cap T|$ |
| \item \code{num\_bad} $= |P \setminus T|$ |
| \item \code{num\_missing} $= \max(0, |T| - \text{num\_good})$ |
| \item \code{is\_exact} $= (P \neq \varnothing) \land (P = T)$ |
| \item $\mathrm{tri}(n) = n(n+1)/2$ (rewards larger correct sets superlinearly) |
| \end{itemize} |
|
|
| Then |
| \begin{align*} |
| r &= \mathrm{tri}(\text{num\_good}) \cdot R_g \;-\; \text{num\_bad} \cdot P_b \\ |
| &\quad - \indic[P=\varnothing]\, P_e \;-\; \indic[|P|=1, |T|>1, i<2]\, P_s \\ |
| &\quad - \text{num\_missing}\cdot P_m \;+\; \indic[\text{is\_exact}]\, B_x \\ |
| &\quad - \indic[|P|<|T|, |T|>1]\, P_c |
| \end{align*} |
|
|
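with the following parameters. These values produced the 0.580 latent
solve on the 100-puzzle eval. The 0.675 peak on the 40-puzzle eval used
the sharper variant $(P_m, B_x, P_c) = (1.0, 4.0, 3.0)$ (Appendix A):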
|
|
| \begin{center} |
| \begin{tabular}{cllr} |
| \toprule |
| \textbf{symbol} & \textbf{flag} & \textbf{role} & \textbf{value} \\ |
| \midrule |
| $R_g$ & \code{--reward\_good\_value} & per-correct-value reward (triangular shape) & 1.25 \\ |
| $P_b$ & \code{--penalty\_bad\_value} & per-extra-wrong-value penalty & 1.0 \\ |
| $P_{\text{mal}}$ & \code{--penalty\_malformed} & flat penalty if JSON parse fails & 4.0 \\ |
| $P_e$ & \code{--penalty\_empty} & flat penalty if predicted set is empty & 0.5 \\ |
| $P_s$ & \code{--penalty\_singleton} & only at stage$<$2: punishes singleton on multi-value targets & 1.5 \\ |
| $P_m$ & \code{--penalty\_missing} & per-missing-value (recall pressure) --- \textbf{NEW} & \textbf{0.75} \\ |
| $B_x$ & \code{--exact\_match\_bonus} & only when $P = T$ --- \textbf{NEW} & \textbf{2.0} \\ |
| $P_c$ & \code{--cardinality\_mismatch\_penalty} & when $|P| < |T|$ and $|T|>1$ --- \textbf{NEW} & \textbf{1.0} \\ |
| \bottomrule |
| \end{tabular} |
| \end{center} |
|
|
| Parse failures short-circuit to $r = -P_{\text{mal}}$ and zero per-cell |
| metrics. |
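Below is a self-contained transcription of the formula above, with the recipe values from the table as defaults. It is a sketch rather than \code{rewards.py}. The class and function names, and the convention that \code{predicted=None} marks a parse failure, are our assumptions.

\begin{lstlisting}[language=Python]
from dataclasses import dataclass
from typing import Optional, Set


@dataclass(frozen=True)
class RewardParams:
    reward_good_value: float = 1.25             # R_g
    penalty_bad_value: float = 1.0              # P_b
    penalty_malformed: float = 4.0              # P_mal
    penalty_empty: float = 0.5                  # P_e
    penalty_singleton: float = 1.5              # P_s (only when stage_i < 2)
    penalty_missing: float = 0.75               # P_m
    exact_match_bonus: float = 2.0              # B_x
    cardinality_mismatch_penalty: float = 1.0   # P_c


def tri(n: int) -> int:
    """Triangular number n(n+1)/2."""
    return n * (n + 1) // 2


def cell_reward(predicted: Optional[Set[int]], target: Set[int], stage_i: int,
                p: RewardParams = RewardParams()) -> float:
    """Reward for one cell; predicted=None means the JSON failed to parse."""
    if predicted is None:
        return -p.penalty_malformed
    num_good = len(predicted & target)
    num_bad = len(predicted - target)
    num_missing = max(0, len(target) - num_good)
    is_exact = bool(predicted) and predicted == target

    r = tri(num_good) * p.reward_good_value - num_bad * p.penalty_bad_value
    if not predicted:
        r -= p.penalty_empty
    if len(predicted) == 1 and len(target) > 1 and stage_i < 2:
        r -= p.penalty_singleton
    r -= num_missing * p.penalty_missing
    if is_exact:
        r += p.exact_match_bonus
    if len(predicted) < len(target) and len(target) > 1:
        r -= p.cardinality_mismatch_penalty
    return r
\end{lstlisting}

Take $T=\{8,9\}$ at stage 3. The sketch scores the singleton $\{8\}$ at $-0.5$ and the exact set at $5.75$. With $P_m$, $B_x$, and $P_c$ set to zero, the same two predictions score $1.25$ and $3.75$.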
|
|
| \subsection{Why those three new terms exist (the breakthrough)} |
|
|
| Diagnosis: at the v3/v4 plateau, eval reported |
|
|
| \begin{lstlisting} |
| exact=0.95 precision=0.95 recall=0.95 solve=0.30 avg_set_size=1.000 |
| \end{lstlisting} |
|
|
across all checkpoints. Per-cell exact, precision, and recall were all
near 0.95, but the model \textbf{always predicted a single digit}
(\code{avg\_set\_size=1.000}). On a multi-value target $T=\{8,9\}$,
predicting $\{8\}$ keeps precision at $1.0$ and recall at $0.5$, yet
\code{exact\_set\_match}$=0$. A puzzle is solved only if all $N=20$ cells
match. A singleton-only policy therefore fails every puzzle in which even
one target cell has a multi-value set.
|
|
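Without the new terms, a singleton on a multi-value cell is never
penalised for what it leaves out; it only forgoes the triangular
increment for the second value. A near-deterministic policy that never
samples a second value (\code{avg\_set\_size}$=1.000$) never observes that
upside during GRPO, so nothing pushes it off the singleton. The three new
terms close exactly that hole: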
|
|
| \begin{itemize}[leftmargin=*,nosep] |
\item $P_m$ (\code{penalty\_missing}) directly penalises each missed value (recall);
\item $B_x$ (\code{exact\_match\_bonus}) adds a flat reward paid only when $P=T$, widening the margin between the exact set and every near miss;
\item $P_c$ (\code{cardinality\_mismatch\_penalty}) is a flat penalty whenever $|P|<|T|$ on a multi-value cell.
| \end{itemize} |
|
|
After these terms were added, GRPO on the latent variant moved solve
from $\sim 0.30$ to $\sim 0.58$ (100-puzzle eval) over $\sim 200$
steps. The same fix was later ported back into the baseline pipeline
as the v6 sweep (see \S 10).
|
|
| \section{Multi-stage warm-baseline pipeline (the recipe that worked)} |
|
|
| Master script: |
| \code{hard\_9x9\_stage1\_consistency\_queue/launch\_20empty\_warm\_baseline\_all\_latent\_modes\_stages123.sh}. |
|
|
| For each curriculum stage we run \textbf{three sub-phases in order}: |
|
|
| \begin{lstlisting} |
| [stage i] |
| (1) baseline warm SFT (no latent tokens, k=0, vanilla LM) |
| (2) latent SFT (k = i, latent mode = recurrent_hidden) |
| (3) latent GRPO (k = i) |
| \end{lstlisting} |
|
|
\textbf{The warm-baseline phase (1) is what makes the curriculum work.}
At every stage transition the data distribution changes ($i$ increases
$\Rightarrow$ target sets shrink) and a new latent slot appears. Running
vanilla SFT on the new distribution first lets the LM re-fit the new
targets with its existing parameters. Only then does latent SFT add the
extra thought slot, on top of an already-good policy. When we instead
added a new latent slot directly on top of the previous stage's GRPO
checkpoint, the training loss did not decrease.
|
|
| Concrete LR schedule used for the champion run: |
|
|
| \begin{center} |
| \begin{tabular}{lllc} |
| \toprule |
| \textbf{phase} & \textbf{init from} & \textbf{LR} & \textbf{k} \\ |
| \midrule |
| S1 baseline SFT & base Qwen & 2e-4 & 0 \\ |
| S1 latent SFT & S1 baseline & 2e-4 & 1 \\ |
| S1 latent GRPO & S1 latent SFT & 1e-6 & 1 \\ |
| S2 baseline warm SFT & S1 GRPO & 5e-5 & 0 \\ |
| S2 latent SFT & S2 baseline & 5e-5 & 2 \\ |
| S2 latent GRPO & S2 latent SFT & 1e-6 & 2 \\ |
| S3 baseline warm SFT & S2 GRPO & 5e-5 & 0 \\ |
| S3 latent SFT & S3 baseline & 5e-5 $\rightarrow$ 1e-5 (champion) & 3 \\ |
| S3 latent GRPO & S3 latent SFT & 5e-6 ($\beta=0$) & 3 \\ |
| \bottomrule |
| \end{tabular} |
| \end{center} |
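The init chain in the table can also be generated programmatically. Doing so makes one easily-missed edge explicit: each stage's warm-baseline SFT starts from the previous stage's GRPO output. The sketch is hypothetical; the real orchestration is the shell script named above. The dictionary keys and phase labels are ours. The learning rates and $k$ values are taken from the table, using the champion $10^{-5}$ for S3 latent SFT.

\begin{lstlisting}[language=Python]
from typing import Dict, List

# (baseline SFT LR, latent SFT LR, latent GRPO LR) per stage, from the table above
STAGE_LRS = {1: (2e-4, 2e-4, 1e-6), 2: (5e-5, 5e-5, 1e-6), 3: (5e-5, 1e-5, 5e-6)}


def build_schedule(num_stages: int = 3) -> List[Dict]:
    """Return the warm-baseline schedule (3 phases per stage) with init_from links."""
    phases: List[Dict] = []
    prev = "base Qwen"
    for i in range(1, num_stages + 1):
        lr_base, lr_latent, lr_grpo = STAGE_LRS[i]
        for name, kind, lr, k in [
            (f"S{i} baseline SFT", "sft", lr_base, 0),    # vanilla SFT on stage-i targets
            (f"S{i} latent SFT", "sft", lr_latent, i),    # adds the i-th latent slot
            (f"S{i} latent GRPO", "grpo", lr_grpo, i),    # RL with beta = 0
        ]:
            phases.append({"name": name, "kind": kind, "stage_i": i,
                           "k": k, "lr": lr, "init_from": prev})
            prev = name                                   # next phase starts from this output
    return phases
\end{lstlisting}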
|
|
| Other shared knobs: |
|
|
| \begin{lstlisting} |
| LoRA: r=32 a=64 dropout=0.05 on q,k,v,o,gate,up,down |
| SFT: per_device_bs=8 grad_accum=2 nproc=8 -> eff_bs=128 |
| GRPO: per_device_bs=4 grad_accum=2 nproc=8 -> eff_bs=64 |
| num_generations=4 beta=0.0 max_prompt_length=1024 |
| max_completion_length=24 |
| multi_value_oversample_factor=5, exact_match_bonus=2.0, |
| penalty_missing=0.75, cardinality_mismatch_penalty=1.0 |
| \end{lstlisting} |
|
|
| \section{GRPO settings that mattered} |
|
|
| \begin{itemize}[leftmargin=*] |
| \item \textbf{$\beta = 0$.} The KL anchor was harmful in every sweep |
| where we tried $\beta>0$. \code{s3\_grpo\_kl04} ($\beta=0.04$) peaked |
| at solve $=0.625$ (40p) at step 100 and regressed to $0.525$ by step |
| 500. |
|
|
| \item \textbf{\code{num\_generations} $= 4$.} With \code{num\_generations}$=2$ |
| we routinely saw \code{reward\_std}$=0$ (all sampled completions |
| identical $\Rightarrow$ no gradient). Bumping to 4 fixed it. |
|
|
\item \textbf{Low LR.} \code{lr=5e-6} was the steadiest. \code{lr=1e-5}
peaked at step 200 (solve $0.65$) and then fell back to $0.54$,
consistent with the policy collapsing at the higher LR.
|
|
\item \textbf{Batch divisibility.} In TRL 0.15.x, GRPOTrainer requires the
global per-step batch (\code{num\_processes * per\_device\_bs}) to be evenly
divisible by \code{num\_generations}. With 8 GPUs ($8 \times 4 = 32$
completions per step, effective batch 64 after gradient accumulation) this
holds trivially. Single-GPU rerunners should keep
\code{per\_device\_bs=4 grad\_accum=2 num\_generations=4}.
|
|
| \item \textbf{\code{enable\_input\_require\_grads()} on the wrapped backbone.} |
| Required for TRL 0.15.x + PEFT LoRA + gradient checkpointing --- |
| otherwise the loss tensor produced by GRPOTrainer has |
| \code{requires\_grad=False} and \code{.backward()} raises. Also |
| \code{unwrapped.config.use\_cache = False}. |
| \end{itemize} |
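Two of the items above reduce to a few lines. As we understand the check in TRL 0.15.x, the global per-step batch (\code{num\_processes} $\times$ \code{per\_device\_train\_batch\_size}) must split into whole groups of \code{num\_generations} completions; \code{grpo\_batch\_ok} restates that rule. \code{prepare\_backbone\_for\_grpo} is a hypothetical helper. It applies the two settings named above through the standard \code{transformers} model method \code{enable\_input\_require\_grads()} and the \code{config.use\_cache} attribute. Where the real trainers call these is not shown here.

\begin{lstlisting}[language=Python]
def grpo_batch_ok(num_processes: int, per_device_bs: int, num_generations: int) -> bool:
    """Global per-step batch must split into whole groups of num_generations."""
    return (num_processes * per_device_bs) % num_generations == 0


def prepare_backbone_for_grpo(model):
    """Apply the two fixes above to an (unwrapped) transformers model before GRPO."""
    model.enable_input_require_grads()  # keep grads flowing under grad checkpointing + LoRA
    model.config.use_cache = False      # KV cache is incompatible with gradient checkpointing
    return model
\end{lstlisting}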
|
|
| \section{Final hyperparameters table --- champion latent run} |
|
|
| \begin{center} |
| \begin{longtable}{lll} |
| \toprule |
| \textbf{group} & \textbf{hyperparameter} & \textbf{value} \\ |
| \midrule |
| \endfirsthead |
| \toprule |
| \textbf{group} & \textbf{hyperparameter} & \textbf{value} \\ |
| \midrule |
| \endhead |
| Backbone & model & Qwen/Qwen2.5-1.5B-Instruct \\ |
| Backbone & dtype & bf16 \\ |
| Backbone & LoRA target modules & q,k,v,o,gate,up,down \\ |
| Backbone & LoRA $r$ / $\alpha$ / dropout & 32 / 64 / 0.05 \\ |
| Latent & mode & \code{recurrent\_hidden} \\ |
| Latent & \code{num\_cot\_tokens} (S1/S2/S3) & 1 / 2 / 3 \\ |
| Latent & \code{max\_latent\_slots} / seeds & 8 / 8 \\ |
| Data & total empties & 20 \\ |
| Data & train rows / eval rows & 10\,000 / 100 \\ |
| Data & \code{multi\_value\_oversample\_factor} & 5 \\ |
| Data & \code{mixed\_stage1\_ratio} (S1) & 1 \\ |
| Data & \code{mixed\_stage2\_ratio} (S$\geq 2$) & 1 \\ |
| SFT & per\_device\_bs / grad\_accum & 8 / 2 \\ |
| SFT & \code{num\_epochs} (cap) & 64 \\ |
| SFT & LR (S1 latent) & 2e-4 \\ |
| SFT & LR (S2/S3 baseline warm + latent) & 5e-5 \\ |
| SFT & LR (S3 latent champion \code{s3b\_lr1e5\_o5}) & 1e-5 \\ |
| SFT & weight\_decay & 0.0 \\ |
| SFT & gradient checkpointing & on \\ |
| GRPO & per\_device\_bs / grad\_accum & 4 / 2 \\ |
| GRPO & \code{num\_generations} & 4 \\ |
| GRPO & LR & 5e-6 (S3); 1e-6 (S1, S2) \\ |
| GRPO & $\beta$ (KL) & 0.0 \\ |
| GRPO & \code{max\_prompt\_length} & 1024 \\ |
| GRPO & \code{max\_completion\_length} & 24 \\ |
| Reward & \code{reward\_good\_value} & 1.25 \\ |
| Reward & \code{penalty\_bad\_value} & 1.0 \\ |
| Reward & \code{penalty\_malformed} & 4.0 \\ |
| Reward & \code{penalty\_empty} & 0.5 \\ |
| Reward & \code{penalty\_singleton} & 1.5 \\ |
| Reward & \code{penalty\_missing} & 0.75 \\ |
| Reward & \code{exact\_match\_bonus} & 2.0 \\ |
| Reward & \code{cardinality\_mismatch\_penalty} & 1.0 \\ |
| Eval & early-stop on prec/recall & 0.98 \\ |
| \bottomrule |
| \end{longtable} |
| \end{center} |
|
|
| \section{Headline results} |
|
|
| \subsection{Latent (with thought tokens, \code{recurrent\_hidden})} |
|
|
| \begin{center} |
| \begin{tabular}{llrrrrr} |
| \toprule |
| \textbf{eval} & \textbf{model / phase} & \textbf{step} & \textbf{exact} & \textbf{prec} & \textbf{recall} & \textbf{solve} \\ |
| \midrule |
\textbf{100p (authoritative)} & \code{s3\_grpo\_baseline} (S3 GRPO, $\beta=0$, lr=5e-6) & 200 & 0.9665 & 0.9673 & 0.9680 & \textbf{0.580 (58/100)} \\
| 40p & \code{s3\_grpo\_sharp\_rwd} ($B_x{=}4$, $P_c{=}3$) & 300 & --- & --- & --- & \textbf{0.675 (27/40)} \\ |
| 40p & \code{s3\_grpo\_lr1e5} & 200 & 0.978 & 0.978 & 0.979 & 0.650 \\ |
| 40p & \code{s3b\_lr1e5\_o5} (S3 SFT champion) & 2400 & 0.974 & 0.974 & 0.975 & 0.600 \\ |
| \bottomrule |
| \end{tabular} |
| \end{center} |
|
|
| \subsection{Vanilla baseline (no thought tokens, same Qwen2.5-1.5B + LoRA)} |
|
|
| \begin{center} |
| \begin{tabular}{llrrr} |
| \toprule |
| \textbf{sweep} & \textbf{best variant} & \textbf{best step} & \textbf{exact} & \textbf{solve (100p)} \\ |
| \midrule |
| v3 (single-GPU, no oversample, no new reward) & \code{baseline\_3stage\_20260522} & --- & 0.730 & \textbf{0.000} \\ |
| v4 (LR sweep, multi-GPU, original reward) & \code{pipe\_v\_sft\_extend} (S3 SFT extended) & 4000 & 0.948 & \textbf{0.400} \\ |
\textbf{v6 (ports latent reward + oversample)} & \code{v6\_i\_sft\_v\_oversample10} & running & 0.952$+$ & \textbf{0.440 (best so far)} \\
| \bottomrule |
| \end{tabular} |
| \end{center} |
|
|
| The v6 sweep is still running --- \code{v6\_e/f/i} are in S3 SFT |
| continuation, GRPO follow-on phases queued. The \code{v6\_i} variant |
| has hit \textbf{solve $=0.44$} at SFT eval (new baseline best, |
| $+0.04$ over v4) and is still climbing. |
|
|
| \subsection{Stage-by-stage trajectory (latent, 40-puzzle eval)} |
|
|
| \begin{lstlisting} |
| S1 SFT : exact ~ 0.85, solve ~ 0.20 |
| S1 GRPO : exact ~ 0.90, solve ~ 0.20 |
| S2 SFT (no oversample) : exact ~ 0.94, solve ~ 0.20-0.25 <- the wall |
| S2 SFT + multi_value_oversample=5 : exact ~ 0.96, solve ~ 0.30-0.35 |
| S2 GRPO + new reward terms : exact ~ 0.96, solve ~ 0.35-0.40 |
| S3 SFT (s3b_lr1e5_o5 step 2400) : exact 0.974, solve 0.600 <- SFT champion |
| S3 GRPO (s3_grpo_baseline step 200,100p): exact 0.967, solve 0.580 <- 100p champion |
| S3 GRPO (s3_grpo_sharp_rwd step 300,40p): solve 0.675 <- 40p peak |
| \end{lstlisting} |
|
|
| \subsection{Latent vs baseline gap (head-to-head, same 100p eval, same prompts)} |
|
|
| \begin{center} |
| \begin{tabular}{lrrrrr} |
| \toprule |
| \textbf{model} & \textbf{exact} & \textbf{prec} & \textbf{recall} & \textbf{solve} & \textbf{solved/100} \\ |
| \midrule |
| Latent \code{recurrent\_hidden}, S3 GRPO & 0.9665 & 0.9673 & 0.9680 & \textbf{0.580} & 58 \\ |
| Vanilla baseline, \code{v6\_i} (best at time of writing) & 0.952 & 0.952 & 0.952 & \textbf{0.440} & 44 \\ |
| \bottomrule |
| \end{tabular} |
| \end{center} |
|
|
| Gap on 100-puzzle solve: $\approx$ \textbf{$+0.14$ absolute / $+32\%$ |
| relative} for latent over the strongest baseline we have. |
|
|
| \section{Why the latent works (interpretation hypotheses)} |
|
|
These are working hypotheses that the experiments are consistent with.
None is proven, and the ablations are still in progress.
|
|
| \begin{enumerate}[leftmargin=*] |
| \item \textbf{Constraint-propagation depth.} Stage-3 i-consistency is |
| essentially 2-ply lookahead. With $k=3$ recurrent hidden tokens the |
| model gets exactly three extra forward passes between prompt and |
| output --- one for the legality check, one for 1-step propagation, |
| one for the second step of propagation. Empirically the gap to the |
| no-thought-token baseline appears at stages where multi-step |
| propagation matters (stage 2 onward; stage 1 numbers are essentially |
| identical). |
|
|
\item \textbf{Multi-value cells require enumeration, which a near-greedy
vanilla decoder does not do.} A vanilla 1.5B LM decodes almost
deterministically at low temperature. For a target set $\{8, 9\}$ it
emits the more probable digit and closes the list. The latent model can
use the recurrent hidden steps to ``consider'' each option before
committing to the first token. That addresses exactly the failure mode
seen in the data (\code{avg\_set\_size} $= 1.000$ for the baseline,
$\approx 1.05$ for the latent S3 model on the same eval).
|
|
| \item \textbf{Stable curriculum capacity growth.} Adding a new latent |
| slot at every stage gives the model a ``fresh slate'' of |
| representational capacity at the exact transition where the task |
| gets harder. The warm-baseline SFT between stages prevents the new |
| slot from corrupting the previously learned policy. Without warm |
| baseline, training loss did not decrease at all (we observed this |
| directly when we tried to skip the warm baseline). |
|
|
| \item \textbf{GRPO without latent slots is starved of variance.} With |
| \code{max\_completion\_length} 24 and the model essentially |
| deterministic, GRPO's 4 sampled completions per prompt collapse to a |
| single answer --- \code{reward\_std}$=0$, no gradient. With latent |
| recurrence + the new \code{exact\_match\_bonus} reward, the model |
| occasionally samples a 2-element set, gets a much higher reward, and |
| that prompt gets a real gradient signal. |
| \end{enumerate} |
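Hypothesis 4 follows from how GRPO scores a group of completions. The sketch computes group-relative advantages $A_j = (r_j - \bar r)/(s + \epsilon)$ over the \code{num\_generations} completions of one prompt, where $s$ is the group's standard deviation. The use of the sample standard deviation and $\epsilon = 10^{-4}$ reflects our reading of TRL's implementation and should be treated as an assumption. With $\beta = 0$ the loss reaches the completions only through these advantages, so a group with identical rewards contributes no gradient.

\begin{lstlisting}[language=Python]
import statistics
from typing import List


def group_advantages(rewards: List[float], eps: float = 1e-4) -> List[float]:
    """Group-relative advantages for the completions sampled from one prompt."""
    mean = statistics.fmean(rewards)
    std = statistics.stdev(rewards) if len(rewards) > 1 else 0.0   # sample std
    return [(r - mean) / (std + eps) for r in rewards]
\end{lstlisting}

Under the \S 6 recipe, four identical singleton answers $\{8\}$ on a $T=\{8,9\}$ cell all score $-0.5$, and every advantage is zero. Suppose instead that one of the four samples is the exact set (reward $5.75$). That sample's advantage is then about $+1.5$, and each singleton's is about $-0.5$.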
|
|
| \section{Reproducibility} |
|
|
| \noindent |
| Code repository: \url{https://github.com/Avra98/curriculum_cot} \\ |
| Latent checkpoints: \url{https://huggingface.co/Avra98/sudoku-latent-recurrent-hidden-20empty-stages} \\ |
| Baseline checkpoints: \url{https://huggingface.co/Avra98/sudoku-9x9-20empty-baseline-1p5b-sweep} |
|
|
| Key scripts: |
|
|
| \begin{itemize}[leftmargin=*,nosep] |
| \item Master orchestrator (latent, 9-phase warm-baseline pipeline): |
| \code{hard\_9x9\_stage1\_consistency\_queue/launch\_20empty\_warm\_baseline\_all\_latent\_modes\_stages123.sh} |
| \item Vanilla baseline pipeline: |
| \code{\_runs/baseline\_1p5b\_pipeline\_v4.sh} (with v6 launchers |
| \code{\_runs/launch\_baseline\_push\_v6.sh}) |
| \item SFT trainer (vanilla): |
| \code{multi\_output\_cell\_policy/sft\_multi\_output\_train.py} |
| \item GRPO trainer (vanilla): |
| \code{multi\_output\_cell\_policy/grpo\_multi\_output\_train.py} |
| \item SFT trainer (latent): |
| \code{latent\_multi\_output\_cell\_policy/sft\_latent\_multi\_output\_train.py} |
| \item GRPO trainer (latent): |
| \code{latent\_multi\_output\_cell\_policy/grpo\_residual\_projector\_latent\_train.py} |
| \item Reward function: \code{multi\_output\_cell\_policy/rewards.py} |
| \item Prompt builder: \code{multi\_output\_cell\_policy/prompt\_builder.py} |
| \item Stage-i consistency: |
| \code{multi\_output\_cell\_policy/shared\_multi\_output\_policy.py} |
| \item 100-puzzle evaluator: \code{analysis/eval\_stage2\_checkpoint.py} |
| \end{itemize} |
|
|
| To reproduce the latent champion (1.5B, 9-phase, $\sim 16$ GPU$\cdot$h |
| on $8\times$H100 80GB): |
|
|
| \begin{lstlisting} |
| export STAGE1_BASELINE_ADAPTER_DIR=/path/to/stage1_baseline_seed_adapter |
| bash hard_9x9_stage1_consistency_queue/launch_20empty_warm_baseline_all_latent_modes_stages123.sh |
| \end{lstlisting} |
|
|
| To reproduce the v6 baseline push (single-GPU per variant, $\sim 6$ |
| GPU$\cdot$h): |
|
|
| \begin{lstlisting} |
| bash _runs/launch_baseline_push_v6.sh |
| \end{lstlisting} |
|
|
| \appendix |
|
|
\section{The reward fix as a minimal patch}
|
|
| The single most consequential code change in this whole project, as a |
| self-contained patch on \code{multi\_output\_cell\_policy/rewards.py}: |
|
|
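\begin{lstlisting}[language=Python]
# Illustrative, self-contained version of the patched reward. The function
# name and signature are ours; only the three new terms mirror the patch.
def triangular_number(n: int) -> int:
    return n * (n + 1) // 2


def patched_reward(predicted_values, target_set, *,
                   reward_good_value: float, penalty_bad_value: float,
                   # new args (default 0.0 preserves legacy behaviour)
                   penalty_missing: float = 0.0,
                   exact_match_bonus: float = 0.0,
                   cardinality_mismatch_penalty: float = 0.0) -> float:
    predicted_set = set(predicted_values)
    num_good = len(predicted_set & target_set)
    num_bad = len(predicted_set - target_set)
    num_missing = max(0, len(target_set) - num_good)
    is_exact = bool(predicted_set) and predicted_set == target_set

    # base reward (legacy empty / singleton / malformed terms omitted here)
    reward = (triangular_number(num_good) * reward_good_value
              - num_bad * penalty_bad_value)

    if num_missing > 0:          # recall pressure
        reward -= num_missing * penalty_missing
    if is_exact:                 # paid only when P == T
        reward += exact_match_bonus
    if len(predicted_set) < len(target_set) and len(target_set) > 1:
        reward -= cardinality_mismatch_penalty
    return reward
\end{lstlisting}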
|
|
Defaults are zero, so old runs are unaffected. The main recipe sets
$(P_m, B_x, P_c) = (0.75, 2.0, 1.0)$, and the ``sharp\_rwd'' variant uses
$(1.0, 4.0, 3.0)$.
|
|
| \section{The warm-baseline trick as a sequence diagram} |
|
|
\begin{lstlisting}
Stage 1                  Stage 2                  Stage 3
-------                  -------                  -------
[base Qwen]
    |
    v
S1 baseline SFT     +--> S2 baseline SFT     +--> S3 baseline SFT
(k=0, lr 2e-4)      |    (k=0, lr 5e-5)      |    (k=0, lr 5e-5)
    |               |        |               |        |
    v               |        v               |        v
S1 latent SFT       |    S2 latent SFT       |    S3 latent SFT
(k=1, lr 2e-4)      |    (k=2, lr 5e-5)      |    (k=3, lr 1e-5)
    |               |        |               |        |
    v               |        v               |        v
S1 latent GRPO -----+    S2 latent GRPO -----+    S3 latent GRPO
(k=1, b=0, lr 1e-6)      (k=2, b=0, lr 1e-6)      (k=3, b=0, lr 5e-6)
                                                      |
                                                      v
                                                  [final policy]
\end{lstlisting}

Every arrow is \code{init\_adapter\_dir = <previous output>}. Each column
is one curriculum stage, run top to bottom: baseline SFT ($k=0$), then
latent SFT, then latent GRPO. The arrow out of each stage's GRPO phase
feeds the next stage's warm-baseline SFT. That phase moves to a harder
target distribution ($i\mathrel{+}=1$) before latent SFT adds one more
slot ($k\mathrel{+}=1$). The actual training trajectory therefore runs
down column 1, across, down column 2, across, and down column 3.
|
|
| \bigskip |
| \noindent\emph{End of report.} |
|
|
| \end{document} |
|
|