| """Teacher-forced 3-way z-ablation eval. |
| |
| Uses `forward_with_latent` directly (which respects the `block_z_to_x` flag via |
| the 4D mask) and computes per-token accuracy on y under three conditions: |
| normal-z : z computed by the M-step loop |
| random-z : z input replaced by Gaussian noise (matched to z_std) |
| zero-z : K=0 (no z slots at all) |
| |
| This is *teacher-forced* accuracy (the model sees gold y prefix when predicting |
| each token), so it's not the same metric as autoregressive `generate` accuracy. |
| But it directly tests "does z's content carry the signal y needs?" — which is |
| exactly the question the leak hypothesis is about. Autoregressive generation |
| with `block_z_to_x` would require non-trivial changes to `generate_with_latent` |
| (its KV-cache path doesn't use the 4D mask). For the principled experiment, |
| teacher-forced acc is the cleaner signal. |
| |
| Usage: |
| python -m experiments.blt_reasoner.scripts.ablate_teacher_forced \ |
| --ckpt /path/to/final --config <config.json> --n 200 --K 8 \ |
| --out /path/to/ablation_tf.json |
| """ |
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import time |
| from pathlib import Path |
| from typing import Optional |
|
|
| import torch |
| import torch.nn.functional as F |
| from torch.utils.data import DataLoader |
|
|
| from ..data import GSM8KDataset, MATHDataset, collate_batch |
| from ..model import BLTConfig, LatentProjector, build_base, forward_with_latent |
|
|
|
|
@torch.no_grad()
def estimate_z_std(model, projector, tokenizer, loader, device, K, block_z_to_x,
                   max_batches=4):
    """Estimate a single scalar standard deviation of the latent z.

    Runs ``forward_with_latent`` on the first ``max_batches`` batches of
    ``loader``, collects the second element of its return tuple (z), and takes
    the standard deviation over every element of the tensors concatenated
    along dim 0.

    Args:
        model: Model handed to ``forward_with_latent``.
        projector: Latent projector handed to ``forward_with_latent``.
        tokenizer: Unused here; kept so existing call sites keep working.
        loader: Iterable of batches exposing ``x_ids``, ``x_attn`` and ``y_ids``.
        device: Device the batch tensors are moved to before the forward pass.
        K: Latent count forwarded to ``forward_with_latent``.
        block_z_to_x: Forwarded unchanged to ``forward_with_latent``.
        max_batches: How many leading batches to sample (default 4).

    Returns:
        The standard deviation as a Python ``float``.

    Raises:
        ValueError: If ``loader`` yields no batches, so nothing can be estimated.
    """
    from itertools import islice

    latents = []
    # islice stops *before* fetching batch number `max_batches`; an
    # enumerate-and-break loop would load and collate one extra batch only to
    # throw it away.
    for batch in islice(loader, max_batches):
        _, z, _ = forward_with_latent(
            model, batch.x_ids.to(device), batch.x_attn.to(device),
            batch.y_ids.to(device), projector, K,
            block_y_to_x=True, block_z_to_x=block_z_to_x,
        )
        # Accumulate in float32 on CPU so the statistic does not depend on the
        # model's compute dtype and no extra device memory is held.
        latents.append(z.float().cpu())
    if not latents:
        raise ValueError("estimate_z_std: loader yielded no batches; cannot estimate z_std")
    return float(torch.cat(latents, 0).std().item())
|
|
|
|
_ABLATION_CONDITIONS = ("normal", "random", "zero")


def _override_logits(inner, x_ids, x_attn, y_ids, override, k_eff, device, block_z_to_x):
    """Run the backbone on ``[x ; override ; y]`` embeddings; return logits for y.

    ``override`` takes the place of the latent slots (it has ``k_eff`` slots,
    possibly zero). The returned slice starts one position before the first y
    token, so position ``i`` of the result predicts ``y[:, i]``.
    """
    # Imported lazily: only the ablated conditions need the mask builder.
    from ..model import build_blt_mask

    embed_in = inner.get_input_embeddings()
    x_embeds = embed_in(x_ids)
    y_embeds = embed_in(y_ids)
    B = x_ids.size(0)
    P = x_ids.size(1)
    L_y = y_ids.size(1)
    full_embeds = torch.cat([x_embeds, override.to(y_embeds.dtype), y_embeds], dim=1)
    mask = build_blt_mask(B, P, k_eff, L_y, device=device, dtype=full_embeds.dtype,
                          block_y_to_x=True, block_z_to_x=block_z_to_x)

    pad_x = x_attn == 0
    if pad_x.any():
        # Only prompt positions can be padding; the z and y key columns stay visible.
        pad_kv = torch.cat(
            [pad_x, torch.zeros(B, k_eff + L_y, device=device, dtype=torch.bool)],
            dim=1,
        )
        # -1e9 is not representable in float16 and makes masked_fill_ raise an
        # overflow error; clamp to the most negative finite value of the dtype.
        fill = -1e9
        if mask.is_floating_point():
            fill = max(fill, torch.finfo(mask.dtype).min)
        mask = mask.clone()
        mask.masked_fill_(pad_kv[:, None, None, :], fill)

    out = inner.model(inputs_embeds=full_embeds, attention_mask=mask,
                      use_cache=False, return_dict=True)
    logits_all = inner.get_output_embeddings()(out.last_hidden_state)
    # The same formula covers k_eff == 0: the slice then starts at the last prompt token.
    start = P + k_eff - 1
    return logits_all[:, start:start + L_y, :]


