File size: 6,815 Bytes
9477b5c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 | """Identifiability smoke test — pre-registered architectural decision gate.
Take a fixed batch of N=32 GSM8K problems. Freeze the base model. Train ONLY
the latent projector + InfoNCE head with the InfoNCE loss alone, for up to
200 steps. Measure the retrieval accuracy of z↔y.
Pre-registered decision rule (fixed before launching the 24h pilot):
best retrieval_acc >= 0.70 within 200 steps → PASS → launch pilot
best retrieval_acc < 0.70 → FAIL → do not launch; a best accuracy near
chance (~1/N = 1/32 ≈ 0.031) indicates the architecture is broken
Why this gate matters: a constant-z attractor is mechanically incompatible
with high retrieval accuracy (if all B latents z in a batch are identical,
every query ranks the candidate answers y identically, so retrieval accuracy
is 1/B and the InfoNCE loss is at least log B by construction). So PASS
proves the architecture is at least capable of producing identifiable
latents — a necessary condition for the pilot to be worth running.
Cost: ~5-10 minutes on a single GH200. Cheap gate for a 24h commitment.
"""
from __future__ import annotations
import argparse
import json
import os
import time
from pathlib import Path
import torch
from .data import GSM8KDataset, collate_batch
from .losses import InfoNCEHead, infonce_loss
from .model import BLTConfig, LatentProjector, build_base, forward_with_latent
def _build_blt_config(cfg: dict) -> BLTConfig:
    """Build the BLTConfig for the smoke run from the pilot JSON config.

    ``base_model`` and ``K_latents`` are required keys (KeyError if absent);
    every other field falls back to the default spelled out below.
    """
    return BLTConfig(
        base_model=cfg["base_model"],
        use_lora=cfg.get("use_lora", False),
        lora_r=cfg.get("lora_r", 16),
        lora_alpha=cfg.get("lora_alpha", 32),
        lora_dropout=cfg.get("lora_dropout", 0.05),
        lora_target_modules=tuple(
            cfg.get("lora_target_modules", ("q_proj", "k_proj", "v_proj", "o_proj"))
        ),
        K_latents=cfg["K_latents"],
        block_y_to_x=cfg.get("block_y_to_x", True),
        proj_init_scale=cfg.get("proj_init_scale", 0.02),
        dtype=cfg.get("dtype", "bfloat16"),
        attn_impl=cfg.get("attn_impl", "eager"),
    )


def _retrieval_accuracy(z_emb: torch.Tensor, y_emb: torch.Tensor) -> float:
    """Return the fraction of rows i whose highest-scoring y_emb row is row i.

    Row i of ``z_emb`` is the query whose positive is row i of ``y_emb``;
    the score is the dot product (cosine similarity when both inputs are
    L2-normalized). Computed without tracking gradients.
    """
    with torch.no_grad():
        sims = z_emb @ y_emb.t()
        preds = sims.argmax(dim=-1)
        targets = torch.arange(sims.size(0), device=sims.device)
        return float((preds == targets).float().mean().item())


