File size: 9,021 Bytes
bc7101b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
"""Teacher-forced 3-way z-ablation eval.

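Computes per-token accuracy on y under three conditions:
  normal-z    : z computed by the M-step loop inside `forward_with_latent`
                (which respects the `block_z_to_x` flag via the 4D mask)
  random-z    : z replaced by i.i.d. Gaussian noise scaled to z_std, the
                empirical std of normal-z estimated on a few batches
  zero-z      : K=0 (no z slots at all)
For random-z and zero-z the final [x | z | y] pass is rebuilt by hand with
`build_blt_mask` (same `block_z_to_x` setting), because `forward_with_latent`
offers no way to inject an override z.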

This is *teacher-forced* accuracy (the model sees gold y prefix when predicting
each token), so it's not the same metric as autoregressive `generate` accuracy.
But it directly tests "does z's content carry the signal y needs?" — which is
exactly the question the leak hypothesis is about. Autoregressive generation
with `block_z_to_x` would require non-trivial changes to `generate_with_latent`
(its KV-cache path doesn't use the 4D mask). For the principled experiment,
teacher-forced acc is the cleaner signal.

Usage:
    python -m experiments.blt_reasoner.scripts.ablate_teacher_forced \
        --ckpt /path/to/final --config <config.json> --n 200 --K 8 \
        --out /path/to/ablation_tf.json
"""
from __future__ import annotations

import argparse
import json
import time
from pathlib import Path
from typing import Optional

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

from ..data import GSM8KDataset, MATHDataset, collate_batch
from ..model import BLTConfig, LatentProjector, build_base, forward_with_latent


@torch.no_grad()
def estimate_z_std(model, projector, tokenizer, loader, device, K, block_z_to_x,
                   max_batches: int = 4):
    """Estimate the scalar std of the latent z produced by the normal path.

    Runs ``forward_with_latent`` (with ``block_y_to_x=True``) on the first
    ``max_batches`` batches of ``loader``. It concatenates the returned z tensors
    along dim 0 in float32 on CPU and returns the std over all of their elements.
    The random-z ablation uses this value to scale its Gaussian noise.

    ``tokenizer`` is unused; it is kept so existing call sites keep working.

    Raises:
        ValueError: if ``loader`` yields no batches or ``max_batches`` <= 0.
    """
    from itertools import islice

    zs = []
    for batch in islice(loader, max_batches):
        _, z, _ = forward_with_latent(
            model, batch.x_ids.to(device), batch.x_attn.to(device),
            batch.y_ids.to(device), projector, K,
            block_y_to_x=True, block_z_to_x=block_z_to_x,
        )
        zs.append(z.float().cpu())
    if not zs:
        # torch.cat([]) would raise an opaque RuntimeError; say what actually went wrong.
        raise ValueError("estimate_z_std: loader yielded no batches")
    return float(torch.cat(zs, 0).std().item())


def _logits_with_override_z(inner, x_ids, x_attn, y_ids, z, *, block_z_to_x):
    """Run the final [x | z | y] forward pass with a caller-supplied ``z``.

    ``forward_with_latent`` has no override-z hook, so this mirrors its last pass
    by hand. It embeds x and y, splices ``z`` between them, and builds the BLT 4D
    attention mask with ``build_blt_mask`` (always ``block_y_to_x=True``). It then
    runs the transformer body and the LM head.

    Args:
        inner: unwrapped causal LM (the PEFT base model when wrapped); must expose
            ``get_input_embeddings``, ``get_output_embeddings`` and ``.model``.
        x_ids: prompt ids, (B, P).
        x_attn: prompt attention mask, (B, P); zeros mark padding.
        y_ids: gold answer ids, (B, L_y), fed teacher-forced.
        z: latent slots, (B, K, d_model); K may be 0 (no z slots).
        block_z_to_x: passed through to ``build_blt_mask``.

    Returns:
        Logits of shape (B, L_y, vocab); row ``t`` is the prediction for ``y_ids[:, t]``.
    """
    from ..model import build_blt_mask

    B, P = x_ids.shape
    K_eff = z.size(1)
    L_y = y_ids.size(1)
    device = x_ids.device

    embed_in = inner.get_input_embeddings()
    x_embeds = embed_in(x_ids)
    y_embeds = embed_in(y_ids)
    full_embeds = torch.cat([x_embeds, z.to(y_embeds.dtype), y_embeds], dim=1)

    mask = build_blt_mask(B, P, K_eff, L_y, device=device, dtype=full_embeds.dtype,
                          block_y_to_x=True, block_z_to_x=block_z_to_x)
    x_pad = x_attn == 0
    if x_pad.any():
        # Hide padded x positions as attention keys for every query. Use the
        # dtype's most negative finite value: a literal like -1e9 does not fit in
        # float16, and masked_fill_ rejects such out-of-range fill values.
        pad_kv = torch.cat([x_pad, x_pad.new_zeros(B, K_eff + L_y)], dim=1)
        mask = mask.clone()
        mask.masked_fill_(pad_kv[:, None, None, :], torch.finfo(mask.dtype).min)

    out = inner.model(inputs_embeds=full_embeds, attention_mask=mask,
                      use_cache=False, return_dict=True)
    # The position just before y[:, t] is P + K_eff - 1 + t: the last z slot, or
    # the last x position when K_eff == 0. Slice hidden states before the LM head
    # so the vocab projection only runs on the L_y rows that are scored.
    start = P + K_eff - 1
    hidden_y = out.last_hidden_state[:, start:start + L_y, :]
    return inner.get_output_embeddings()(hidden_y)


@torch.no_grad()
def teacher_forced_accuracy(
    model, projector, tokenizer, loader, device, K,
    *, condition: str, z_std: float, block_z_to_x: bool, seed: int = 0,
) -> dict:
    """Per-token accuracy on y, scored token-by-token against gold y under
    teacher forcing (the model sees the gold prefix for each prediction).

    This is not autoregressive accuracy. It checks whether z's content carries
    signal that y needs.

    Args:
        model: base model, optionally PEFT-wrapped.
        projector: passed to ``forward_with_latent`` in the normal condition;
            its parameter dtype also sets the dtype of the override z.
        tokenizer: used to decode up to three sample predictions.
        loader: yields batches with ``x_ids``, ``x_attn``, ``y_ids``, ``y_attn``.
        device: device that batches are moved to.
        K: number of z slots for the normal and random conditions.
        condition: ``"normal"`` (z from the M-step loop), ``"random"``
            (Gaussian z scaled by ``z_std``) or ``"zero"`` (no z slots).
        z_std: noise scale for the random condition.
        block_z_to_x: forwarded to the attention-mask construction. Note that
            ``block_y_to_x`` is always True here.
        seed: base seed; batch ``i`` of the random condition uses ``seed + i``.

    Returns:
        dict with ``condition``, ``K`` (slots actually used; 0 for "zero"),
        ``tok_acc``, ``n_tokens`` and up to three ``sample_preds`` strings.

    Raises:
        ValueError: on an unknown ``condition``.
    """
    if condition not in ("normal", "random", "zero"):
        raise ValueError(
            f"unknown condition {condition!r}; expected 'normal', 'random' or 'zero'"
        )

    inner = model.get_base_model() if hasattr(model, "get_base_model") else model
    d_model = inner.config.hidden_size
    proj_dtype = next(projector.parameters()).dtype
    K_eff = 0 if condition == "zero" else K

    total_correct = 0
    total = 0
    sample_texts = []
    for batch_idx, batch in enumerate(loader):
        x_ids = batch.x_ids.to(device)
        x_attn = batch.x_attn.to(device)
        y_ids = batch.y_ids.to(device)
        valid = batch.y_attn.to(device).bool()
        B = x_ids.size(0)

        if condition == "normal":
            logits_y, _, _ = forward_with_latent(
                model, x_ids, x_attn, y_ids, projector, K_eff,
                block_y_to_x=True, block_z_to_x=block_z_to_x,
            )
        else:
            if condition == "random":
                # Seed from the batch index: each batch gets a distinct,
                # reproducible draw, and the seed is always an int.
                g = torch.Generator(device=device).manual_seed(seed + batch_idx)
                z = torch.randn(B, K_eff, d_model, device=device, generator=g,
                                dtype=proj_dtype) * z_std
            else:  # "zero": no latent slots at all
                z = torch.zeros(B, 0, d_model, device=device, dtype=proj_dtype)
            # Skip the M-step loop and run the final pass with the override z.
            logits_y = _logits_with_override_z(inner, x_ids, x_attn, y_ids, z,
                                               block_z_to_x=block_z_to_x)

        # logits_y[:, t] is already aligned to predict y_ids[:, t].
        pred = logits_y.argmax(dim=-1)
        total_correct += (pred.eq(y_ids) & valid).sum().item()
        total += valid.sum().item()
        if len(sample_texts) < 3:
            # Decode only the scored (non-padding) positions of the first example.
            text = tokenizer.decode(pred[0][valid[0]].tolist(), skip_special_tokens=True)
            sample_texts.append(text[:200])
    return {
        "condition": condition,
        "K": K_eff,
        "tok_acc": total_correct / max(total, 1),
        "n_tokens": total,
        "sample_preds": sample_texts,
    }


def _load_model_and_projector(cfg, ckpt, K, block_z_to_x, device):
    """Build the base model (plus the saved PEFT adapter, if any) and the projector.

    The base model is built with ``use_lora=False``. If
    ``<ckpt>/model/adapter_config.json`` exists, the adapter is loaded from
    ``<ckpt>/model``. The projector weights come from ``<ckpt>/projector.pt``.

    Returns:
        ``(model, projector, tokenizer)``; model and projector are on ``device``
        in eval mode, and the projector is cast to the model's parameter dtype.
    """
    bcfg_nolora = BLTConfig(
        base_model=cfg["base_model"], use_lora=False,
        lora_r=cfg["lora_r"], lora_alpha=cfg["lora_alpha"],
        lora_dropout=cfg["lora_dropout"],
        lora_target_modules=tuple(cfg["lora_target_modules"]),
        K_latents=K, block_y_to_x=cfg["block_y_to_x"],
        block_z_to_x=block_z_to_x,
        proj_init_scale=cfg["proj_init_scale"],
        dtype=cfg["dtype"], attn_impl=cfg["attn_impl"],
        gradient_checkpointing=False,
    )
    base_model, tokenizer = build_base(bcfg_nolora)
    from peft import PeftModel

    adapter_dir = ckpt / "model"
    if (adapter_dir / "adapter_config.json").exists():
        model = PeftModel.from_pretrained(base_model, str(adapter_dir))
        print(f"[load] adapter from {adapter_dir}")
    else:
        model = base_model
    model.to(device).eval()

    inner_base = model.get_base_model() if hasattr(model, "get_base_model") else model
    d_model = inner_base.config.hidden_size
    projector = LatentProjector(
        d_model, init_scale=cfg["proj_init_scale"],
        use_mlp=cfg.get("proj_mlp", False),
        hidden_mult=cfg.get("proj_hidden_mult", 4),
    ).to(device).to(next(model.parameters()).dtype)
    projector.load_state_dict(torch.load(ckpt / "projector.pt", map_location=device))
    projector.eval()
    return model, projector, tokenizer


def _build_val_loader(cfg, tokenizer, n, batch_size=8):
    """Return a non-shuffled loader over the first ``n`` test examples.

    The dataset is MATH when ``cfg["dataset"]`` is "math" (case-insensitive),
    otherwise GSM8K. Prompt and answer lengths come from the config.
    """
    ds_name = cfg.get("dataset", "gsm8k")
    if ds_name.lower() == "math":
        val_ds = MATHDataset(split="test", max_examples=n)
    else:
        val_ds = GSM8KDataset(split="test", max_examples=n)

    def collate(examples):
        return collate_batch(examples, tokenizer,
                             max_prompt_len=cfg["max_prompt_len"],
                             max_answer_len=cfg["max_answer_len"])

    return DataLoader(val_ds, batch_size=batch_size, shuffle=False, collate_fn=collate)


def main():
    """CLI entry point: run the 3-way z ablation on one checkpoint and write JSON.

    Steps:
      1. Read the training config and load the model, adapter and projector
         from ``--ckpt``.
      2. Estimate z_std, then score the normal / random / zero conditions on
         the first ``--n`` test examples.
      3. Write a summary to ``--out`` (default
         ``<ckpt>/ablation_teacher_forced.json``), creating parent directories
         as needed.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--ckpt", required=True)
    p.add_argument("--config", required=True)
    p.add_argument("--n", type=int, default=200)
    p.add_argument("--K", type=int, default=None)
    p.add_argument("--out", default=None)
    p.add_argument("--batch-size", type=int, default=8,
                   help="evaluation batch size")
    p.add_argument("--seed", type=int, default=0,
                   help="base seed for the random-z condition")
    args = p.parse_args()
    with open(args.config) as f:
        cfg = json.load(f)
    # Default K: second field of the last K_curriculum entry
    # (presumably [step, K] pairs — confirm against the training config).
    K = args.K if args.K is not None else cfg.get("K_curriculum", [[0, 8]])[-1][1]
    block_z_to_x = bool(cfg.get("block_z_to_x", False))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    ckpt = Path(args.ckpt)
    model, projector, tokenizer = _load_model_and_projector(cfg, ckpt, K, block_z_to_x, device)
    loader = _build_val_loader(cfg, tokenizer, args.n, batch_size=args.batch_size)

    z_std = estimate_z_std(model, projector, tokenizer, loader, device, K, block_z_to_x)
    print(f"[z_std] {z_std:.4f}")
    results = {}
    t0 = time.perf_counter()
    for cond in ("normal", "random", "zero"):
        r = teacher_forced_accuracy(model, projector, tokenizer, loader, device, K,
                                    condition=cond, z_std=z_std,
                                    block_z_to_x=block_z_to_x, seed=args.seed)
        results[cond] = r
        print(f"[{cond}] tok_acc={r['tok_acc']:.4f} elapsed={time.perf_counter()-t0:.0f}s")

    summary = {
        "ckpt": str(ckpt), "K": K, "n": args.n, "z_std": z_std, "seed": args.seed,
        "block_z_to_x_at_train_and_eval": block_z_to_x,
        "results": results,
        "delta_tokacc_normal_minus_random": results["normal"]["tok_acc"] - results["random"]["tok_acc"],
        "delta_tokacc_normal_minus_zero": results["normal"]["tok_acc"] - results["zero"]["tok_acc"],
    }
    out = args.out or str(ckpt / "ablation_teacher_forced.json")
    out_path = Path(out)
    # --out may point into a directory that does not exist yet.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(json.dumps(summary, indent=2))
    print(f"[written] {out}")
    print(f"Δ_random_tok = {summary['delta_tokacc_normal_minus_random']:+.4f}")
    print(f"Δ_zero_tok   = {summary['delta_tokacc_normal_minus_zero']:+.4f}")


if __name__ == "__main__":
    # Script entry point. main() returns None, so SystemExit(None) exits with
    # status 0, the same as falling off the end of the module.
    raise SystemExit(main())