# Source: blt-reasoner-pilot1/code/scripts/ablate_teacher_forced.py
# Uploaded by LauraGG — commit bc7101b (verified):
# "Refresh code/ with latest BLT-Reasoner sources (post-campaign)"
"""Teacher-forced 3-way z-ablation eval.
Uses `forward_with_latent` directly (which respects the `block_z_to_x` flag via
the 4D mask) and computes per-token accuracy on y under three conditions:
normal-z : z computed by the M-step loop
random-z : z input replaced by Gaussian noise (matched to z_std)
zero-z : K=0 (no z slots at all)
This is *teacher-forced* accuracy (the model sees gold y prefix when predicting
each token), so it's not the same metric as autoregressive `generate` accuracy.
But it directly tests "does z's content carry the signal y needs?" — which is
exactly the question the leak hypothesis is about. Autoregressive generation
with `block_z_to_x` would require non-trivial changes to `generate_with_latent`
(its KV-cache path doesn't use the 4D mask). For the principled experiment,
teacher-forced acc is the cleaner signal.
Usage:
python -m experiments.blt_reasoner.scripts.ablate_teacher_forced \
--ckpt /path/to/final --config <config.json> --n 200 --K 8 \
--out /path/to/ablation_tf.json
"""
from __future__ import annotations
import argparse
import json
import time
from pathlib import Path
from typing import Optional
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from ..data import GSM8KDataset, MATHDataset, collate_batch
from ..model import BLTConfig, LatentProjector, build_base, forward_with_latent
@torch.no_grad()
def estimate_z_std(model, projector, tokenizer, loader, device, K, block_z_to_x,
                   max_batches: int = 4):
    """Estimate the standard deviation of the latent z on a few batches.

    Runs ``forward_with_latent`` on at most ``max_batches`` batches drawn from
    ``loader``, pools the returned z tensors and returns the standard
    deviation taken over every element of the pooled tensor.

    Args:
        model: Model passed through to ``forward_with_latent``.
        projector: Latent projector passed through to ``forward_with_latent``.
        tokenizer: Unused here; kept so the signature matches the other
            evaluation helpers in this file.
        loader: Iterable of batches exposing ``x_ids``, ``x_attn`` and
            ``y_ids`` tensors.
        device: Device the batch tensors are moved to before the forward pass.
        K: Number of latent slots requested from ``forward_with_latent``.
        block_z_to_x: Forwarded to ``forward_with_latent`` unchanged.
        max_batches: Upper bound on the number of batches consumed
            (defaults to 4, the previously hard-coded value).

    Returns:
        The pooled standard deviation of z as a Python float.

    Raises:
        ValueError: If no batch was processed (empty loader or
            ``max_batches <= 0``), since there is nothing to estimate from.
    """
    pooled = []
    for i, b in enumerate(loader):
        if i >= max_batches:
            break
        # Only the second return value (z) is needed for the estimate.
        _, z, _ = forward_with_latent(
            model, b.x_ids.to(device), b.x_attn.to(device),
            b.y_ids.to(device), projector, K,
            block_y_to_x=True, block_z_to_x=block_z_to_x,
        )
        # Accumulate on CPU in float32 so the std is not computed in a
        # reduced-precision dtype.
        pooled.append(z.float().cpu())
    if not pooled:
        # torch.cat([]) would fail with an opaque message; be explicit.
        raise ValueError(
            "estimate_z_std: no batches were processed; cannot estimate z_std"
        )
    # Concatenating along dim 0 assumes all z share their trailing dims
    # (presumably (B, K, d_model) — TODO confirm against forward_with_latent).
    return float(torch.cat(pooled, 0).std().item())
def _logits_with_override_z(inner, x_ids, x_attn, y_ids, override, K_eff,
                            block_z_to_x, device):
    """Score y with a caller-supplied z, bypassing the M-step loop.

    Builds the embedded sequence ``[x ; override ; y]``, runs the inner
    transformer with the 4D mask from ``build_blt_mask`` and returns the
    slice of logits aligned with the y tokens.
    """
    # forward_with_latent doesn't expose an override for z, so its second
    # pass is mimicked manually here.
    from ..model import build_blt_mask

    embed_in = inner.get_input_embeddings()
    x_embeds = embed_in(x_ids)
    y_embeds = embed_in(y_ids)
    B = x_ids.size(0)
    P = x_ids.size(1)
    L_y = y_ids.size(1)
    full_embeds = torch.cat(
        [x_embeds, override.to(y_embeds.dtype), y_embeds], dim=1
    )
    mask = build_blt_mask(B, P, K_eff, L_y, device=device,
                          dtype=full_embeds.dtype,
                          block_y_to_x=True, block_z_to_x=block_z_to_x)

    # Mask out x-pad positions in keys.
    x_is_pad = x_attn == 0
    if x_is_pad.any():
        pad_kv = torch.cat(
            [x_is_pad,
             torch.zeros(B, K_eff + L_y, device=device, dtype=torch.bool)],
            dim=1,
        )
        # -1e9 is not representable in float16 and masked_fill_ refuses
        # overflowing scalars, so clamp to the dtype's most negative finite
        # value. For float32 / bfloat16 this is still exactly -1e9.
        if mask.dtype.is_floating_point:
            fill_value = max(-1e9, torch.finfo(mask.dtype).min)
        else:
            fill_value = -1e9
        mask = mask.clone()
        mask.masked_fill_(pad_kv[:, None, None, :], fill_value)

    transformer = inner.model
    lm_head = inner.get_output_embeddings()
    out = transformer(inputs_embeds=full_embeds, attention_mask=mask,
                      use_cache=False, return_dict=True)
    logits_all = lm_head(out.last_hidden_state)
    # The logit at position p predicts the token at p + 1, so the logits that
    # score y[0..L_y) start one position before the first y slot.
    start = P + K_eff - 1
    return logits_all[:, start:start + L_y, :]


