"""Minimal z rank computation — no generation, no perturbation curve.

Forward N problems through a ckpt and compute the rank statistics of the
resulting z (input embeddings from the M-step latent loop). Lightweight
enough to run in parallel with training (~15-20 GB activation footprint
vs full capacity_diagnostic's ~30 GB).

Used to ablate the "harder problems → richer z" hypothesis: run the GSM8K-
trained GRPO ckpt against MATH test problems, compare stable_rank to the
known GSM8K-eval value (6.73).
"""
from __future__ import annotations

import argparse
import json
import math
import time
from pathlib import Path
from typing import Optional

import torch
from torch.utils.data import DataLoader

from ..data import GSM8KDataset, MATHDataset, collate_batch
from ..model import BLTConfig, LatentProjector, build_base, forward_with_latent
|
|
|
|
@torch.no_grad()
def collect_z_batch(model, projector, loader, device, K, max_batches=20):
    """Forward up to ``max_batches`` batches and stack the returned z rows.

    Each batch is run through ``forward_with_latent`` with ``return_z=True``.
    The second element of its return value (z) is cast to float32, flattened
    to 2-D by merging every leading dimension, and moved to the CPU so that
    per-batch results do not accumulate on ``device``.

    Args:
        model: Model handed to ``forward_with_latent``.
        projector: Latent projector handed to ``forward_with_latent``.
        loader: Iterable of batches exposing ``x_ids``, ``x_attn`` and
            ``y_ids`` tensors.
        device: Device the batch tensors are moved to before the forward.
        K: Latent count, forwarded positionally to ``forward_with_latent``.
        max_batches: Maximum number of batches consumed from ``loader``.

    Returns:
        A float32 CPU tensor of shape ``(rows, z.size(-1))`` containing the
        rows of every processed batch, in loader order.

    Raises:
        ValueError: If no batch was processed (empty ``loader`` or
            ``max_batches <= 0``), since there is nothing to concatenate.
    """
    chunks = []
    # zip() consults range() first, so iteration stops *before* pulling a
    # batch beyond the limit (enumerate + break would load one extra batch).
    for _, batch in zip(range(max_batches), loader):
        # NOTE(review): block_y_to_x is fixed to True here regardless of the
        # value stored in the training config — confirm this is intended.
        _, z, _ = forward_with_latent(
            model,
            batch.x_ids.to(device),
            batch.x_attn.to(device),
            batch.y_ids.to(device),
            projector,
            K,
            block_y_to_x=True,
            return_z=True,
        )
        chunks.append(z.float().reshape(-1, z.size(-1)).cpu())
    if not chunks:
        raise ValueError(
            "collect_z_batch: no batches were processed "
            "(loader is empty or max_batches <= 0)"
        )
    return torch.cat(chunks, dim=0)
|
|
|
|
def rank_stats(M: torch.Tensor) -> dict:
    """Summarise the singular-value spectrum of a 2-D matrix.

    The matrix is used as given (it is NOT mean-centred), so the
    ``*_var_frac`` entries are fractions of the total squared-singular-value
    energy of the raw matrix.

    Args:
        M: Non-empty 2-D tensor; it is cast to float32 before the SVD.

    Returns:
        A dict with:
            n_singvals: number of singular values, ``min(rows, cols)``.
            eff_rank_exp_entropy: ``exp`` of the Shannon entropy of the
                singular values normalised to sum to 1.
            stable_rank: ``sum(s**2) / max(s**2)``.
            top1_var_frac / top4_var_frac / top8_var_frac: share of the
                total ``s**2`` carried by the leading 1 / 4 / 8 values.
            cum_explained_var_first16: cumulative ``s**2`` share for the
                first (up to) 16 singular values.
            z_std: standard deviation over all entries of ``M``.

    Raises:
        ValueError: If ``M`` is not 2-D or has no elements.
    """
    M = M.float()
    if M.dim() != 2 or M.numel() == 0:
        raise ValueError(
            f"rank_stats expects a non-empty 2-D matrix, got shape {tuple(M.shape)}"
        )

    # Only the singular values are needed, so skip computing U and Vh.
    # The clamp keeps log() finite and the ratios defined for zero values.
    sv = torch.linalg.svdvals(M).clamp_min(1e-12)

    probs = sv / sv.sum()
    eff_rank = torch.exp(-(probs * probs.log()).sum())

    energy = sv.pow(2)  # computed once and shared by every ratio below
    total = energy.sum()
    cumulative = (energy.cumsum(0) / total).tolist()

    return {
        "n_singvals": int(sv.numel()),
        "eff_rank_exp_entropy": float(eff_rank.item()),
        "stable_rank": float((total / energy.max()).item()),
        "top1_var_frac": float((energy[0] / total).item()),
        "top4_var_frac": float((energy[:4].sum() / total).item()),
        "top8_var_frac": float((energy[:8].sum() / total).item()),
        "cum_explained_var_first16": cumulative[:16],
        "z_std": float(M.std().item()),
    }
