File size: 9,411 Bytes
bc7101b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
"""Capacity diagnostic for the trained BLT latent: does K=32 have headroom?

Two cheap tests on an existing K-trained ckpt:

  1. **Effective rank of z.** Compute z for N test problems, stack into a
     matrix [N*K, d], compute SVD, then `eff_rank = exp(H(σ/sum(σ)))` where
     H is entropy of the normalized singular value distribution. If eff_rank
     ≈ K, the K slots are all carrying distinct information (K=32 might
     add more). If eff_rank << K, slots are redundant (K=32 unlikely to
     help).

  2. **Perturbation curve.** Replace a fraction
     p ∈ {0, 0.125, 0.25, 0.5, 0.75, 0.875, 1.0} of slots with Gaussian
     noise whose std matches the empirical std of z. Run autoregressive (AR)
     generation and measure GSM8K accuracy. If accuracy drops gradually
     with p, all slots contribute. If it stays flat up to high p and then
     crashes, many slots are redundant.

Usage:
    python -m experiments.blt_reasoner.scripts.capacity_diagnostic \
        --ckpt /path/to/grpo_final --config <config.json> --n 100 --K 16
"""
from __future__ import annotations

import argparse
import json
import math
import re
import time
from pathlib import Path
from typing import Optional

import torch
from torch.utils.data import DataLoader

from ..data import GSM8KDataset, collate_batch
from ..model import (
    BLTConfig, LatentProjector, build_base,
    forward_with_latent, generate_with_latent,
)
from ..eval import parse_pred, correct, _perturb_z


@torch.no_grad()
def collect_z_batch(model, projector, loader, device, K, max_batches=20):
    """Encode up to ``max_batches`` batches and stack their latents z.

    Each batch is passed through ``forward_with_latent(..., return_z=True)``
    with ``block_y_to_x=True``. The returned z (``[B, K, d]`` per the inline
    note below) is flattened to ``[B*K, d]``, cast to float32 and moved to CPU.

    Args:
        model: Language model passed to ``forward_with_latent``.
        projector: Latent projector passed to ``forward_with_latent``.
        loader: Iterable of batches exposing ``x_ids``, ``x_attn``, ``y_ids``.
        device: Device the batch tensors are moved to before the forward pass.
        K: Number of latent slots.
        max_batches: Maximum number of batches to consume from ``loader``.

    Returns:
        Float32 CPU tensor of shape ``[N*K, d]``, where N is the number of
        examples in the consumed batches.

    Raises:
        ValueError: If no batch was processed (empty loader or
            ``max_batches <= 0``).
    """
    chunks = []
    # zip() advances the range first, so once max_batches is reached the
    # loader is never asked for (and never collates) an extra batch.
    for _, b in zip(range(max_batches), loader):
        _, z, _ = forward_with_latent(
            model, b.x_ids.to(device), b.x_attn.to(device),
            b.y_ids.to(device), projector, K,
            block_y_to_x=True, return_z=True,
        )
        # z: [B, K, d] -> [B*K, d]
        chunks.append(z.float().reshape(-1, z.size(-1)).cpu())
    if not chunks:
        raise ValueError(
            f"collect_z_batch: no batches processed (max_batches={max_batches}); "
            "is the loader empty?"
        )
    return torch.cat(chunks, dim=0)


def effective_rank(M: torch.Tensor) -> dict:
    """Summarize the singular-value spectrum of ``M``.

    With sigma = singular values of ``M`` (cast to float32, descending):

    * ``eff_rank_exp_entropy``: ``exp(H(p))`` where ``p = sigma / sum(sigma)``
      and H is the Shannon entropy.
    * ``stable_rank``: ``sum(sigma**2) / max(sigma**2)``.
    * ``top{1,4,8}_var_frac``: share of ``sum(sigma**2)`` held by the leading
      1 / 4 / 8 singular values (all of them if fewer exist).
    * ``cum_explained_var_first16``: cumulative share for up to the first 16.

    ``M`` is not mean-centered here, so the "variance" fractions are shares of
    the raw matrix's squared singular values rather than centered variance.

    Args:
        M: Matrix to analyze (e.g. stacked latents of shape ``[N*K, d]``).

    Returns:
        dict with the keys above plus ``n_singvals``.

    Raises:
        ValueError: If ``M`` has no singular values (empty matrix).
    """
    # Only the singular values are needed; svdvals skips building U and V.
    S = torch.linalg.svdvals(M.float())
    if S.numel() == 0:
        raise ValueError("effective_rank: input matrix is empty (no singular values)")
    # Clamp so exact zeros do not produce log(0) = -inf in the entropy term.
    sv = S.clamp_min(1e-12)
    p = sv / sv.sum()
    eff = float(torch.exp((-p * p.log()).sum()).item())
    energy = sv.pow(2)
    total = energy.sum()
    stable = float((total / energy.max()).item())
    cum = (energy.cumsum(0) / total).tolist()
    return {
        "n_singvals": int(S.numel()),
        "eff_rank_exp_entropy": eff,
        "stable_rank": stable,
        "top1_var_frac": float((energy[0] / total).item()),
        "top4_var_frac": float((energy[:4].sum() / total).item()),
        "top8_var_frac": float((energy[:8].sum() / total).item()),
        "cum_explained_var_first16": cum[:16],
    }


