# BLT-Reasoner — Bottlenecked Latent Thoughts
A compute-constrained replacement for Abstract-CoT's discrete-sampled z̃.
## Motivation
The Abstract-CoT (`experiments/abstract_cot/`) recipe at our scale produced
a 7B GRPO checkpoint whose z̃ is **decorative** (z_ablation: normal 0.57,
random-z 0.54, zero-z 0.52 on MATH-500 maj@8). Root cause: there is no
gradient path from *which* abstract token is sampled to the LM loss — z̃
comes from a `torch.no_grad()` `multinomial`. Closing the four side-channels
that bypass z̃ (Phase-B causal mask, delimiter contextualization, prior-y AR,
direct x→y) makes training on the primary loss collapse to a constant-token or empty z̃.
## Approach (3 components)
1. **Continuous latent loop**: K=4..16 latent vectors `z_t = W_proj(h_{t-1})`,
with full backprop. No sampling, no `no_grad`. The bottleneck is enforced by a 4D
attention mask: `y` rows cannot attend to `x` columns, so `x` can reach `y` only through z.
2. **InfoNCE identifiability lock**: contrastive loss between mean-pooled z
and the frozen-base encoding of the gold answer (adapters disabled).
The constant-z attractor is penalized by construction: for any z that is
constant across the batch, the InfoNCE loss is at least `log B` (chance level
for batch size B).
3. **β-VAE KL prior** on z magnitude: keep z in a bounded region of the
residual stream so it can be interpreted as a learned thought slot.
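The bottleneck in component 1 can be pictured with a small mask builder. This is a minimal plain-Python sketch, not the code in `model.py`: it assumes the sequence is laid out as `x | z | y` and that attention is otherwise causal, and the name `build_bottleneck_mask` is hypothetical. The actual 4D mask would presumably add batch and head dimensions on top of this 2D pattern.

```python
def build_bottleneck_mask(n_x, n_z, n_y):
    """Return a T x T boolean mask (True = may attend) for the assumed layout x | z | y."""
    total = n_x + n_z + n_y
    y_start = n_x + n_z  # index of the first y position
    mask = []
    for q in range(total):          # q indexes the query row
        row = []
        for k in range(total):      # k indexes the key column
            allowed = k <= q        # assumption: ordinary causal attention
            if q >= y_start and k < n_x:
                allowed = False     # bottleneck: y rows never see x columns
            row.append(allowed)
        mask.append(row)
    return mask


if __name__ == "__main__":
    # Toy sizes: 3 prompt tokens, 2 latent slots, 2 answer tokens.
    for row in build_bottleneck_mask(n_x=3, n_z=2, n_y=2):
        print("".join("1" if a else "." for a in row))
```

In the printed pattern, the two `y` rows have their first three columns blanked, so the only route from `x` to `y` in this sketch runs through the `z` positions.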
## Files
- `model.py`: `BLTConfig`, LoRA wrap, `LatentProjector`, the
`forward_with_latent` two-pass routine, and `generate_with_latent`
(greedy/temperature decoding with optional `override_z` for ablations).
- `losses.py`: `InfoNCEHead`, `infonce_loss`, `kl_to_gaussian`,
`encode_answer_for_infonce` (frozen-base encoder), `lm_loss_on_y`.
- `data.py`: GSM8K loader, `format_prompt`, `collate_batch`.
- `train.py`: training loop, K curriculum, cosine LR, eval cadence.
- `eval.py`: **pre-registered** ablation (normal-z / random-z / zero-z).
- `configs/pilot_qwen15b_gsm8k.json`: first 24h pilot config.
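For orientation on the contrastive objective listed under `losses.py`, here is a minimal plain-Python sketch of an in-batch InfoNCE loss. It is not the repo's `infonce_loss`: the dot-product scoring, the `temperature` parameter, and the toy vectors are all assumptions made for illustration. It shows why a z that is constant across the batch cannot score below `log B`.

```python
import math


def infonce_loss(z, a, temperature=1.0):
    """Mean cross-entropy of matching z[i] to a[i] among all a[j] in the batch.

    Assumption: similarity is a plain dot product; the real head may differ.
    """
    batch = len(z)
    total = 0.0
    for i in range(batch):
        logits = [
            sum(p * q for p, q in zip(z[i], a[j])) / temperature
            for j in range(batch)
        ]
        peak = max(logits)  # subtract the max for a stable log-sum-exp
        log_norm = peak + math.log(sum(math.exp(l - peak) for l in logits))
        total += log_norm - logits[i]
    return total / batch


# Toy answer encodings (made up), batch size B = 4.
answers = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [-1.0, 0.0, 0.0]]
informative_z = [[5.0 * v for v in row] for row in answers]  # z aligned with its own answer
constant_z = [[0.3, -0.2, 0.9]] * len(answers)               # same z for every example

if __name__ == "__main__":
    print("log B        :", math.log(len(answers)))
    print("informative z:", infonce_loss(informative_z, answers))
    print("constant z   :", infonce_loss(constant_z, answers))
```

With a constant z every row of logits is identical, so the loss equals the log-sum-exp of those logits minus their mean, which is at least `log B` by Jensen's inequality. The informative z in the toy example lands well below that floor.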
## Pre-registered success criterion (before looking at raw accuracy)
```
acc(normal-z) - acc(random-z) >= 0.15
acc(normal-z) - acc(zero-z) >= 0.25
```
Evaluated on GSM8K-test, n=200. Both inequalities must hold for H1 (z carries information).
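The criterion above can be written as a small check. This is a sketch, not the logic in `eval.py`: the function name `h1_holds` and its count-based signature are hypothetical. It works on counts of correct answers so the comparison at the exact thresholds does not hinge on floating-point subtraction of two accuracies.

```python
def h1_holds(correct_normal, correct_random, correct_zero, n=200):
    """Return True only if both pre-registered accuracy gaps are met.

    Arguments are counts of correct answers out of n test items for the
    normal-z, random-z and zero-z conditions.
    """
    gap_random = (correct_normal - correct_random) / n
    gap_zero = (correct_normal - correct_zero) / n
    return gap_random >= 0.15 and gap_zero >= 0.25


if __name__ == "__main__":
    # Made-up counts, for illustration only (not results from any run).
    print(h1_holds(correct_normal=120, correct_random=90, correct_zero=70))
    print(h1_holds(correct_normal=120, correct_random=91, correct_zero=70))
```

The first call sits exactly on both thresholds (gaps of 0.15 and 0.25) and passes. The second misses the random-z gap by one item and fails.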
## Run
```bash
# On the box (after rsyncing), run the smoke test first (~2 min on GH200):
python3 -m experiments.blt_reasoner.scripts.smoke_test
# Then launch the pilot:
bash experiments/blt_reasoner/scripts/run_pilot.sh
```
## Deliberate non-features
- No discrete abstract vocabulary in v1. Codebook overlay is a v2 extension.
- No RL. Continuous-loop SFT with InfoNCE first; RL only if the ablation
Δ confirms H1.
- No multi-GPU. Single GH200; if the pilot needs scale we move to 7B base
after H1 confirmation.