|
|
|
|
def main():
    """CLI entry point: compute z rank statistics for one checkpoint.

    Steps:
      1. Parse arguments and read the JSON training config.
      2. Build the base model, attach the LoRA adapter found under
         ``<ckpt>/model`` when one exists, and load ``<ckpt>/projector.pt``.
      3. Forward the first ``--n`` examples of the chosen dataset's test
         split and collect z via ``collect_z_batch``.
      4. Print the statistics from ``rank_stats`` and write them as JSON to
         ``--out`` (default ``<ckpt>/rank_on_<eval_dataset>.json``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt", required=True)
    parser.add_argument("--config", required=True)
    parser.add_argument("--n", type=int, default=100)
    parser.add_argument("--K", type=int, default=None)
    parser.add_argument("--eval_dataset", required=True, choices=["gsm8k", "math"],
                        help="Which dataset's TEST split to evaluate against (independent of training data)")
    parser.add_argument("--batch_size", type=int, default=4,
                        help="Examples per forward pass")
    parser.add_argument("--out", default=None)
    args = parser.parse_args()
    # Fail fast on values that would otherwise only surface after the model
    # has been loaded.
    if args.n < 1:
        parser.error("--n must be >= 1")
    if args.batch_size < 1:
        parser.error("--batch_size must be >= 1")

    with open(args.config, encoding="utf-8") as f:
        cfg = json.load(f)
    # Without --K, use the K of the last K_curriculum entry (falling back to 8
    # when the config has no curriculum).
    K = args.K if args.K is not None else cfg.get("K_curriculum", [[0, 8]])[-1][1]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    ckpt = Path(args.ckpt)
    # use_lora=False: the base model is built bare; a trained adapter, if
    # present in the checkpoint, is attached below.
    bcfg = BLTConfig(
        base_model=cfg["base_model"], use_lora=False,
        lora_r=cfg["lora_r"], lora_alpha=cfg["lora_alpha"],
        lora_dropout=cfg["lora_dropout"],
        lora_target_modules=tuple(cfg["lora_target_modules"]),
        K_latents=K, block_y_to_x=cfg["block_y_to_x"],
        proj_init_scale=cfg["proj_init_scale"],
        dtype=cfg["dtype"], attn_impl=cfg["attn_impl"],
        gradient_checkpointing=False,
    )
    base_model, tokenizer = build_base(bcfg)

    adapter_dir = ckpt / "model"
    if (adapter_dir / "adapter_config.json").exists():
        # Imported lazily so peft is only required when an adapter exists.
        from peft import PeftModel

        model = PeftModel.from_pretrained(base_model, str(adapter_dir))
        print(f"[load] adapter from {adapter_dir}", flush=True)
    else:
        model = base_model
        # Make the fallback visible: the stats then describe the bare base model.
        print(f"[load] no adapter found in {adapter_dir}; using base model", flush=True)
    model.to(device).eval()

    # PEFT wraps the network; unwrap it to reach the underlying model config.
    inner = model.get_base_model() if hasattr(model, "get_base_model") else model
    d_model = inner.config.hidden_size
    projector = LatentProjector(
        d_model, init_scale=cfg["proj_init_scale"],
        use_mlp=cfg.get("proj_mlp", False),
        hidden_mult=cfg.get("proj_hidden_mult", 4),
    ).to(device).to(next(model.parameters()).dtype)
    # NOTE(review): torch.load unpickles the file — only point --ckpt at
    # checkpoints you trust (consider weights_only=True where supported).
    projector.load_state_dict(torch.load(ckpt / "projector.pt", map_location=device))
    projector.eval()

    if args.eval_dataset == "math":
        ds = MATHDataset(split="test", max_examples=args.n)
        # For MATH the length limits are floored at 192 / 256 tokens —
        # presumably because its problems and solutions are longer; confirm.
        max_p, max_a = max(192, cfg["max_prompt_len"]), max(256, cfg["max_answer_len"])
    else:
        ds = GSM8KDataset(split="test", max_examples=args.n)
        max_p, max_a = cfg["max_prompt_len"], cfg["max_answer_len"]

    def collate(batch):
        return collate_batch(batch, tokenizer, max_prompt_len=max_p, max_answer_len=max_a)

    loader = DataLoader(ds, batch_size=args.batch_size, shuffle=False, collate_fn=collate)

    print(f"[rank] dataset={args.eval_dataset} n={args.n} K={K}", flush=True)
    t0 = time.time()
    # Exactly enough batches to cover n examples (n // bs + 1 would allow one
    # batch too many whenever n is a multiple of the batch size).
    max_batches = math.ceil(args.n / args.batch_size)
    Z = collect_z_batch(model, projector, loader, device, K, max_batches=max_batches)
    print(f"[rank] collected Z: shape={tuple(Z.shape)} ({time.time()-t0:.0f}s)", flush=True)
    stats = rank_stats(Z)
    print(f"[rank] stable_rank = {stats['stable_rank']:.2f}", flush=True)
    print(f"[rank] eff_rank = {stats['eff_rank_exp_entropy']:.2f}", flush=True)
    print(f"[rank] top-1/4/8 var = {stats['top1_var_frac']:.3f} / {stats['top4_var_frac']:.3f} / {stats['top8_var_frac']:.3f}", flush=True)
    print(f"[rank] z_std = {stats['z_std']:.4f}", flush=True)

    # "n" is the requested example count; "z_shape" records what was actually
    # collected.
    summary = {"ckpt": str(ckpt), "eval_dataset": args.eval_dataset, "n": args.n,
               "K": K, "z_shape": list(Z.shape), "rank": stats}
    out = args.out or str(ckpt / f"rank_on_{args.eval_dataset}.json")
    out_path = Path(out)
    # Create the target directory so a missing folder does not discard the
    # finished computation.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")
    print(f"[written] {out}", flush=True)
|
|
|
|
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|