blt-reasoner-pilot1 / code /scripts /z_rank_only.py
LauraGG's picture
Refresh code/ with latest BLT-Reasoner sources (post-campaign)
bc7101b verified
"""Minimal z rank computation — no generation, no perturbation curve.
Forward N problems through a ckpt and compute the rank statistics of the
resulting z (input embeddings from the M-step latent loop). Lightweight
enough to run in parallel with training (~15-20 GB activation footprint
vs full capacity_diagnostic's ~30 GB).
Used to ablate the "harder problems → richer z" hypothesis: run the GSM8K-
trained GRPO ckpt against MATH test problems, compare stable_rank to the
known GSM8K-eval value (6.73).
"""
from __future__ import annotations
import argparse
import json
import time
from pathlib import Path
from typing import Optional
import torch
from torch.utils.data import DataLoader
from ..data import GSM8KDataset, MATHDataset, collate_batch
from ..model import BLTConfig, LatentProjector, build_base, forward_with_latent
@torch.no_grad()
def collect_z_batch(model, projector, loader, device, K, max_batches: Optional[int] = 20,
                    *, block_y_to_x: bool = True) -> torch.Tensor:
    """Forward up to ``max_batches`` batches through the latent loop and stack z.

    Args:
        model: model passed straight to ``forward_with_latent``.
        projector: latent projector passed straight to ``forward_with_latent``.
        loader: iterable of batches exposing ``x_ids``, ``x_attn`` and ``y_ids``
            tensors (each is moved to ``device`` before the forward).
        device: target device for the batch tensors.
        K: number of latents, forwarded to ``forward_with_latent``.
        max_batches: stop after this many batches; ``None`` consumes the whole
            loader. Defaults to 20 as before.
        block_y_to_x: forwarded to ``forward_with_latent``; defaults to the
            previously hard-coded ``True``.

    Returns:
        A 2-D float32 CPU tensor: every z returned by ``forward_with_latent``
        reshaped to ``(-1, z.size(-1))`` and concatenated across batches.

    Raises:
        RuntimeError: if no batch was forwarded (empty loader or
            ``max_batches <= 0``), instead of the opaque empty-``torch.cat`` error.
    """
    chunks = []
    for i, b in enumerate(loader):
        if max_batches is not None and i >= max_batches:
            break
        _, z, _ = forward_with_latent(
            model, b.x_ids.to(device), b.x_attn.to(device),
            b.y_ids.to(device), projector, K,
            block_y_to_x=block_y_to_x, return_z=True,
        )
        # Collapse every leading dim so each latent vector becomes one row, and
        # move it to CPU so device memory is not held across batches.
        chunks.append(z.float().reshape(-1, z.size(-1)).cpu())
    if not chunks:
        raise RuntimeError(
            "collect_z_batch: no batches were processed "
            "(empty loader or max_batches <= 0)"
        )
    return torch.cat(chunks, dim=0)
def rank_stats(M: torch.Tensor) -> dict:
    """Summarise the singular-value spectrum of a 2-D matrix.

    ``M`` is cast to float32 before the SVD. Singular values are clamped to
    ``>= 1e-12`` so the entropy's ``log`` stays finite.

    Returns:
        dict with keys
          - ``n_singvals``: number of singular values (``min(M.shape)``).
          - ``eff_rank_exp_entropy``: ``exp`` of the Shannon entropy of the
            singular values normalised to sum to 1.
          - ``stable_rank``: ``sum(s**2) / max(s**2)``.
          - ``top1_var_frac`` / ``top4_var_frac`` / ``top8_var_frac``: share of
            ``sum(s**2)`` held by the leading 1 / 4 / 8 singular values.
          - ``cum_explained_var_first16``: cumulative share of ``sum(s**2)``
            for (at most) the first 16 singular values.
          - ``z_std``: standard deviation over all entries of ``M``.

    Raises:
        ValueError: if ``M`` is not 2-D or has an empty dimension, where the
            statistics are undefined.
    """
    M = M.float()
    if M.dim() != 2 or 0 in M.shape:
        raise ValueError(
            f"rank_stats expects a non-empty 2-D matrix, got shape {tuple(M.shape)}"
        )
    # Only the spectrum is needed; svdvals skips computing U and V.
    S = torch.linalg.svdvals(M)
    sv = S.clamp_min(1e-12)
    p = sv / sv.sum()
    eff = float(torch.exp(-(p * p.log()).sum()).item())
    # svdvals returns values in descending order, so energy[0] is the largest.
    energy = sv.pow(2)
    total = energy.sum()
    cum = (energy.cumsum(0) / total).tolist()
    return {
        "n_singvals": int(S.numel()),
        "eff_rank_exp_entropy": eff,
        "stable_rank": float((total / energy.max()).item()),
        "top1_var_frac": float((energy[0] / total).item()),
        "top4_var_frac": float((energy[:4].sum() / total).item()),
        "top8_var_frac": float((energy[:8].sum() / total).item()),
        "cum_explained_var_first16": cum[:16],
        "z_std": float(M.std().item()),
    }
