Create app.py
app.py
ADDED
@@ -0,0 +1,858 @@
+# =============================================================================
+# COMPRESSION NAVIGATOR · extended + annotated edition
+# =============================================================================
+# An LLM is a lossy codec for text. Training compresses a corpus into weights;
+# a forward pass decompresses a continuation. These five tools let you watch
+# that decompression happen and poke at where facts physically live.
+#
+# The five tabs are not toys invented here - each one is a real mechanistic-
+# interpretability technique you'll find in papers:
+#
+#   1. Decompress   = LOGIT LENS           (nostalgebraist, 2020)
+#   2. Triangulate  = EMBEDDING NEIGHBOURS (the geometry of the vocab)
+#   3. Re-route     = ACTIVATION STEERING  (ActAdd / repr. engineering)
+#   4. Diff         = CROSS-MODEL ALIGNMENT (compare checkpoints by depth)
+#   5. Causal trace = ACTIVATION PATCHING  (ROME, Meng et al., 2022)
+#
+# WHY THE GLASS-BOX MODELS MATTER
+# -------------------------------
+# On a real model (gpt2) you never know the ground truth, so you can't tell
+# whether a tool is *correct* or just producing plausible-looking output.
+# This file ships two models whose internals you fully specify, so you can
+# check each tool against a known answer:
+#
+#   "handmade" - facts stored as a LOOKUP TABLE keyed on the prompt string.
+#                The computation happens in a side channel (string match),
+#                NOT in the residual stream. Lesson: such a model is almost
+#                invisible to residual-stream interpretability. Logit lens
+#                sees a sudden jump with no build-up; causal tracing finds
+#                nothing, because corrupting activations doesn't touch the
+#                string match. This is a real and underappreciated *limit*
+#                of these methods.
+#
+#   "glassbox" - facts stored the way real transformers store them: as
+#                key->value writes into the RESIDUAL STREAM (Geva et al.'s
+#                "MLPs are key-value memories", which is exactly what ROME
+#                edits). Because the fact flows through activations, ALL five
+#                tools light up correctly - and you can verify they report
+#                the layer you actually put the fact in. This is a unit-test
+#                harness for interpretability code.
+#
+# Run order suggestion: glassbox -> handmade -> gpt2
+#   glassbox shows what "correct" looks like; handmade shows a failure mode;
+#   gpt2 shows the fuzzy, distributed real thing.
+# =============================================================================
+
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import gradio as gr
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+DTYPE = torch.float32
+MODELS = {}               # name -> (model, tokenizer) cache
+STATE = {"name": None}    # currently loaded model name
+
+
+# =============================================================================
+# A tiny shared tokenizer for both glass-box models.
+# Case is CANONICALISED to lowercase everywhere (this fixes a real bug in the
+# original: "Paris" from a pinned fact and "paris" from the Markov table became
+# two different vocab entries, so the boosted token and the *tracked* token
+# silently diverged - every neighbour read cos=0.000 and every tracked prob 0).
+# =============================================================================
+class FakeBatchEncoding(dict):
+    def to(self, device):   # let callers do tok(...).to(DEVICE) safely
+        return self
+
+
+class SimpleTok:
+    """Whitespace tokenizer over a fixed vocab. Not 'fast' (no offset map)."""
+    is_fast = False
+
+    def __init__(self, stoi, itos):
+        self.stoi, self.itos = stoi, itos
+        self.eos_token_id = stoi["."]   # period doubles as end-of-sequence
+
+    def _ids(self, text):
+        words = text.lower().replace(".", " .").split()
+        return [self.stoi.get(w, self.stoi["<s>"]) for w in words]
+
+    def __call__(self, text, return_tensors=None, return_offsets_mapping=False):
+        ids = self._ids(text) or [self.stoi["<s>"]]
+        return FakeBatchEncoding(
+            input_ids=torch.tensor([ids]),
+            attention_mask=torch.ones(1, len(ids), dtype=torch.long),
+        )
+
+    def encode(self, text, add_special_tokens=False):
+        return self._ids(text)
+
+    def decode(self, ids, skip_special_tokens=False):
+        out = []
+        for i in ids:
+            w = self.itos.get(int(i), "?")
+            if skip_special_tokens and w in ("<pad>", "<s>"):
+                continue
+            out.append(w)
+        return " ".join(out)
+
+
+class _Out:
+    """Mimics a HF CausalLMOutput: .logits and (optional) .hidden_states."""
+    def __init__(self, logits, hidden_states):
+        self.logits = logits
+        self.hidden_states = hidden_states
+
+
+def _greedy_generate(model, input_ids, max_new_tokens=20, pad_token_id=None, **_):
+    """Minimal greedy decode so the steering tab works on the toy models too
+    (the originals had no .generate, so that tab crashed on 'handmade')."""
+    ids = input_ids
+    for _ in range(int(max_new_tokens)):
+        nxt = model(input_ids=ids).logits[0, -1].argmax().view(1, 1)
+        ids = torch.cat([ids, nxt], dim=1)
+        if pad_token_id is not None and int(nxt.item()) == int(pad_token_id):
+            break
+    return ids
+
+
+# =============================================================================
+# MODEL 1 - "handmade": facts as a LOOKUP TABLE (the side-channel glass box)
+# -----------------------------------------------------------------------------
+# Embeddings are the identity matrix (each token is its own one-hot). The two
+# "layers" don't read the residual stream in a meaningful linear way:
+#   - MemoryBlock matches the *decoded prompt string* and boosts the answer.
+#   - MarkovBlock adds a hand-built bigram transition for the last token.
+# Because MemoryBlock keys on the prompt TEXT, not on activations, this is a
+# deliberate demonstration of a model that residual-stream interpretability
+# cannot see. Use it as the "what failure looks like" control.
+# =============================================================================
+PINNED = {   # answers are lowercase now (bug fix)
+    "the capital of france is": " paris",
+    "the eiffel tower is in": " paris",
+    "two plus two equals": " four",
+}
+MARKOV = {
+    "<s>": {"the": 3, "i": 2, "a": 1},
+    "the": {"city": 2, "tower": 2, "answer": 1},
+    "i": {"think": 2, "am": 1},
+    "a": {"model": 2, "city": 1},
+    "city": {"of": 3, "is": 1},
+    "of": {"light": 2, "paris": 1},
+    "tower": {"is": 3},
+    "is": {"in": 2, "a": 1},
+    "in": {"paris": 2, "france": 1},
+    "model": {"is": 2},
+    "think": {"the": 2},
+    "paris": {".": 1},
+    "france": {".": 1},
+    "light": {".": 1},
+    "four": {".": 1},
+}
+
+
+def _build_handmade_vocab():
+    toks, seen = ["<pad>", "<s>", "."], {"<pad>", "<s>", "."}
+    def add(w):
+        if w not in seen:
+            toks.append(w); seen.add(w)
+    for v in PINNED.values():
+        add(v.strip())
+    for w, nxts in MARKOV.items():
+        add(w)
+        for x in nxts:
+            add(x)
+    for k in PINNED:
+        for w in k.split():
+            add(w)
+    return toks
+
+
+HM_VOCAB = _build_handmade_vocab()
+HM_STOI = {w: i for i, w in enumerate(HM_VOCAB)}
+HM_ITOS = {i: w for w, i in HM_STOI.items()}
+HM_V = len(HM_VOCAB)
+
+
+class _MemoryBlock(nn.Module):
+    """If the decoded prompt ends with a pinned key, slam the answer logit.
+    NOTE: this reads prompt_ids (the string), not x - that's the whole point."""
+    def forward(self, x, prompt_ids=None):
+        out = x.clone()
+        if prompt_ids is not None:
+            text = " ".join(HM_ITOS.get(int(i), "") for i in prompt_ids).strip()
+            for key, ans in PINNED.items():
+                if text.endswith(key):
+                    out[0, -1, HM_STOI[ans.strip()]] += 12.0
+        return (out,)
+
+
+class _MarkovBlock(nn.Module):
+    """Add a hand-built bigram transition row for the last token."""
+    def __init__(self):
+        super().__init__()
+        T = torch.zeros(HM_V, HM_V)
+        for w, nxts in MARKOV.items():
+            if w in HM_STOI:
+                tot = sum(nxts.values())
+                for x, wt in nxts.items():
+                    if x in HM_STOI:
+                        T[HM_STOI[w], HM_STOI[x]] = wt / tot
+        self.register_buffer("T", T)
+
+    def forward(self, x, prompt_ids=None):
+        out = x.clone()
+        if prompt_ids:
+            out[0, -1] += 4.0 * self.T[int(prompt_ids[-1])]
+        return (out,)
+
+
+class _HMTransformer(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.wte = nn.Embedding(HM_V, HM_V)
+        with torch.no_grad():
+            self.wte.weight.copy_(torch.eye(HM_V))   # one-hot embeddings
+        self.h = nn.ModuleList([_MemoryBlock(), _MarkovBlock()])
+        self.ln_f = nn.Identity()
+
+
+class HandmadeModel(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.transformer = _HMTransformer()
+        self.head = nn.Linear(HM_V, HM_V, bias=False)
+        with torch.no_grad():
+            self.head.weight.copy_(torch.eye(HM_V))   # identity unembed
+        self.tok = SimpleTok(HM_STOI, HM_ITOS)
+
+    def get_input_embeddings(self): return self.transformer.wte
+    def get_output_embeddings(self): return self.head
+    def generate(self, input_ids=None, attention_mask=None, **kw):
+        return _greedy_generate(self, input_ids, **kw)
+
+    def forward(self, input_ids=None, attention_mask=None, output_hidden_states=False):
+        ids = input_ids[0].tolist()
+        x = self.transformer.wte(input_ids).float()
+        hs = [x]; h = x
+        for blk in self.transformer.h:
+            (h,) = blk(h, prompt_ids=ids); hs.append(h)
+        logits = self.head(self.transformer.ln_f(h))
+        return _Out(logits, tuple(hs) if output_hidden_states else None)
+
+
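The bigram rows that `_MarkovBlock` builds can be checked by hand. The sketch below is a dependency-free illustration rather than part of app.py: `transition_row` and `logit_boost` are hypothetical helper names, and the factor 4.0 is copied from `_MarkovBlock.forward`.

```python
def transition_row(counts):
    """Normalise raw bigram counts into probabilities, one row of the table T."""
    total = sum(counts.values())
    return {word: c / total for word, c in counts.items()}


def logit_boost(counts, scale=4.0):
    """Logit added to each successor: scale times its transition probability."""
    return {word: scale * p for word, p in transition_row(counts).items()}


# The MARKOV["the"] entry from the table above.
row = transition_row({"city": 2, "tower": 2, "answer": 1})
boost = logit_boost({"city": 2, "tower": 2, "answer": 1})
print(row)
print(boost)
```

Because each row sums to 1, the largest Markov boost is 4.0, so the +12.0 that `_MemoryBlock` adds on a pinned prompt dominates whenever both blocks fire.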
+# =============================================================================
+# MODEL 2 - "glassbox": facts as RESIDUAL-STREAM key->value writes
+# -----------------------------------------------------------------------------
+# This is the model the original was missing. It stores facts the way real
+# transformers do, so every tool works AND can be checked against ground truth.
+#
+# Vocab + structured embeddings (d=32). Country and its capital deliberately
+# SHARE an embedding dimension, so the neighbours tool finds real geometry
+# (paris is near france).
+#
+# Four layers:
+#   L0 subject site   : (identity here) the residual the trace will restore
+#   L1 pool/attention : copies subject signal from earlier positions -> last
+#   L2 fact MLP       : key(subject+relation) -> relu -> value(answer dir) <- ROME edits this kind of layer
+#   L3 cleanup        : identity
+#
+# Ground truth you can verify:
+#   - logit lens: the answer is INVISIBLE until L2, then appears. Compare with
+#     handmade (sudden, no build-up) and gpt2 (fuzzy, spread over many layers).
+#   - causal trace: corrupting the subject and restoring layer by layer peaks
+#     at L0 - because L1's "attention" re-reads the restored subject. That is
+#     the ROME story: the causal site is an early layer at the SUBJECT token.
+#   - steering / neighbours: both operate on real directions, so both work.
+# =============================================================================
+GB_D = 32
+GB_TOKS = ["<pad>", "<s>", ".", "the", "capital", "of", "is", "in",
+           "france", "germany", "japan", "paris", "berlin", "tokyo"]
+GB_STOI = {w: i for i, w in enumerate(GB_TOKS)}
+GB_ITOS = {i: w for w, i in GB_STOI.items()}
+GB_V = len(GB_TOKS)
+GB_FACTS = [("france", "paris"), ("germany", "berlin"), ("japan", "tokyo")]
+
+
+def _build_gb_embeddings():
+    E = torch.zeros(GB_V, GB_D)
+    def setd(tok, pairs):
+        for d, v in pairs:
+            E[GB_STOI[tok], d] = v
+    # country/capital pairs share their first dim -> positive cosine (geometry!)
+    setd("france", [(0, 1.0), (1, 0.6), (20, 0.5)])
+    setd("paris", [(0, 0.8), (2, 0.9), (21, 0.5)])
+    setd("germany", [(3, 1.0), (4, 0.6), (22, 0.5)])
+    setd("berlin", [(3, 0.8), (5, 0.9), (23, 0.5)])
+    setd("japan", [(6, 1.0), (7, 0.6), (24, 0.5)])
+    setd("tokyo", [(6, 0.8), (8, 0.9), (25, 0.5)])
+    setd("is", [(9, 1.0), (26, 0.4)])   # the relation marker
+    for i, t in enumerate(GB_TOKS):     # give fillers an id
+        if E[i].abs().sum() == 0:
+            E[i, 10 + i % 6] = 1.0
+    return E / (E.norm(dim=-1, keepdim=True) + 1e-9)   # unit rows
+
+
+GB_E = _build_gb_embeddings()
+GB_SUBJ = torch.zeros(GB_D, GB_D)   # projector onto subject dims 0..8
+for _d in range(9):
+    GB_SUBJ[_d, _d] = 1.0
+
+
+class _GBIdent(nn.Module):
+    def forward(self, x, prompt_ids=None):
+        return (x.clone(),)
+
+
+class _GBPool(nn.Module):
+    """Toy 'attention': sum the subject-projected residual of all earlier
+    positions into the last position. Corrupting the subject earlier shows up
+    here; restoring the subject BEFORE this layer is what makes the trace
+    recover - that is why the causal peak lands at L0, not L1."""
+    def forward(self, x, prompt_ids=None):
+        out = x.clone()
+        if x.shape[1] > 1:
+            pooled = (x[0, :-1] @ GB_SUBJ.T).sum(0)
+            out[0, -1] = out[0, -1] + 0.9 * pooled
+        return (out,)
+
+
+class _GBFactMLP(nn.Module):
+    """Geva-style key->value memory. W_in rows are (subject+relation) keys;
+    relu gates which fact fires; W_out columns are answer unembed directions.
+    This is structurally the exact layer ROME rewrites to edit a fact."""
+    def __init__(self):
+        super().__init__()
+        Win = torch.zeros(len(GB_FACTS), GB_D)
+        Wout = torch.zeros(GB_D, len(GB_FACTS))
+        rel = GB_E[GB_STOI["is"]]
+        for k, (s, a) in enumerate(GB_FACTS):
+            key = (GB_E[GB_STOI[s]] @ GB_SUBJ.T) * 0.9 + rel
+            Win[k] = key / key.norm()
+            Wout[:, k] = GB_E[GB_STOI[a]]   # write answer direction
+        self.register_buffer("Win", Win)
+        self.register_buffer("Wout", Wout)
+        self.bias, self.gain = 0.85, 6.0    # tuned: clean p~0.5, corrupt p~0.07
+
+    def forward(self, x, prompt_ids=None):
+        out = x.clone()
+        pre = F.relu(self.Win @ out[0, -1] - self.bias)
+        out[0, -1] = out[0, -1] + self.gain * (self.Wout @ pre)
+        return (out,)
+
+
+class _GBTransformer(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.wte = nn.Embedding(GB_V, GB_D)
+        with torch.no_grad():
+            self.wte.weight.copy_(GB_E)
+        self.h = nn.ModuleList([_GBIdent(), _GBPool(), _GBFactMLP(), _GBIdent()])
+        self.ln_f = nn.Identity()
+
+
+class GlassBoxModel(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.transformer = _GBTransformer()
+        self.head = nn.Linear(GB_D, GB_V, bias=False)
+        with torch.no_grad():
+            self.head.weight.copy_(GB_E)   # tied unembed
+        self.tok = SimpleTok(GB_STOI, GB_ITOS)
+
+    def get_input_embeddings(self): return self.transformer.wte
+    def get_output_embeddings(self): return self.head
+    def generate(self, input_ids=None, attention_mask=None, **kw):
+        return _greedy_generate(self, input_ids, **kw)
+
+    def forward(self, input_ids=None, attention_mask=None, output_hidden_states=False):
+        ids = input_ids[0].tolist()
+        x = self.transformer.wte(input_ids).float()
+        hs = [x]; h = x
+        for blk in self.transformer.h:
+            (h,) = blk(h, prompt_ids=ids); hs.append(h)
+        logits = self.head(self.transformer.ln_f(h))
+        return _Out(logits, tuple(hs) if output_hidden_states else None)
+
+
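The key->value write in `_GBFactMLP` reduces to a few lines of arithmetic. The following is a minimal, dependency-free sketch under stated assumptions: the 3-d keys and values are made-up toy vectors (not the real `Win` / `Wout` rows), and `kv_memory` is a hypothetical helper name. Only the bias 0.85 and gain 6.0 are taken from the class above.

```python
def dot(a, b):
    """Plain dot product of two equal-length lists."""
    return sum(x * y for x, y in zip(a, b))


def kv_memory(x, keys, values, bias, gain):
    """One key->value step: out = x + gain * sum_k relu(key_k . x - bias) * value_k."""
    acts = [max(0.0, dot(k, x) - bias) for k in keys]
    out = list(x)
    for act, value in zip(acts, values):
        for d in range(len(out)):
            out[d] += gain * act * value[d]
    return out, acts


# Toy memory with two stored facts (illustrative vectors only).
keys = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]      # one unit key per fact
values = [[0.0, 0.0, 1.0], [0.0, 0.0, -1.0]]   # direction each fact writes
out, acts = kv_memory([1.0, 0.0, 0.0], keys, values, bias=0.85, gain=6.0)
print(acts)   # only the first key clears the bias
print(out)
```

The bias is what makes the memory selective: a key only fires when its match with the residual exceeds 0.85, so a corrupted subject that matches a key less well writes nothing.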
+# =============================================================================
+# REAL MODELS - resolve the architecture-specific module paths
+# =============================================================================
+def _resolve(model, paths):
+    for path in paths:
+        obj, ok = model, True
+        for part in path.split("."):
+            if hasattr(obj, part):
+                obj = getattr(obj, part)
+            else:
+                ok = False; break
+        if ok:
+            return obj
+    return None
+
+
+def get_blocks(model):
+    blocks = _resolve(model, ["transformer.h", "model.layers",
+                              "gpt_neox.layers", "model.decoder.layers"])
+    if blocks is None:
+        raise RuntimeError("Could not locate transformer blocks.")
+    return blocks
+
+
+def get_final_norm(model):
+    norm = _resolve(model, ["transformer.ln_f", "model.norm",
+                            "gpt_neox.final_layer_norm",
+                            "model.decoder.final_layer_norm"])
+    return norm if norm is not None else (lambda x: x)
+
+
+def get_head(model):
+    return model.get_output_embeddings()
+
+
+def get_handles(name):
+    if name not in MODELS:
+        if name == "handmade":
+            m = HandmadeModel().eval(); MODELS[name] = (m, m.tok)
+        elif name == "glassbox":
+            m = GlassBoxModel().eval(); MODELS[name] = (m, m.tok)
+        else:
+            tok = AutoTokenizer.from_pretrained(name)
+            model = AutoModelForCausalLM.from_pretrained(
+                name, torch_dtype=DTYPE).to(DEVICE).eval()
+            MODELS[name] = (model, tok)
+    return MODELS[name]
+
+
+def load_model(name):
+    name = name.strip()
+    model, _ = get_handles(name)
+    STATE["name"] = name
+    return "Loaded **%s** (%d layers)." % (name, len(get_blocks(model)))
+
+
+# =============================================================================
+# Shared readout: project every layer's last-token residual to a vocab dist.
+# =============================================================================
+@torch.no_grad()
+def layer_distributions(model, tok, prompt):
+    inputs = tok(prompt, return_tensors="pt").to(DEVICE)
+    out = model(**inputs, output_hidden_states=True)
+    hs = out.hidden_states
+    norm, head, n = get_final_norm(model), get_head(model), len(out.hidden_states)
+    dists = []
+    for i, layer_hs in enumerate(hs):
+        vec = layer_hs[0, -1].to(DTYPE)
+        # HF convention: the LAST hidden_states entry is already post-ln_f,
+        # so skip norm there; apply ln_f to intermediates (logit-lens style).
+        logits = head(vec) if i == n - 1 else head(norm(vec))
+        dists.append(("embed" if i == 0 else "L%d" % i, F.softmax(logits, dim=-1)))
+    return dists
+
+
+def _entropy_bits(probs):
+    p = probs.clamp_min(1e-12)
+    return float(-(p * p.log()).sum() / math.log(2))
+
+
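The two numerical steps behind the readout, a softmax over logits and an entropy in bits, can be sanity-checked without torch. This is a small sketch, not part of app.py: `softmax` and `entropy_bits` are illustrative names, and the 1e-12 clamp mirrors the one in `_entropy_bits`.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def entropy_bits(probs, eps=1e-12):
    """Shannon entropy in bits, with probabilities clamped at eps."""
    return -sum(max(p, eps) * math.log2(max(p, eps)) for p in probs)


uniform = softmax([0.0, 0.0, 0.0, 0.0])   # four equally likely tokens
peaked = softmax([12.0, 0.0, 0.0, 0.0])   # one logit boosted by +12
print(entropy_bits(uniform))   # 2.0 bits for a uniform 4-way distribution
print(entropy_bits(peaked))    # close to 0: the distribution has collapsed
```

In the logit-lens table, a drop in the entropy column from one layer to the next is the sign that a layer has committed to an answer.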
+# =============================================================================
+# TAB 1 - LOGIT LENS: watch the answer condense out of the residual stream
+# =============================================================================
+@torch.no_grad()
+def logit_lens(prompt, top_k, track):
+    if STATE["name"] is None:
+        return "Load a model first."
+    model, tok = get_handles(STATE["name"])
+    top_k = int(top_k)
+    tids = tok.encode(track, add_special_tokens=False) if track.strip() else []
+    tid = tids[0] if tids else None
+    dists = layer_distributions(model, tok, prompt)
+    header = "layer | top tokens (prob) | entropy" \
+        + (" | p(%r)" % track if tid is not None else "")
+    lines = ["prompt: %r" % prompt, header, "-" * len(header)]
+    for label, probs in dists:
+        # clamp k: the toy vocabularies can be smaller than the requested top_k
+        p, idx = probs.topk(min(top_k, probs.numel()))
+        shown = " ".join("%r:%.2f" % (tok.decode([t]).replace("\n", "\\n"), v)
+                         for t, v in zip(idx.tolist(), p.tolist()))
+        row = "%5s | %-40s | %4.1fb" % (label, shown, _entropy_bits(probs))
+        if tid is not None:
+            row += " | %.3f" % probs[tid].item()
+        lines.append(row)
+    return "\n".join(lines)
+
+
+# =============================================================================
+# TAB 2 - NEIGHBOURS: the geometry of the (un)embedding space
+# =============================================================================
+@torch.no_grad()
+def neighbors(word, top_k):
+    if STATE["name"] is None:
+        return "Load a model first."
+    model, tok = get_handles(STATE["name"])
+    top_k = int(top_k)
+    ids = tok.encode(word, add_special_tokens=False)
+    if not ids:
+        return "Could not tokenize %r." % word
+    tid = ids[0]
+    W = F.normalize(get_head(model).weight.to(DTYPE), dim=-1)
+    sims = W @ W[tid]
+    # clamp k: top_k + 1 can exceed the vocab size on the toy models
+    vals, idx = sims.topk(min(top_k + 1, sims.numel()))
+    note = ""
+    if STATE["name"] == "handmade":
+        note = ("(handmade uses one-hot embeddings, so every token is "
+                "orthogonal -> all cosines are 0 by construction. This is the "
+                "tool telling the truth about a model with no vocab geometry.)\n")
+    lines = [note + "neighbours of %r:" % word]
+    for v, j in zip(vals.tolist(), idx.tolist()):
+        if j != tid:
+            lines.append(" %14r cos=%.3f" % (tok.decode([j]), v))
+    return "\n".join(lines[: top_k + 1])
+
+
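The neighbours lookup is a cosine ranking over unit-normalised rows. The sketch below reproduces that idea in plain Python. The 4-d vectors are hypothetical stand-ins chosen so that "france" and "paris" share their first dimension, as the glass-box embeddings do. They are not the actual `GB_E` rows, and `nearest` is an illustrative name.

```python
import math


def cosine(a, b):
    """Cosine similarity of two vectors (0.0 if either is all zeros)."""
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(x * x for x in b))
    if na == 0 or nb == 0:
        return 0.0
    return sum(x * y for x, y in zip(a, b)) / (na * nb)


def nearest(word, table, top_k):
    """Rank every other token by cosine similarity to `word`, best first."""
    scored = [(cosine(table[word], vec), other)
              for other, vec in table.items() if other != word]
    scored.sort(reverse=True)
    return scored[:min(top_k, len(scored))]   # never ask for more than exists


table = {
    "france":  [1.0, 0.6, 0.0, 0.0],
    "paris":   [0.8, 0.0, 0.9, 0.0],   # shares dim 0 with france
    "germany": [0.0, 0.0, 0.0, 1.0],   # orthogonal to both
}
for sim, tok in nearest("france", table, top_k=5):
    print("%10s cos=%.3f" % (tok, sim))
```

With one-hot vectors, as in the handmade model, every off-diagonal cosine is exactly 0, which is the behaviour the note in `neighbors` describes.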
+# =============================================================================
+# TAB 3 - STEERING: bend behaviour by adding a direction, no retraining
+# =============================================================================
+def _make_steer_hook(direction, alpha):
+    d = direction * alpha
+    def hook(module, inp, out):
+        if isinstance(out, tuple):
+            return (out[0] + d.to(out[0].dtype).to(out[0].device),) + out[1:]
+        return out + d.to(out.dtype).to(out.device)
+    return hook
+
+
+@torch.no_grad()
+def steer_generate(prompt, source, target, layer, alpha, max_new):
+    if STATE["name"] is None:
+        return "Load a model first.", ""
+    model, tok = get_handles(STATE["name"])
+    layer, max_new = int(layer), int(max_new)
+    emb = model.get_input_embeddings().weight
+    def first_emb(w):
+        ids = tok.encode(w, add_special_tokens=False)
+        # fall back to zeros on emb's own device (the toy models stay on CPU)
+        return emb[ids[0]] if ids else torch.zeros(emb.shape[-1], device=emb.device)
+    direction = F.normalize((first_emb(target) - first_emb(source)).to(DTYPE), dim=-1)
+    inputs = tok(prompt, return_tensors="pt").to(DEVICE)
+    gk = dict(max_new_tokens=max_new, do_sample=False, pad_token_id=tok.eos_token_id)
+    base = tok.decode(model.generate(**inputs, **gk)[0], skip_special_tokens=True)
+    blocks = get_blocks(model)
+    layer = max(0, min(layer, len(blocks) - 1))
+    handle = blocks[layer].register_forward_hook(_make_steer_hook(direction, alpha))
+    try:
+        steered = tok.decode(model.generate(**inputs, **gk)[0], skip_special_tokens=True)
+    finally:
+        handle.remove()
+    return base, "steer %r -> %r @ L%d alpha=%s\n%s" % (source, target, layer, alpha, steered)
+
+
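Stripped of hooks and generation, steering is vector addition followed by the usual readout. Here is a minimal sketch, assuming a made-up 2-d unembedding with one axis per capital. `unit`, `steer` and `top_token` are hypothetical helpers, not functions from app.py.

```python
import math


def unit(v):
    """Scale v to unit length (returned unchanged if it is all zeros)."""
    n = math.sqrt(sum(x * x for x in v))
    return [x / n for x in v] if n > 0 else list(v)


def steer(residual, source_vec, target_vec, alpha):
    """Add alpha times the unit (target - source) direction to a residual."""
    direction = unit([t - s for s, t in zip(source_vec, target_vec)])
    return [r + alpha * d for r, d in zip(residual, direction)]


def top_token(residual, unembed):
    """Token whose unembedding row has the largest dot product with residual."""
    return max(unembed,
               key=lambda tok: sum(a * b for a, b in zip(unembed[tok], residual)))


# Hypothetical 2-d unembedding: one axis per capital.
unembed = {"paris": [1.0, 0.0], "berlin": [0.0, 1.0]}
residual = [1.0, 0.2]   # reads out as "paris"
steered = steer(residual, unembed["paris"], unembed["berlin"], alpha=2.0)
print(top_token(residual, unembed), "->", top_token(steered, unembed))
```

The same logic explains the `alpha` slider: with `alpha=0` the residual is untouched, and the prediction only flips once the added direction outweighs the original evidence.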
| 551 |
+
# =============================================================================
|
| 552 |
+
# TAB 4 - DIFF: compare two models on one prompt, aligned by relative depth
|
| 553 |
+
# =============================================================================
|
| 554 |
+
@torch.no_grad()
|
| 555 |
+
def diff_models(name_a, name_b, prompt, target, top_k):
|
| 556 |
+
ma, ta = get_handles(name_a.strip())
|
| 557 |
+
mb, tb = get_handles(name_b.strip())
|
| 558 |
+
ida = ta.encode(target, add_special_tokens=False)
|
| 559 |
+
idb = tb.encode(target, add_special_tokens=False)
|
| 560 |
+
if not ida or not idb:
|
| 561 |
+
return "Could not tokenize target %r in both models." % target
|
| 562 |
+
ida, idb = ida[0], idb[0]
|
| 563 |
+
da = layer_distributions(ma, ta, prompt)
|
| 564 |
+
db = layer_distributions(mb, tb, prompt)
|
| 565 |
+
nA, nB = len(da) - 1, len(db) - 1
|
| 566 |
+
def top1(probs, tok):
|
| 567 |
+
v, i = probs.topk(1)
|
| 568 |
+
return "%r:%.2f" % (tok.decode([i.item()]), v.item())
|
| 569 |
+
lines = ["prompt: %r target: %r" % (prompt, target),
|
| 570 |
+
"%18s | %16s %6s | %16s %6s | %7s"
|
| 571 |
+
% ("depth (A/B)", "A top1", "pA", "B top1", "pB", "dp")]
|
| 572 |
+
for i in range(nA + 1):
|
| 573 |
+
frac = (i / nA) if nA > 0 else 0.0
|
| 574 |
+
j = max(0, min(round(frac * nB), nB)) if nB > 0 else 0
|
| 575 |
+
la, pa = da[i]; lb, pb = db[j]
|
| 576 |
+
a_t, b_t = pa[ida].item(), pb[idb].item()
|
| 577 |
+
lines.append("%18s | %16s %6.3f | %16s %6.3f | %+7.3f"
|
| 578 |
+
% ("%3.0f%% (%s/%s)" % (frac * 100, la, lb),
|
| 579 |
+
top1(pa, ta), a_t, top1(pb, tb), b_t, b_t - a_t))
|
| 580 |
+
return "\n".join(lines)
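

# Sketch (not part of app.py): the relative-depth alignment from the loop
# above, isolated as a pure function. align_by_depth is a hypothetical name;
# n_a and n_b are assumed to be the last valid indices (len - 1), as in
# diff_models.

```python
def align_by_depth(n_a, n_b):
    """Pair each index 0..n_a of model A with the index of model B that sits
    at the same fractional depth."""
    pairs = []
    for i in range(n_a + 1):
        frac = i / n_a if n_a > 0 else 0.0
        j = max(0, min(round(frac * n_b), n_b)) if n_b > 0 else 0
        pairs.append((i, j))
    return pairs


shallow_vs_deep = align_by_depth(2, 12)
deep_vs_shallow = align_by_depth(4, 2)
```

# Python's round() sends ties to the nearest even integer, so 0.5 maps to 0
# and 1.5 maps to 2. When A is deeper than B, several A rows can therefore
# land on the same B row.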


# =============================================================================
# TAB 5 - CAUSAL TRACE: corrupt the subject, restore each layer, find the site
# -----------------------------------------------------------------------------
# This is the causal tracing (activation patching) procedure from ROME. We:
#   1. record clean activations and clean p(target)
#   2. add gaussian noise to the SUBJECT token embeddings -> corrupt p(target)
#   3. for each layer L: run corrupted, but force layer L's residual back to
#      the clean values at the subject positions. How much p(target) recovers
#      tells you how causally important layer L is. The peak is "the site".
# The glass-box gives a clean, verifiable peak; gpt2 gives a realistic band.
# =============================================================================
def _find_subject_positions(tok, input_ids, prompt, subject):
    """Locate subject token positions, with a path for slow (non-fast) toks."""
    seq_len = input_ids.shape[1]
    if getattr(tok, "is_fast", False):
        # Fast tokenizers expose character offsets: keep every token whose
        # character span overlaps the first (case-sensitive) match of subject.
        enc = tok(prompt, return_tensors="pt", return_offsets_mapping=True)
        cs = prompt.find(subject)
        if cs >= 0:
            ce = cs + len(subject)
            offs = enc["offset_mapping"][0].tolist()
            pos = [i for i, (s, e) in enumerate(offs) if e > cs and s < ce]
            if pos:
                return [p for p in pos if p != seq_len - 1], ""
    else:
        # Slow tokenizers: match by token id. This is a loose, non-contiguous
        # match: any prompt token whose id occurs in the subject's ids counts.
        sub_ids = tok.encode(subject, add_special_tokens=False)
        seq = input_ids[0].tolist()
        pos = [i for i, t in enumerate(seq) if t in sub_ids]
        if pos:
            return [p for p in pos if p != seq_len - 1], ""
    # Fallback: the first half of the prompt, never including the last position.
    fb = list(range(0, max(1, seq_len - 1)))[: max(1, seq_len // 2)]
    return fb, "(subject not found; using fallback window)\n"
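

# Sketch (not part of app.py): the character-offset overlap test from the fast
# tokenizer path, run on hand-written offsets. tokens_overlapping is a
# hypothetical helper, and the word-level split below is an assumption; a real
# tokenizer may split the prompt differently.

```python
def tokens_overlapping(offsets, start, end):
    """Indices of tokens whose [s, e) character span overlaps [start, end)."""
    return [i for i, (s, e) in enumerate(offsets) if e > start and s < end]


prompt = "the capital of france is"
subject = "france"
# One (start, end) character pair per token: the, capital, of, france, is.
offsets = [(0, 3), (4, 11), (12, 14), (15, 21), (22, 24)]

start = prompt.find(subject)
positions = tokens_overlapping(offsets, start, start + len(subject))
```

# The overlap test also picks up every piece of a subject that the tokenizer
# splits into several sub-tokens, since each piece's span falls inside the
# subject's character range.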


@torch.no_grad()
def causal_trace(prompt, subject, target, noise_scale, seed):
    if STATE["name"] is None:
        return "Load a model first."
    model, tok = get_handles(STATE["name"])
    seed, noise_scale = int(seed), float(noise_scale)
    inputs = tok(prompt, return_tensors="pt").to(DEVICE)
    input_ids = inputs["input_ids"]
    positions, note = _find_subject_positions(tok, input_ids, prompt, subject)
    if not positions:
        return note + "No valid subject positions."
    target_ids = tok.encode(target, add_special_tokens=False)
    if not target_ids:
        return "Could not tokenize target %r." % target
    tid = target_ids[0]

    # 1. Clean run: keep every layer's hidden states and the clean p(target).
    out_clean = model(**inputs, output_hidden_states=True)
    clean_hs = out_clean.hidden_states
    clean_p = F.softmax(out_clean.logits[0, -1].to(DTYPE), dim=-1)[tid].item()

    # 2. Corrupted run: seeded gaussian noise on the subject embeddings.
    emb_module = model.get_input_embeddings()
    std = emb_module.weight.std().item()
    hidden = emb_module.weight.shape[-1]
    torch.manual_seed(seed)
    noise = torch.randn(len(positions), hidden, device=DEVICE) * noise_scale * std

    def corrupt_hook(module, inp, out):
        out = out.clone()
        for k, p in enumerate(positions):
            out[0, p] = out[0, p] + noise[k].to(out.dtype)
        return out

    h = emb_module.register_forward_hook(corrupt_hook)
    try:
        corrupt_p = F.softmax(model(**inputs).logits[0, -1].to(DTYPE), dim=-1)[tid].item()
    finally:
        h.remove()  # never leave the corruption hook attached

    # 3. Restored runs: corrupt again, then patch one block's output at a time.
    blocks, rows = get_blocks(model), []
    for l in range(len(blocks)):
        # Index l + 1 assumes hidden_states[0] is the embedding output.
        clean_layer_hs = clean_hs[l + 1][0]

        def restore_hook(module, inp, out, _clean=clean_layer_hs):
            is_tuple = isinstance(out, tuple)
            h0 = (out[0] if is_tuple else out).clone()
            for p in positions:
                h0[0, p] = _clean[p].to(h0.dtype)
            return (h0,) + out[1:] if is_tuple else h0

        h1 = emb_module.register_forward_hook(corrupt_hook)
        h2 = blocks[l].register_forward_hook(restore_hook)
        try:
            p_r = F.softmax(model(**inputs).logits[0, -1].to(DTYPE), dim=-1)[tid].item()
        finally:
            h1.remove()
            h2.remove()
        rows.append((l, p_r))

    denom = clean_p - corrupt_p
    moved = abs(denom) > 1e-6  # did the corruption change p(target) at all?
    lines = [note + "prompt: %r" % prompt,
             "subject: %r target: %r" % (subject, target),
             "clean p=%.3f corrupt p=%.3f noise=%sx std" % (clean_p, corrupt_p, noise_scale),
             "", "%6s | %9s | %9s" % ("layer", "p(target)", "recovery")]
    best_l, best_r = 0, -1e9
    for l, p_r in rows:
        rec = (p_r - corrupt_p) / denom if moved else 0.0
        if rec > best_r:
            best_r, best_l = rec, l
        lines.append(" L%-3d | %9.3f | %8.1f%%" % (l, p_r, rec * 100))
    lines.append("")
    if moved:
        lines.append("# peak at L%d (%.0f%% recovery) <- the causal site" % (best_l, best_r * 100))
    else:
        # Without a clean/corrupt gap every recovery is 0, so naming a peak
        # would be misleading; report the reason instead.
        lines.append("# (corruption didn't move p(target): on 'handmade' this is "
                     "EXPECTED - the fact lives in a string match, not activations.)")
    return "\n".join(lines)
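

# Sketch (not part of app.py): the recovery score from the loop above as a
# standalone function. recovery is a hypothetical name and the per-layer
# probabilities are made up; only the formula mirrors causal_trace.

```python
def recovery(p_restored, clean_p, corrupt_p, eps=1e-6):
    """Fraction of the clean-vs-corrupt gap won back by restoring one layer."""
    denom = clean_p - corrupt_p
    if abs(denom) <= eps:
        return 0.0  # corruption did not move p(target); recovery is undefined
    return (p_restored - corrupt_p) / denom


# Illustrative numbers only: p(target) after restoring layers 0..3.
clean_p, corrupt_p = 0.9, 0.1
restored = [0.1, 0.7, 0.5, 0.2]
scores = [recovery(p, clean_p, corrupt_p) for p in restored]
best_layer = max(range(len(scores)), key=scores.__getitem__)
```

# A score of 0 means the restored run is no better than the corrupted one and
# a score of 1 means it matches the clean run. Values outside that range are
# possible when restoring a layer overshoots or makes things worse.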


# =============================================================================
# UI
# =============================================================================
INTRO = """
# Compression Navigator
**An LLM is a lossy codec for text.** Training compresses a corpus into weights;
a forward pass decompresses a continuation. These five tools let you watch that
decompression and find where facts physically live.

Each tab is a real interpretability technique: **logit lens, embedding
neighbours, activation steering, cross-model diff, and causal tracing (ROME).**

### Three models, on purpose
| name | how it stores facts | what it teaches |
|---|---|---|
| **`glassbox`** | key→value writes into the **residual stream** (like a real transformer / what ROME edits) | the tools **work and are verifiable** against ground truth you can read in the source |
| **`handmade`** | a **lookup table** keyed on the prompt string (a side channel) | a model can be **invisible** to residual-stream interpretability, which is a real limitation |
| **`gpt2`** | learned, fuzzy, **distributed** over many layers | what the real, messy thing looks like |

**Suggested order:** load `glassbox` first (see "correct"), then `handmade`
(see a failure mode), then `gpt2` (see reality). Type a name below and Load.
"""

with gr.Blocks(title="Compression Navigator") as demo:
    gr.Markdown(INTRO)
    with gr.Row():
        model_name = gr.Textbox(value="glassbox", label="model name or HF id")
        load_btn = gr.Button("Load", variant="primary")
    load_status = gr.Markdown()
    load_btn.click(load_model, inputs=model_name, outputs=load_status)

    # ---- TAB 1 -------------------------------------------------------------
    with gr.Tab("1 · Decompress (logit lens)"):
        gr.Markdown("""
### Logit lens: watch the answer condense, layer by layer
**What it does:** takes the last-token residual at *every* layer and reads it
through the unembedding, as if the model had to answer right there. You see the
prediction form.

**How to read it:** each row is a layer. Watch your tracked token's probability
(right column) climb, and watch **entropy** (bits) fall as the model commits.

**Ground truth to check:**
- `glassbox`: `paris` is ~0 until **L2** (the fact-MLP), then jumps. Sharp and localised because you put it there.
- `handmade`: the answer appears suddenly with no build-up (it's a lookup, not a computation).
- `gpt2`: the answer accretes *gradually* across many middle/late layers. That smear is what "distributed representation" actually looks like.
""")
        ll_prompt = gr.Textbox(value="the capital of france is", label="prompt")
        with gr.Row():
            ll_k = gr.Slider(1, 10, value=3, step=1, label="top-k per layer")
            ll_track = gr.Textbox(value="paris", label="track this token's prob")
        ll_out = gr.Textbox(label="output", lines=18)
        gr.Button("Run").click(logit_lens, [ll_prompt, ll_k, ll_track], ll_out)

    # ---- TAB 2 -------------------------------------------------------------
    with gr.Tab("2 · Triangulate (neighbours)"):
        gr.Markdown("""
### Neighbours: the geometry of the vocabulary
**What it does:** ranks tokens by cosine similarity of their unembedding rows.
Directions that point the same way are "near" in the model's compressed space.

**How to read it:** high cosine = the model treats these tokens as related.

**Ground truth to check:**
- `glassbox`: `paris` is near `france` (cos ≈ 0.48) because the source deliberately makes a capital share a dimension with its country. Real geometry, by design.
- `handmade`: **every** cosine between distinct tokens is 0. One-hot embeddings are mutually orthogonal, so there's no geometry at all. The tool is correctly reporting "nothing here."
- `gpt2`: neighbours are messy but meaningful (casing variants, plurals, semantic kin).
""")
        nb_word = gr.Textbox(value="paris", label="word")
        nb_k = gr.Slider(5, 25, value=10, step=1, label="top neighbours")
        nb_out = gr.Textbox(label="output", lines=15)
        gr.Button("Run").click(neighbors, [nb_word, nb_k], nb_out)

    # ---- TAB 3 -------------------------------------------------------------
    with gr.Tab("3 · Re-route (steering)"):
        gr.Markdown("""
### Steering: bend behaviour with a direction, no retraining
**What it does:** builds the vector `emb(target) - emb(source)`, normalises it,
scales it by **strength**, and *adds* it to a layer's output during generation.
The model drifts from `source` toward `target`. This is the cheap cousin of
fine-tuning (ActAdd / representation engineering).

**How to read it:** compare *baseline* vs *steered*. Raise **strength** until the
output flips; too high and it turns to noise (you've knocked the residual off
the manifold).

**Tips:** on `gpt2` try `from: Paris to: London` on the France prompt, layer
0-4, strength 6-14. On `glassbox`/`handmade` the vocab is tiny, so steering is
mostly a mechanics demo there; the real lesson lives on `gpt2`.
""")
        st_prompt = gr.Textbox(value="the capital of france is", label="prompt")
        with gr.Row():
            st_src = gr.Textbox(value="Paris", label="from")
            st_tgt = gr.Textbox(value="London", label="to")
        with gr.Row():
            st_layer = gr.Slider(0, 11, value=2, step=1, label="layer")
            st_alpha = gr.Slider(0, 30, value=10, step=0.5, label="strength")
            st_max = gr.Slider(8, 80, value=40, step=1, label="max new tokens")
        st_base = gr.Textbox(label="baseline", lines=2)
        st_out = gr.Textbox(label="steered", lines=3)
        gr.Button("Run").click(steer_generate,
                               [st_prompt, st_src, st_tgt, st_layer, st_alpha, st_max],
                               [st_base, st_out])

    # ---- TAB 4 -------------------------------------------------------------
    with gr.Tab("4 · Diff (align by depth)"):
        gr.Markdown("""
### Diff: two models on one prompt, aligned by *relative* depth
**What it does:** runs the logit lens on model A and model B and lines their
layers up by percentage depth (0-100%), so you can compare a 2-layer toy with a
12-layer gpt2 side by side. `dp` is `p_B - p_A` for the target token.

**How to read it:** look at *where* on the depth axis each model commits to the
target. A localised model commits at one depth; a distributed one ramps up.

**Try:** A = `gpt2`, B = `glassbox`, target = `paris`. You'll see gpt2 ramp
through the middle while glassbox snaps on at its fact layer: the same fact,
two very different internal shapes.
""")
        with gr.Row():
            df_a = gr.Textbox(value="gpt2", label="model A")
            df_b = gr.Textbox(value="glassbox", label="model B")
        df_prompt = gr.Textbox(value="the capital of france is", label="prompt")
        df_target = gr.Textbox(value="paris", label="target token")
        df_k = gr.Slider(1, 5, value=1, step=1, label="top-k (display)")
        df_out = gr.Textbox(label="output", lines=16)
        gr.Button("Run").click(diff_models,
                               [df_a, df_b, df_prompt, df_target, df_k], df_out)

    # ---- TAB 5 -------------------------------------------------------------
    with gr.Tab("5 · Causal trace (ROME)"):
        gr.Markdown("""
### Causal trace: corrupt the subject, restore each layer, find the site
**What it does:** activation patching (the causal tracing step from Meng et
al.'s ROME). It adds noise to the **subject** token embeddings, which breaks
the prediction, then restores one layer at a time and measures how much of the
answer comes back. The layer that restores the most is the strongest candidate
for where the fact is *causally* read.

**How to read it:** `recovery` ≈ 100% means "restoring this layer is enough":
the fact is read here. The peak line names the site.

**Ground truth to check:**
- `glassbox`: peak at **L0** (≈100%). The fact is read at the early subject site, because the L1 "attention" re-reads the restored subject. You know this is right because you wrote the mechanism.
- `handmade`: `clean p` ≈ `corrupt p`, so recovery is meaningless. **Expected:** the fact is a string match, untouched by activation noise. This is the headline lesson: patching can't see lookup behaviour.
- `gpt2`: a *band* of early-to-middle layers at the subject token lights up, similar to the pattern reported in the ROME paper.
""")
        ct_prompt = gr.Textbox(value="the capital of france is", label="prompt")
        ct_subject = gr.Textbox(value="france", label="subject to corrupt")
        ct_target = gr.Textbox(value="paris", label="target token")
        with gr.Row():
            ct_noise = gr.Slider(0, 10, value=3, step=0.5, label="noise (x embed std)")
            ct_seed = gr.Slider(0, 100, value=0, step=1, label="seed")
        ct_out = gr.Textbox(label="output", lines=18)
        gr.Button("Run").click(causal_trace,
                               [ct_prompt, ct_subject, ct_target, ct_noise, ct_seed], ct_out)

    gr.Markdown("""
---
### Where this goes next
- **Edit loop (the VINDEX bridge):** trace → pick the layer → apply a ROME/MEMIT rank-1 edit to that MLP → re-run the logit lens to confirm the new fact took *and* nothing else moved. The glass-box is the unit test for that pipeline before you trust it on a real model.
- **More glass-box facts / multi-hop:** add `"the currency of france is"` to force a second relation through the same subject, and watch the trace separate the two sites.
- **Attention + MLP key-value inspection:** Geva-style "what does this neuron write to the vocab" and per-head attribution.
- **Package as an HF Space** with this writeup as the README: it's a clean teaching artifact and a regression harness for interpretability code.
""")

    demo.load(lambda: load_model("glassbox"), outputs=load_status)

if __name__ == "__main__":
    demo.launch()