# BLT-Reasoner — Bottlenecked Latent Thoughts

A compute-constrained replacement for Abstract-CoT's discrete-sampled z̃.

## Motivation

The Abstract-CoT (`experiments/abstract_cot/`) recipe at our scale produced
a 7B GRPO checkpoint whose z̃ is **decorative** (z_ablation: normal 0.57,
random-z 0.54, zero-z 0.52 on MATH-500 maj@8). Root cause: there is no
gradient path from *which* abstract token is sampled to the LM loss — z̃
comes from a `torch.no_grad()` `multinomial`. Closing the four side-channels
that bypass z̃ (Phase-B causal mask, delimiter contextualization, prior-y AR,
direct x→y) drives training on the primary loss to a constant-token or
empty z̃.
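
The missing gradient path can be reproduced in a few lines of PyTorch. This is a minimal sketch, not the Abstract-CoT code: `head`, `embed`, and `w_proj` are hypothetical stand-ins with arbitrary sizes, chosen only to contrast a `no_grad` sample with a differentiable projection.

```python
import torch

torch.manual_seed(0)
h = torch.randn(4, requires_grad=True)   # stand-in for a hidden state
head = torch.nn.Linear(4, 8)             # hypothetical scores over 8 abstract tokens
embed = torch.nn.Embedding(8, 4)         # hypothetical abstract-token embeddings
w_proj = torch.nn.Linear(4, 4)           # hypothetical continuous projector

# Discrete path: the token index is sampled under no_grad, so the choice
# itself is detached from both the hidden state and the scoring head.
with torch.no_grad():
    probs = torch.softmax(head(h), dim=-1)
    idx = torch.multinomial(probs, num_samples=1)
z_discrete = embed(idx)
z_discrete.sum().backward()
discrete_grad_reaches_h = h.grad is not None
discrete_grad_reaches_head = head.weight.grad is not None
embedding_row_gets_grad = embed.weight.grad is not None

# Continuous path: a plain projection keeps the full graph.
z_continuous = w_proj(h)
z_continuous.sum().backward()
continuous_grad_reaches_h = h.grad is not None

print(discrete_grad_reaches_h, discrete_grad_reaches_head,
      embedding_row_gets_grad, continuous_grad_reaches_h)
```

In this toy setup the sampled embedding row still receives a gradient, but nothing upstream of the sampling step does, which is the sense in which the choice of token is invisible to the loss.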

## Approach (3 components)

1. **Continuous latent loop**: K=4..16 latent vectors `z_t = W_proj(h_{t-1})`,
   with full backprop. No sampling, no `no_grad`. Bottleneck via 4D
   attention mask: `y` rows cannot attend to `x` columns.
2. **InfoNCE identifiability lock**: contrastive loss between mean-pooled z
   and the frozen-base encoding of the gold answer (adapters disabled).
   A constant-z attractor cannot lower this loss: for any z that is
   constant across a batch of size B, InfoNCE is bounded below by `log B`,
   the chance level.
3. **β-VAE KL prior** on z magnitude: keep z in a bounded region of the
   residual stream so it can be interpreted as a learned thought slot.
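
Two of the claims above, the bottleneck mask in component 1 and the `log B` bound in component 2, can be checked with a dependency-free sketch. It is not the code in `model.py` or `losses.py`: the `[x | z | y]` token layout, the dot-product similarity, the `temperature` parameter, and both function names are assumptions made for illustration, and the broadcast to a 4D tensor is omitted.

```python
import math

def bottleneck_mask(n_x, n_z, n_y):
    """Boolean mask for an assumed [x | z | y] layout.

    mask[i][j] is True when query position i may attend to key position j.
    Attention is causal everywhere, and y rows are also blocked from
    x columns, so y can reach x only through the z slots.
    """
    n = n_x + n_z + n_y
    y_start = n_x + n_z
    mask = []
    for i in range(n):
        row = []
        for j in range(n):
            causal = j <= i
            y_to_x = i >= y_start and j < n_x
            row.append(causal and not y_to_x)
        mask.append(row)
    return mask

def infonce(z, a, temperature=1.0):
    """Mean InfoNCE loss: z[i] should match a[i] against every a[j]."""
    total = 0.0
    for i, z_i in enumerate(z):
        logits = [sum(p * q for p, q in zip(z_i, a_j)) / temperature for a_j in a]
        peak = max(logits)  # subtract the max for a stable log-sum-exp
        log_norm = peak + math.log(sum(math.exp(v - peak) for v in logits))
        total += log_norm - logits[i]
    return total / len(z)

mask = bottleneck_mask(n_x=2, n_z=1, n_y=2)

B = 4
answers = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 2.0]]
constant_z = [[0.3, -0.7]] * B            # the same z for every example
loss_constant = infonce(constant_z, answers)

basis = [[5.0 if i == j else 0.0 for j in range(B)] for i in range(B)]
loss_aligned = infonce(basis, basis)       # each z matches its own answer

print(loss_constant >= math.log(B), loss_aligned < math.log(B))
```

With a constant z every row of the similarity matrix is identical, so the softmax cannot put high probability on all B diagonal targets at once. That is where the `log B` floor comes from.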

## Files

- `model.py`: `BLTConfig`, LoRA wrap, `LatentProjector`, the
  `forward_with_latent` two-pass routine, and `generate_with_latent`
  (greedy/temperature decoding with optional `override_z` for ablations).
- `losses.py`: `InfoNCEHead`, `infonce_loss`, `kl_to_gaussian`,
  `encode_answer_for_infonce` (frozen-base encoder), `lm_loss_on_y`.
- `data.py`: GSM8K loader, `format_prompt`, `collate_batch`.
- `train.py`: training loop, K curriculum, cosine LR, eval cadence.
- `eval.py`: **pre-registered** ablation (normal-z / random-z / zero-z).
- `configs/pilot_qwen15b_gsm8k.json`: first 24h pilot config.

## Pre-registered success criterion (before looking at raw accuracy)

```
acc(normal-z) - acc(random-z)  >=  0.15
acc(normal-z) - acc(zero-z)    >=  0.25
```

on GSM8K-test n=200. Both must hold for H1 (z carries information).
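
At n=200 the two thresholds correspond to gaps of 30 and 50 correctly answered questions. A minimal sketch of the check follows. The function name `h1_holds` and its count-based signature are hypothetical and not taken from `eval.py`. Working with integer counts avoids float rounding exactly at the boundary.

```python
def h1_holds(correct_normal, correct_random, correct_zero, n=200):
    """Return True when both pre-registered accuracy gaps are met.

    Arguments are counts of correct answers out of n. The comparison is
    done in integers: gap / n >= 0.15 is the same as 100 * gap >= 15 * n.
    """
    gap_random = correct_normal - correct_random
    gap_zero = correct_normal - correct_zero
    return 100 * gap_random >= 15 * n and 100 * gap_zero >= 25 * n

# Hypothetical counts, for illustration only.
exactly_at_threshold = h1_holds(130, 100, 80)   # gaps of 30 and 50
one_question_short = h1_holds(130, 101, 80)     # random-z gap is 29
print(exactly_at_threshold, one_question_short)
```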

## Run

```bash
# On the box (after rsyncing), smoke test first (~2 min on GH200):
python3 -m experiments.blt_reasoner.scripts.smoke_test
# Then the pilot run:
bash experiments/blt_reasoner/scripts/run_pilot.sh
```

## Deliberate non-features

- No discrete abstract vocabulary in v1. Codebook overlay is a v2 extension.
- No RL. Continuous-loop SFT with InfoNCE first; RL only if the ablation
  Δ confirms H1.
- No multi-GPU. Single GH200; if the pilot needs scale we move to 7B base
  after H1 confirmation.