@torch.no_grad()
def teacher_forced_accuracy(
    model, projector, tokenizer, loader, device, K,
    *, condition: str, z_std: float, block_z_to_x: bool, seed: int = 0,
) -> dict:
    """Per-token accuracy on y, scored token-by-token vs gold y under
    teacher forcing (the model sees gold prefix for each prediction).

    Args:
        model: Model to evaluate; if it exposes ``get_base_model`` the
            returned base model is used for embeddings / transformer / head.
        projector: Latent projector; its parameter dtype is used for the
            override z tensors.
        tokenizer: Used only to decode a few sample predictions.
        loader: Iterable of batches exposing ``x_ids``, ``x_attn``, ``y_ids``
            and ``y_attn`` tensors.
        device: Device the batch tensors are moved to.
        K: Number of latent slots for the "normal" and "random" conditions.
        condition: One of ``"normal"`` (z from ``forward_with_latent``),
            ``"random"`` (z replaced by Gaussian noise scaled by ``z_std``)
            or ``"zero"`` (no z slots at all, i.e. K = 0).
        z_std: Scale applied to the Gaussian noise in the "random" condition.
        block_z_to_x: Forwarded to the mask construction / forward pass.
        seed: Base seed for the noise generator; the running token count is
            added to it so every batch draws different noise.

    Returns:
        Dict with keys ``condition``, ``K`` (effective number of z slots),
        ``tok_acc``, ``n_tokens`` and ``sample_preds`` (up to three decoded
        predictions, truncated to 200 characters).

    Raises:
        ValueError: If ``condition`` is not one of the three known values.
    """
    if condition not in ("normal", "random", "zero"):
        # Previously an unknown condition was silently scored as "normal".
        raise ValueError(
            f"unknown condition {condition!r}; "
            "expected 'normal', 'random' or 'zero'"
        )

    inner = model.get_base_model() if hasattr(model, "get_base_model") else model
    d_model = inner.config.hidden_size
    proj_dtype = next(projector.parameters()).dtype
    # Bound before the loop so the summary is well-defined for an empty loader.
    K_eff = 0 if condition == "zero" else K

    total_correct = 0
    total = 0
    sample_texts = []
    for batch in loader:
        x_ids = batch.x_ids.to(device)
        x_attn = batch.x_attn.to(device)
        y_ids = batch.y_ids.to(device)
        y_mask = batch.y_attn.to(device)
        B = x_ids.size(0)

        if condition == "normal":
            logits_y, _, _ = forward_with_latent(
                model, x_ids, x_attn, y_ids, projector, K_eff,
                block_y_to_x=True, block_z_to_x=block_z_to_x,
            )
        else:
            if condition == "random":
                g = torch.Generator(device=device).manual_seed(seed + total)
                override = torch.randn(
                    B, K, d_model, device=device, generator=g, dtype=proj_dtype
                ) * z_std
            else:  # "zero": an empty z block, i.e. no latent slots at all
                override = torch.zeros(
                    B, 0, d_model, device=device, dtype=proj_dtype
                )
            logits_y = _logits_with_override_z(
                inner, x_ids, x_attn, y_ids, override, K_eff,
                block_z_to_x, device,
            )

        pred = logits_y.argmax(dim=-1)
        # logits_y[:, t] is already aligned with y_ids[:, t]; padded y
        # positions are excluded through y_mask.
        # Cast to int: `total` feeds manual_seed and the JSON summary.
        correct = int(((pred == y_ids) * y_mask).sum().item())
        n = int(y_mask.sum().item())
        total_correct += correct
        total += n
        if len(sample_texts) < 3:
            t = tokenizer.decode(pred[0].clamp(min=0), skip_special_tokens=True)
            sample_texts.append(t[:200])

    return {
        "condition": condition,
        "K": K_eff,
        "tok_acc": total_correct / max(total, 1),
        "n_tokens": total,
        "sample_preds": sample_texts,
    }