def main():
    """Command-line entry point: rank statistics of z for one checkpoint.

    Steps:
      1. Read the training config JSON given by ``--config``. K defaults to the
         second entry of the last ``K_curriculum`` stage (8 if the key is absent).
      2. Rebuild the base model, then wrap it with the PEFT adapter from
         ``<ckpt>/model`` when ``adapter_config.json`` exists there.
      3. Load the latent projector weights from ``<ckpt>/projector.pt``.
      4. Forward up to ``--n`` problems of the chosen TEST split and collect z.
      5. Print the headline statistics and write a JSON summary to ``--out``
         (default ``<ckpt>/rank_on_<eval_dataset>.json``).
    """
    p = argparse.ArgumentParser(
        description="Compute rank statistics of the latent z produced by a checkpoint."
    )
    p.add_argument("--ckpt", required=True)
    p.add_argument("--config", required=True)
    p.add_argument("--n", type=int, default=100)
    p.add_argument("--K", type=int, default=None)
    p.add_argument("--eval_dataset", required=True, choices=["gsm8k", "math"],
                   help="Which dataset's TEST split to evaluate against (independent of training data)")
    p.add_argument("--out", default=None)
    args = p.parse_args()
    if args.n <= 0:
        # With nothing to forward the run would only fail much later, in torch.cat.
        p.error("--n must be a positive integer")

    with open(args.config) as f:
        cfg = json.load(f)
    K = args.K if args.K is not None else cfg.get("K_curriculum", [[0, 8]])[-1][1]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    ckpt = Path(args.ckpt)

    # The base model is built without LoRA (use_lora=False); a trained adapter,
    # if the checkpoint has one, is attached below.
    bcfg = BLTConfig(
        base_model=cfg["base_model"], use_lora=False,
        lora_r=cfg["lora_r"], lora_alpha=cfg["lora_alpha"],
        lora_dropout=cfg["lora_dropout"],
        lora_target_modules=tuple(cfg["lora_target_modules"]),
        K_latents=K, block_y_to_x=cfg["block_y_to_x"],
        proj_init_scale=cfg["proj_init_scale"],
        dtype=cfg["dtype"], attn_impl=cfg["attn_impl"],
        gradient_checkpointing=False,
    )
    base_model, tokenizer = build_base(bcfg)

    adapter_dir = ckpt / "model"
    if (adapter_dir / "adapter_config.json").exists():
        # Imported lazily: checkpoints without an adapter do not need peft at all.
        from peft import PeftModel
        model = PeftModel.from_pretrained(base_model, str(adapter_dir))
        print(f"[load] adapter from {adapter_dir}", flush=True)
    else:
        model = base_model
        print(f"[load] no adapter at {adapter_dir}; using the base model", flush=True)
    model.to(device).eval()

    inner = model.get_base_model() if hasattr(model, "get_base_model") else model
    d_model = inner.config.hidden_size
    projector = LatentProjector(
        d_model, init_scale=cfg["proj_init_scale"],
        use_mlp=cfg.get("proj_mlp", False),
        hidden_mult=cfg.get("proj_hidden_mult", 4),
    ).to(device).to(next(model.parameters()).dtype)
    # NOTE(review): torch.load unpickles the file — only point this at trusted
    # checkpoints (consider weights_only=True if the installed torch supports it).
    projector.load_state_dict(torch.load(ckpt / "projector.pt", map_location=device))
    projector.eval()

    if args.eval_dataset == "math":
        ds = MATHDataset(split="test", max_examples=args.n)
        # MATH problems are longer — bump the loader's max_prompt/max_answer
        max_p, max_a = max(192, cfg["max_prompt_len"]), max(256, cfg["max_answer_len"])
    else:
        ds = GSM8KDataset(split="test", max_examples=args.n)
        max_p, max_a = cfg["max_prompt_len"], cfg["max_answer_len"]

    batch_size = 4
    loader = DataLoader(
        ds, batch_size=batch_size, shuffle=False,
        collate_fn=lambda b: collate_batch(b, tokenizer, max_prompt_len=max_p, max_answer_len=max_a),
    )
    # Ceil division: exactly enough batches to cover n examples.
    n_batches = -(-args.n // batch_size)

    print(f"[rank] dataset={args.eval_dataset} n={args.n} K={K}", flush=True)
    t0 = time.time()
    Z = collect_z_batch(model, projector, loader, device, K, max_batches=n_batches)
    print(f"[rank] collected Z: shape={tuple(Z.shape)} ({time.time()-t0:.0f}s)", flush=True)

    stats = rank_stats(Z)
    print(f"[rank] stable_rank = {stats['stable_rank']:.2f}", flush=True)
    print(f"[rank] eff_rank = {stats['eff_rank_exp_entropy']:.2f}", flush=True)
    print(f"[rank] top-1/4/8 var = {stats['top1_var_frac']:.3f} / {stats['top4_var_frac']:.3f} / {stats['top8_var_frac']:.3f}", flush=True)
    print(f"[rank] z_std = {stats['z_std']:.4f}", flush=True)

    # "z_rows" records how many z vectors actually went into the statistics,
    # which can differ from n * K if the split has fewer than n problems.
    summary = {"ckpt": str(ckpt), "eval_dataset": args.eval_dataset, "n": args.n,
               "K": K, "z_rows": int(Z.shape[0]), "rank": stats}
    out = args.out or str(ckpt / f"rank_on_{args.eval_dataset}.json")
    out_path = Path(out)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(json.dumps(summary, indent=2))
    print(f"[written] {out}", flush=True)


if __name__ == "__main__":
    main()