@torch.no_grad()
def teacher_forced_accuracy(
    model, projector, tokenizer, loader, device, K,
    *, condition: str, z_std: float, block_z_to_x: bool, seed: int = 0,
) -> dict:
    """Per-token accuracy on y under teacher forcing for one ablation condition.

    The argmax prediction at every y position is compared with the gold token
    and counted where ``batch.y_attn`` is non-zero.

    Conditions:
        ``"normal"``: logits come from ``forward_with_latent`` with ``K`` latents.
        ``"random"``: latent slots are replaced by Gaussian noise scaled by ``z_std``.
        ``"zero"``:   no latent slots at all (effective K is 0).

    Args:
        model: Model, optionally wrapped (anything exposing ``get_base_model``).
        projector: Latent projector; its parameter dtype is used for the override.
        tokenizer: Used only to decode a few sample predictions.
        loader: Iterable of batches with ``x_ids``, ``x_attn``, ``y_ids``, ``y_attn``.
        device: Device the batch tensors are moved to.
        K: Number of latent slots for the ``normal`` and ``random`` conditions.
        condition: One of ``"normal"``, ``"random"``, ``"zero"``.
        z_std: Scale applied to the random latents.
        block_z_to_x: Forwarded to the mask builder / ``forward_with_latent``.
        seed: Base seed for the per-batch noise generator.

    Returns:
        Dict with keys ``condition``, ``K`` (effective latent count),
        ``tok_acc``, ``n_tokens`` and ``sample_preds`` (up to three decoded
        predictions, each truncated to 200 characters).

    Raises:
        ValueError: If ``condition`` is not one of the three known names.
    """
    # Reject typos up front: an unknown name used to fall through to the
    # normal path and get reported under the wrong label.
    if condition not in _ABLATION_CONDITIONS:
        raise ValueError(
            f"unknown condition {condition!r}; expected one of {_ABLATION_CONDITIONS}"
        )

    inner = model.get_base_model() if hasattr(model, "get_base_model") else model
    d_model = inner.config.hidden_size
    proj_dtype = next(projector.parameters()).dtype
    # Resolved before the loop so the summary is well-defined for an empty loader.
    K_eff = 0 if condition == "zero" else K

    total_correct = 0
    total = 0
    sample_texts = []
    for batch in loader:
        x_ids = batch.x_ids.to(device)
        x_attn = batch.x_attn.to(device)
        y_ids = batch.y_ids.to(device)
        y_mask = batch.y_attn.to(device)
        B = x_ids.size(0)

        if condition == "random":
            # Offsetting the seed by the running token count gives every batch
            # its own noise while keeping the run reproducible.
            g = torch.Generator(device=device).manual_seed(seed + total)
            override = torch.randn(B, K, d_model, device=device, generator=g,
                                   dtype=proj_dtype) * z_std
        elif condition == "zero":
            override = torch.zeros(B, 0, d_model, device=device, dtype=proj_dtype)
        else:
            override = None

        if override is None:
            logits_y, _, _ = forward_with_latent(
                model, x_ids, x_attn, y_ids, projector, K_eff,
                block_y_to_x=True, block_z_to_x=block_z_to_x,
            )
        else:
            logits_y = _override_logits(
                inner, x_ids, x_attn, y_ids, override, K_eff, device, block_z_to_x,
            )

        pred = logits_y.argmax(dim=-1)
        # int(...) keeps the counters (and therefore the seed offset) integral
        # whatever dtype the attention mask arrives in.
        total_correct += int(((pred == y_ids) * y_mask).sum().item())
        total += int(y_mask.sum().item())
        if len(sample_texts) < 3:
            text = tokenizer.decode(pred[0], skip_special_tokens=True)
            sample_texts.append(text[:200])

    return {
        "condition": condition,
        "K": K_eff,
        "tok_acc": total_correct / max(total, 1),
        "n_tokens": total,
        "sample_preds": sample_texts,
    }
