| """Capacity diagnostic for the trained BLT latent: does K=32 have headroom? |
| |
| Two cheap tests on an existing K-trained ckpt: |
| |
| 1. **Effective rank of z.** Compute z for N test problems, stack into a |
| matrix [N*K, d], compute SVD, then `eff_rank = exp(H(σ/sum(σ)))` where |
| H is entropy of the normalized singular value distribution. If eff_rank |
| ≈ K, the K slots are all carrying distinct information (K=32 might |
| add more). If eff_rank << K, slots are redundant (K=32 unlikely to |
| help). |
| |
| 2. **Perturbation curve.** Replace fraction p ∈ {0, 0.125, 0.25, 0.5, 0.75, 0.875, 1.0} |
| of slots with Gaussian noise (std-matched). Run AR generation, |
| measure GSM8K accuracy. If acc drops gradually with p, all slots |
| contribute. If acc stays flat up to high p then crashes, many |
| slots are redundant. |
| |
| Usage: |
| python -m experiments.blt_reasoner.scripts.capacity_diagnostic \ |
| --ckpt /path/to/grpo_final --config <config.json> --n 100 --K 16 |
| """ |
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import math |
| import re |
| import time |
| from pathlib import Path |
| from typing import Optional |
|
|
| import torch |
| from torch.utils.data import DataLoader |
|
|
| from ..data import GSM8KDataset, collate_batch |
| from ..model import ( |
| BLTConfig, LatentProjector, build_base, |
| forward_with_latent, generate_with_latent, |
| ) |
| from ..eval import parse_pred, correct, _perturb_z |
|
|
|
|
@torch.no_grad()
def collect_z_batch(model, projector, loader, device, K, max_batches=20):
    """Gather latents from the first ``max_batches`` batches into one CPU matrix.

    Every batch is pushed through ``forward_with_latent`` with ``return_z=True``.
    The returned ``z`` is cast to float32, reshaped so that its last dimension
    is kept and all leading dimensions are merged, and moved to the CPU. The
    per-batch matrices are concatenated along dim 0 (``[N*K, d]`` overall).
    """
    stacked_rows = []
    for step, batch in enumerate(loader):
        if step >= max_batches:
            break
        prompt_ids = batch.x_ids.to(device)
        prompt_mask = batch.x_attn.to(device)
        answer_ids = batch.y_ids.to(device)
        outputs = forward_with_latent(
            model, prompt_ids, prompt_mask, answer_ids, projector, K,
            block_y_to_x=True, return_z=True,
        )
        # Only the middle element of the returned triple (the latent) is used.
        _, latent, _ = outputs
        width = latent.size(-1)
        stacked_rows.append(latent.float().reshape(-1, width).cpu())
    return torch.cat(stacked_rows, dim=0)
|
|
|
|
def effective_rank(M: torch.Tensor) -> dict:
    """Summarize how the spectrum of ``M`` is spread over its singular values.

    ``eff_rank = exp(entropy(σ / sum(σ)))``; the stable rank and cumulative
    energy fractions are reported alongside it.

    Note: ``M`` is not mean-centred here, so the ``*_var_frac`` and
    ``cum_explained_var_first16`` entries are fractions of the total squared
    singular values (spectral energy) of ``M`` as given.

    Args:
        M: Matrix to analyse (expected 2-D, e.g. ``[N*K, d]``). It is cast to
            float32 before the decomposition.

    Returns:
        A dict with keys
        ``n_singvals`` (number of singular values),
        ``eff_rank_exp_entropy`` (exp of the entropy of the normalized
        singular values),
        ``stable_rank`` (sum of σ² divided by the largest σ²),
        ``top1_var_frac`` / ``top4_var_frac`` / ``top8_var_frac`` (share of
        Σσ² held by the leading 1 / 4 / 8 singular values), and
        ``cum_explained_var_first16`` (cumulative share of Σσ² for up to the
        first 16 singular values, as a list of floats).

    Raises:
        ValueError: If ``M`` has no elements.
    """
    M = M.float()
    if M.numel() == 0:
        raise ValueError("effective_rank() requires a non-empty matrix")

    # Only the spectrum is needed, so avoid materializing U and Vh.
    S = torch.linalg.svdvals(M)
    # Floor the values so that log() below stays finite for zero singular values.
    sv = S.clamp_min(1e-12)

    p = sv / sv.sum()
    eff = float(torch.exp((-p * p.log()).sum()).item())

    # Squared singular values and their total are shared by every ratio below.
    energy = sv.pow(2)
    total_energy = energy.sum()
    stable = float((total_energy / energy.max()).item())
    cum = (energy.cumsum(0) / total_energy).tolist()

    return {
        "n_singvals": int(S.numel()),
        "eff_rank_exp_entropy": eff,
        "stable_rank": stable,
        "top1_var_frac": float((energy[0] / total_energy).item()),
        "top4_var_frac": float((energy[:4].sum() / total_energy).item()),
        "top8_var_frac": float((energy[:8].sum() / total_energy).item()),
        "cum_explained_var_first16": cum[:16],
    }
