BLT-Reasoner pilot 1: ckpts + code + logs + ablations (LauraGG, commit 9477b5c)

BLT-Reasoner — Bottlenecked Latent Thoughts

A compute-constrained replacement for Abstract-CoT's discrete-sampled z̃.

Motivation

The Abstract-CoT recipe (experiments/abstract_cot/) at our scale produced a 7B GRPO checkpoint whose z̃ is decorative (z_ablation on MATH-500 maj@8: normal 0.57, random-z 0.54, zero-z 0.52). Root cause: there is no gradient path from the choice of which abstract token is sampled to the LM loss, because z̃ comes from a multinomial draw inside torch.no_grad(). Closing the four side-channels that bypass z̃ (Phase-B causal mask, delimiter contextualization, prior-y AR, direct x→y) makes the primary loss collapse to a constant-token or empty z̃.

Approach (3 components)

  1. Continuous latent loop: K=4..16 latent vectors z_t = W_proj(h_{t-1}), trained with full backprop (no sampling, no no_grad). Bottleneck via a 4D attention mask: y rows cannot attend to x columns, so information about x can reach y only through z.
  2. InfoNCE identifiability lock: contrastive loss between mean-pooled z and the frozen-base encoding of the gold answer (adapters disabled). The constant-z attractor is blocked mechanically: for any z that is constant across the batch, the InfoNCE loss is bounded below by log B (B = batch size, i.e. chance level), so collapsing z can never push the loss below chance.
  3. β-VAE KL prior on z magnitude: keep z in a bounded region of the residual stream so it can be interpreted as a learned thought slot.
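The three components can be illustrated in plain Python on a single sequence and a small batch. This is a minimal sketch, not the code in model.py or losses.py. The following are all assumptions:

* the [x | z_1..z_K | y] token layout;
* causal attention outside the bottleneck;
* a plain dot-product score in InfoNCE (no InfoNCEHead projection);
* a diagonal-Gaussian form for the KL term.

The boolean mask below is the per-example 2D slice; broadcasting it over batch and heads gives the 4D mask.

```python
import math
import random


def bottleneck_mask(n_x: int, k: int, n_y: int) -> list[list[bool]]:
    """Attention mask for one sequence laid out as [x | z_1..z_K | y] (assumed layout).

    mask[i][j] is True when query position i may attend to key position j.
    Assumed rules: causal everywhere, and y rows never see x columns, so
    the only route from x to y runs through the K latent slots.
    """
    n = n_x + k + n_y
    y_start = n_x + k
    mask = [[j <= i for j in range(n)] for i in range(n)]  # causal base
    for i in range(y_start, n):
        for j in range(n_x):
            mask[i][j] = False  # the bottleneck: cut y -> x attention
    return mask


def infonce_loss(z: list[list[float]], a: list[list[float]], temperature: float = 1.0) -> float:
    """Batch InfoNCE: row i must pick its own answer encoding a[i] out of all B answers.

    Uses a plain dot product as the score (an assumption; a learned head and
    cosine similarity are common alternatives).
    """
    b = len(z)
    total = 0.0
    for i in range(b):
        logits = [sum(p * q for p, q in zip(z[i], a[j])) / temperature for j in range(b)]
        m = max(logits)
        log_norm = m + math.log(sum(math.exp(s - m) for s in logits))  # stable logsumexp
        total += log_norm - logits[i]  # -log softmax(logits)[i]
    return total / b


def kl_to_std_gaussian(mu: list[float], logvar: list[float]) -> float:
    """KL(N(mu, diag(exp(logvar))) || N(0, I)), summed over dimensions."""
    return 0.5 * sum(m * m + math.exp(lv) - lv - 1.0 for m, lv in zip(mu, logvar))


if __name__ == "__main__":
    rng = random.Random(0)
    B, D = 8, 64
    answers = [[rng.gauss(0.0, 1.0) for _ in range(D)] for _ in range(B)]
    collapsed = [[0.5] * D for _ in range(B)]  # same z for every example
    informative = [list(a) for a in answers]   # z that tracks its own answer
    print("collapsed  :", infonce_loss(collapsed, answers), ">= log B =", math.log(B))
    print("informative:", infonce_loss(informative, answers))
    # With unit variance (logvar = 0) the KL reduces to 0.5 * ||mu||^2, a magnitude penalty.
    print("KL:", kl_to_std_gaussian([3.0, 4.0], [0.0, 0.0]))  # 12.5
```

For a constant z every row sees the same logits s_1..s_B. The loss is then logsumexp(s) minus mean(s), which is at least log B by Jensen's inequality. A z that tracks its own answer can drive the loss toward zero. That gap is what the identifiability lock relies on.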

Files

  • model.py — BLTConfig, LoRA wrap, LatentProjector, the forward_with_latent two-pass routine, and generate_with_latent (greedy/temperature decoding with optional override_z for ablations).
  • losses.py — InfoNCEHead, infonce_loss, kl_to_gaussian, encode_answer_for_infonce (frozen-base encoder), lm_loss_on_y.
  • data.py — GSM8K loader, format_prompt, collate_batch.
  • train.py — training loop, K curriculum, cosine LR, eval cadence.
  • eval.py — pre-registered ablation: normal-z / random-z / zero-z.
  • configs/pilot_qwen15b_gsm8k.json — first 24h pilot config.
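The three ablation conditions differ only in the latent handed to generate_with_latent as override_z. This README does not say how eval.py builds random-z. The sketch below assumes Gaussian noise rescaled to each latent vector's norm, so random-z differs from normal-z in content rather than scale. make_override_z is a hypothetical helper on plain Python lists, not a function from eval.py.

```python
import math
import random


def make_override_z(z: list[list[float]], mode: str, rng: random.Random) -> list[list[float]]:
    """Latent vectors to decode from under one ablation condition (hypothetical helper).

    z holds the K latent vectors the model produced for one example.
    """
    if mode == "normal":
        return [list(v) for v in z]  # unchanged copy
    if mode == "zero":
        return [[0.0] * len(v) for v in z]  # remove all content
    if mode == "random":
        out = []
        for v in z:
            noise = [rng.gauss(0.0, 1.0) for _ in v]
            # Assumption: match each vector's L2 norm so only the direction is random.
            scale = math.sqrt(sum(x * x for x in v)) / math.sqrt(sum(x * x for x in noise))
            out.append([scale * x for x in noise])
        return out
    raise ValueError(f"unknown ablation mode: {mode!r}")


if __name__ == "__main__":
    rng = random.Random(0)
    z = [[rng.gauss(0.0, 1.0) for _ in range(8)] for _ in range(4)]  # K=4 toy latents
    for mode in ("normal", "random", "zero"):
        zz = make_override_z(z, mode, rng)
        print(mode, [round(math.sqrt(sum(x * x for x in v)), 3) for v in zz])
```

Norm matching is one design choice among several. Shuffling z across examples in the batch is a common alternative that keeps z on the learned distribution.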

Pre-registered success criterion (before looking at raw accuracy)

acc(normal-z) - acc(random-z)  >=  0.15
acc(normal-z) - acc(zero-z)    >=  0.25

on GSM8K test, n=200. Both must hold to accept H1 (z carries information).
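At n=200 the two margins translate into whole problems. normal-z needs 0.15 × 200 = 30 more correct answers than random-z and 0.25 × 200 = 50 more than zero-z. Here is a minimal sketch of the check. It assumes the per-condition counts of correct answers are available; h1_holds and the constant names are hypothetical.

```python
N_TEST = 200            # GSM8K-test problems per condition
MARGIN_RANDOM_PCT = 15  # acc(normal-z) - acc(random-z) >= 0.15
MARGIN_ZERO_PCT = 25    # acc(normal-z) - acc(zero-z)   >= 0.25


def h1_holds(correct_normal: int, correct_random: int, correct_zero: int, n: int = N_TEST) -> bool:
    """Pre-registered H1 test on raw correct-answer counts out of n.

    Comparing 100 * (count difference) with pct * n keeps everything in integers,
    so an exact-threshold result (e.g. a 0.15 gap) is not lost to float round-off.
    """
    random_ok = 100 * (correct_normal - correct_random) >= MARGIN_RANDOM_PCT * n
    zero_ok = 100 * (correct_normal - correct_zero) >= MARGIN_ZERO_PCT * n
    return random_ok and zero_ok


if __name__ == "__main__":
    # Hypothetical counts, not results: at n=200 the margins are 30 and 50 problems.
    print(h1_holds(120, 90, 70))   # True: gaps of exactly 30 and 50
    print(h1_holds(120, 91, 70))   # False: random-z gap is only 29
```

Working in counts matters at the boundary. In IEEE doubles, `120/200 - 90/200 >= 0.15` evaluates to False, even though the true gap is exactly 0.15.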

Run

# On the box (after rsyncing), smoke test first (~2 min on GH200):
python3 -m experiments.blt_reasoner.scripts.smoke_test
# Then the pilot:
bash experiments/blt_reasoner/scripts/run_pilot.sh

Deliberate non-features

  • No discrete abstract vocabulary in v1. Codebook overlay is a v2 extension.
  • No RL. Continuous-loop SFT with InfoNCE first; RL only if the ablation Δ confirms H1.
  • No multi-GPU. Single GH200; if the pilot needs scale we move to 7B base after H1 confirmation.