|
|
|
|
def main():
    """CLI entry point: load a checkpoint and run the 3-way z-ablation eval.

    Steps:
      1. Parse ``--ckpt``, ``--config``, ``--n``, ``--K`` and ``--out``.
      2. Build the base model without LoRA, then attach the adapter found in
         ``<ckpt>/model`` when an ``adapter_config.json`` is present.
      3. Load the projector weights from ``<ckpt>/projector.pt``.
      4. Build the test loader (MATH when the config's ``dataset`` is "math",
         GSM8K otherwise), estimate ``z_std``, and evaluate the ``normal``,
         ``random`` and ``zero`` conditions.
      5. Write a JSON summary (``--out``, or
         ``<ckpt>/ablation_teacher_forced.json`` by default) and print the
         accuracy deltas.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--ckpt", required=True)
    p.add_argument("--config", required=True)
    p.add_argument("--n", type=int, default=200)
    p.add_argument("--K", type=int, default=None)
    p.add_argument("--out", default=None)
    args = p.parse_args()

    with open(args.config, encoding="utf-8") as f:
        cfg = json.load(f)
    # Without --K, fall back to the second entry of the last K_curriculum item.
    K = args.K if args.K is not None else cfg.get("K_curriculum", [[0, 8]])[-1][1]
    block_z_to_x = bool(cfg.get("block_z_to_x", False))

    ckpt = Path(args.ckpt)
    out = args.out or str(ckpt / "ablation_teacher_forced.json")
    # Create the destination directory up front so a bad --out path fails
    # before the evaluation runs, not after all the compute has been spent.
    Path(out).parent.mkdir(parents=True, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    bcfg_nolora = BLTConfig(
        base_model=cfg["base_model"], use_lora=False,
        lora_r=cfg["lora_r"], lora_alpha=cfg["lora_alpha"],
        lora_dropout=cfg["lora_dropout"],
        lora_target_modules=tuple(cfg["lora_target_modules"]),
        K_latents=K, block_y_to_x=cfg["block_y_to_x"],
        block_z_to_x=block_z_to_x,
        proj_init_scale=cfg["proj_init_scale"],
        dtype=cfg["dtype"], attn_impl=cfg["attn_impl"],
        gradient_checkpointing=False,
    )
    base_model, tokenizer = build_base(bcfg_nolora)

    adapter_dir = ckpt / "model"
    if (adapter_dir / "adapter_config.json").exists():
        # peft is only needed when an adapter is actually present.
        from peft import PeftModel

        model = PeftModel.from_pretrained(base_model, str(adapter_dir))
        print(f"[load] adapter from {adapter_dir}")
    else:
        model = base_model
    model.to(device).eval()

    inner_base = model.get_base_model() if hasattr(model, "get_base_model") else model
    d_model = inner_base.config.hidden_size
    projector = LatentProjector(
        d_model, init_scale=cfg["proj_init_scale"],
        use_mlp=cfg.get("proj_mlp", False),
        hidden_mult=cfg.get("proj_hidden_mult", 4),
    ).to(device).to(next(model.parameters()).dtype)
    # weights_only=True restricts unpickling to tensors and plain containers,
    # which is all a state dict needs, and avoids running arbitrary pickled
    # code from a checkpoint path supplied on the command line.
    projector.load_state_dict(
        torch.load(ckpt / "projector.pt", map_location=device, weights_only=True)
    )
    projector.eval()

    ds_name = cfg.get("dataset", "gsm8k")
    if ds_name.lower() == "math":
        val_ds = MATHDataset(split="test", max_examples=args.n)
    else:
        val_ds = GSM8KDataset(split="test", max_examples=args.n)
    loader = DataLoader(
        val_ds, batch_size=8, shuffle=False,
        collate_fn=lambda b: collate_batch(b, tokenizer,
                                           max_prompt_len=cfg["max_prompt_len"],
                                           max_answer_len=cfg["max_answer_len"]),
    )

    z_std = estimate_z_std(model, projector, tokenizer, loader, device, K, block_z_to_x)
    print(f"[z_std] {z_std:.4f}")

    results = {}
    t0 = time.time()
    for cond in ["normal", "random", "zero"]:
        r = teacher_forced_accuracy(model, projector, tokenizer, loader, device, K,
                                    condition=cond, z_std=z_std,
                                    block_z_to_x=block_z_to_x, seed=0)
        results[cond] = r
        print(f"[{cond}] tok_acc={r['tok_acc']:.4f} elapsed={time.time()-t0:.0f}s")

    normal_acc = results["normal"]["tok_acc"]
    summary = {
        "ckpt": str(ckpt), "K": K, "n": args.n, "z_std": z_std,
        "block_z_to_x_at_train_and_eval": block_z_to_x,
        "results": results,
        "delta_tokacc_normal_minus_random": normal_acc - results["random"]["tok_acc"],
        "delta_tokacc_normal_minus_zero": normal_acc - results["zero"]["tok_acc"],
    }
    Path(out).write_text(json.dumps(summary, indent=2))
    print(f"[written] {out}")
    print(f"Δ_random_tok = {summary['delta_tokacc_normal_minus_random']:+.4f}")
    print(f"Δ_zero_tok = {summary['delta_tokacc_normal_minus_zero']:+.4f}")


# Run the evaluation only when executed as a script (e.g. via ``python -m``),
# not when this module is imported.
if __name__ == "__main__":
    main()
|
|