def main():
    """CLI entry point for the teacher-forced 3-way z-ablation.

    Loads the training config and checkpoint (base model, optional adapter
    under ``<ckpt>/model`` and ``<ckpt>/projector.pt``), estimates ``z_std``,
    evaluates the "normal", "random" and "zero" conditions on the test split
    and writes a JSON summary to ``--out`` (default:
    ``<ckpt>/ablation_teacher_forced.json``).
    """
    p = argparse.ArgumentParser()
    p.add_argument("--ckpt", required=True)
    p.add_argument("--config", required=True)
    p.add_argument("--n", type=int, default=200)
    p.add_argument("--K", type=int, default=None)
    p.add_argument("--out", default=None)
    args = p.parse_args()

    with open(args.config, encoding="utf-8") as f:
        cfg = json.load(f)
    # Default K is the K of the last curriculum entry in the config.
    K = args.K if args.K is not None else cfg.get("K_curriculum", [[0, 8]])[-1][1]
    block_z_to_x = bool(cfg.get("block_z_to_x", False))
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    ckpt = Path(args.ckpt)

    # The base model is built without LoRA; a trained adapter, if present,
    # is attached from the checkpoint below.
    # NOTE(review): the eval helpers in this file are called with
    # block_y_to_x=True regardless of cfg["block_y_to_x"] — confirm intended.
    bcfg_nolora = BLTConfig(
        base_model=cfg["base_model"], use_lora=False,
        lora_r=cfg["lora_r"], lora_alpha=cfg["lora_alpha"],
        lora_dropout=cfg["lora_dropout"],
        lora_target_modules=tuple(cfg["lora_target_modules"]),
        K_latents=K, block_y_to_x=cfg["block_y_to_x"],
        block_z_to_x=block_z_to_x,
        proj_init_scale=cfg["proj_init_scale"],
        dtype=cfg["dtype"], attn_impl=cfg["attn_impl"],
        gradient_checkpointing=False,
    )
    base_model, tokenizer = build_base(bcfg_nolora)

    adapter_dir = ckpt / "model"
    if (adapter_dir / "adapter_config.json").exists():
        # Imported lazily so peft is only required when an adapter exists.
        from peft import PeftModel
        model = PeftModel.from_pretrained(base_model, str(adapter_dir))
        print(f"[load] adapter from {adapter_dir}")
    else:
        # Make the fallback visible: scoring the bare base model by accident
        # would silently invalidate the ablation.
        print(f"[load] no adapter at {adapter_dir}; evaluating base model")
        model = base_model
    model.to(device).eval()

    inner_base = model.get_base_model() if hasattr(model, "get_base_model") else model
    d_model = inner_base.config.hidden_size
    projector = LatentProjector(
        d_model, init_scale=cfg["proj_init_scale"],
        use_mlp=cfg.get("proj_mlp", False),
        hidden_mult=cfg.get("proj_hidden_mult", 4),
    ).to(device).to(next(model.parameters()).dtype)
    # NOTE: torch.load unpickles the file; only use trusted checkpoints.
    projector.load_state_dict(torch.load(ckpt / "projector.pt", map_location=device))
    projector.eval()

    ds_name = cfg.get("dataset", "gsm8k")
    if ds_name.lower() == "math":
        val_ds = MATHDataset(split="test", max_examples=args.n)
    else:
        val_ds = GSM8KDataset(split="test", max_examples=args.n)
    loader = DataLoader(
        val_ds, batch_size=8, shuffle=False,
        collate_fn=lambda b: collate_batch(b, tokenizer,
                                           max_prompt_len=cfg["max_prompt_len"],
                                           max_answer_len=cfg["max_answer_len"]),
    )

    z_std = estimate_z_std(model, projector, tokenizer, loader, device, K, block_z_to_x)
    print(f"[z_std] {z_std:.4f}")

    results = {}
    t0 = time.time()
    for cond in ["normal", "random", "zero"]:
        r = teacher_forced_accuracy(model, projector, tokenizer, loader, device, K,
                                    condition=cond, z_std=z_std,
                                    block_z_to_x=block_z_to_x, seed=0)
        results[cond] = r
        print(f"[{cond}] tok_acc={r['tok_acc']:.4f} elapsed={time.time()-t0:.0f}s")

    summary = {
        "ckpt": str(ckpt), "K": K, "n": args.n, "z_std": z_std,
        "block_z_to_x_at_train_and_eval": block_z_to_x,
        "results": results,
        "delta_tokacc_normal_minus_random": results["normal"]["tok_acc"] - results["random"]["tok_acc"],
        "delta_tokacc_normal_minus_zero": results["normal"]["tok_acc"] - results["zero"]["tok_acc"],
    }
    out = args.out or str(ckpt / "ablation_teacher_forced.json")
    out_path = Path(out)
    # Create the destination directory so a finished evaluation is not lost
    # to a missing --out parent directory.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(json.dumps(summary, indent=2))
    print(f"[written] {out}")
    print(f"Δ_random_tok = {summary['delta_tokacc_normal_minus_random']:+.4f}")
    print(f"Δ_zero_tok = {summary['delta_tokacc_normal_minus_zero']:+.4f}")


if __name__ == "__main__":
    main()