|
|
|
|
@torch.no_grad()
def estimate_z_std(model, projector, loader, device, K, max_batches=4):
    """Return the standard deviation, as a Python float, taken over every
    element of the latent matrix that ``collect_z_batch`` gathers from the
    first ``max_batches`` batches of ``loader``."""
    latents = collect_z_batch(
        model, projector, loader, device, K, max_batches=max_batches,
    )
    spread = latents.std()
    return float(spread.item())
|
|
|
|
@torch.no_grad()
def run_perturbation_curve(model, projector, tokenizer, loader, device, K, *,
                           z_std, severities, max_new_tokens, temperature, seed=0):
    """Measure GSM8K AR accuracy at each perturbation severity.

    For every severity the whole ``loader`` is traversed: the latent ``z`` is
    computed with ``forward_with_latent``, perturbed by ``_perturb_z`` (reused
    from eval.py; it is given ``severity``, ``z_std`` and a per-batch seed and,
    per the original description, replaces that fraction of slots with
    N(0, z_std²) noise), and then passed to ``generate_with_latent`` through
    ``override_z``. Decoded generations are scored with ``parse_pred`` /
    ``correct`` against ``batch.final_strs``.

    The function runs under ``torch.no_grad()`` like the other collection
    helpers in this file, since it is evaluation only and no autograd graph is
    needed for the forward pass.

    Args:
        model, projector, tokenizer: Trained components used for the forward
            pass, generation and decoding.
        loader: Iterable of batches exposing ``x_ids``, ``x_attn``, ``y_ids``
            and ``final_strs``. It is iterated once per severity.
        device: Device the batch tensors are moved to.
        K: Number of latent slots handed to the model helpers.
        z_std: Noise scale forwarded to ``_perturb_z``.
        severities: Iterable of severity values to sweep.
        max_new_tokens, temperature: Generation settings.
        seed: Base seed; batch ``bi`` uses ``seed + bi``, so the same batch
            gets the same seed at every severity.

    Returns:
        Dict mapping ``float(severity)`` to
        ``{"acc": ..., "n": ..., "correct": ...}``. ``acc`` is 0.0 when no
        example was scored.
    """
    proj_dtype = next(projector.parameters()).dtype
    results = {}
    for sev in severities:
        correct_n = total = 0
        for bi, batch in enumerate(loader):
            x_ids = batch.x_ids.to(device)
            x_attn = batch.x_attn.to(device)
            y_ids = batch.y_ids.to(device)
            B = x_ids.size(0)

            _, z, _ = forward_with_latent(
                model, x_ids, x_attn, y_ids, projector, K,
                block_y_to_x=True, return_z=True,
            )
            z_pert = _perturb_z(z.to(device=device, dtype=proj_dtype),
                                severity=sev, z_std=z_std, seed=seed + bi)
            gen = generate_with_latent(
                model, tokenizer, projector,
                x_ids=x_ids, x_attn=x_attn, K=K,
                block_y_to_x=True, max_new_tokens=max_new_tokens,
                temperature=temperature, eos_token_id=tokenizer.eos_token_id,
                override_z=z_pert,
            )
            for b in range(B):
                text = tokenizer.decode(gen[b], skip_special_tokens=True)
                pred = parse_pred(text)
                # Gold strings carry the GSM8K "#### " marker; drop it before comparing.
                gold = batch.final_strs[b].replace("#### ", "").strip()
                if correct(pred, gold):
                    correct_n += 1
                total += 1
        # max(total, 1) keeps an empty loader from dividing by zero.
        results[float(sev)] = {"acc": correct_n / max(total, 1), "n": total, "correct": correct_n}
        print(f" severity={sev:.2f} acc={results[float(sev)]['acc']:.3f} ({correct_n}/{total})", flush=True)
    return results
|
|
|
|
def _rank_verdict(eff_rank, K):
    """Map an effective-rank value to the HIGH / MEDIUM / LOW verdict string."""
    if eff_rank >= K * 0.7:
        return "HIGH eff_rank — slots are using distinct directions → K=32 plausibly helps"
    if eff_rank >= K * 0.4:
        return "MEDIUM eff_rank — partial slot redundancy; K=32 may give marginal lift"
    return "LOW eff_rank — many redundant slots; K=32 unlikely to help"


