blt-reasoner-pilot1 / code /smoke_test.py
LauraGG's picture
BLT-Reasoner pilot 1: ckpts + code + logs + ablations
9477b5c verified
"""Identifiability smoke test — pre-registered architectural decision gate.
Take a fixed batch of N=32 GSM8K problems. Freeze the base model. Train ONLY
the latent projector + InfoNCE head with the InfoNCE loss alone, for up to
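200 steps. Measure z→y retrieval accuracy: the fraction of the N problems whose
pooled latent z is most similar to its own answer embedding y among all N.
Pre-registered decision rule (before launching the 24h pilot):
best retrieval_acc >= 0.70 within 200 steps → PASS → launch pilot
best retrieval_acc < 0.70 → FAIL → do not launch; a best retrieval_acc near
chance (~1/N=0.031) means the architecture is broken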
Why this gate matters: a constant-z attractor is mechanically incompatible
with high retrieval accuracy. If all B latents z are identical, every z
retrieves the same y, so retrieval accuracy is exactly 1/B (chance), and the
InfoNCE loss cannot drop below log B. So PASS proves the architecture is at
least capable of producing identifiable latents — a necessary condition for
the pilot to be worth running.
Cost: ~5-10 minutes on a single GH200. Cheap gate for a 24h commitment.
"""
from __future__ import annotations
import argparse
import json
import os
import time
from pathlib import Path
import torch
from .data import GSM8KDataset, collate_batch
from .losses import InfoNCEHead, infonce_loss
from .model import BLTConfig, LatentProjector, build_base, forward_with_latent
def main():
    """Run the identifiability smoke test and write a PASS/FAIL summary.

    Reads the pilot JSON config given by ``--config``, freezes the base
    model, and trains only the latent projector and the InfoNCE head on a
    single fixed batch of ``--n_problems`` GSM8K training problems for at
    most ``--n_steps`` steps. Every step computes the InfoNCE loss and the
    z→y nearest-neighbour retrieval accuracy (measured on the embeddings of
    that step, i.e. before its optimizer update).

    Outputs, written to ``--out`` (default ``<output_dir>/smoke``):
        smoke_log.txt  -- timestamped progress log, also echoed to stdout
        summary.json   -- best accuracy, chance level, decision, timings

    The decision is PASS iff the best retrieval accuracy seen reaches
    ``--threshold``; otherwise FAIL. Invalid ``--n_problems``/``--n_steps``
    values exit through ``argparse`` with a usage error.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--config", required=True, help="reuse the pilot config (keys: base_model, K_latents, etc.)")
    ap.add_argument("--n_problems", type=int, default=32)
    ap.add_argument("--n_steps", type=int, default=200)
    ap.add_argument("--lr", type=float, default=1e-3)
    ap.add_argument("--tau", type=float, default=0.1)
    ap.add_argument("--threshold", type=float, default=0.70)
    ap.add_argument("--out", default=None)
    args = ap.parse_args()
    # Reject degenerate settings up front. With a single problem, retrieval
    # among one candidate is always correct (chance = 1.0), so the gate would
    # pass trivially; 0 problems divides by zero, and 0 steps measures nothing.
    if args.n_problems < 2:
        ap.error("--n_problems must be >= 2 (with 1 problem retrieval is trivially correct)")
    if args.n_steps < 1:
        ap.error("--n_steps must be >= 1")

    with open(args.config) as f:
        cfg = json.load(f)
    out_dir = Path(args.out or os.path.join(cfg["output_dir"], "smoke"))
    out_dir.mkdir(parents=True, exist_ok=True)
    log_path = out_dir / "smoke_log.txt"
    summary_path = out_dir / "summary.json"

    torch.manual_seed(cfg.get("seed", 42))
    device = "cuda" if torch.cuda.is_available() else "cpu"

    blt_cfg = BLTConfig(
        base_model=cfg["base_model"],
        use_lora=cfg.get("use_lora", False),
        lora_r=cfg.get("lora_r", 16), lora_alpha=cfg.get("lora_alpha", 32),
        lora_dropout=cfg.get("lora_dropout", 0.05),
        lora_target_modules=tuple(cfg.get("lora_target_modules",
                                          ("q_proj", "k_proj", "v_proj", "o_proj"))),
        K_latents=cfg["K_latents"], block_y_to_x=cfg.get("block_y_to_x", True),
        proj_init_scale=cfg.get("proj_init_scale", 0.02),
        dtype=cfg.get("dtype", "bfloat16"),
        attn_impl=cfg.get("attn_impl", "eager"),
    )
    model, tok = build_base(blt_cfg)
    model.to(device)
    # Freeze the base model so the only trainable params are projector + head.
    for p in model.parameters():
        p.requires_grad_(False)
    model.eval()

    # Unwrap the model when it exposes get_base_model() (presumably a LoRA
    # wrapper when use_lora is set) to reach the underlying config.
    inner = model.get_base_model() if hasattr(model, "get_base_model") else model
    d_model = inner.config.hidden_size
    dtype = getattr(torch, blt_cfg.dtype)
    projector = LatentProjector(d_model, init_scale=blt_cfg.proj_init_scale).to(device=device, dtype=dtype)
    head = InfoNCEHead(d_z=d_model, d_y=d_model, d_out=cfg.get("nce_proj_dim", 256)).to(device=device, dtype=dtype)
    # Built once and reused for the optimizer, gradient clipping and the
    # parameter count.
    trainable = list(projector.parameters()) + list(head.parameters())

    # Fixed batch — same problems every step (true identifiability test).
    ds = GSM8KDataset(split="train", max_examples=args.n_problems)
    batch = collate_batch(
        [ds[i] for i in range(args.n_problems)], tok,
        max_prompt_len=cfg.get("max_prompt_len", 256),
        max_answer_len=cfg.get("max_answer_len", 256),
    )
    x_ids = batch.x_ids.to(device)
    x_attn = batch.x_attn.to(device)
    y_ids = batch.y_ids.to(device)
    # NOTE(review): the answer padding mask (batch.y_attn) is not passed to
    # forward_with_latent — confirm padded answer tokens cannot influence z.

    # Frozen-base y encoding of the gold "#### N" answer strings: the fixed
    # InfoNCE positives (no gradient flows into them).
    from .losses import encode_answer_for_infonce
    with torch.no_grad():
        f_y = encode_answer_for_infonce(model, tok, batch.final_strs, device=device, max_len=16)

    opt = torch.optim.AdamW(trainable, lr=args.lr, weight_decay=0.0)
    K = blt_cfg.K_latents

    # The context manager guarantees the log is closed even if training raises.
    with open(log_path, "w", encoding="utf-8") as log:
        t0 = time.time()
        best_acc = 0.0
        converged_step = None
        steps_run = 0

        def _log(msg: str):
            """Print a timestamped line and append it to the smoke log."""
            line = f"[{time.time() - t0:6.1f}s] {msg}"
            print(line, flush=True)
            log.write(line + "\n")
            log.flush()

        _log(f"smoke start: N={args.n_problems} K={K} steps={args.n_steps} thr={args.threshold}")
        _log(f"trainable proj+head params = {sum(p.numel() for p in trainable)}")

        for step in range(args.n_steps):
            _, z, _ = forward_with_latent(
                model, x_ids, x_attn, y_ids, projector, K,
                block_y_to_x=blt_cfg.block_y_to_x,
            )
            z_pool = z.mean(dim=1).float()  # [B, d]
            z_emb, y_emb = head(z_pool, f_y.float())  # both L2-normalized
            loss = infonce_loss(z_emb, y_emb, tau=args.tau)

            # Diagnostic: z→y nearest-neighbour retrieval accuracy. Row i is
            # correct when z_i is most similar to its own y_i. A collapsed
            # (constant) z picks the same y for every row, so acc = 1/B.
            with torch.no_grad():
                sims = z_emb @ y_emb.t()
                preds = sims.argmax(dim=-1)
                acc = float((preds == torch.arange(sims.size(0), device=device)).float().mean().item())

            opt.zero_grad(set_to_none=True)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(trainable, 1.0)
            opt.step()
            steps_run = step + 1

            best_acc = max(best_acc, acc)
            if converged_step is None and acc >= args.threshold:
                converged_step = step
            # Early stop 20 steps after accuracy FIRST reaches the threshold.
            # Accuracy is not required to stay above it: the decision only
            # depends on best_acc, which is already settled once reached.
            stop_early = converged_step is not None and step >= converged_step + 20
            if step % 10 == 0 or step == args.n_steps - 1 or stop_early:
                _log(f"step={step:4d} loss={loss.item():.3f} retr_acc={acc:.3f} best={best_acc:.3f}")
            if stop_early:
                _log("converged + 20 buffer steps reached, stopping early.")
                break

        chance = 1.0 / args.n_problems
        decision = "PASS" if best_acc >= args.threshold else "FAIL"
        summary = {
            "N": args.n_problems, "K": K,
            "steps_run": steps_run,
            "best_retr_acc": best_acc,
            "converged_step": converged_step,
            "threshold": args.threshold,
            "chance": chance,
            "decision": decision,
            "duration_s": time.time() - t0,
        }
        summary_path.write_text(json.dumps(summary, indent=2))
        _log(f"summary: {summary}")

    print(f"[smoke] decision={decision} best_acc={best_acc:.3f} chance={chance:.4f}")


if __name__ == "__main__":
    main()