@torch.no_grad()
def estimate_z_std(model, projector, loader, device, K, max_batches=4):
    """Return one scalar std over every latent entry from the first few batches.

    The latents are gathered with ``collect_z_batch`` (at most ``max_batches``
    batches) and reduced with ``Tensor.std()`` over all elements. The result
    is what ``main`` passes as ``z_std`` to size the replacement noise.
    """
    stacked = collect_z_batch(
        model, projector, loader, device, K, max_batches=max_batches,
    )
    spread = stacked.std()
    return float(spread.item())


@torch.no_grad()
def run_perturbation_curve(model, projector, tokenizer, loader, device, K, *,
                            z_std, severities, max_new_tokens, temperature, seed=0):
    """Measure GSM8K generation accuracy as latent slots are replaced by noise.

    For each ``sev`` in ``severities``, every batch of ``loader`` is:

      1. encoded to z with ``forward_with_latent(..., return_z=True)``;
      2. perturbed with ``_perturb_z`` from eval.py, which replaces a
         ``sev`` fraction of slots with N(0, z_std²) noise;
      3. decoded with ``generate_with_latent`` using the perturbed z;
      4. scored by parsing each decoded text with ``parse_pred`` and comparing
         it to the gold answer (``"#### "`` stripped) with ``correct``.

    Runs under ``torch.no_grad()`` (like the other z-collection helpers) so no
    autograd graph is built through the projector during evaluation.

    Args:
        model, projector, tokenizer: Model components used for encoding and
            generation.
        loader: Iterable of batches exposing ``x_ids``, ``x_attn``, ``y_ids``
            and ``final_strs``.
        device: Device inputs are moved to.
        K: Number of latent slots.
        z_std: Std of the replacement noise.
        severities: Iterable of severities to evaluate.
        max_new_tokens: Generation length cap.
        temperature: Sampling temperature passed to generation.
        seed: Base seed; batch ``bi`` uses ``seed + bi``.

    Returns:
        dict mapping ``float(sev)`` to ``{"acc", "n", "correct"}``.
    """
    proj_dtype = next(projector.parameters()).dtype
    results = {}
    for sev in severities:
        correct_n = total = 0
        for bi, batch in enumerate(loader):
            x_ids = batch.x_ids.to(device)
            x_attn = batch.x_attn.to(device)
            y_ids = batch.y_ids.to(device)
            B = x_ids.size(0)
            # Compute z, then perturb
            _, z, _ = forward_with_latent(
                model, x_ids, x_attn, y_ids, projector, K,
                block_y_to_x=True, return_z=True,
            )
            # The seed depends only on the batch index (not on sev), so every
            # severity level is given the same seed for a given batch.
            z_pert = _perturb_z(z.to(device=device, dtype=proj_dtype),
                                severity=sev, z_std=z_std, seed=seed + bi)
            gen = generate_with_latent(
                model, tokenizer, projector,
                x_ids=x_ids, x_attn=x_attn, K=K,
                block_y_to_x=True, max_new_tokens=max_new_tokens,
                temperature=temperature, eos_token_id=tokenizer.eos_token_id,
                override_z=z_pert,
            )
            for b in range(B):
                text = tokenizer.decode(gen[b], skip_special_tokens=True)
                pred = parse_pred(text)
                gold = batch.final_strs[b].replace("#### ", "").strip()
                if correct(pred, gold):
                    correct_n += 1
                total += 1
        acc = correct_n / max(total, 1)
        results[float(sev)] = {"acc": acc, "n": total, "correct": correct_n}
        print(f"  severity={sev:.2f}  acc={acc:.3f}  ({correct_n}/{total})", flush=True)
    return results


