qwen3-8b-parallel-cot

A Qwen3-8B LoRA latent chain-of-thought organism in which one soft token per reasoning step carries an entire K-cell state. It solves a coupled cellular automaton (CA: a grid of cells updated by a fixed local rule each step) inside a single-token latent chain (no text reasoning). The chain is load-bearing (the answer is a causal read-out of the chain, decoded from the final token by a per-cell head), and, because training mixes chain lengths, the model generalises to chains longer than any seen in training. It is a deliberately constructed model organism (a positive control for activation-oracle / CoT-faithfulness evals): the latent reasoning is induced by supervised fine-tuning with a teacher, in which the head is trained on the ground-truth CA states with scheduled-sampling teacher forcing (annealed 1 → 0). Notation: z_t ∈ ℝ^{d_vocab} (soft tokens), h (residual), c ∈ {0,…,9}^K (ground-truth state).

TL;DR (free-running, self-generated tokens; n=400, chance 0.10). Accuracy is 1.000 at the trained length; blinding z_1 to the prompt drops it to 0.040. Replacing the whole chain with another problem's tokens (while keeping this prompt) makes the answer follow the donor (1.00) rather than the prompt (0.00). Although training used only chain lengths [2, 3, 4, 5, 6], free-running accuracy on longer chains (up to T=14) averages 0.96.

Task & notation

Task (diffuse): a coupled ring cellular automaton (CA). K=3 cells c1, c2, c3 sit on a ring (c3 is adjacent to c1) and start at random digits; at each step, all cells update simultaneously to the sum mod 10 of their two ring neighbours:

c_i(t) = \bigl(c_{i-1}(t-1) + c_{i+1}(t-1)\bigr) \bmod 10 \qquad \text{(indices wrap around the ring)}.

The rule runs for T steps; then one cell is queried, answered as a single digit \boxed{d}. The query appears after the latent block, so the latents must carry the whole row.
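A minimal pure-Python sketch of the rule, useful for generating ground truth. The function names (ca_step, ca_run, ca_answer) are illustrative and not taken from the latent_threads code, and cells are 0-indexed here, so the card's c2 is index 1. It reproduces the first sample trajectory given later in this card.

```python
def ca_step(row):
    """One synchronous update: each cell becomes the sum mod 10 of its two ring
    neighbours (indices wrap, so the last cell neighbours the first)."""
    k = len(row)
    return [(row[(i - 1) % k] + row[(i + 1) % k]) % 10 for i in range(k)]


def ca_run(row, steps):
    """Trajectory [c(1), ..., c(T)] for T = steps, starting from c(0) = row."""
    traj = []
    for _ in range(steps):
        row = ca_step(row)
        traj.append(row)
    return traj


def ca_answer(row, steps, query):
    """Digit of cell `query` (0-based) after `steps` >= 1 updates."""
    return ca_run(row, steps)[-1][query]


print(ca_run([6, 6, 0], 6))
# [[6, 6, 2], [8, 8, 2], [0, 0, 6], [6, 6, 0], [6, 6, 2], [8, 8, 2]]
print(ca_answer([6, 6, 0], 6, query=1))  # 8 (cell c2)
```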

Notation (latent-reasoning conventions). Sizes for this model: K = 3 cells, 10 digit values (0–9), soft-token dimension d_vocab = K·10 = 30, model residual width 4096.

  • c(t) ∈ {0,…,9}^K: the ground-truth state (the K-cell row after step t); c_i(t) is cell i.
  • h: the residual stream (a 4096-vector at a token position).
  • z_t ∈ ℝ^{d_vocab}: the autoregressively fed soft token at step t (t = 1..T), viewed as a K×10 array of per-cell digit distributions: z_t[i] ∈ ℝ^{10} is cell i's distribution over the 10 digits (z_t[i,d] = P(cell i = d)). So one token carries the whole row c(t).
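A small NumPy sketch of this layout, under the assumption that the 30-dim view is the row-major flattening of the K×10 array (entry i·10 + d holds P(cell i = d)); the card does not state the flattening order. A ground-truth row becomes a token with one one-hot distribution per cell, and the per-cell argmax recovers the row. The helper names are hypothetical.

```python
import numpy as np

K, DIGITS = 3, 10  # sizes from this card: d_vocab = K * DIGITS = 30


def row_to_token(row):
    """Teacher-forced token for a ground-truth row: a K x 10 array whose
    i-th row is the one-hot distribution of cell i's digit."""
    z = np.zeros((K, DIGITS))
    z[np.arange(K), row] = 1.0
    return z


def token_to_row(z):
    """Per-cell argmax: recover each cell's digit from a K x 10 token."""
    return z.argmax(axis=1).tolist()


z = row_to_token([6, 6, 2])
flat = z.reshape(-1)       # assumed 30-dim view: entry i*10 + d is P(cell i = d)
print(flat.shape)          # (30,)
print(token_to_row(z))     # [6, 6, 2]
```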

The point: one reasoning step = one token

Each z_t is a single autoregressive position whose residual h holds the entire K-cell row: in that one forward step, reading the previous step's fed token (or, for z_1, the prompt), the model computes the row c(t) for all K cells at once. The latent chain is a strict single-token Markov chain

\text{prompt} \to z_1 \to z_2 \to \cdots \to z_T \to \text{answer},

where z_1 reads the prompt, z_t (for t > 1) attends only to z_{t-1}, and the answer attends only to z_T. (Contrast this with a design that spends K token positions per step; here the whole row is a single token.)
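The data flow of the chain can be sketched as a plain loop. This is illustrative control flow, not the model: first_step and next_step are hypothetical stand-ins for the forward passes at z_1 and at z_t (t > 1). Here the true CA rule plays both roles, with the prompt reduced to the initial row c(0).

```python
from typing import Callable, List, Sequence

Row = List[int]


def run_latent_chain(
    first_step: Callable[[Sequence[int]], Row],  # z_1: the only step that reads the prompt
    next_step: Callable[[Row], Row],             # z_t computed from z_{t-1} alone
    prompt: Sequence[int],
    T: int,
    query: int,
) -> int:
    """Markov-chain control flow: prompt -> z_1 -> ... -> z_T -> answer."""
    z = first_step(prompt)
    for _ in range(T - 1):
        z = next_step(z)
    return z[query]  # the query is applied only after the latent block


def ca_step(row: Sequence[int]) -> Row:
    k = len(row)
    return [(row[(i - 1) % k] + row[(i + 1) % k]) % 10 for i in range(k)]


# Stand-in: the true CA rule plays the role of the learned per-step map.
print(run_latent_chain(ca_step, ca_step, [7, 5, 9], T=10, query=0))  # 8
```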

Feedback (the soft token)

A small trainable head H : ℝ^{4096} → ℝ^{K×10} reads the residual h_t at z_t and outputs K per-cell digit logits (reshaped to K×10); a per-cell softmax turns them into the digit distributions z_t[i] ∈ ℝ^{10} (i = 1..K). H is a 2-layer MLP, LayerNorm(4096) → Linear(4096→1024) → GELU → Linear(1024→K·10), shared across all steps and cells (one H, applied at every z_t) and learned jointly with the LoRA (shipped in single_extra.pt). These distributions are embedded through a learned codebook C ∈ ℝ^{K×10×4096}, with one 4096-vector C[i,d] per (cell i, digit d) symbol, giving the vector fed into z_{t+1}:

\text{fed} = \sum_{i=1}^{K}\sum_{d=0}^{9} z_t[i,d]\,C[i,d] \;\in\; \mathbb{R}^{4096}.
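A NumPy sketch of the head forward pass and the feedback embedding, run at toy sizes (d = 64, hidden = 32) with random weights. The card does not specify the GELU variant or the LayerNorm affine parameters; this sketch uses the tanh approximation and omits the affine terms, and none of the helper names come from single_extra.pt or the training code.

```python
import numpy as np

K, DIGITS = 3, 10


def layer_norm(x, eps=1e-5):
    # LayerNorm without learned gain/bias (omitting them is an assumption)
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)


def gelu(x):
    # tanh approximation of GELU (the card does not say which variant is used)
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))


def head(h, W1, b1, W2, b2):
    """H: LayerNorm -> Linear(d -> hidden) -> GELU -> Linear(hidden -> K*10), as K x 10."""
    return (gelu(layer_norm(h) @ W1 + b1) @ W2 + b2).reshape(K, DIGITS)


def per_cell_softmax(logits):
    e = np.exp(logits - logits.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)


def feed(z, C):
    """fed = sum_i sum_d z[i, d] * C[i, d]: a mixture of the K*10 codebook vectors."""
    return np.einsum("id,idh->h", z, C)


# Toy sizes so the sketch runs quickly; the card's sizes are d = 4096, hidden = 1024.
rng = np.random.default_rng(0)
d, hidden = 64, 32
W1, b1 = rng.normal(size=(d, hidden)) / np.sqrt(d), np.zeros(hidden)
W2, b2 = rng.normal(size=(hidden, K * DIGITS)) / np.sqrt(hidden), np.zeros(K * DIGITS)
C = rng.normal(size=(K, DIGITS, d))

h_t = rng.normal(size=d)                        # residual at z_t (random stand-in)
z_t = per_cell_softmax(head(h_t, W1, b1, W2, b2))
fed = feed(z_t, C)                              # input embedding for position z_{t+1}
print(z_t.shape, fed.shape)                     # (3, 10) (64,)
```

With a one-hot token (teacher forcing), the fed vector reduces to the sum of exactly one codebook vector per cell.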

C is initialised from K·10 = 30 distinct in-distribution token embeddings so that the K cells occupy separable directions. A near-identical per-cell init makes the fed token a positionally ambiguous digit-sum, and training stalls at chance.
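The failure mode of a per-cell-identical codebook can be seen directly: if C[i,d] is the same vector for every cell i, the fed vector depends only on the multiset of digits in the row, so rows such as [1, 2, 3] and [3, 2, 1] embed identically. The sketch contrasts that with a separable codebook built from 30 distinct rows of an embedding table. The random embedding table and the chosen token ids are hypothetical stand-ins; the card does not say which Qwen3-8B tokens were used.

```python
import numpy as np

K, DIGITS, d = 3, 10, 64
rng = np.random.default_rng(1)


def onehot_row(row):
    """K x 10 token with a one-hot distribution per cell."""
    z = np.zeros((K, DIGITS))
    z[np.arange(K), row] = 1.0
    return z


def feed(z, C):
    """fed = sum_i sum_d z[i, d] * C[i, d]."""
    return np.einsum("id,idh->h", z, C)


# Hypothetical embedding table and token ids (stand-ins for Qwen3-8B embeddings).
embedding_table = rng.normal(size=(1000, d))
token_ids = rng.choice(1000, size=K * DIGITS, replace=False)  # 30 distinct ids

# Separable init: 30 distinct embedding rows, one per (cell, digit) symbol.
C_sep = embedding_table[token_ids].reshape(K, DIGITS, d)

# Per-cell-identical init: every cell reuses the same 10 digit vectors.
C_shared = np.broadcast_to(embedding_table[token_ids[:DIGITS]], (K, DIGITS, d))

a, b = [1, 2, 3], [3, 2, 1]  # same digits, different cell positions
print(np.allclose(feed(onehot_row(a), C_shared), feed(onehot_row(b), C_shared)))  # True
print(np.allclose(feed(onehot_row(a), C_sep), feed(onehot_row(b), C_sep)))        # False
```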

Because every fed vector is a weighted combination of those 30 codebook vectors whose weights sum to K (each of the K per-cell distributions sums to 1), it lives in the affine subspace K·mean(C) + U, where U is the 29-dim span of the centred codebook vectors C[i,d] − mean(C): the "soft-token subspace" of the 4096-dim residual (30 symbols → rank 29 after centring).

What the U-patch does (used in the results). With Q ∈ ℝ^{4096×29} an orthonormal basis of U and projector P = Q·Qᵀ, for each fed vector v (this problem) and its donor counterpart w at the same chain position, the U-patch sets v ← (I−P)·v + P·w (keep this problem's part orthogonal to U, take the donor's part in U); the complement patch does the reverse, v ← P·v + (I−P)·w. Only the fed soft-token vectors are touched (not the residual, head, or codebook). Since every fed vector lies in K·mean(C) + U, all fed vectors share the same U⊥-part, (I−P)·K·mean(C). The complement swap is therefore a no-op (it returns v exactly, so the answer follows the prompt), and the U-swap returns w exactly, moving the entire message (the answer follows the donor).
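A NumPy check of the subspace argument on a random stand-in codebook (toy width d = 128 instead of 4096): Q is taken from an SVD of the centred codebook, and two random fed vectors play the roles of v and w. Under these assumptions the U-patch returns w exactly and the complement patch returns v exactly, because the U⊥ parts of all fed vectors coincide.

```python
import numpy as np

K, DIGITS, d = 3, 10, 128
rng = np.random.default_rng(2)
C = rng.normal(size=(K, DIGITS, d))       # stand-in codebook (30 symbols)

# Orthonormal basis Q of U = span of the centred codebook vectors.
flat = C.reshape(K * DIGITS, d)
centred = flat - flat.mean(axis=0)
_, s, Vt = np.linalg.svd(centred, full_matrices=False)
rank = int((s > 1e-8 * s[0]).sum())       # 29 for 30 generic vectors
Q = Vt[:rank].T                           # d x rank, orthonormal columns
P = Q @ Q.T                               # projector onto U
I = np.eye(d)


def random_fed():
    """Fed vector from random per-cell softmax distributions."""
    logits = rng.normal(size=(K, DIGITS))
    z = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
    return np.einsum("id,idh->h", z, C)


v, w = random_fed(), random_fed()         # this problem's and the donor's fed vectors
u_patch = (I - P) @ v + P @ w             # donor's U-part, own U-perp part
comp_patch = P @ v + (I - P) @ w          # own U-part, donor's U-perp part

print(rank)                                                  # 29
print(np.allclose(u_patch, w), np.allclose(comp_patch, v))   # True True
# The shared U-perp part is the projection of K * mean(C).
print(np.allclose((I - P) @ v, (I - P) @ (K * flat.mean(axis=0))))  # True
```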

Readout / decoding

Let h_t be the last-layer residual at token z_t. The per-cell head H (the same head that produces the feedback) maps it to K×10 digit logits; the soft token fed forward is the per-cell softmax, and the decoded value of cell i is the argmax over that cell's 10 logits:

\ell_t = H(h_t) \in \mathbb{R}^{K\times 10}, \qquad z_t[i] = \operatorname{softmax}\bigl(\ell_t[i,:]\bigr), \qquad \hat c_i(t) = \operatorname*{arg\,max}_{d\in\{0,\dots,9\}} \ell_t[i,d].

z_t[i] (softmax) is the soft token that drives the recurrence; ĉ_i(t) (argmax) is the decoded digit, used for the trajectories below and for the answer. The answer is the queried cell of the final token, ĉ_q(T): a single read-out path, with no separate text/LM route to bypass. The probe in Figure 2 confirms that each h_t linearly encodes the whole row.
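A sketch of the read-out, with hand-made logits standing in for H(h_T) (hypothetical values chosen to decode to the first sample's final row). Because softmax is monotone within each cell, the argmax of z_t[i] and of ℓ_t[i,:] coincide, so decoding from either gives the same digits.

```python
import numpy as np


def decode(logits):
    """Per-cell argmax of K x 10 head logits: c_hat_i(t) = argmax_d l_t[i, d]."""
    return np.asarray(logits).argmax(axis=1).tolist()


def read_answer(final_logits, query):
    """Answer = queried cell of the final token, c_hat_q(T) (0-based query index)."""
    return decode(final_logits)[query]


# Hypothetical logits for z_6 of the first sample trajectory (decodes to 8 8 2).
logits_T = np.full((3, 10), -5.0)
logits_T[0, 8] = logits_T[1, 8] = logits_T[2, 2] = 5.0
print(decode(logits_T))          # [8, 8, 2]
print(read_answer(logits_T, 1))  # 8 (cell c2)
```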

Decoded sample trajectories

In free-running mode, each latent z_t is decoded as in Readout (ĉ(t) is the per-cell argmax of the head logits ℓ_t). The decoded row matches the true CA evolution at every step, including on a chain longer than any trained on (the per-step recurrence is shared across lengths):

[trained length, T=6]   initial row c(0) = [6, 6, 0]   (query: cell c2)
   z_1  decodes to  6 6 2    (CA truth 6 6 2)
   z_2  decodes to  8 8 2    (CA truth 8 8 2)
   z_3  decodes to  0 0 6    (CA truth 0 0 6)
   z_4  decodes to  6 6 0    (CA truth 6 6 0)
   z_5  decodes to  6 6 2    (CA truth 6 6 2)
   z_6  decodes to  8 8 2    (CA truth 8 8 2)
   answer = read cell c2 of z_6 -> 8   (CA truth 8)

[trained length, T=6]   initial row c(0) = [8, 7, 6]   (query: cell c2)
   z_1  decodes to  3 4 5    (CA truth 3 4 5)
   z_2  decodes to  9 8 7    (CA truth 9 8 7)
   z_3  decodes to  5 6 7    (CA truth 5 6 7)
   z_4  decodes to  3 2 1    (CA truth 3 2 1)
   z_5  decodes to  3 4 5    (CA truth 3 4 5)
   z_6  decodes to  9 8 7    (CA truth 9 8 7)
   answer = read cell c2 of z_6 -> 8   (CA truth 8)

[LONGER than any trained length, T=10]   initial row c(0) = [7, 5, 9]   (query: cell c1)
   z_1  decodes to  4 6 2    (CA truth 4 6 2)
   z_2  decodes to  8 6 0    (CA truth 8 6 0)
   z_3  decodes to  6 8 4    (CA truth 6 8 4)
   z_4  decodes to  2 0 4    (CA truth 2 0 4)
   z_5  decodes to  4 6 2    (CA truth 4 6 2)
   z_6  decodes to  8 6 0    (CA truth 8 6 0)
   z_7  decodes to  6 8 4    (CA truth 6 8 4)
   z_8  decodes to  2 0 4    (CA truth 2 0 4)
   z_9  decodes to  4 6 2    (CA truth 4 6 2)
   z_10 decodes to  8 6 0    (CA truth 8 6 0)
   answer = read cell c1 of z_10 -> 8   (CA truth 8)
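For K = 3, each cell's two ring neighbours are the other two cells, so the rule simplifies to c_i(t) = S(t−1) − c_i(t−1) (mod 10), where S(t) is the row sum; summing the rule over all cells gives S(t) = 2·S(t−1). Unrolling (a worked derivation from the rule stated above, not a result reported by the card) gives

c(t) = a_t\,S(0) + (-1)^t\,c(0) \bmod 10, \qquad a_t = \frac{2^t - (-1)^t}{3},

and since a_{t+4} − a_t = 5·2^t ≡ 0 (mod 10) for t ≥ 1, every trajectory satisfies c(t+4) = c(t) from t = 1 on. This is visible in the samples above (z_5 repeats z_1, z_9 repeats z_5). The pure-Python check below (illustrative helper names) confirms both statements over all 1000 starting rows up to T = 14.

```python
from itertools import product


def ca_step(row):
    k = len(row)
    return [(row[(i - 1) % k] + row[(i + 1) % k]) % 10 for i in range(k)]


def closed_form(row0, t):
    """c(t) = a_t * S(0) + (-1)^t * c(0) (mod 10); valid for K = 3 only."""
    a_t = (2**t - (-1) ** t) // 3
    s0 = sum(row0)
    return [(a_t * s0 + (-1) ** t * c) % 10 for c in row0]


def trajectory(row0, T):
    """[c(1), ..., c(T)] from the step-by-step rule."""
    rows, row = [], list(row0)
    for _ in range(T):
        row = ca_step(row)
        rows.append(row)
    return rows


all_ok = True
for row0 in product(range(10), repeat=3):                 # all 1000 starting rows
    traj = trajectory(row0, 14)                            # c(1) .. c(14)
    all_ok &= all(traj[t - 1] == closed_form(row0, t) for t in range(1, 15))
    all_ok &= all(traj[t - 1] == traj[t + 3] for t in range(1, 11))  # c(t) == c(t+4)
print(all_ok)  # True
```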

Training (annealing SFT)

LoRA (r=32, α=64) on a frozen bf16 Qwen3-8B; trainable: the adapter, the per-cell head, the codebook C, and the step-1 query. Data is synthetic and unlimited.

  • Loss: a single per-cell cross-entropy. The head predicts every cell of c(t) at every step (including the queried cell at the final step, which is weighted 4×). Because the answer is read from this same head, there is no separate answer objective that could bypass the per-step state encoding (an earlier LM-head answer path did exactly that and had to be removed).
  • Scheduled-sampling teacher forcing: with a probability annealed from 1 to 0 over 1500 steps, the fed token is the ground-truth one-hot row instead of the self-generated one; after annealing, training runs entirely on the self-generated chain.
  • Mixed chain lengths: each training step draws T from [2, 3, 4, 5, 6], so the single per-step recurrence is shared across lengths and extrapolates to longer chains at test time.
  • Separable codebook init (see above): the fix that lets the single token carry all K cells.
  • AdamW, lr 0.0001, grad-clip 1.0; mastery curriculum (stop once accuracy ≥ 0.9 twice at teacher-forcing probability 0).
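A sketch of one plausible bookkeeping for these training ingredients, computed in NumPy on given logits (no model involved). Several details are assumptions not stated in the card: the anneal is linear, the teacher-forcing coin is flipped independently for each of the T − 1 fed tokens, T is drawn uniformly from the length pool, and the 4× weight sits on the queried cell at the final step. The function names are hypothetical.

```python
import random

import numpy as np

LENGTHS = [2, 3, 4, 5, 6]  # chain-length pool from this card
ANNEAL_STEPS = 1500        # teacher forcing anneals 1 -> 0 over this many steps
ANSWER_WEIGHT = 4.0        # assumed: extra weight on the queried cell at the final step


def teacher_forcing_prob(step):
    """P(feed the ground-truth one-hot row); the linear shape is an assumption."""
    return max(0.0, 1.0 - step / ANNEAL_STEPS)


def sample_training_config(step, rng=random):
    """Draw T from the length pool (uniform, assumed) and, for each of the T - 1 fed
    tokens, whether it is the ground-truth row (True) or the model's own soft token."""
    T = rng.choice(LENGTHS)
    p = teacher_forcing_prob(step)
    return T, [rng.random() < p for _ in range(T - 1)]


def per_cell_ce_loss(logits, targets, query):
    """logits: (T, K, 10) head outputs for z_1..z_T; targets: (T, K) true digits c(t).
    Weighted mean cross-entropy over every cell at every step."""
    logits = np.asarray(logits, dtype=float)
    targets = np.asarray(targets)
    m = logits.max(axis=-1, keepdims=True)
    log_probs = logits - m - np.log(np.exp(logits - m).sum(axis=-1, keepdims=True))
    nll = -np.take_along_axis(log_probs, targets[..., None], axis=-1)[..., 0]  # (T, K)
    weights = np.ones(nll.shape)
    weights[-1, query] = ANSWER_WEIGHT
    return float((weights * nll).sum() / weights.sum())


print(teacher_forcing_prob(0), teacher_forcing_prob(750), teacher_forcing_prob(3000))
# 1.0 0.5 0.0
```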


Training curve. The per-cell state decodability (orange) rises first β€” the head learns to read the whole row from one token; the length curriculum (purple) then grows T from 2 to 6; and the free-running readout (blue) reaches 1.0, all before teacher forcing (grey) finishes annealing to 0. The separable codebook is what lets a single token hold the whole row.

Results: causal load-bearing battery

Each test free-runs the model, intervenes on the T-token chain, and re-reads the queried cell from z_T via the per-cell head (n=400; latent_threads/eval_single_report.py).


Figure 1. (a) Shuffling the step order of the chain, or replacing it, breaks the answer. (b) Overwriting the chain with a different problem's tokens makes the answer follow that donor (1.00), not the prompt (0.00); patching only the codebook subspace U (29-dim) reproduces this (1.00), while patching its complement U⊥ does nothing (0.00). (c) Free-running accuracy vs chain length T; the green band is the trained range, and accuracy holds on longer chains.


Figure 2. (a) The answer read-out is near-deterministic (P(correct) 1.00). (b) A linear probe recovers the whole K-cell row c_t from each single token's residual, reaching 1.00 by layer 36 β€” confirming one token holds all K cells.

measurement                                        value          what it shows
free-running accuracy (trained length)             1.000          solves it (chance 0.10)
ablate z_1 ↛ prompt                                0.040          no input → at/below chance
worst single-step noise                            0.040          every step is load-bearing
shuffle step order                                 0.465          order matters
donor patch → donor / prompt                       1.000 / 0.000  follows latents, not prompt
donor patch, codebook subspace U (29-dim)          1.000          message lives in the soft-token subspace
donor patch, complement U⊥                         0.000          the rest is inert
single-cell patch c_i(t):=d → CA-propagated        0.56           the edit propagates through the rule
single-cell patch → keeps original (ignores it)    0.07           the edit is (almost) never ignored
longer-chain accuracy (T > 6)                      0.96           generalises beyond trained lengths
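The values above are fractions of evaluated items. A hedged sketch of how such rates can be computed from answers collected after an intervention; the function names are hypothetical, and dropping pairs whose donor and prompt answers coincide is an assumption (the card does not say how such ties are handled). The same rate helper covers the single-cell rows, with the target set to the CA-propagated answer or to the original answer.

```python
from typing import Sequence, Tuple


def rate(answers: Sequence[int], targets: Sequence[int]) -> float:
    """Fraction of items whose (intervened) answer equals the given target."""
    return sum(a == t for a, t in zip(answers, targets)) / len(answers)


def donor_follow_rates(
    answers: Sequence[int], donor_truth: Sequence[int], prompt_truth: Sequence[int]
) -> Tuple[float, float]:
    """(follows donor, follows prompt) over pairs whose true answers differ.
    Dropping tied pairs is an assumption: a tie would count for both rates."""
    kept = [(a, d, p) for a, d, p in zip(answers, donor_truth, prompt_truth) if d != p]
    if not kept:
        raise ValueError("no pairs with distinct donor and prompt answers")
    ans, donor, prompt = zip(*kept)
    return rate(ans, donor), rate(ans, prompt)


# Toy data: the second item is a tie and is dropped; every kept answer follows the donor.
answers = [3, 7, 1, 5]
donor_truth = [3, 7, 1, 5]
prompt_truth = [4, 7, 9, 0]
print(donor_follow_rates(answers, donor_truth, prompt_truth))  # (1.0, 0.0)
```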

Single-cell patch: does editing one c_i(t) change the answer in the expected way?

The donor/subspace patches above swap the whole chain. This one is surgical: free-run to step t, then overwrite one cell's distribution z_t[i] ∈ ℝ^{10} with the one-hot of a target digit d ∈ {0,…,9} (i.e. force c_i(t) = d, leaving the other K−1 cells untouched), re-embed via the codebook, and continue free-running so that the change propagates through the learned CA rule; then compare the answer to the true CA evolved from c(t) with cell i set to d.
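The two reference answers for a single-cell edit can be computed from the ground-truth CA alone. A pure-Python sketch with hypothetical helper names (steps are 1-based as in the card, cell indices 0-based): the counterfactual is the CA evolved for the remaining T − t steps from c(t) with cell i set to d, and the original is the unpatched answer. The model's patched answer would then be scored against both.

```python
def ca_step(row):
    k = len(row)
    return [(row[(i - 1) % k] + row[(i + 1) % k]) % 10 for i in range(k)]


def ca_from(row, steps):
    """Evolve a row forward by `steps` CA updates."""
    for _ in range(steps):
        row = ca_step(row)
    return row


def cellpatch_targets(c0, T, t, i, d, query):
    """(counterfactual answer, original answer) for forcing c_i(t) := d,
    with c0 the initial row and 1 <= t <= T."""
    row_t = ca_from(list(c0), t)           # true c(t) before the edit
    original = ca_from(row_t, T - t)[query]
    patched = list(row_t)
    patched[i] = d                         # the other K-1 cells stay untouched
    counterfactual = ca_from(patched, T - t)[query]
    return counterfactual, original


print(cellpatch_targets([6, 6, 0], T=6, t=2, i=0, d=5, query=1))  # (3, 8)
```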

[Figure: single-cell patch results (cellpatch)]

The edit is causal and almost never ignored: the answer keeps its original value only 0.07 of the time. It matches the exact CA-propagated counterfactual the majority of the time (0.56 overall, 0.73 at one propagation step; chance 0.10). This is not the clean ~1.0 of the whole-chain patch: forcing one cell to an arbitrary value yields an off-trajectory row that the model's learned CA computes less reliably than a full in-distribution chain, and the unpatched z_t residual still feeds the original cell via attention at the boundary. So a single c_i(t) is load-bearing and propagates through the rule most of the time, though counterfactual single-cell edits are not computed perfectly. A systematic sweep over all 15 (t,i) sites × 7 problems reproduces this (0.55 overall) and shows that it is site-dependent: s1.c1 propagates perfectly (7/7), while a few late/edge sites fall to ≈0.29 (eval_single_cellpatch.py --sweep).

Files & reproduction

LoRA adapter + tokenizer + single_extra.pt (per-cell head, codebook C, step-1 query) + lt_cfg.json; interventions.png / decode.png / training.png; training_code/.

  • Train: python -m latent_threads.train_single --config latent_threads/configs/single_k3m6.json
  • Eval: python -m latent_threads.eval_single_report --ckpt <dir> --n 400

A clean, fully controlled organism in which latent chain-of-thought is demonstrably load-bearing and length-generalising, with one token per step (the whole row computed in parallel).


Model tree for cds-jb/qwen3-8b-parallel-cot: adapter fine-tuned from Qwen/Qwen3-8B.