def _load_model_and_projector(cfg, ckpt, K, device):
    """Build the base model, attach the checkpoint's adapter if present, and load the projector."""
    bcfg = BLTConfig(
        base_model=cfg["base_model"], use_lora=False,
        lora_r=cfg["lora_r"], lora_alpha=cfg["lora_alpha"],
        lora_dropout=cfg["lora_dropout"],
        lora_target_modules=tuple(cfg["lora_target_modules"]),
        K_latents=K, block_y_to_x=cfg["block_y_to_x"],
        proj_init_scale=cfg["proj_init_scale"],
        dtype=cfg["dtype"], attn_impl=cfg["attn_impl"],
        gradient_checkpointing=False,
    )
    base_model, tokenizer = build_base(bcfg)

    adapter_dir = ckpt / "model"
    if (adapter_dir / "adapter_config.json").exists():
        # Imported lazily so checkpoints without an adapter do not need peft.
        from peft import PeftModel
        model = PeftModel.from_pretrained(base_model, str(adapter_dir))
        print(f"[load] adapter from {adapter_dir}")
    else:
        model = base_model
    model.to(device).eval()

    inner = model.get_base_model() if hasattr(model, "get_base_model") else model
    d_model = inner.config.hidden_size
    projector = LatentProjector(
        d_model, init_scale=cfg["proj_init_scale"],
        use_mlp=cfg.get("proj_mlp", False),
        hidden_mult=cfg.get("proj_hidden_mult", 4),
    ).to(device).to(next(model.parameters()).dtype)
    # SECURITY: torch.load unpickles the file. The checkpoint path comes from
    # the command line, so restrict loading to plain tensors/containers;
    # load_state_dict only needs a state dict.
    state = torch.load(ckpt / "projector.pt", map_location=device, weights_only=True)
    projector.load_state_dict(state)
    projector.eval()
    return model, tokenizer, projector


def main():
    """Run both capacity diagnostics on a checkpoint and write a JSON summary.

    Diagnostic 1 computes spectrum statistics of the collected latents and
    derives a verdict; diagnostic 2 sweeps perturbation severities and records
    accuracy. The summary goes to ``--out`` or, by default, to
    ``<ckpt>/capacity_diagnostic.json``.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--ckpt", required=True)
    p.add_argument("--config", required=True)
    p.add_argument("--n", type=int, default=100)
    p.add_argument("--K", type=int, default=None)
    p.add_argument("--max_new_tokens", type=int, default=128)
    p.add_argument("--out", default=None)
    args = p.parse_args()
    with open(args.config, encoding="utf-8") as f:
        cfg = json.load(f)
    # Without --K, fall back to the K of the last K_curriculum entry (8 if absent).
    K = args.K if args.K is not None else cfg.get("K_curriculum", [[0, 8]])[-1][1]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    ckpt = Path(args.ckpt)
    model, tokenizer, projector = _load_model_and_projector(cfg, ckpt, K, device)

    val_ds = GSM8KDataset(split="test", max_examples=args.n)
    loader = DataLoader(
        val_ds, batch_size=8, shuffle=False,
        collate_fn=lambda b: collate_batch(b, tokenizer,
                                           max_prompt_len=cfg["max_prompt_len"],
                                           max_answer_len=cfg["max_answer_len"]),
    )

    # Diagnostic 1: spectrum of the stacked latents.
    # With batch_size=8 and max_batches=20 at most 160 problems feed this
    # diagnostic, whatever --n is.
    print("\n[diagnostic 1] effective rank of z across test problems")
    t0 = time.time()
    Z = collect_z_batch(model, projector, loader, device, K, max_batches=20)
    print(f" collected Z: shape={tuple(Z.shape)} ({time.time()-t0:.0f}s)")
    rank_stats = effective_rank(Z)
    print(f" n_singvals={rank_stats['n_singvals']}")
    print(f" eff_rank (exp_entropy) = {rank_stats['eff_rank_exp_entropy']:.2f} (K={K})")
    print(f" stable_rank = {rank_stats['stable_rank']:.2f}")
    print(f" top-1 variance frac = {rank_stats['top1_var_frac']:.3f}")
    print(f" top-4 variance frac = {rank_stats['top4_var_frac']:.3f}")
    print(f" top-8 variance frac = {rank_stats['top8_var_frac']:.3f}")

    # NOTE(review): Z stacks the slots of many problems, so eff_rank is bounded
    # by n_singvals rather than by K; the K-relative thresholds used for the
    # verdict are a heuristic -- confirm they are calibrated as intended.
    verdict_rank = _rank_verdict(rank_stats['eff_rank_exp_entropy'], K)
    print(f" verdict: {verdict_rank}")

    # Diagnostic 2: accuracy as a function of the fraction of slots replaced.
    print("\n[diagnostic 2] perturbation curve (AR accuracy vs fraction-of-slots-replaced)")
    z_std = estimate_z_std(model, projector, loader, device, K, max_batches=4)
    print(f" z_std={z_std:.4f}")
    severities = [0.0, 0.125, 0.25, 0.5, 0.75, 0.875, 1.0]
    pert_results = run_perturbation_curve(
        model, projector, tokenizer, loader, device, K,
        z_std=z_std, severities=severities,
        max_new_tokens=args.max_new_tokens, temperature=0.0,
    )

    summary = {
        "ckpt": str(ckpt),
        "n": args.n, "K": K, "z_std": z_std,
        "effective_rank": rank_stats,
        "verdict_eff_rank": verdict_rank,
        "perturbation_curve": pert_results,
    }
    out = args.out or str(ckpt / "capacity_diagnostic.json")
    out_path = Path(out)
    # Create the target directory so a long run is not lost at the final write.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(json.dumps(summary, indent=2))
    print(f"\n[written] {out}")


if __name__ == "__main__":
    main()
|
|