"""Minimal z rank computation — no generation, no perturbation curve. Forward N problems through a ckpt and compute the rank statistics of the resulting z (input embeddings from the M-step latent loop). Lightweight enough to run in parallel with training (~15-20 GB activation footprint vs full capacity_diagnostic's ~30 GB). Used to ablate the "harder problems → richer z" hypothesis: run the GSM8K- trained GRPO ckpt against MATH test problems, compare stable_rank to the known GSM8K-eval value (6.73). """ from __future__ import annotations import argparse import json import time from pathlib import Path from typing import Optional import torch from torch.utils.data import DataLoader from ..data import GSM8KDataset, MATHDataset, collate_batch from ..model import BLTConfig, LatentProjector, build_base, forward_with_latent @torch.no_grad() def collect_z_batch(model, projector, loader, device, K, max_batches=20): chunks = [] for i, b in enumerate(loader): if i >= max_batches: break _, z, _ = forward_with_latent( model, b.x_ids.to(device), b.x_attn.to(device), b.y_ids.to(device), projector, K, block_y_to_x=True, return_z=True, ) chunks.append(z.float().reshape(-1, z.size(-1)).cpu()) return torch.cat(chunks, dim=0) def rank_stats(M: torch.Tensor) -> dict: M = M.float() U, S, V = torch.linalg.svd(M, full_matrices=False) sv = S.clamp_min(1e-12) p = sv / sv.sum() eff = float(torch.exp((-p * p.log()).sum()).item()) stable = float((sv.pow(2).sum() / sv.pow(2).max()).item()) cum = (sv.pow(2).cumsum(0) / sv.pow(2).sum()).tolist() return { "n_singvals": int(S.numel()), "eff_rank_exp_entropy": eff, "stable_rank": stable, "top1_var_frac": float((sv[0].pow(2) / sv.pow(2).sum()).item()), "top4_var_frac": float((sv[:4].pow(2).sum() / sv.pow(2).sum()).item()), "top8_var_frac": float((sv[:8].pow(2).sum() / sv.pow(2).sum()).item()), "cum_explained_var_first16": cum[:16], "z_std": float(M.std().item()), } def main(): p = argparse.ArgumentParser() p.add_argument("--ckpt", required=True) p.add_argument("--config", required=True) p.add_argument("--n", type=int, default=100) p.add_argument("--K", type=int, default=None) p.add_argument("--eval_dataset", required=True, choices=["gsm8k", "math"], help="Which dataset's TEST split to evaluate against (independent of training data)") p.add_argument("--out", default=None) args = p.parse_args() with open(args.config) as f: cfg = json.load(f) K = args.K if args.K is not None else cfg.get("K_curriculum", [[0, 8]])[-1][1] device = torch.device("cuda" if torch.cuda.is_available() else "cpu") ckpt = Path(args.ckpt) bcfg = BLTConfig( base_model=cfg["base_model"], use_lora=False, lora_r=cfg["lora_r"], lora_alpha=cfg["lora_alpha"], lora_dropout=cfg["lora_dropout"], lora_target_modules=tuple(cfg["lora_target_modules"]), K_latents=K, block_y_to_x=cfg["block_y_to_x"], proj_init_scale=cfg["proj_init_scale"], dtype=cfg["dtype"], attn_impl=cfg["attn_impl"], gradient_checkpointing=False, ) base_model, tokenizer = build_base(bcfg) from peft import PeftModel adapter_dir = ckpt / "model" if (adapter_dir / "adapter_config.json").exists(): model = PeftModel.from_pretrained(base_model, str(adapter_dir)) print(f"[load] adapter from {adapter_dir}", flush=True) else: model = base_model model.to(device).eval() inner = model.get_base_model() if hasattr(model, "get_base_model") else model d_model = inner.config.hidden_size projector = LatentProjector( d_model, init_scale=cfg["proj_init_scale"], use_mlp=cfg.get("proj_mlp", False), hidden_mult=cfg.get("proj_hidden_mult", 4), ).to(device).to(next(model.parameters()).dtype) projector.load_state_dict(torch.load(ckpt / "projector.pt", map_location=device)) projector.eval() if args.eval_dataset == "math": ds = MATHDataset(split="test", max_examples=args.n) # MATH problems are longer — bump the loader's max_prompt/max_answer max_p, max_a = max(192, cfg["max_prompt_len"]), max(256, cfg["max_answer_len"]) else: ds = GSM8KDataset(split="test", max_examples=args.n) max_p, max_a = cfg["max_prompt_len"], cfg["max_answer_len"] loader = DataLoader( ds, batch_size=4, shuffle=False, collate_fn=lambda b: collate_batch(b, tokenizer, max_prompt_len=max_p, max_answer_len=max_a), ) print(f"[rank] dataset={args.eval_dataset} n={args.n} K={K}", flush=True) t0 = time.time() Z = collect_z_batch(model, projector, loader, device, K, max_batches=args.n // 4 + 1) print(f"[rank] collected Z: shape={tuple(Z.shape)} ({time.time()-t0:.0f}s)", flush=True) stats = rank_stats(Z) print(f"[rank] stable_rank = {stats['stable_rank']:.2f}", flush=True) print(f"[rank] eff_rank = {stats['eff_rank_exp_entropy']:.2f}", flush=True) print(f"[rank] top-1/4/8 var = {stats['top1_var_frac']:.3f} / {stats['top4_var_frac']:.3f} / {stats['top8_var_frac']:.3f}", flush=True) print(f"[rank] z_std = {stats['z_std']:.4f}", flush=True) summary = {"ckpt": str(ckpt), "eval_dataset": args.eval_dataset, "n": args.n, "K": K, "rank": stats} out = args.out or str(ckpt / f"rank_on_{args.eval_dataset}.json") Path(out).write_text(json.dumps(summary, indent=2)) print(f"[written] {out}", flush=True) if __name__ == "__main__": main()