"""Identifiability smoke test — pre-registered architectural decision gate.

Take a fixed batch of N=32 GSM8K problems. Freeze the base model. Train ONLY
the latent projector + InfoNCE head with the InfoNCE loss alone, for up to
200 steps. Measure the retrieval accuracy of z↔y. (N, the step budget, and
the pass threshold are the CLI defaults; each can be overridden.)

Pre-registered decision rule (before launching the 24h pilot):
    best retrieval_acc >= 0.70 within 200 steps    → PASS → launch pilot
    best retrieval_acc near chance (~1/N = 0.031)  → FAIL → architecture broken

Note: the script reports FAIL for ANY best accuracy below the threshold,
including values well above chance. Once the threshold is first reached,
training stops 20 steps later instead of using the full step budget.

Why this gate matters: a constant-z attractor is mechanically incompatible
with high retrieval accuracy (InfoNCE on a batch of B identical z gives
loss = log B by construction). So PASS proves the architecture is at least
capable of producing identifiable latents — a necessary condition for the
pilot to be worth running.

Cost: ~5-10 minutes on a single GH200. Cheap gate for a 24h commitment.
"""
from __future__ import annotations

import argparse
import json
import os
import time
from pathlib import Path

import torch

from .data import GSM8KDataset, collate_batch
from .losses import InfoNCEHead, encode_answer_for_infonce, infonce_loss
from .model import BLTConfig, LatentProjector, build_base, forward_with_latent
|
|
|
|
def main():
    """Run the identifiability smoke test and record a PASS/FAIL decision.

    Parses the CLI arguments, loads the JSON pilot config, builds the base
    model with every parameter frozen, and trains only a ``LatentProjector``
    and an ``InfoNCEHead`` with the InfoNCE loss on one fixed batch of
    ``--n_problems`` GSM8K training examples for at most ``--n_steps`` steps.
    Retrieval accuracy is the fraction of rows whose highest z/y similarity
    lies on the diagonal of the similarity matrix.

    Outputs, written to ``--out`` (default ``<cfg["output_dir"]>/smoke``):
      * ``smoke_log.txt`` -- timestamped progress lines (also printed).
      * ``summary.json``  -- best accuracy, convergence step, and decision.

    The decision is PASS iff the best retrieval accuracy seen reaches
    ``--threshold``; anything lower is reported as FAIL.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--config", required=True, help="reuse the pilot config (keys: base_model, K_latents, etc.)")
    ap.add_argument("--n_problems", type=int, default=32)
    ap.add_argument("--n_steps", type=int, default=200)
    ap.add_argument("--lr", type=float, default=1e-3)
    ap.add_argument("--tau", type=float, default=0.1)
    ap.add_argument("--threshold", type=float, default=0.70)
    ap.add_argument("--out", default=None)
    args = ap.parse_args()

    # Fail fast: chance = 1 / n_problems is computed only after the whole run,
    # so a non-positive value would otherwise crash after the expensive part.
    if args.n_problems < 1:
        ap.error("--n_problems must be >= 1")

    with open(args.config, encoding="utf-8") as f:
        cfg = json.load(f)

    # cfg["output_dir"] is only required when --out is not given.
    out_dir = Path(args.out or os.path.join(cfg["output_dir"], "smoke"))
    out_dir.mkdir(parents=True, exist_ok=True)
    log_path = out_dir / "smoke_log.txt"
    summary_path = out_dir / "summary.json"

    torch.manual_seed(cfg.get("seed", 42))
    device = "cuda" if torch.cuda.is_available() else "cpu"

    blt_cfg = BLTConfig(
        base_model=cfg["base_model"],
        use_lora=cfg.get("use_lora", False),
        lora_r=cfg.get("lora_r", 16),
        lora_alpha=cfg.get("lora_alpha", 32),
        lora_dropout=cfg.get("lora_dropout", 0.05),
        lora_target_modules=tuple(
            cfg.get("lora_target_modules", ("q_proj", "k_proj", "v_proj", "o_proj"))
        ),
        K_latents=cfg["K_latents"],
        block_y_to_x=cfg.get("block_y_to_x", True),
        proj_init_scale=cfg.get("proj_init_scale", 0.02),
        dtype=cfg.get("dtype", "bfloat16"),
        attn_impl=cfg.get("attn_impl", "eager"),
    )
    model, tok = build_base(blt_cfg)
    model.to(device)

    # Freeze the base model entirely; only the projector and head are trained.
    for p in model.parameters():
        p.requires_grad_(False)
    model.eval()

    # Unwrap a wrapper exposing get_base_model() to reach the config
    # (presumably a PEFT/LoRA wrapper -- confirm against build_base).
    inner = model.get_base_model() if hasattr(model, "get_base_model") else model
    d_model = inner.config.hidden_size
    dtype = getattr(torch, blt_cfg.dtype)

    projector = LatentProjector(d_model, init_scale=blt_cfg.proj_init_scale).to(device=device, dtype=dtype)
    # NOTE(review): the head is cast to `dtype` but is fed float32 inputs
    # below (.float()); confirm InfoNCEHead handles the dtype mismatch.
    head = InfoNCEHead(d_z=d_model, d_y=d_model, d_out=cfg.get("nce_proj_dim", 256)).to(device=device, dtype=dtype)

    # One fixed batch, reused for every optimisation step.
    ds = GSM8KDataset(split="train", max_examples=args.n_problems)
    batch = collate_batch(
        [ds[i] for i in range(args.n_problems)], tok,
        max_prompt_len=cfg.get("max_prompt_len", 256),
        max_answer_len=cfg.get("max_answer_len", 256),
    )
    x_ids = batch.x_ids.to(device)
    x_attn = batch.x_attn.to(device)
    y_ids = batch.y_ids.to(device)
    y_attn = batch.y_attn.to(device)

    # Answer-side features are computed once, without gradients, since the
    # base model is frozen and the batch never changes.
    with torch.no_grad():
        f_y = encode_answer_for_infonce(model, tok, batch.final_strs, device=device, max_len=16)

    # Built once and shared by the optimizer, the parameter count, and
    # gradient clipping.
    trainable = list(projector.parameters()) + list(head.parameters())
    opt = torch.optim.AdamW(trainable, lr=args.lr, weight_decay=0.0)

    K = blt_cfg.K_latents
    t0 = time.time()
    best_acc = 0.0
    converged_step = None
    steps_run = 0  # stays 0 when --n_steps is 0, so the summary is still valid

    # The context manager guarantees the log is flushed and closed even if
    # training raises part-way through.
    with open(log_path, "w", encoding="utf-8") as log:

        def _log(msg: str) -> None:
            """Print a timestamped line and mirror it to the log file."""
            line = f"[{time.time() - t0:6.1f}s] {msg}"
            print(line, flush=True)
            log.write(line + "\n")
            log.flush()

        _log(f"smoke start: N={args.n_problems} K={K} steps={args.n_steps} thr={args.threshold}")
        _log(f"trainable proj+head params = "
             f"{sum(p.numel() for p in trainable)}")

        for step in range(args.n_steps):
            steps_run = step + 1
            _, z, _ = forward_with_latent(
                model, x_ids, x_attn, y_ids, projector, K,
                block_y_to_x=blt_cfg.block_y_to_x,
            )
            # Mean over dim 1, then project both sides through the head.
            z_pool = z.mean(dim=1).float()
            z_emb, y_emb = head(z_pool, f_y.float())
            loss = infonce_loss(z_emb, y_emb, tau=args.tau)

            # Retrieval accuracy: row i is correct when its best-matching
            # column is i. Measured before this step's parameter update.
            with torch.no_grad():
                sims = z_emb @ y_emb.t()
                preds = sims.argmax(dim=-1)
                targets = torch.arange(sims.size(0), device=device)
                acc = float((preds == targets).float().mean().item())

            opt.zero_grad(set_to_none=True)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(trainable, 1.0)
            opt.step()

            if acc > best_acc:
                best_acc = acc
            if converged_step is None and acc >= args.threshold:
                converged_step = step

            if step % 10 == 0 or step == args.n_steps - 1:
                _log(f"step={step:4d} loss={loss.item():.3f} retr_acc={acc:.3f} best={best_acc:.3f}")

            # After first reaching the threshold, run 20 more steps and stop.
            if converged_step is not None and step >= converged_step + 20:
                _log("converged + 20 buffer steps reached, stopping early.")
                break

        chance = 1.0 / args.n_problems
        decision = "PASS" if best_acc >= args.threshold else "FAIL"
        summary = {
            "N": args.n_problems, "K": K,
            "steps_run": steps_run,
            "best_retr_acc": best_acc,
            "converged_step": converged_step,
            "threshold": args.threshold,
            "chance": chance,
            "decision": decision,
            "duration_s": time.time() - t0,
        }
        summary_path.write_text(json.dumps(summary, indent=2))
        _log(f"summary: {summary}")

    print(f"[smoke] decision={decision} best_acc={best_acc:.3f} chance={chance:.4f}")
|
|
|
# Script entry point: run the smoke test only when this module is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()
|
|