| """Induction smoke test — verify Gemma-2-2B does token-copying induction. |
| |
| Build sequences like [random tokens][A][B][random tokens][A] and check that the |
| top-1 logit at the final position is B. Reports accuracy over N probes. |
| |
| This is the critical Phase 1 gate: if baseline ICL accuracy is <50%, the project's |
| core assumption (Gemma-2-2B has functional induction) is wrong and we must escalate. |
| """ |
|
|
import sys
from typing import Optional, Tuple

import torch
from transformer_lens import HookedTransformer
|
|
# --- Probe-set configuration --------------------------------------------------
N_PROBES = 200  # how many random induction probes are scored
SEQ_LEN_BEFORE_AB = 30  # random filler tokens placed before the [A][B] pair
GAP_TOKENS = 80  # random filler tokens between B and the repeated A

# Bounds of the token-id range that probe tokens are drawn from. Presumably the
# lower bound skips special/control ids at the start of the vocab — TODO confirm.
MIN_VOCAB_ID = 1_000
MAX_VOCAB_ID = 50_000
|
|
|
|
def build_probe(
    rng: torch.Generator,
    vocab_lo: int,
    vocab_hi: int,
    prefix_len: Optional[int] = None,
    gap_len: Optional[int] = None,
) -> Tuple[torch.Tensor, int, int]:
    """Build one induction probe laid out as ``[prefix][A][B][gap][A]``.

    Every token is drawn uniformly from ``[vocab_lo, vocab_hi)`` using ``rng``.
    B is resampled until it differs from A. Any prefix or gap token that
    collides with A is replaced by a neighbouring id. As a result, A occurs
    exactly twice (before B and at the end), and the correct next token after
    the final A is unambiguously B.

    Args:
        rng: Generator driving all sampling, which makes probes reproducible.
        vocab_lo: Inclusive lower bound of sampled token ids.
        vocab_hi: Exclusive upper bound of sampled token ids.
        prefix_len: Number of random tokens before ``[A][B]``; defaults to
            ``SEQ_LEN_BEFORE_AB``.
        gap_len: Number of random tokens between ``B`` and the final ``A``;
            defaults to ``GAP_TOKENS``.

    Returns:
        ``(seq, a, b)``: the 1-D int64 token sequence of length
        ``prefix_len + gap_len + 3``, plus the ids of A and B.

    Raises:
        ValueError: if the id range holds fewer than two ids. In that case
            ``A != B`` is impossible and the resampling loop would never end.
    """
    if prefix_len is None:
        prefix_len = SEQ_LEN_BEFORE_AB
    if gap_len is None:
        gap_len = GAP_TOKENS
    if vocab_hi - vocab_lo < 2:
        raise ValueError(
            f"need at least two token ids in [{vocab_lo}, {vocab_hi}) to pick A != B"
        )

    # Draw order is kept identical to the previous version, so a given seed
    # yields the same A/B choices.
    prefix = torch.randint(vocab_lo, vocab_hi, (prefix_len,), generator=rng)
    a = torch.randint(vocab_lo, vocab_hi, (1,), generator=rng)
    b = torch.randint(vocab_lo, vocab_hi, (1,), generator=rng)
    while b.item() == a.item():
        b = torch.randint(vocab_lo, vocab_hi, (1,), generator=rng)
    gap = torch.randint(vocab_lo, vocab_hi, (gap_len,), generator=rng)

    a_id = int(a.item())
    # Step down instead of clamping when A is the top id. Clamping (a + 1) at
    # vocab_hi - 1 would map the collision straight back onto A.
    replacement = a_id + 1 if a_id + 1 < vocab_hi else a_id - 1
    # Scrub A from the prefix too. Otherwise an earlier [A][X] pair would give
    # the model a competing copy target.
    prefix = prefix.masked_fill(prefix == a_id, replacement)
    gap = gap.masked_fill(gap == a_id, replacement)
    return torch.cat([prefix, a, b, gap, a]), a_id, int(b.item())
|
|
|
|
def main() -> int:
    """Run the Phase 1 induction gate and return a process exit code.

    Loads Gemma-2-2B (bfloat16) through TransformerLens. It then scores
    ``N_PROBES`` probes built by ``build_probe`` from a fixed-seed generator,
    so the probe set is reproducible. Top-1 and top-5 accuracy are printed.

    Returns:
        ``0`` (PASS) when top-1 accuracy is at least 50%, otherwise ``1``
        (FAIL, with a message on stderr).
    """
    # Fall back to CPU instead of crashing when no GPU is present (slow, but runs).
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = HookedTransformer.from_pretrained(
        "google/gemma-2-2b", dtype=torch.bfloat16, device=device
    )
    model.eval()

    vocab_lo = MIN_VOCAB_ID
    vocab_hi = min(MAX_VOCAB_ID, model.cfg.d_vocab - 1)
    rng = torch.Generator().manual_seed(0)

    # HookedTransformer only prepends BOS when it tokenizes string input; raw
    # token tensors are fed as-is. Gemma models behave poorly without a leading
    # BOS, which would make this gate fail spuriously, so prepend it explicitly.
    bos_id = getattr(model.tokenizer, "bos_token_id", None)
    bos = (
        torch.tensor([bos_id], dtype=torch.long)
        if bos_id is not None
        else torch.empty(0, dtype=torch.long)
    )

    correct = 0
    top5_correct = 0
    with torch.no_grad():
        for i in range(N_PROBES):
            seq, a, b = build_probe(rng, vocab_lo, vocab_hi)
            tokens = torch.cat([bos, seq]).unsqueeze(0).to(device)
            # Next-token logits at the final position (the repeated A).
            logits = model(tokens)[0, -1]
            top5 = torch.topk(logits, 5).indices.tolist()
            if top5[0] == b:
                correct += 1
            if b in top5:
                top5_correct += 1
            if i < 5:
                print(f" probe {i}: A={a} B={b} top5={top5}")

    acc = correct / N_PROBES
    top5_acc = top5_correct / N_PROBES
    print(f"\nInduction top-1 accuracy: {acc:.1%} ({correct}/{N_PROBES})")
    print(f"Induction top-5 accuracy: {top5_acc:.1%} ({top5_correct}/{N_PROBES})")

    if acc < 0.5:
        print("\nFAIL: baseline ICL <50% — induction may not be working on Gemma-2-2B", file=sys.stderr)
        return 1
    print("\nPASS")
    return 0
|
|
|
|
if __name__ == "__main__":
    # Surface main()'s result (0 = PASS, 1 = FAIL) as the process exit status;
    # raising SystemExit directly is what sys.exit() does under the hood.
    raise SystemExit(main())
|
|