"""Minimal z rank computation — no generation, no perturbation curve.
Forward N problems through a ckpt and compute the rank statistics of the
resulting z (input embeddings from the M-step latent loop). Lightweight
enough to run in parallel with training (~15-20 GB activation footprint
vs full capacity_diagnostic's ~30 GB).
Used to ablate the "harder problems → richer z" hypothesis: run the GSM8K-
trained GRPO ckpt against MATH test problems, compare stable_rank to the
known GSM8K-eval value (6.73).
"""
from __future__ import annotations
import argparse
import json
import time
from pathlib import Path
from typing import Optional
import torch
from torch.utils.data import DataLoader
from ..data import GSM8KDataset, MATHDataset, collate_batch
from ..model import BLTConfig, LatentProjector, build_base, forward_with_latent
@torch.no_grad()
def collect_z_batch(model, projector, loader, device, K, max_batches=20):
    """Forward batches through the model and stack the returned ``z`` rows.

    Iterates ``loader`` and stops once ``max_batches`` batches have been
    processed. Each batch is passed to ``forward_with_latent`` with
    ``block_y_to_x=True`` and ``return_z=True``; the middle element of the
    returned triple is cast to float32, flattened to 2-D over its last
    dimension, and moved to CPU.

    Returns:
        A 2-D float32 CPU tensor: all flattened ``z`` rows concatenated
        along dim 0, with ``z.size(-1)`` columns.
    """
    collected = []
    for batch_idx, batch in enumerate(loader):
        if batch_idx >= max_batches:
            break
        prompt_ids = batch.x_ids.to(device)
        prompt_mask = batch.x_attn.to(device)
        answer_ids = batch.y_ids.to(device)
        _first, z, _third = forward_with_latent(
            model,
            prompt_ids,
            prompt_mask,
            answer_ids,
            projector,
            K,
            block_y_to_x=True,
            return_z=True,
        )
        # Flatten every leading dimension so each row is one z vector.
        rows = z.float().reshape(-1, z.size(-1))
        collected.append(rows.cpu())
    return torch.cat(collected)
def rank_stats(M: torch.Tensor) -> dict:
    """Summarise the singular-value spectrum of a 2-D matrix.

    Args:
        M: 2-D tensor (rows are samples); it is cast to float32 before the
            decomposition.

    Returns:
        A dict with:
          * ``n_singvals``: number of singular values.
          * ``eff_rank_exp_entropy``: exp of the Shannon entropy of the
            singular values normalised to sum to 1.
          * ``stable_rank``: sum of squared singular values divided by the
            largest squared singular value.
          * ``top1_var_frac`` / ``top4_var_frac`` / ``top8_var_frac``:
            fraction of the total squared singular values carried by the
            leading 1 / 4 / 8 of them.
          * ``cum_explained_var_first16``: the first (up to) 16 entries of
            the cumulative squared-singular-value fraction.
          * ``z_std``: standard deviation over all entries of ``M``.

    Raises:
        ValueError: if ``M`` is not 2-D or contains no elements.
    """
    if M.dim() != 2 or M.numel() == 0:
        raise ValueError(
            f"rank_stats expects a non-empty 2-D matrix, got shape {tuple(M.shape)}"
        )
    M = M.float()
    # Only the singular values are used, so skip building U and Vh.
    S = torch.linalg.svdvals(M)
    # Clamp so exact zeros do not produce log(0) in the entropy below.
    sv = S.clamp_min(1e-12)

    p = sv / sv.sum()
    eff = float(torch.exp((-p * p.log()).sum()).item())

    energy = sv.pow(2)
    total_energy = energy.sum()
    stable = float((total_energy / energy.max()).item())
    cum = (energy.cumsum(0) / total_energy).tolist()

    return {
        "n_singvals": int(S.numel()),
        "eff_rank_exp_entropy": eff,
        "stable_rank": stable,
        "top1_var_frac": float((energy[0] / total_energy).item()),
        "top4_var_frac": float((energy[:4].sum() / total_energy).item()),
        "top8_var_frac": float((energy[:8].sum() / total_energy).item()),
        "cum_explained_var_first16": cum[:16],
        "z_std": float(M.std().item()),
    }