def main():
    """CLI entry point: load a checkpoint, run both capacity diagnostics, save JSON.

    Steps:
      * read the training config JSON given by ``--config``;
      * rebuild the base model with ``build_base``, and wrap it with the PEFT
        adapter in ``<ckpt>/model`` when ``adapter_config.json`` exists there;
      * load ``<ckpt>/projector.pt`` into a ``LatentProjector``;
      * on up to ``--n`` GSM8K test examples, compute spectrum statistics of
        the stacked latents (``effective_rank``) and print a heuristic verdict;
      * run the slot-replacement perturbation curve with ``temperature=0.0``.

    The summary is written to ``--out``, or to
    ``<ckpt>/capacity_diagnostic.json`` by default.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--ckpt", required=True)
    p.add_argument("--config", required=True)
    p.add_argument("--n", type=int, default=100)
    p.add_argument("--K", type=int, default=None)
    p.add_argument("--max_new_tokens", type=int, default=128)
    p.add_argument("--out", default=None)
    args = p.parse_args()
    with open(args.config, encoding="utf-8") as f:
        cfg = json.load(f)
    # Default K: the K value of the last K_curriculum entry (8 if absent).
    K = args.K if args.K is not None else cfg.get("K_curriculum", [[0, 8]])[-1][1]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    ckpt = Path(args.ckpt)
    bcfg = BLTConfig(
        base_model=cfg["base_model"], use_lora=False,
        lora_r=cfg["lora_r"], lora_alpha=cfg["lora_alpha"],
        lora_dropout=cfg["lora_dropout"],
        lora_target_modules=tuple(cfg["lora_target_modules"]),
        K_latents=K, block_y_to_x=cfg["block_y_to_x"],
        proj_init_scale=cfg["proj_init_scale"],
        dtype=cfg["dtype"], attn_impl=cfg["attn_impl"],
        gradient_checkpointing=False,
    )
    base_model, tokenizer = build_base(bcfg)
    from peft import PeftModel
    adapter_dir = ckpt / "model"
    if (adapter_dir / "adapter_config.json").exists():
        model = PeftModel.from_pretrained(base_model, str(adapter_dir))
        print(f"[load] adapter from {adapter_dir}")
    else:
        model = base_model
    model.to(device).eval()
    inner = model.get_base_model() if hasattr(model, "get_base_model") else model
    d_model = inner.config.hidden_size
    projector = LatentProjector(
        d_model, init_scale=cfg["proj_init_scale"],
        use_mlp=cfg.get("proj_mlp", False),
        hidden_mult=cfg.get("proj_hidden_mult", 4),
    ).to(device).to(next(model.parameters()).dtype)
    # projector.pt is consumed only by load_state_dict, so it only needs
    # tensors; weights_only=True refuses arbitrary pickled objects.
    projector.load_state_dict(
        torch.load(ckpt / "projector.pt", map_location=device, weights_only=True)
    )
    projector.eval()

    val_ds = GSM8KDataset(split="test", max_examples=args.n)
    loader = DataLoader(
        val_ds, batch_size=8, shuffle=False,
        collate_fn=lambda b: collate_batch(b, tokenizer,
                                            max_prompt_len=cfg["max_prompt_len"],
                                            max_answer_len=cfg["max_answer_len"]),
    )

    # 1. Effective rank
    print("\n[diagnostic 1] effective rank of z across test problems")
    t0 = time.time()
    Z = collect_z_batch(model, projector, loader, device, K, max_batches=20)
    print(f"  collected Z: shape={tuple(Z.shape)} ({time.time()-t0:.0f}s)")
    rank_stats = effective_rank(Z)
    print(f"  n_singvals={rank_stats['n_singvals']}")
    print(f"  eff_rank (exp_entropy) = {rank_stats['eff_rank_exp_entropy']:.2f} (K={K})")
    print(f"  stable_rank            = {rank_stats['stable_rank']:.2f}")
    print(f"  top-1  variance frac   = {rank_stats['top1_var_frac']:.3f}")
    print(f"  top-4  variance frac   = {rank_stats['top4_var_frac']:.3f}")
    print(f"  top-8  variance frac   = {rank_stats['top8_var_frac']:.3f}")

    # Interpretation hint.
    # NOTE(review): Z stacks slots from many problems ([N*K, d]), so its rank is
    # bounded by min(N*K, d), not by K; confirm K-relative thresholds are intended.
    if rank_stats['eff_rank_exp_entropy'] >= K * 0.7:
        verdict_rank = "HIGH eff_rank — slots are using distinct directions → K=32 plausibly helps"
    elif rank_stats['eff_rank_exp_entropy'] >= K * 0.4:
        verdict_rank = "MEDIUM eff_rank — partial slot redundancy; K=32 may give marginal lift"
    else:
        verdict_rank = "LOW eff_rank — many redundant slots; K=32 unlikely to help"
    print(f"  verdict: {verdict_rank}")

    # 2. Perturbation curve
    print("\n[diagnostic 2] perturbation curve (AR accuracy vs fraction-of-slots-replaced)")
    z_std = estimate_z_std(model, projector, loader, device, K, max_batches=4)
    print(f"  z_std={z_std:.4f}")
    severities = [0.0, 0.125, 0.25, 0.5, 0.75, 0.875, 1.0]
    pert_results = run_perturbation_curve(
        model, projector, tokenizer, loader, device, K,
        z_std=z_std, severities=severities,
        max_new_tokens=args.max_new_tokens, temperature=0.0,
    )

    summary = {
        "ckpt": str(ckpt),
        "n": args.n, "K": K, "z_std": z_std,
        "effective_rank": rank_stats,
        "verdict_eff_rank": verdict_rank,
        "perturbation_curve": pert_results,
    }
    out = Path(args.out) if args.out else ckpt / "capacity_diagnostic.json"
    # A custom --out may point into a directory that does not exist yet;
    # create it so the (expensive) run is not lost at the final write.
    out.parent.mkdir(parents=True, exist_ok=True)
    out.write_text(json.dumps(summary, indent=2), encoding="utf-8")
    print(f"\n[written] {out}")


if __name__ == "__main__":
    # Run the diagnostics only when executed as a script (e.g. `python -m ...`),
    # not when this module is imported.
    main()