def main():
    """Run the identifiability smoke gate and write its log and summary.

    Freezes the base model and trains only the latent projector and the
    InfoNCE head, with the InfoNCE loss alone, on one fixed batch of GSM8K
    training problems. The decision is PASS when the best z→y retrieval
    accuracy seen during training reaches ``--threshold``, otherwise FAIL.

    Writes ``smoke_log.txt`` and ``summary.json`` into ``--out`` (default:
    ``<config output_dir>/smoke``) and prints the decision to stdout.
    Exits via ``argparse`` with an error for invalid ``--n_problems`` or
    ``--n_steps``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--config", required=True, help="reuse the pilot config (keys: base_model, K_latents, etc.)")
    ap.add_argument("--n_problems", type=int, default=32)
    ap.add_argument("--n_steps", type=int, default=200)
    ap.add_argument("--lr", type=float, default=1e-3)
    ap.add_argument("--tau", type=float, default=0.1)
    ap.add_argument("--threshold", type=float, default=0.70)
    ap.add_argument("--out", default=None)
    args = ap.parse_args()
    # With a single problem, retrieval is trivially correct (acc = 1.0), which
    # would be a spurious PASS; with zero problems chance = 1/N is undefined.
    if args.n_problems < 2:
        ap.error("--n_problems must be >= 2 for retrieval accuracy to be meaningful")
    # With zero steps nothing is measured (and the summary had no step count).
    if args.n_steps < 1:
        ap.error("--n_steps must be >= 1")

    with open(args.config, encoding="utf-8") as f:
        cfg = json.load(f)
    out_dir = Path(args.out or os.path.join(cfg["output_dir"], "smoke"))
    out_dir.mkdir(parents=True, exist_ok=True)
    log_path = out_dir / "smoke_log.txt"
    summary_path = out_dir / "summary.json"

    torch.manual_seed(cfg.get("seed", 42))
    device = "cuda" if torch.cuda.is_available() else "cpu"

    blt_cfg = _build_blt_config(cfg)
    model, tok = build_base(blt_cfg)
    model.to(device)
    # Freeze the base model so the only trainable params are projector + head.
    for p in model.parameters():
        p.requires_grad_(False)
    model.eval()
    # Unwrap the model if it exposes get_base_model() (presumably a LoRA/PEFT
    # wrapper — confirm against build_base) to read the hidden size.
    inner = model.get_base_model() if hasattr(model, "get_base_model") else model
    d_model = inner.config.hidden_size
    dtype = getattr(torch, blt_cfg.dtype)
    projector = LatentProjector(d_model, init_scale=blt_cfg.proj_init_scale).to(device=device, dtype=dtype)
    head = InfoNCEHead(d_z=d_model, d_y=d_model, d_out=cfg.get("nce_proj_dim", 256)).to(device=device, dtype=dtype)
    # Built once: used for the optimizer, the parameter count and grad clipping.
    trainable = [*projector.parameters(), *head.parameters()]

    # Fixed batch — same problems every step (true identifiability test).
    ds = GSM8KDataset(split="train", max_examples=args.n_problems)
    batch = collate_batch(
        [ds[i] for i in range(args.n_problems)], tok,
        max_prompt_len=cfg.get("max_prompt_len", 256),
        max_answer_len=cfg.get("max_answer_len", 256),
    )
    x_ids = batch.x_ids.to(device)
    x_attn = batch.x_attn.to(device)
    y_ids = batch.y_ids.to(device)

    # Frozen-base encoding of the gold final-answer strings (batch.final_strs);
    # these are the InfoNCE positives and stay constant for the whole run.
    from .losses import encode_answer_for_infonce
    with torch.no_grad():
        f_y = encode_answer_for_infonce(model, tok, batch.final_strs, device=device, max_len=16)

    opt = torch.optim.AdamW(trainable, lr=args.lr, weight_decay=0.0)
    K = blt_cfg.K_latents
    chance = 1.0 / args.n_problems

    t0 = time.time()
    best_acc = 0.0
    converged_step = None
    steps_run = 0

    # `with` guarantees the log file is closed even if training raises.
    with open(log_path, "w", encoding="utf-8") as log:
        def _log(msg: str) -> None:
            """Print a timestamped line and append it (flushed) to the log file."""
            line = f"[{time.time() - t0:6.1f}s] {msg}"
            print(line, flush=True)
            log.write(line + "\n")
            log.flush()

        _log(f"smoke start: N={args.n_problems} K={K} steps={args.n_steps} thr={args.threshold}")
        _log(f"trainable proj+head params = {sum(p.numel() for p in trainable)}")

        for step in range(args.n_steps):
            _, z, _ = forward_with_latent(
                model, x_ids, x_attn, y_ids, projector, K,
                block_y_to_x=blt_cfg.block_y_to_x,
            )
            # Mean-pool over dim 1 (presumably the K latent positions — confirm
            # against forward_with_latent) to get one vector per problem.
            z_pool = z.mean(dim=1).float()  # [B, d]
            # NOTE(review): the head's parameters are in `dtype` (bfloat16 by
            # default) while both inputs are cast to float32 here; this assumes
            # InfoNCEHead reconciles the dtypes internally — confirm.
            z_emb, y_emb = head(z_pool, f_y.float())  # both L2-normalized
            loss = infonce_loss(z_emb, y_emb, tau=args.tau)
            # Diagnostic, measured with the parameters *before* this update.
            acc = _retrieval_accuracy(z_emb, y_emb)

            opt.zero_grad(set_to_none=True)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(trainable, 1.0)
            opt.step()
            steps_run = step + 1

            best_acc = max(best_acc, acc)
            if converged_step is None and acc >= args.threshold:
                converged_step = step
            if step % 10 == 0 or step == args.n_steps - 1:
                _log(f"step={step:4d} loss={loss.item():.3f} retr_acc={acc:.3f} best={best_acc:.3f}")
            # The decision depends only on best_acc, so once the threshold has
            # been crossed the outcome is fixed; run 20 more steps for the log
            # (no check that accuracy stays above threshold) and then stop.
            if converged_step is not None and step >= converged_step + 20:
                _log("converged + 20 buffer steps reached, stopping early.")
                break

        decision = "PASS" if best_acc >= args.threshold else "FAIL"
        summary = {
            "N": args.n_problems, "K": K,
            "steps_run": steps_run,
            "best_retr_acc": best_acc,
            "converged_step": converged_step,
            "threshold": args.threshold,
            "chance": chance,
            "decision": decision,
            "duration_s": time.time() - t0,
        }
        summary_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")
        _log(f"summary: {summary}")

    print(f"[smoke] decision={decision} best_acc={best_acc:.3f} chance={chance:.4f}")
if __name__ == "__main__":
    # Executed as a script / `python -m`: run the smoke gate; importing the
    # module does nothing beyond defining main().
    main()
|