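"""Pre-registered ablation evaluation for BLT-Reasoner.

Computes final-answer accuracy (GSM8K by default, or MATH when the config sets
``"dataset": "math"``) under three conditions:
  A. normal-z : latents from the W_proj(h_{t-1}) loop
  B. random-z : latents drawn from N(0, σ²) with σ matched to mean ||z||
                (as implemented, σ is the element-wise std of normal-condition
                z, pooled over the first few validation batches)
  C. zero-z   : K=0, i.e. no latents at all. With block_y_to_x on (the
                default), y cannot attend to x either, so y has no information
                path — expected to be ~0% if the bottleneck is working.
                (Under --no_block_y_to_x, y attends to x directly instead.)

H1 success: acc(A) - acc(B) >= 15pp AND acc(A) - acc(C) >= 25pp

With --perturbation_curve, the script also sweeps the fraction of latent
positions replaced by noise and reports accuracy at each severity.
"""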
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import re |
| import time |
| from pathlib import Path |
| from typing import List, Optional |
|
|
| import torch |
| from torch.utils.data import DataLoader |
|
|
| from .data import GSM8KDataset, MATHDataset, collate_batch, extract_final_number, extract_boxed_answer |
| from .model import BLTConfig, LatentProjector, build_base, forward_with_latent, generate_with_latent |
|
|
|
|
# A number, optionally written with thousands separators ("1,234,567.5").
# The comma-grouped form is tried first so "1,000" is read as one number
# rather than as "1" followed by "000".
_NUMBER_PATTERN = r"-?\d{1,3}(?:,\d{3})+(?:\.\d+)?|-?\d+(?:\.\d+)?"
_NUMBER_RE = re.compile(_NUMBER_PATTERN)
# GSM8K-style final-answer marker: "#### <number>".
_GSM8K_FINAL_RE = re.compile(r"####\s*(" + _NUMBER_PATTERN + r")")


def _last_number(text: str) -> Optional[str]:
    """Return the last number in ``text`` with thousands separators removed, or None."""
    nums = _NUMBER_RE.findall(text)
    return nums[-1].replace(",", "") if nums else None


def parse_pred(text: str, dataset: str = "gsm8k") -> Optional[str]:
    """Extract the final answer from model output.

    Dataset-aware:
      * gsm8k: the number after the last ``####`` marker, falling back to the
        last number anywhere in the text.
      * math: the content of the LAST ``\\boxed{...}`` (via
        ``extract_boxed_answer``), falling back to the last number.

    Numbers written with thousands separators ("1,000") are returned without
    the commas ("1000").

    Args:
        text: decoded model output.
        dataset: "gsm8k" or "math" (case-insensitive); ``None`` means gsm8k.

    Returns:
        The extracted answer string, or None if nothing could be extracted.
    """
    ds = (dataset or "gsm8k").lower()
    if ds == "math":
        boxed = extract_boxed_answer(text)
        if boxed is not None:
            return boxed.strip()
        return _last_number(text)

    marked = _GSM8K_FINAL_RE.findall(text)
    if marked:
        return marked[-1].replace(",", "").strip()
    return _last_number(text)
|
|
|
|
| def _normalize_math_answer(s: str) -> str: |
| """Aggressively normalize MATH-style answer strings for comparison. |
| |
| Strips whitespace, LaTeX wrappers, dollar signs, common formatting noise. |
| Not a complete LaTeX-equivalent checker — close to but weaker than the |
| Hendrycks et al. evaluator. For our purposes we want a fast, deterministic |
| string compare that catches the common-case correctness signals. |
| """ |
| if s is None: |
| return "" |
| s = s.strip().replace(" ", "") |
| |
| while s.startswith("$") and s.endswith("$") and len(s) > 2: |
| s = s[1:-1] |
| |
| s = re.sub(r"\\text\{([^{}]*)\}", r"\1", s) |
| |
| while s.startswith("{") and s.endswith("}") and len(s) > 2: |
| s = s[1:-1] |
| |
| if s.endswith("."): |
| s = s[:-1] |
| return s |
|
|
|
|
# A number written with thousands separators, e.g. "1,234" or "-12,345.5".
_THOUSANDS_RE = re.compile(r"-?\d{1,3}(?:,\d{3})+(?:\.\d+)?")


def _as_number(s: str) -> Optional[float]:
    """Parse ``s`` as a float, accepting thousands separators; None if not numeric.

    Commas are only removed when ``s`` is a properly grouped number, so a
    tuple-like answer such as "1,2" is NOT read as 12.
    """
    s = s.strip()
    if _THOUSANDS_RE.fullmatch(s):
        s = s.replace(",", "")
    try:
        return float(s)
    except ValueError:
        return None


def correct(pred: Optional[str], gold: str, dataset: str = "gsm8k") -> bool:
    """Return True if ``pred`` matches ``gold`` for the given dataset.

    * math: exact match after ``_normalize_math_answer``; otherwise a numeric
      comparison of the normalized strings.
    * gsm8k (default): numeric comparison.

    Numeric comparison uses an absolute tolerance of 1e-4 and accepts
    thousands separators ("1,000" == "1000").

    Args:
        pred: predicted answer, or None when extraction failed.
        gold: reference answer string.
        dataset: "gsm8k" or "math" (case-insensitive); ``None`` means gsm8k.
    """
    if pred is None:
        return False
    ds = (dataset or "gsm8k").lower()
    if ds == "math":
        p = _normalize_math_answer(pred)
        g = _normalize_math_answer(gold)
        if p == g:
            return True
    else:
        p, g = pred, gold

    p_num = _as_number(p)
    g_num = _as_number(g)
    if p_num is None or g_num is None:
        return False
    return abs(p_num - g_num) < 1e-4
|
|
|
|
def estimate_z_std(model, projector, tokenizer, val_loader, device, K,
                   n_batches: int = 4) -> float:
    """Estimate the spread of the latents z produced by the normal latent loop.

    Runs ``forward_with_latent`` (with ``block_y_to_x=True``) under
    ``torch.no_grad()`` on the first ``n_batches`` batches of ``val_loader``.
    It returns the standard deviation pooled over *all* elements of the
    collected z tensors. The result is a single scalar, not a per-coordinate
    vector, and is used as the scale of the random-z condition.

    NOTE(review): the module docstring pre-registers σ "matched to mean ||z||".
    A pooled element-wise std ignores any nonzero mean of z, so the two only
    agree when z is roughly zero-mean — confirm this is the intended statistic.

    Args:
        model, projector: passed through to ``forward_with_latent``.
        tokenizer: unused; kept for interface compatibility.
        val_loader: iterable of batches exposing ``x_ids``, ``x_attn``, ``y_ids``.
        device: device the batch tensors are moved to.
        K: number of latent positions.
        n_batches: how many leading batches to sample (default 4).

    Returns:
        The pooled std of z as a Python float.

    Raises:
        ValueError: if ``n_batches < 1`` or ``val_loader`` yields no batches.
    """
    if n_batches < 1:
        raise ValueError(f"n_batches must be >= 1, got {n_batches}")
    model.eval()
    all_z = []
    with torch.no_grad():
        for i, batch in enumerate(val_loader):
            if i >= n_batches:
                break
            x_ids = batch.x_ids.to(device)
            x_attn = batch.x_attn.to(device)
            y_ids = batch.y_ids.to(device)
            _, z, _ = forward_with_latent(model, x_ids, x_attn, y_ids, projector, K,
                                          block_y_to_x=True)
            all_z.append(z.float().cpu())
    if not all_z:
        raise ValueError("estimate_z_std: val_loader produced no batches")
    z_cat = torch.cat(all_z, dim=0)
    return float(z_cat.std().item())
|
|
|
|
def run_condition(model, projector, tokenizer, val_loader, device, K, condition: str,
                  z_std: float, max_new_tokens: int, temperature: float,
                  dataset: str = "gsm8k",
                  block_y_to_x: bool = True) -> dict:
    """Evaluate accuracy under one ablation condition.

    Conditions:
      * "normal": the model's own latent loop (``override_z=None``, K latents).
      * "random": K latents drawn from N(0, z_std²) via the global torch RNG.
      * "zero":   no latents at all (an empty (B, 0, d_model) override, K=0).

    ``dataset`` controls parsing of the gold final answer and of the
    prediction: "gsm8k" → "#### N", "math" → ``\\boxed{...}``. ``None`` is
    treated as "gsm8k", matching ``parse_pred`` / ``correct``.

    Returns:
        dict with keys condition, K (latents actually used), acc, n, correct,
        and up to 5 example records.

    Raises:
        ValueError: for an unknown ``condition``.
    """
    if condition not in ("normal", "random", "zero"):
        raise ValueError(condition)
    ds = (dataset or "gsm8k").lower()
    # K_eff is fixed per condition; computing it up front also keeps the
    # return value well-defined when val_loader yields no batches.
    K_eff = 0 if condition == "zero" else K

    inner = model.get_base_model() if hasattr(model, "get_base_model") else model
    d_model = inner.config.hidden_size
    correct_n = 0
    total = 0
    examples = []
    model.eval()
    with torch.no_grad():  # pure evaluation: no gradients are needed
        for batch in val_loader:
            x_ids = batch.x_ids.to(device)
            x_attn = batch.x_attn.to(device)
            B = x_ids.size(0)
            if condition == "normal":
                override_z = None
            elif condition == "random":
                override_z = torch.randn(B, K, d_model, device=device,
                                         dtype=next(projector.parameters()).dtype) * z_std
            else:  # "zero"
                override_z = torch.zeros(B, 0, d_model, device=device,
                                         dtype=next(projector.parameters()).dtype)

            gen = generate_with_latent(
                model, tokenizer, projector,
                x_ids=x_ids, x_attn=x_attn, K=K_eff,
                block_y_to_x=block_y_to_x, max_new_tokens=max_new_tokens,
                temperature=temperature, eos_token_id=tokenizer.eos_token_id,
                override_z=override_z,
            )
            for b in range(B):
                text = tokenizer.decode(gen[b], skip_special_tokens=True)
                pred = parse_pred(text, dataset=ds)

                raw_gold = batch.final_strs[b]
                if ds == "math":
                    gold = raw_gold.strip()
                else:
                    # Drop the GSM8K "####" marker regardless of spacing.
                    gold = re.sub(r"####\s*", "", raw_gold).strip()
                ok = correct(pred, gold, dataset=ds)
                correct_n += int(ok)
                total += 1
                if len(examples) < 5:
                    examples.append({"text": text[:200], "pred": pred, "gold": gold, "ok": ok})
    return {"condition": condition, "K": K_eff, "acc": correct_n / max(total, 1),
            "n": total, "correct": correct_n, "examples": examples}
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
@torch.no_grad()
def _get_z_for_batch(model, projector, x_ids, x_attn, K):
    """Return the latents z for one batch of prompts, without gradients.

    Calls ``forward_with_latent`` with ``y_ids=None``, ``return_z=True`` and
    ``block_y_to_x=True``. It unpacks the 3-tuple result and keeps only the
    middle element.
    """
    outputs = forward_with_latent(
        model,
        x_ids,
        x_attn,
        y_ids=None,
        projector=projector,
        K=K,
        block_y_to_x=True,
        return_z=True,
    )
    _first, latents, _third = outputs
    return latents
|
|
|
|
| def _perturb_z(z: torch.Tensor, severity: float, z_std: float, seed: int) -> torch.Tensor: |
| """Replace ⌊severity·K⌋ randomly-chosen latent positions per example with |
| Gaussian noise matched to z_std. Deterministic given seed for fair compare |
| across severities and conditions. |
| """ |
| B, K, d = z.shape |
| if severity <= 0.0: |
| return z |
| n_replace = max(1, int(round(severity * K))) |
| g = torch.Generator(device=z.device).manual_seed(seed) |
| out = z.clone() |
| for b in range(B): |
| idx = torch.randperm(K, generator=g, device=z.device)[:n_replace] |
| noise = torch.randn(n_replace, d, device=z.device, generator=g, |
| dtype=z.dtype) * z_std |
| out[b, idx] = noise |
| return out |
|
|
|
|
def run_perturbation_curve(model, projector, tokenizer, val_loader, device, K,
                           z_std: float, severities, max_new_tokens: int,
                           temperature: float, seed: int = 0,
                           dataset: str = "gsm8k",
                           block_y_to_x: bool = True) -> dict:
    """Sweep latent-perturbation severity and report accuracy at each level.

    For each severity p, ``_perturb_z`` overwrites about a fraction p of the K
    latent positions of every example with N(0, z_std²) noise. The model then
    generates from the perturbed latents.

    Args:
        severities: iterable of floats in [0, 1]. They are evaluated in
            ascending order, so ``acc_at_0`` / ``acc_at_1`` and the
            monotonicity count refer to the lowest / highest severities.
        seed: base seed. Batch ``bi`` uses ``seed + bi``, so every severity
            perturbs a given batch with the same random stream.
        dataset: "gsm8k" or "math"; selects answer parsing as in
            ``run_condition``. ``None`` means gsm8k.
        block_y_to_x: mask setting passed to ``generate_with_latent``. The
            latents themselves are always computed by ``_get_z_for_batch``,
            which uses ``block_y_to_x=True``.

    Returns:
        dict with the per-severity curve, monotonicity counts, accuracy at
        the lowest / highest severity, and up to 3 examples at the highest
        severity.

    Raises:
        ValueError: if ``severities`` is empty or has a value outside [0, 1].
    """
    sev = sorted(float(p) for p in severities)
    if not sev:
        raise ValueError("severities must contain at least one value")
    if sev[0] < 0.0 or sev[-1] > 1.0:
        raise ValueError(f"severities must lie in [0, 1], got {sev}")
    ds = (dataset or "gsm8k").lower()
    model.eval()

    correct_by_p = [0] * len(sev)
    total_by_p = [0] * len(sev)
    examples_at_max = []
    last = len(sev) - 1
    for bi, batch in enumerate(val_loader):
        x_ids = batch.x_ids.to(device)
        x_attn = batch.x_attn.to(device)
        B = x_ids.size(0)
        # z does not depend on severity: compute it once per batch instead of
        # once per (batch, severity) pair.
        z = _get_z_for_batch(model, projector, x_ids, x_attn, K)
        for si, p in enumerate(sev):
            z_pert = _perturb_z(z, severity=p, z_std=z_std, seed=seed + bi)
            with torch.no_grad():
                gen = generate_with_latent(
                    model, tokenizer, projector,
                    x_ids=x_ids, x_attn=x_attn, K=K,
                    block_y_to_x=block_y_to_x, max_new_tokens=max_new_tokens,
                    temperature=temperature, eos_token_id=tokenizer.eos_token_id,
                    override_z=z_pert,
                )
            for b in range(B):
                text = tokenizer.decode(gen[b], skip_special_tokens=True)
                pred = parse_pred(text, dataset=ds)
                raw_gold = batch.final_strs[b]
                if ds == "math":
                    gold = raw_gold.strip()
                else:
                    gold = re.sub(r"####\s*", "", raw_gold).strip()
                ok = correct(pred, gold, dataset=ds)
                correct_by_p[si] += int(ok)
                total_by_p[si] += 1
                if si == last and len(examples_at_max) < 3:
                    examples_at_max.append({"text": text[:200], "pred": pred,
                                            "gold": gold, "ok": ok})

    curve = []
    for p, n_ok, n in zip(sev, correct_by_p, total_by_p):
        acc = n_ok / max(n, 1)
        curve.append({"severity": p, "acc": acc, "correct": n_ok, "n": n})
        print(f"[perturb p={p:.2f}] acc={acc:.4f} ({n_ok}/{n})")

    accs = [c["acc"] for c in curve]
    n_monotone = sum(1 for lo_sev, hi_sev in zip(accs, accs[1:]) if lo_sev >= hi_sev)
    return {
        "curve": curve,
        "n_pairs_monotone_decreasing": n_monotone,
        "n_pairs_total": len(accs) - 1,
        "acc_at_0": accs[0],
        "acc_at_1": accs[-1],
        "drop_0_to_1": accs[0] - accs[-1],
        "examples_at_max_severity": examples_at_max,
        "block_y_to_x": block_y_to_x,
        "dataset": ds,
    }


def main():
    """Command-line entry point: load a checkpoint, run the ablation, write JSON.

    Steps:
      1. Build the base model and attach a PEFT adapter from ``<ckpt>/model``
         when ``adapter_config.json`` is present there.
      2. Load the latent projector from ``<ckpt>/projector.pt``.
      3. Evaluate the normal / random / zero conditions and, optionally, the
         perturbation sweep.
      4. Write the summary to ``--out`` (default ``<ckpt>/ablation.json``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt", required=True, help="path to ckpt dir containing model/, projector.pt, head.pt")
    parser.add_argument("--config", required=True)
    parser.add_argument("--n", type=int, default=200)
    parser.add_argument("--K", type=int, default=None, help="latent count to use (defaults to config end-of-curriculum)")
    parser.add_argument("--max_new_tokens", type=int, default=256)
    parser.add_argument("--temperature", type=float, default=0.0)
    parser.add_argument("--out", default=None)
    parser.add_argument("--perturbation_curve", action="store_true",
                        help="Also run a perturbation-severity sweep (Viteri-style)")
    parser.add_argument("--severities", default="0.0,0.25,0.5,0.75,1.0",
                        help="Comma-separated severities for the perturbation curve")
    parser.add_argument("--no_block_y_to_x", action="store_true",
                        help="EVALUATE without the y→only-z bottleneck mask (lets y "
                             "attend to x directly during generation). Tests "
                             "bottleneck-as-regularizer: does z's learned structure "
                             "help when the inference constraint is lifted?")
    parser.add_argument("--batch_size", type=int, default=8,
                        help="Evaluation batch size")
    parser.add_argument("--seed", type=int, default=0,
                        help="Seed for the random-z draws and the perturbation noise")
    args = parser.parse_args()

    with open(args.config, encoding="utf-8") as f:
        cfg = json.load(f)
    K = args.K if args.K is not None else cfg["K_curriculum"][-1][1]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    ckpt = Path(args.ckpt)

    # Build the base without LoRA layers; a trained adapter, if present, is
    # attached from <ckpt>/model below.
    bcfg_nolora = BLTConfig(
        base_model=cfg["base_model"], use_lora=False,
        K_latents=K, block_y_to_x=cfg["block_y_to_x"],
        proj_init_scale=cfg["proj_init_scale"],
        dtype=cfg["dtype"], attn_impl=cfg["attn_impl"],
    )
    base_model, tokenizer = build_base(bcfg_nolora)
    from peft import PeftModel
    adapter_dir = ckpt / "model"
    if (adapter_dir / "adapter_config.json").exists():
        model = PeftModel.from_pretrained(base_model, str(adapter_dir))
        print(f"[load] adapter from {adapter_dir}")
    else:
        model = base_model
        print(f"[load] no adapter at {adapter_dir} (using base only)")
    model.to(device).eval()

    inner_base = model.get_base_model() if hasattr(model, "get_base_model") else model
    d_model = inner_base.config.hidden_size
    projector = LatentProjector(
        d_model, init_scale=cfg["proj_init_scale"],
        use_mlp=cfg.get("proj_mlp", False),
        hidden_mult=cfg.get("proj_hidden_mult", 4),
    ).to(device).to(next(model.parameters()).dtype)
    # NOTE(security): torch.load unpickles the file — only evaluate trusted checkpoints.
    projector.load_state_dict(torch.load(ckpt / "projector.pt", map_location=device))
    projector.eval()

    # `or` also covers an explicit "dataset": null in the config.
    dataset_name = cfg.get("dataset") or "gsm8k"
    if dataset_name.lower() == "math":
        val_ds = MATHDataset(split="test", max_examples=args.n)
    else:
        val_ds = GSM8KDataset(split="test", max_examples=args.n)
    val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False,
                            collate_fn=lambda b: collate_batch(
                                b, tokenizer,
                                max_prompt_len=cfg["max_prompt_len"],
                                max_answer_len=cfg["max_answer_len"],
                            ))

    z_std = estimate_z_std(model, projector, tokenizer, val_loader, device, K)
    print(f"[z_std estimate] {z_std:.4f} dataset={dataset_name}")

    eval_block_y_to_x = not args.no_block_y_to_x
    print(f"[mode] eval_block_y_to_x={eval_block_y_to_x}")
    # Seed the global RNG so the random-z condition is reproducible.
    torch.manual_seed(args.seed)
    results = {}
    t0 = time.time()
    for cond in ["normal", "random", "zero"]:
        r = run_condition(model, projector, tokenizer, val_loader, device, K,
                          cond, z_std, args.max_new_tokens, args.temperature,
                          dataset=dataset_name, block_y_to_x=eval_block_y_to_x)
        results[cond] = r
        print(f"[{cond}] acc={r['acc']:.4f} ({r['correct']}/{r['n']}) elapsed={time.time()-t0:.0f}s")

    summary = {
        "ckpt": str(ckpt), "K": K, "n": args.n, "z_std": z_std,
        "eval_block_y_to_x": eval_block_y_to_x,
        "dataset": dataset_name,
        "seed": args.seed,
        "batch_size": args.batch_size,
        "results": results,
        "delta_normal_minus_random": results["normal"]["acc"] - results["random"]["acc"],
        "delta_normal_minus_zero": results["normal"]["acc"] - results["zero"]["acc"],
    }
    # Pre-registered H1 thresholds (see module docstring): 15pp and 25pp.
    success_random = summary["delta_normal_minus_random"] >= 0.15
    success_zero = summary["delta_normal_minus_zero"] >= 0.25
    summary["H1_supported"] = bool(success_random and success_zero)

    if args.perturbation_curve:
        severities = [float(s) for s in args.severities.split(",") if s.strip()]
        print(f"[perturbation_curve] severities={severities}")
        curve = run_perturbation_curve(
            model, projector, tokenizer, val_loader, device, K, z_std,
            severities=severities, max_new_tokens=args.max_new_tokens,
            temperature=args.temperature, seed=args.seed,
            dataset=dataset_name, block_y_to_x=eval_block_y_to_x,
        )
        summary["perturbation_curve"] = curve
        print(f"[perturbation_curve] acc(p=0)={curve['acc_at_0']:.3f} -> "
              f"acc(p=1)={curve['acc_at_1']:.3f} drop={curve['drop_0_to_1']:.3f} "
              f"monotone={curve['n_pairs_monotone_decreasing']}/{curve['n_pairs_total']}")

    out = args.out or str(ckpt / "ablation.json")
    with open(out, "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2)
    print(f"[written] {out}")
    print(f"H1 supported? {summary['H1_supported']} "
          f"(Δ_random={summary['delta_normal_minus_random']:.3f}, "
          f"Δ_zero={summary['delta_normal_minus_zero']:.3f})")
|
|
|
|
# Script entry point: run the ablation only when executed directly, not on import.
if __name__ == "__main__":
    main()
|
|