# BLT-Reasoner: Bottlenecked Latent Thoughts

A compute-constrained replacement for Abstract-CoT's discretely sampled z̃.

## Motivation

The Abstract-CoT recipe (`experiments/abstract_cot/`) at our scale produced a 7B GRPO checkpoint whose z̃ is **decorative** (z_ablation on MATH-500 maj@8: normal 0.57, random-z 0.54, zero-z 0.52, a spread of only 5 points). Root cause: there is no gradient path from *which* abstract token is sampled to the LM loss, because z̃ comes from a `multinomial` draw under `torch.no_grad()`. Closing the four side-channels that bypass z̃ (Phase-B causal mask, delimiter contextualization, prior-y AR, direct x→y) makes the primary loss collapse onto a constant-token or empty z̃.

## Approach (3 components)

1. **Continuous latent loop**: K=4..16 latent vectors `z_t = W_proj(h_{t-1})`, trained with full backprop: no sampling, no `no_grad`. The bottleneck is a 4D attention mask in which `y` rows cannot attend to `x` columns, so `y` sees `x` only through the z positions.
2. **InfoNCE identifiability lock**: a contrastive loss between mean-pooled z and the frozen-base encoding of the gold answer (adapters disabled). This rules out the constant-z attractor mechanically: for any z that is constant across the batch, the InfoNCE loss is bounded below by `log B` (B = batch size), so a constant z cannot minimize it.
3. **β-VAE KL prior** on z magnitude: keeps z in a bounded region of the residual stream so it can be interpreted as a learned thought slot.

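A minimal sketch of the bottleneck mask from component 1, assuming a packed sequence layout of `[x | z_1..z_K | y]`, causal attention everywhere else, and the boolean convention where `True` means "may attend". The function name `bottleneck_mask` and the layout are illustrative assumptions; the mask built in `model.py` may order or pad positions differently.

```python
import numpy as np


def bottleneck_mask(n_x: int, k: int, n_y: int, batch: int = 1) -> np.ndarray:
    """Boolean 4D mask of shape (batch, 1, L, L); True = query row may attend key column.

    Hypothetical layout: [0, n_x) are x tokens, [n_x, n_x + k) are latents z,
    and the remaining n_y positions are y tokens.
    """
    L = n_x + k + n_y
    # Standard causal mask: row i sees columns 0..i.
    mask = np.tril(np.ones((L, L), dtype=bool))
    # Bottleneck: y rows may not attend x columns, so x reaches y only via z positions.
    y_start = n_x + k
    mask[y_start:, :n_x] = False
    # Add batch and head axes; the head axis of size 1 broadcasts over all heads.
    return np.broadcast_to(mask, (batch, 1, L, L)).copy()


m = bottleneck_mask(n_x=3, k=2, n_y=2)
print(m[0, 0].astype(int))
```

Here rows 5 and 6 (the `y` tokens) can see only columns 3 to 6 (the two latents and earlier `y` tokens), while the latent rows 3 and 4 still see all of `x`. If the model expects an additive float mask instead of a boolean one, `np.where(m, 0.0, -np.inf)` converts it.
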
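A minimal sketch of the InfoNCE term from component 2 and its constant-z bound, assuming cosine-similarity logits, a temperature `tau`, and in-batch negatives (row i's positive is answer i). The name `infonce` and `tau=0.1` are illustrative and do not reflect the `infonce_loss` signature in `losses.py`. When z is constant across the batch, every row of the logit matrix is the same vector c, so the loss equals `logsumexp(c) - mean(c)`, which Jensen's inequality bounds below by `log B`.

```python
import numpy as np


def infonce(z: np.ndarray, a: np.ndarray, tau: float = 0.1) -> float:
    """Mean InfoNCE loss with in-batch negatives.

    z: (B, d) mean-pooled latents; a: (B, d) answer encodings; row i's positive is a[i].
    """
    # Cosine-similarity logits scaled by a (hypothetical) temperature.
    zn = z / np.linalg.norm(z, axis=1, keepdims=True)
    an = a / np.linalg.norm(a, axis=1, keepdims=True)
    logits = zn @ an.T / tau  # (B, B)
    # Numerically stable row-wise log-softmax.
    logits = logits - logits.max(axis=1, keepdims=True)
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    # Cross-entropy with the diagonal as the target class.
    return float(-np.mean(np.diag(log_probs)))


rng = np.random.default_rng(0)
B, d = 8, 16
answers = rng.normal(size=(B, d))

const_z = np.tile(rng.normal(size=(1, d)), (B, 1))  # same z for every example
print(infonce(const_z, answers), np.log(B))  # first value is >= log(8) ≈ 2.079

aligned_z = answers + 0.01 * rng.normal(size=(B, d))  # z that identifies its answer
print(infonce(aligned_z, answers))  # close to 0
```

The gap between the two printed losses is the point of the lock: only a z that varies with the example can push the loss below `log B`.
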
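For component 3, the standard closed form of the β-VAE KL term is given below. It assumes each latent is treated as a diagonal Gaussian with mean μ (the projected z) and variance σ² against a standard-normal prior. This README does not say whether `kl_to_gaussian` learns σ or fixes it. With σ fixed at 1, the term reduces to a plain squared-magnitude penalty on z, which β then weights in the loss.

```latex
\mathrm{KL}\left(\mathcal{N}(\mu, \operatorname{diag}\sigma^2)\,\|\,\mathcal{N}(0, I)\right)
  = \tfrac{1}{2}\sum_{i=1}^{d}\left(\mu_i^2 + \sigma_i^2 - \log\sigma_i^2 - 1\right)
\qquad\xrightarrow{\;\sigma_i = 1\;}\qquad
\tfrac{1}{2}\lVert \mu \rVert_2^2
```
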
## Files

- `model.py`: `BLTConfig`, LoRA wrapping, `LatentProjector`, the `forward_with_latent` two-pass routine, and `generate_with_latent` (greedy/temperature decoding with optional `override_z` for ablations).
- `losses.py`: `InfoNCEHead`, `infonce_loss`, `kl_to_gaussian`, `encode_answer_for_infonce` (frozen-base encoder), `lm_loss_on_y`.
- `data.py`: GSM8K loader, `format_prompt`, `collate_batch`.
- `train.py`: training loop, K curriculum, cosine LR, eval cadence.
- `eval.py`: **pre-registered** ablation (normal-z / random-z / zero-z).
- `configs/pilot_qwen15b_gsm8k.json`: config for the first 24h pilot.

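The three ablation conditions in `eval.py` could be built as `override_z` inputs along the lines below. This is a hypothetical sketch written with NumPy for brevity. `make_ablation_z` is not a function from the repo, and defining random-z as Gaussian noise rescaled to each latent's norm is an assumption chosen so that only the content, not the scale, differs from normal-z. The actual random-z construction may differ, for example by shuffling z across the batch.

```python
import numpy as np


def make_ablation_z(z: np.ndarray, mode: str, seed: int = 0) -> np.ndarray:
    """Hypothetical override builder for latents z of shape (B, K, d).

    mode: "normal" (unchanged), "random" (norm-matched Gaussian noise), or "zero".
    """
    if mode == "normal":
        return z
    if mode == "zero":
        return np.zeros_like(z)
    if mode == "random":
        rng = np.random.default_rng(seed)
        noise = rng.normal(size=z.shape)
        # Rescale each noise vector to the L2 norm of the latent it replaces.
        scale = np.linalg.norm(z, axis=-1, keepdims=True) / np.linalg.norm(
            noise, axis=-1, keepdims=True
        )
        return noise * scale
    raise ValueError(f"unknown ablation mode: {mode!r}")


z = np.random.default_rng(1).normal(size=(2, 4, 8))
for mode in ("normal", "random", "zero"):
    print(mode, np.linalg.norm(make_ablation_z(z, mode), axis=-1).round(2))
```

Matching norms keeps random-z inside the same region of the residual stream that the KL prior targets. A gap between normal-z and random-z therefore points to the content of z rather than its magnitude.
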
## Pre-registered success criterion (before looking at raw accuracy)

Evaluated on GSM8K-test, n=200:

```
acc(normal-z) - acc(random-z) >= 0.15
acc(normal-z) - acc(zero-z) >= 0.25
```

Both conditions must hold for H1 (z carries information).

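At n=200, the two thresholds mean normal-z must answer at least 30 more problems correctly than random-z and at least 50 more than zero-z. Below is a minimal sketch of the check, assuming accuracies come from per-condition correct counts. `h1_holds` is a hypothetical helper, not taken from `eval.py`. It uses exact fractions because floats get the boundary case wrong: `0.70 - 0.55 >= 0.15` is `False` in IEEE doubles.

```python
from fractions import Fraction


def h1_holds(correct_normal: int, correct_random: int, correct_zero: int, n: int = 200) -> bool:
    """Pre-registered H1 check: both accuracy gaps must clear their thresholds.

    Exact rational arithmetic so a gap of exactly 0.15 (30/200) counts as passing.
    """
    gap_random = Fraction(correct_normal - correct_random, n)
    gap_zero = Fraction(correct_normal - correct_zero, n)
    return gap_random >= Fraction(15, 100) and gap_zero >= Fraction(25, 100)


# Illustrative counts only, not results.
print(h1_holds(140, 110, 90))   # gaps 0.15 and 0.25 -> True
print(h1_holds(120, 115, 110))  # gaps 0.025 and 0.05 -> False
```
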
## Run

```bash
# On the box (after rsyncing), run the smoke test first (~2 min on GH200):
python3 -m experiments.blt_reasoner.scripts.smoke_test
# Then launch the pilot:
bash experiments/blt_reasoner/scripts/run_pilot.sh
```

## Deliberate non-features

- No discrete abstract vocabulary in v1. A codebook overlay is a v2 extension.
- No RL. Continuous-loop SFT with InfoNCE comes first; RL only if the ablation Δ confirms H1.
- No multi-GPU. Single GH200; if the pilot needs scale, we move to a 7B base after H1 is confirmed.