"""Teacher-forced 3-way z-ablation eval. Uses `forward_with_latent` directly (which respects the `block_z_to_x` flag via the 4D mask) and computes per-token accuracy on y under three conditions: normal-z : z computed by the M-step loop random-z : z input replaced by Gaussian noise (matched to z_std) zero-z : K=0 (no z slots at all) This is *teacher-forced* accuracy (the model sees gold y prefix when predicting each token), so it's not the same metric as autoregressive `generate` accuracy. But it directly tests "does z's content carry the signal y needs?" — which is exactly the question the leak hypothesis is about. Autoregressive generation with `block_z_to_x` would require non-trivial changes to `generate_with_latent` (its KV-cache path doesn't use the 4D mask). For the principled experiment, teacher-forced acc is the cleaner signal. Usage: python -m experiments.blt_reasoner.scripts.ablate_teacher_forced \ --ckpt /path/to/final --config --n 200 --K 8 \ --out /path/to/ablation_tf.json """ from __future__ import annotations import argparse import json import time from pathlib import Path from typing import Optional import torch import torch.nn.functional as F from torch.utils.data import DataLoader from ..data import GSM8KDataset, MATHDataset, collate_batch from ..model import BLTConfig, LatentProjector, build_base, forward_with_latent @torch.no_grad() def estimate_z_std(model, projector, tokenizer, loader, device, K, block_z_to_x): all_z = [] for i, b in enumerate(loader): if i >= 4: break _, z, _ = forward_with_latent( model, b.x_ids.to(device), b.x_attn.to(device), b.y_ids.to(device), projector, K, block_y_to_x=True, block_z_to_x=block_z_to_x, ) all_z.append(z.float().cpu()) return float(torch.cat(all_z, 0).std().item()) def teacher_forced_accuracy( model, projector, tokenizer, loader, device, K, *, condition: str, z_std: float, block_z_to_x: bool, seed: int = 0, ) -> dict: """Per-token accuracy on y, scored token-by-token vs gold y under teacher forcing (the model sees gold prefix for each prediction). """ inner = model.get_base_model() if hasattr(model, "get_base_model") else model d_model = inner.config.hidden_size proj_dtype = next(projector.parameters()).dtype total_correct = 0 total = 0 sample_texts = [] for batch in loader: x_ids = batch.x_ids.to(device); x_attn = batch.x_attn.to(device) y_ids = batch.y_ids.to(device); y_mask = batch.y_attn.to(device) B = x_ids.size(0) override = None K_eff = K if condition == "random": g = torch.Generator(device=device).manual_seed(seed + total) override = torch.randn(B, K, d_model, device=device, generator=g, dtype=proj_dtype) * z_std elif condition == "zero": override = torch.zeros(B, 0, d_model, device=device, dtype=proj_dtype) K_eff = 0 if override is not None: # Run pass 2 directly with the override z (skipping the M-step loop). # forward_with_latent doesn't expose override_z, so we mimic it manually. embed_in = inner.get_input_embeddings() x_embeds = embed_in(x_ids) y_embeds = embed_in(y_ids) P = x_ids.size(1); L_y = y_ids.size(1) full_embeds = torch.cat([x_embeds, override.to(y_embeds.dtype), y_embeds], dim=1) from ..model import build_blt_mask mask = build_blt_mask(B, P, K_eff, L_y, device=device, dtype=full_embeds.dtype, block_y_to_x=True, block_z_to_x=block_z_to_x) # Mask out x-pad positions in keys if (x_attn == 0).any(): pad_kv = torch.cat([(x_attn == 0), torch.zeros(B, K_eff + L_y, device=device, dtype=torch.bool)], dim=1) mask = mask.clone() mask.masked_fill_(pad_kv[:, None, None, :], -1e9) transformer = inner.model lm_head = inner.get_output_embeddings() out = transformer(inputs_embeds=full_embeds, attention_mask=mask, use_cache=False, return_dict=True) logits_all = lm_head(out.last_hidden_state) logits_y = logits_all[:, P + K_eff - 1: P + K_eff - 1 + L_y, :] if K_eff > 0 else \ logits_all[:, P - 1: P - 1 + L_y, :] else: logits_y, _, _ = forward_with_latent( model, x_ids, x_attn, y_ids, projector, K_eff, block_y_to_x=True, block_z_to_x=block_z_to_x, ) pred = logits_y.argmax(dim=-1) # Shifted: logits at t predict token at t (already aligned by forward_with_latent). correct = ((pred == y_ids) * y_mask).sum().item() n = y_mask.sum().item() total_correct += correct total += n if len(sample_texts) < 3: t = tokenizer.decode(pred[0].clamp(min=0), skip_special_tokens=True) sample_texts.append(t[:200]) return { "condition": condition, "K": K_eff, "tok_acc": total_correct / max(total, 1), "n_tokens": total, "sample_preds": sample_texts, } def main(): p = argparse.ArgumentParser() p.add_argument("--ckpt", required=True) p.add_argument("--config", required=True) p.add_argument("--n", type=int, default=200) p.add_argument("--K", type=int, default=None) p.add_argument("--out", default=None) args = p.parse_args() with open(args.config) as f: cfg = json.load(f) K = args.K if args.K is not None else cfg.get("K_curriculum", [[0, 8]])[-1][1] block_z_to_x = bool(cfg.get("block_z_to_x", False)) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") ckpt = Path(args.ckpt) bcfg_nolora = BLTConfig( base_model=cfg["base_model"], use_lora=False, lora_r=cfg["lora_r"], lora_alpha=cfg["lora_alpha"], lora_dropout=cfg["lora_dropout"], lora_target_modules=tuple(cfg["lora_target_modules"]), K_latents=K, block_y_to_x=cfg["block_y_to_x"], block_z_to_x=block_z_to_x, proj_init_scale=cfg["proj_init_scale"], dtype=cfg["dtype"], attn_impl=cfg["attn_impl"], gradient_checkpointing=False, ) base_model, tokenizer = build_base(bcfg_nolora) from peft import PeftModel adapter_dir = ckpt / "model" if (adapter_dir / "adapter_config.json").exists(): model = PeftModel.from_pretrained(base_model, str(adapter_dir)) print(f"[load] adapter from {adapter_dir}") else: model = base_model model.to(device).eval() inner_base = model.get_base_model() if hasattr(model, "get_base_model") else model d_model = inner_base.config.hidden_size projector = LatentProjector( d_model, init_scale=cfg["proj_init_scale"], use_mlp=cfg.get("proj_mlp", False), hidden_mult=cfg.get("proj_hidden_mult", 4), ).to(device).to(next(model.parameters()).dtype) projector.load_state_dict(torch.load(ckpt / "projector.pt", map_location=device)) projector.eval() ds_name = cfg.get("dataset", "gsm8k") val_ds = MATHDataset(split="test", max_examples=args.n) if ds_name.lower() == "math" \ else GSM8KDataset(split="test", max_examples=args.n) loader = DataLoader( val_ds, batch_size=8, shuffle=False, collate_fn=lambda b: collate_batch(b, tokenizer, max_prompt_len=cfg["max_prompt_len"], max_answer_len=cfg["max_answer_len"]), ) z_std = estimate_z_std(model, projector, tokenizer, loader, device, K, block_z_to_x) print(f"[z_std] {z_std:.4f}") results = {} t0 = time.time() for cond in ["normal", "random", "zero"]: r = teacher_forced_accuracy(model, projector, tokenizer, loader, device, K, condition=cond, z_std=z_std, block_z_to_x=block_z_to_x, seed=0) results[cond] = r print(f"[{cond}] tok_acc={r['tok_acc']:.4f} elapsed={time.time()-t0:.0f}s") summary = { "ckpt": str(ckpt), "K": K, "n": args.n, "z_std": z_std, "block_z_to_x_at_train_and_eval": block_z_to_x, "results": results, "delta_tokacc_normal_minus_random": results["normal"]["tok_acc"] - results["random"]["tok_acc"], "delta_tokacc_normal_minus_zero": results["normal"]["tok_acc"] - results["zero"]["tok_acc"], } out = args.out or str(ckpt / "ablation_teacher_forced.json") Path(out).write_text(json.dumps(summary, indent=2)) print(f"[written] {out}") print(f"Δ_random_tok = {summary['delta_tokacc_normal_minus_random']:+.4f}") print(f"Δ_zero_tok = {summary['delta_tokacc_normal_minus_zero']:+.4f}") if __name__ == "__main__": main()