def main():
    """CLI entry point: compute rank statistics of z for one checkpoint.

    Steps:
      1. Parse arguments and read the JSON training config.
      2. Build the base model; wrap it with the PEFT adapter found under
         ``<ckpt>/model`` when ``adapter_config.json`` exists there.
      3. Build the ``LatentProjector`` and load ``<ckpt>/projector.pt``.
      4. Load the TEST split of the chosen dataset, collect z over it and
         run ``rank_stats`` on the stacked result.
      5. Print the headline numbers and write a JSON summary to ``--out``
         (default: ``<ckpt>/rank_on_<eval_dataset>.json``).
    """
    p = argparse.ArgumentParser()
    p.add_argument("--ckpt", required=True)
    p.add_argument("--config", required=True)
    p.add_argument("--n", type=int, default=100)
    p.add_argument("--K", type=int, default=None)
    p.add_argument("--eval_dataset", required=True, choices=["gsm8k", "math"],
                   help="Which dataset's TEST split to evaluate against (independent of training data)")
    p.add_argument("--out", default=None)
    args = p.parse_args()

    with open(args.config, encoding="utf-8") as f:
        cfg = json.load(f)
    # An explicit --K wins; otherwise take the second element of the last
    # K_curriculum entry (8 when the config has no curriculum).
    K = args.K if args.K is not None else cfg.get("K_curriculum", [[0, 8]])[-1][1]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    ckpt = Path(args.ckpt)

    # Resolve the output location and create its directory up front, so a
    # bad --out fails before the forward passes rather than after them.
    out = args.out or str(ckpt / f"rank_on_{args.eval_dataset}.json")
    out_path = Path(out)
    out_path.parent.mkdir(parents=True, exist_ok=True)

    bcfg = BLTConfig(
        base_model=cfg["base_model"], use_lora=False,
        lora_r=cfg["lora_r"], lora_alpha=cfg["lora_alpha"],
        lora_dropout=cfg["lora_dropout"],
        lora_target_modules=tuple(cfg["lora_target_modules"]),
        K_latents=K, block_y_to_x=cfg["block_y_to_x"],
        proj_init_scale=cfg["proj_init_scale"],
        dtype=cfg["dtype"], attn_impl=cfg["attn_impl"],
        gradient_checkpointing=False,
    )
    base_model, tokenizer = build_base(bcfg)

    # Imported lazily, as in the original, so peft is only needed at run time.
    from peft import PeftModel
    adapter_dir = ckpt / "model"
    if (adapter_dir / "adapter_config.json").exists():
        model = PeftModel.from_pretrained(base_model, str(adapter_dir))
        print(f"[load] adapter from {adapter_dir}", flush=True)
    else:
        model = base_model
    model.to(device).eval()

    # PEFT wrappers expose the wrapped model through get_base_model().
    inner = model.get_base_model() if hasattr(model, "get_base_model") else model
    d_model = inner.config.hidden_size
    projector = LatentProjector(
        d_model, init_scale=cfg["proj_init_scale"],
        use_mlp=cfg.get("proj_mlp", False),
        hidden_mult=cfg.get("proj_hidden_mult", 4),
    ).to(device).to(next(model.parameters()).dtype)
    # NOTE(review): torch.load unpickles the file; only point --ckpt at
    # trusted checkpoints (or pass weights_only=True where supported).
    projector.load_state_dict(torch.load(ckpt / "projector.pt", map_location=device))
    projector.eval()

    if args.eval_dataset == "math":
        ds = MATHDataset(split="test", max_examples=args.n)
        # MATH problems are longer — bump the loader's max_prompt/max_answer
        max_p, max_a = max(192, cfg["max_prompt_len"]), max(256, cfg["max_answer_len"])
    else:
        ds = GSM8KDataset(split="test", max_examples=args.n)
        max_p, max_a = cfg["max_prompt_len"], cfg["max_answer_len"]

    # One name for the batch size keeps the loader and the batch cap in sync.
    batch_size = 4
    loader = DataLoader(
        ds, batch_size=batch_size, shuffle=False,
        collate_fn=lambda b: collate_batch(b, tokenizer, max_prompt_len=max_p, max_answer_len=max_a),
    )

    print(f"[rank] dataset={args.eval_dataset} n={args.n} K={K}", flush=True)
    t0 = time.time()
    Z = collect_z_batch(model, projector, loader, device, K,
                        max_batches=args.n // batch_size + 1)
    print(f"[rank] collected Z: shape={tuple(Z.shape)} ({time.time()-t0:.0f}s)", flush=True)

    stats = rank_stats(Z)
    print(f"[rank] stable_rank = {stats['stable_rank']:.2f}", flush=True)
    print(f"[rank] eff_rank = {stats['eff_rank_exp_entropy']:.2f}", flush=True)
    print(f"[rank] top-1/4/8 var = {stats['top1_var_frac']:.3f} / {stats['top4_var_frac']:.3f} / {stats['top8_var_frac']:.3f}", flush=True)
    print(f"[rank] z_std = {stats['z_std']:.4f}", flush=True)

    summary = {"ckpt": str(ckpt), "eval_dataset": args.eval_dataset, "n": args.n,
               "K": K, "rank": stats}
    out_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")
    print(f"[written] {out}", flush=True)
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|