"""Identifiability smoke test — pre-registered architectural decision gate. Take a fixed batch of N=32 GSM8K problems. Freeze the base model. Train ONLY the latent projector + InfoNCE head with the InfoNCE loss alone, for up to 200 steps. Measure the retrieval accuracy of z↔y. Pre-registered decision rule (before launching the 24h pilot): best retrieval_acc >= 0.70 within 200 steps → PASS → launch pilot best retrieval_acc near chance (~1/N=0.031) → FAIL → architecture broken Why this gate matters: a constant-z attractor is mechanically incompatible with high retrieval accuracy (InfoNCE on a batch of B identical z gives loss = log B by construction). So PASS proves the architecture is at least capable of producing identifiable latents — a necessary condition for the pilot to be worth running. Cost: ~5-10 minutes on a single GH200. Cheap gate for a 24h commitment. """ from __future__ import annotations import argparse import json import os import time from pathlib import Path import torch from .data import GSM8KDataset, collate_batch from .losses import InfoNCEHead, infonce_loss from .model import BLTConfig, LatentProjector, build_base, forward_with_latent def main(): ap = argparse.ArgumentParser() ap.add_argument("--config", required=True, help="reuse the pilot config (keys: base_model, K_latents, etc.)") ap.add_argument("--n_problems", type=int, default=32) ap.add_argument("--n_steps", type=int, default=200) ap.add_argument("--lr", type=float, default=1e-3) ap.add_argument("--tau", type=float, default=0.1) ap.add_argument("--threshold", type=float, default=0.70) ap.add_argument("--out", default=None) args = ap.parse_args() with open(args.config) as f: cfg = json.load(f) out_dir = Path(args.out or os.path.join(cfg["output_dir"], "smoke")) out_dir.mkdir(parents=True, exist_ok=True) log_path = out_dir / "smoke_log.txt" summary_path = out_dir / "summary.json" torch.manual_seed(cfg.get("seed", 42)) device = "cuda" if torch.cuda.is_available() else "cpu" blt_cfg = BLTConfig( base_model=cfg["base_model"], use_lora=cfg.get("use_lora", False), lora_r=cfg.get("lora_r", 16), lora_alpha=cfg.get("lora_alpha", 32), lora_dropout=cfg.get("lora_dropout", 0.05), lora_target_modules=tuple(cfg.get("lora_target_modules", ("q_proj", "k_proj", "v_proj", "o_proj"))), K_latents=cfg["K_latents"], block_y_to_x=cfg.get("block_y_to_x", True), proj_init_scale=cfg.get("proj_init_scale", 0.02), dtype=cfg.get("dtype", "bfloat16"), attn_impl=cfg.get("attn_impl", "eager"), ) model, tok = build_base(blt_cfg) model.to(device) # Freeze the base model so the only trainable params are projector + head. for p in model.parameters(): p.requires_grad_(False) model.eval() inner = model.get_base_model() if hasattr(model, "get_base_model") else model d_model = inner.config.hidden_size dtype = getattr(torch, blt_cfg.dtype) projector = LatentProjector(d_model, init_scale=blt_cfg.proj_init_scale).to(device=device, dtype=dtype) head = InfoNCEHead(d_z=d_model, d_y=d_model, d_out=cfg.get("nce_proj_dim", 256)).to(device=device, dtype=dtype) # Fixed batch — same problems every step (true identifiability test). ds = GSM8KDataset(split="train", max_examples=args.n_problems) batch = collate_batch( [ds[i] for i in range(args.n_problems)], tok, max_prompt_len=cfg.get("max_prompt_len", 256), max_answer_len=cfg.get("max_answer_len", 256), ) x_ids = batch.x_ids.to(device) x_attn = batch.x_attn.to(device) y_ids = batch.y_ids.to(device) y_attn = batch.y_attn.to(device) # Frozen-base y encoding (target for InfoNCE positives). with torch.no_grad(): from .losses import encode_answer_for_infonce # We can also just use the y_ids end-of-answer text; here we feed gold "#### N" strings. f_y = encode_answer_for_infonce(model, tok, batch.final_strs, device=device, max_len=16) opt = torch.optim.AdamW( list(projector.parameters()) + list(head.parameters()), lr=args.lr, weight_decay=0.0, ) K = blt_cfg.K_latents log = open(log_path, "w") t0 = time.time() best_acc = 0.0 converged_step = None def _log(msg: str): line = f"[{time.time() - t0:6.1f}s] {msg}" print(line, flush=True) log.write(line + "\n"); log.flush() _log(f"smoke start: N={args.n_problems} K={K} steps={args.n_steps} thr={args.threshold}") _log(f"trainable proj+head params = " f"{sum(p.numel() for p in list(projector.parameters()) + list(head.parameters()))}") for step in range(args.n_steps): _, z, _ = forward_with_latent( model, x_ids, x_attn, y_ids, projector, K, block_y_to_x=blt_cfg.block_y_to_x, ) z_pool = z.mean(dim=1).float() # [B, d] z_emb, y_emb = head(z_pool, f_y.float()) # both L2-normalized loss = infonce_loss(z_emb, y_emb, tau=args.tau) # Diagnostic: nearest-neighbor retrieval accuracy. with torch.no_grad(): sims = z_emb @ y_emb.t() preds = sims.argmax(dim=-1) acc = float((preds == torch.arange(sims.size(0), device=device)).float().mean().item()) opt.zero_grad(set_to_none=True) loss.backward() torch.nn.utils.clip_grad_norm_( list(projector.parameters()) + list(head.parameters()), 1.0, ) opt.step() if acc > best_acc: best_acc = acc if converged_step is None and acc >= args.threshold: converged_step = step if step % 10 == 0 or step == args.n_steps - 1: _log(f"step={step:4d} loss={loss.item():.3f} retr_acc={acc:.3f} best={best_acc:.3f}") # Early stop after sustained convergence to save time. if converged_step is not None and step >= converged_step + 20: _log("converged + 20 buffer steps reached, stopping early.") break chance = 1.0 / args.n_problems decision = "PASS" if best_acc >= args.threshold else "FAIL" summary = { "N": args.n_problems, "K": K, "steps_run": step + 1, "best_retr_acc": best_acc, "converged_step": converged_step, "threshold": args.threshold, "chance": chance, "decision": decision, "duration_s": time.time() - t0, } summary_path.write_text(json.dumps(summary, indent=2)) _log(f"summary: {summary}") log.close() print(f"[smoke] decision={decision} best_acc={best_acc:.3f} chance={chance:.4f}") if __name__ == "__main__": main()