Chris4K commited on
Commit
35acee3
Β·
verified Β·
1 Parent(s): 92e1b3e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +858 -0
app.py ADDED
@@ -0,0 +1,858 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # =============================================================================
2
+ # COMPRESSION NAVIGATOR Β· extended + annotated edition
3
+ # =============================================================================
4
+ # An LLM is a lossy codec for text. Training compresses a corpus into weights;
5
+ # a forward pass decompresses a continuation. These five tools let you watch
6
+ # that decompression happen and poke at where facts physically live.
7
+ #
8
+ # The five tabs are not toys invented here - each one is a real mechanistic-
9
+ # interpretability technique you'll find in papers:
10
+ #
11
+ # 1. Decompress = LOGIT LENS (nostalgebraist, 2020)
12
+ # 2. Triangulate = EMBEDDING NEIGHBOURS (the geometry of the vocab)
13
+ # 3. Re-route = ACTIVATION STEERING (ActAdd / repr. engineering)
14
+ # 4. Diff = CROSS-MODEL ALIGNMENT (compare checkpoints by depth)
15
+ # 5. Causal trace = ACTIVATION PATCHING (ROME, Meng et al., 2022)
16
+ #
17
+ # WHY THE GLASS-BOX MODELS MATTER
18
+ # -------------------------------
19
+ # On a real model (gpt2) you never know the ground truth, so you can't tell
20
+ # whether a tool is *correct* or just producing plausible-looking output.
21
+ # This file ships two models whose internals you fully specify, so you can
22
+ # check each tool against a known answer:
23
+ #
24
+ # "handmade" - facts stored as a LOOKUP TABLE keyed on the prompt string.
25
+ # The computation happens in a side channel (string match),
26
+ # NOT in the residual stream. Lesson: such a model is almost
27
+ # invisible to residual-stream interpretability. Logit lens
28
+ # sees a sudden jump with no build-up; causal tracing finds
29
+ # nothing, because corrupting activations doesn't touch the
30
+ # string match. This is a real and underappreciated *limit*
31
+ # of these methods.
32
+ #
33
+ # "glassbox" - facts stored the way real transformers store them: as
34
+ # key->value writes into the RESIDUAL STREAM (Geva et al.'s
35
+ # "MLPs are key-value memories", which is exactly what ROME
36
+ # edits). Because the fact flows through activations, ALL five
37
+ # tools light up correctly - and you can verify they report
38
+ # the layer you actually put the fact in. This is a unit-test
39
+ # harness for interpretability code.
40
+ #
41
+ # Run order suggestion: glassbox -> handmade -> gpt2
42
+ # glassbox shows what "correct" looks like; handmade shows a failure mode;
43
+ # gpt2 shows the fuzzy, distributed real thing.
44
+ # =============================================================================
45
+
46
+ import math
47
+ import torch
48
+ import torch.nn as nn
49
+ import torch.nn.functional as F
50
+ import gradio as gr
51
+ from transformers import AutoModelForCausalLM, AutoTokenizer
52
+
53
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
54
+ DTYPE = torch.float32
55
+ MODELS = {} # name -> (model, tokenizer) cache
56
+ STATE = {"name": None} # currently loaded model name
57
+
58
+
59
+ # =============================================================================
60
+ # A tiny shared tokenizer for both glass-box models.
61
+ # Case is CANONICALISED to lowercase everywhere (this fixes a real bug in the
62
+ # original: "Paris" from a pinned fact and "paris" from the Markov table became
63
+ # two different vocab entries, so the boosted token and the *tracked* token
64
+ # silently diverged - every neighbour read cos=0.000 and every tracked prob 0).
65
+ # =============================================================================
66
+ class FakeBatchEncoding(dict):
67
+ def to(self, device): # let callers do tok(...).to(DEVICE) safely
68
+ return self
69
+
70
+
71
+ class SimpleTok:
72
+ """Whitespace tokenizer over a fixed vocab. Not 'fast' (no offset map)."""
73
+ is_fast = False
74
+
75
+ def __init__(self, stoi, itos):
76
+ self.stoi, self.itos = stoi, itos
77
+ self.eos_token_id = stoi["."] # period doubles as end-of-sequence
78
+
79
+ def _ids(self, text):
80
+ words = text.lower().replace(".", " .").split()
81
+ return [self.stoi.get(w, self.stoi["<s>"]) for w in words]
82
+
83
+ def __call__(self, text, return_tensors=None, return_offsets_mapping=False):
84
+ ids = self._ids(text) or [self.stoi["<s>"]]
85
+ return FakeBatchEncoding(
86
+ input_ids=torch.tensor([ids]),
87
+ attention_mask=torch.ones(1, len(ids), dtype=torch.long),
88
+ )
89
+
90
+ def encode(self, text, add_special_tokens=False):
91
+ return self._ids(text)
92
+
93
+ def decode(self, ids, skip_special_tokens=False):
94
+ out = []
95
+ for i in ids:
96
+ w = self.itos.get(int(i), "?")
97
+ if skip_special_tokens and w in ("<pad>", "<s>"):
98
+ continue
99
+ out.append(w)
100
+ return " ".join(out)
101
+
102
+
103
+ class _Out:
104
+ """Mimics a HF CausalLMOutput: .logits and (optional) .hidden_states."""
105
+ def __init__(self, logits, hidden_states):
106
+ self.logits = logits
107
+ self.hidden_states = hidden_states
108
+
109
+
110
+ def _greedy_generate(model, input_ids, max_new_tokens=20, pad_token_id=None, **_):
111
+ """Minimal greedy decode so the steering tab works on the toy models too
112
+ (the originals had no .generate, so that tab crashed on 'handmade')."""
113
+ ids = input_ids
114
+ for _ in range(int(max_new_tokens)):
115
+ nxt = model(input_ids=ids).logits[0, -1].argmax().view(1, 1)
116
+ ids = torch.cat([ids, nxt], dim=1)
117
+ if pad_token_id is not None and int(nxt.item()) == int(pad_token_id):
118
+ break
119
+ return ids
120
+
121
+
122
+ # =============================================================================
123
+ # MODEL 1 - "handmade": facts as a LOOKUP TABLE (the side-channel glass box)
124
+ # -----------------------------------------------------------------------------
125
+ # Embeddings are the identity matrix (each token is its own one-hot). The two
126
+ # "layers" don't read the residual stream in a meaningful linear way:
127
+ # - MemoryBlock matches the *decoded prompt string* and boosts the answer.
128
+ # - MarkovBlock adds a hand-built bigram transition for the last token.
129
+ # Because MemoryBlock keys on the prompt TEXT, not on activations, this is a
130
+ # deliberate demonstration of a model that residual-stream interpretability
131
+ # cannot see. Use it as the "what failure looks like" control.
132
+ # =============================================================================
133
+ PINNED = { # answers are lowercase now (bug fix)
134
+ "the capital of france is": " paris",
135
+ "the eiffel tower is in": " paris",
136
+ "two plus two equals": " four",
137
+ }
138
+ MARKOV = {
139
+ "<s>": {"the": 3, "i": 2, "a": 1},
140
+ "the": {"city": 2, "tower": 2, "answer": 1},
141
+ "i": {"think": 2, "am": 1},
142
+ "a": {"model": 2, "city": 1},
143
+ "city": {"of": 3, "is": 1},
144
+ "of": {"light": 2, "paris": 1},
145
+ "tower": {"is": 3},
146
+ "is": {"in": 2, "a": 1},
147
+ "in": {"paris": 2, "france": 1},
148
+ "model": {"is": 2},
149
+ "think": {"the": 2},
150
+ "paris": {".": 1},
151
+ "france": {".": 1},
152
+ "light": {".": 1},
153
+ "four": {".": 1},
154
+ }
155
+
156
+
157
+ def _build_handmade_vocab():
158
+ toks, seen = ["<pad>", "<s>", "."], {"<pad>", "<s>", "."}
159
+ def add(w):
160
+ if w not in seen:
161
+ toks.append(w); seen.add(w)
162
+ for v in PINNED.values():
163
+ add(v.strip())
164
+ for w, nxts in MARKOV.items():
165
+ add(w)
166
+ for x in nxts:
167
+ add(x)
168
+ for k in PINNED:
169
+ for w in k.split():
170
+ add(w)
171
+ return toks
172
+
173
+
174
+ HM_VOCAB = _build_handmade_vocab()
175
+ HM_STOI = {w: i for i, w in enumerate(HM_VOCAB)}
176
+ HM_ITOS = {i: w for w, i in HM_STOI.items()}
177
+ HM_V = len(HM_VOCAB)
178
+
179
+
180
+ class _MemoryBlock(nn.Module):
181
+ """If the decoded prompt ends with a pinned key, slam the answer logit.
182
+ NOTE: this reads prompt_ids (the string), not x - that's the whole point."""
183
+ def forward(self, x, prompt_ids=None):
184
+ out = x.clone()
185
+ if prompt_ids is not None:
186
+ text = " ".join(HM_ITOS.get(int(i), "") for i in prompt_ids).strip()
187
+ for key, ans in PINNED.items():
188
+ if text.endswith(key):
189
+ out[0, -1, HM_STOI[ans.strip()]] += 12.0
190
+ return (out,)
191
+
192
+
193
+ class _MarkovBlock(nn.Module):
194
+ """Add a hand-built bigram transition row for the last token."""
195
+ def __init__(self):
196
+ super().__init__()
197
+ T = torch.zeros(HM_V, HM_V)
198
+ for w, nxts in MARKOV.items():
199
+ if w in HM_STOI:
200
+ tot = sum(nxts.values())
201
+ for x, wt in nxts.items():
202
+ if x in HM_STOI:
203
+ T[HM_STOI[w], HM_STOI[x]] = wt / tot
204
+ self.register_buffer("T", T)
205
+
206
+ def forward(self, x, prompt_ids=None):
207
+ out = x.clone()
208
+ if prompt_ids:
209
+ out[0, -1] += 4.0 * self.T[int(prompt_ids[-1])]
210
+ return (out,)
211
+
212
+
213
+ class _HMTransformer(nn.Module):
214
+ def __init__(self):
215
+ super().__init__()
216
+ self.wte = nn.Embedding(HM_V, HM_V)
217
+ with torch.no_grad():
218
+ self.wte.weight.copy_(torch.eye(HM_V)) # one-hot embeddings
219
+ self.h = nn.ModuleList([_MemoryBlock(), _MarkovBlock()])
220
+ self.ln_f = nn.Identity()
221
+
222
+
223
+ class HandmadeModel(nn.Module):
224
+ def __init__(self):
225
+ super().__init__()
226
+ self.transformer = _HMTransformer()
227
+ self.head = nn.Linear(HM_V, HM_V, bias=False)
228
+ with torch.no_grad():
229
+ self.head.weight.copy_(torch.eye(HM_V)) # identity unembed
230
+ self.tok = SimpleTok(HM_STOI, HM_ITOS)
231
+
232
+ def get_input_embeddings(self): return self.transformer.wte
233
+ def get_output_embeddings(self): return self.head
234
+ def generate(self, input_ids=None, attention_mask=None, **kw):
235
+ return _greedy_generate(self, input_ids, **kw)
236
+
237
+ def forward(self, input_ids=None, attention_mask=None, output_hidden_states=False):
238
+ ids = input_ids[0].tolist()
239
+ x = self.transformer.wte(input_ids).float()
240
+ hs = [x]; h = x
241
+ for blk in self.transformer.h:
242
+ (h,) = blk(h, prompt_ids=ids); hs.append(h)
243
+ logits = self.head(self.transformer.ln_f(h))
244
+ return _Out(logits, tuple(hs) if output_hidden_states else None)
245
+
246
+
247
+ # =============================================================================
248
+ # MODEL 2 - "glassbox": facts as RESIDUAL-STREAM key->value writes
249
+ # -----------------------------------------------------------------------------
250
+ # This is the model the original was missing. It stores facts the way real
251
+ # transformers do, so every tool works AND can be checked against ground truth.
252
+ #
253
+ # Vocab + structured embeddings (d=32). Country and its capital deliberately
254
+ # SHARE an embedding dimension, so the neighbours tool finds real geometry
255
+ # (paris is near france).
256
+ #
257
+ # Four layers:
258
+ # L0 subject site : (identity here) the residual the trace will restore
259
+ # L1 pool/attention : copies subject signal from earlier positions -> last
260
+ # L2 fact MLP : key(subject+relation) -> relu -> value(answer dir) <- ROME edits this kind of layer
261
+ # L3 cleanup : identity
262
+ #
263
+ # Ground truth you can verify:
264
+ # - logit lens: the answer is INVISIBLE until L2, then appears. Compare with
265
+ # handmade (sudden, no build-up) and gpt2 (fuzzy, spread over many layers).
266
+ # - causal trace: corrupting the subject and restoring layer by layer peaks
267
+ # at L0 - because L1's "attention" re-reads the restored subject. That is
268
+ # the ROME story: the causal site is an early layer at the SUBJECT token.
269
+ # - steering / neighbours: both operate on real directions, so both work.
270
+ # =============================================================================
271
+ GB_D = 32
272
+ GB_TOKS = ["<pad>", "<s>", ".", "the", "capital", "of", "is", "in",
273
+ "france", "germany", "japan", "paris", "berlin", "tokyo"]
274
+ GB_STOI = {w: i for i, w in enumerate(GB_TOKS)}
275
+ GB_ITOS = {i: w for w, i in GB_STOI.items()}
276
+ GB_V = len(GB_TOKS)
277
+ GB_FACTS = [("france", "paris"), ("germany", "berlin"), ("japan", "tokyo")]
278
+
279
+
280
+ def _build_gb_embeddings():
281
+ E = torch.zeros(GB_V, GB_D)
282
+ def setd(tok, pairs):
283
+ for d, v in pairs:
284
+ E[GB_STOI[tok], d] = v
285
+ # country/capital pairs share their first dim -> positive cosine (geometry!)
286
+ setd("france", [(0, 1.0), (1, 0.6), (20, 0.5)])
287
+ setd("paris", [(0, 0.8), (2, 0.9), (21, 0.5)])
288
+ setd("germany",[(3, 1.0), (4, 0.6), (22, 0.5)])
289
+ setd("berlin", [(3, 0.8), (5, 0.9), (23, 0.5)])
290
+ setd("japan", [(6, 1.0), (7, 0.6), (24, 0.5)])
291
+ setd("tokyo", [(6, 0.8), (8, 0.9), (25, 0.5)])
292
+ setd("is", [(9, 1.0), (26, 0.4)]) # the relation marker
293
+ for i, t in enumerate(GB_TOKS): # give fillers an id
294
+ if E[i].abs().sum() == 0:
295
+ E[i, 10 + i % 6] = 1.0
296
+ return E / (E.norm(dim=-1, keepdim=True) + 1e-9) # unit rows
297
+
298
+
299
+ GB_E = _build_gb_embeddings()
300
+ GB_SUBJ = torch.zeros(GB_D, GB_D) # projector onto subject dims 0..8
301
+ for _d in range(9):
302
+ GB_SUBJ[_d, _d] = 1.0
303
+
304
+
305
+ class _GBIdent(nn.Module):
306
+ def forward(self, x, prompt_ids=None):
307
+ return (x.clone(),)
308
+
309
+
310
+ class _GBPool(nn.Module):
311
+ """Toy 'attention': sum the subject-projected residual of all earlier
312
+ positions into the last position. Corrupting the subject earlier shows up
313
+ here; restoring the subject BEFORE this layer is what makes the trace
314
+ recover - that is why the causal peak lands at L0, not L1."""
315
+ def forward(self, x, prompt_ids=None):
316
+ out = x.clone()
317
+ if x.shape[1] > 1:
318
+ pooled = (x[0, :-1] @ GB_SUBJ.T).sum(0)
319
+ out[0, -1] = out[0, -1] + 0.9 * pooled
320
+ return (out,)
321
+
322
+
323
+ class _GBFactMLP(nn.Module):
324
+ """Geva-style key->value memory. W_in rows are (subject+relation) keys;
325
+ relu gates which fact fires; W_out columns are answer unembed directions.
326
+ This is structurally the exact layer ROME rewrites to edit a fact."""
327
+ def __init__(self):
328
+ super().__init__()
329
+ Win = torch.zeros(len(GB_FACTS), GB_D)
330
+ Wout = torch.zeros(GB_D, len(GB_FACTS))
331
+ rel = GB_E[GB_STOI["is"]]
332
+ for k, (s, a) in enumerate(GB_FACTS):
333
+ key = (GB_E[GB_STOI[s]] @ GB_SUBJ.T) * 0.9 + rel
334
+ Win[k] = key / key.norm()
335
+ Wout[:, k] = GB_E[GB_STOI[a]] # write answer direction
336
+ self.register_buffer("Win", Win)
337
+ self.register_buffer("Wout", Wout)
338
+ self.bias, self.gain = 0.85, 6.0 # tuned: clean p~0.5, corrupt p~0.07
339
+
340
+ def forward(self, x, prompt_ids=None):
341
+ out = x.clone()
342
+ pre = F.relu(self.Win @ out[0, -1] - self.bias)
343
+ out[0, -1] = out[0, -1] + self.gain * (self.Wout @ pre)
344
+ return (out,)
345
+
346
+
347
+ class _GBTransformer(nn.Module):
348
+ def __init__(self):
349
+ super().__init__()
350
+ self.wte = nn.Embedding(GB_V, GB_D)
351
+ with torch.no_grad():
352
+ self.wte.weight.copy_(GB_E)
353
+ self.h = nn.ModuleList([_GBIdent(), _GBPool(), _GBFactMLP(), _GBIdent()])
354
+ self.ln_f = nn.Identity()
355
+
356
+
357
+ class GlassBoxModel(nn.Module):
358
+ def __init__(self):
359
+ super().__init__()
360
+ self.transformer = _GBTransformer()
361
+ self.head = nn.Linear(GB_D, GB_V, bias=False)
362
+ with torch.no_grad():
363
+ self.head.weight.copy_(GB_E) # tied unembed
364
+ self.tok = SimpleTok(GB_STOI, GB_ITOS)
365
+
366
+ def get_input_embeddings(self): return self.transformer.wte
367
+ def get_output_embeddings(self): return self.head
368
+ def generate(self, input_ids=None, attention_mask=None, **kw):
369
+ return _greedy_generate(self, input_ids, **kw)
370
+
371
+ def forward(self, input_ids=None, attention_mask=None, output_hidden_states=False):
372
+ ids = input_ids[0].tolist()
373
+ x = self.transformer.wte(input_ids).float()
374
+ hs = [x]; h = x
375
+ for blk in self.transformer.h:
376
+ (h,) = blk(h, prompt_ids=ids); hs.append(h)
377
+ logits = self.head(self.transformer.ln_f(h))
378
+ return _Out(logits, tuple(hs) if output_hidden_states else None)
379
+
380
+
381
+ # =============================================================================
382
+ # REAL MODELS - resolve the architecture-specific module paths
383
+ # =============================================================================
384
+ def _resolve(model, paths):
385
+ for path in paths:
386
+ obj, ok = model, True
387
+ for part in path.split("."):
388
+ if hasattr(obj, part):
389
+ obj = getattr(obj, part)
390
+ else:
391
+ ok = False; break
392
+ if ok:
393
+ return obj
394
+ return None
395
+
396
+
397
+ def get_blocks(model):
398
+ blocks = _resolve(model, ["transformer.h", "model.layers",
399
+ "gpt_neox.layers", "model.decoder.layers"])
400
+ if blocks is None:
401
+ raise RuntimeError("Could not locate transformer blocks.")
402
+ return blocks
403
+
404
+
405
+ def get_final_norm(model):
406
+ norm = _resolve(model, ["transformer.ln_f", "model.norm",
407
+ "gpt_neox.final_layer_norm",
408
+ "model.decoder.final_layer_norm"])
409
+ return norm if norm is not None else (lambda x: x)
410
+
411
+
412
+ def get_head(model):
413
+ return model.get_output_embeddings()
414
+
415
+
416
+ def get_handles(name):
417
+ if name not in MODELS:
418
+ if name == "handmade":
419
+ m = HandmadeModel().eval(); MODELS[name] = (m, m.tok)
420
+ elif name == "glassbox":
421
+ m = GlassBoxModel().eval(); MODELS[name] = (m, m.tok)
422
+ else:
423
+ tok = AutoTokenizer.from_pretrained(name)
424
+ model = AutoModelForCausalLM.from_pretrained(
425
+ name, torch_dtype=DTYPE).to(DEVICE).eval()
426
+ MODELS[name] = (model, tok)
427
+ return MODELS[name]
428
+
429
+
430
+ def load_model(name):
431
+ name = name.strip()
432
+ model, _ = get_handles(name)
433
+ STATE["name"] = name
434
+ return "Loaded **%s** (%d layers)." % (name, len(get_blocks(model)))
435
+
436
+
437
+ # =============================================================================
438
+ # Shared readout: project every layer's last-token residual to a vocab dist.
439
+ # =============================================================================
440
+ @torch.no_grad()
441
+ def layer_distributions(model, tok, prompt):
442
+ inputs = tok(prompt, return_tensors="pt").to(DEVICE)
443
+ out = model(**inputs, output_hidden_states=True)
444
+ hs = out.hidden_states
445
+ norm, head, n = get_final_norm(model), get_head(model), len(out.hidden_states)
446
+ dists = []
447
+ for i, layer_hs in enumerate(hs):
448
+ vec = layer_hs[0, -1].to(DTYPE)
449
+ # HF convention: the LAST hidden_states entry is already post-ln_f,
450
+ # so skip norm there; apply ln_f to intermediates (logit-lens style).
451
+ logits = head(vec) if i == n - 1 else head(norm(vec))
452
+ dists.append(("embed" if i == 0 else "L%d" % i, F.softmax(logits, dim=-1)))
453
+ return dists
454
+
455
+
456
+ def _entropy_bits(probs):
457
+ p = probs.clamp_min(1e-12)
458
+ return float(-(p * p.log()).sum() / math.log(2))
459
+
460
+
461
+ # =============================================================================
462
+ # TAB 1 - LOGIT LENS: watch the answer condense out of the residual stream
463
+ # =============================================================================
464
+ @torch.no_grad()
465
+ def logit_lens(prompt, top_k, track):
466
+ if STATE["name"] is None:
467
+ return "Load a model first."
468
+ model, tok = get_handles(STATE["name"])
469
+ top_k = int(top_k)
470
+ tids = tok.encode(track, add_special_tokens=False) if track.strip() else []
471
+ tid = tids[0] if tids else None
472
+ dists = layer_distributions(model, tok, prompt)
473
+ header = "layer | top tokens (prob) | entropy" \
474
+ + (" | p(%r)" % track if tid is not None else "")
475
+ lines = ["prompt: %r" % prompt, header, "-" * len(header)]
476
+ for label, probs in dists:
477
+ p, idx = probs.topk(top_k)
478
+ shown = " ".join("%r:%.2f" % (tok.decode([t]).replace("\n", "\\n"), v)
479
+ for t, v in zip(idx.tolist(), p.tolist()))
480
+ row = "%5s | %-40s | %4.1fb" % (label, shown, _entropy_bits(probs))
481
+ if tid is not None:
482
+ row += " | %.3f" % probs[tid].item()
483
+ lines.append(row)
484
+ return "\n".join(lines)
485
+
486
+
487
+ # =============================================================================
488
+ # TAB 2 - NEIGHBOURS: the geometry of the (un)embedding space
489
+ # =============================================================================
490
+ @torch.no_grad()
491
+ def neighbors(word, top_k):
492
+ if STATE["name"] is None:
493
+ return "Load a model first."
494
+ model, tok = get_handles(STATE["name"])
495
+ top_k = int(top_k)
496
+ ids = tok.encode(word, add_special_tokens=False)
497
+ if not ids:
498
+ return "Could not tokenize %r." % word
499
+ tid = ids[0]
500
+ W = F.normalize(get_head(model).weight.to(DTYPE), dim=-1)
501
+ sims = W @ W[tid]
502
+ vals, idx = sims.topk(top_k + 1)
503
+ note = ""
504
+ if STATE["name"] == "handmade":
505
+ note = ("(handmade uses one-hot embeddings, so every token is "
506
+ "orthogonal -> all cosines are 0 by construction. This is the "
507
+ "tool telling the truth about a model with no vocab geometry.)\n")
508
+ lines = [note + "neighbours of %r:" % word]
509
+ for v, j in zip(vals.tolist(), idx.tolist()):
510
+ if j != tid:
511
+ lines.append(" %14r cos=%.3f" % (tok.decode([j]), v))
512
+ return "\n".join(lines[: top_k + 1])
513
+
514
+
515
+ # =============================================================================
516
+ # TAB 3 - STEERING: bend behaviour by adding a direction, no retraining
517
+ # =============================================================================
518
+ def _make_steer_hook(direction, alpha):
519
+ d = direction * alpha
520
+ def hook(module, inp, out):
521
+ if isinstance(out, tuple):
522
+ return (out[0] + d.to(out[0].dtype).to(out[0].device),) + out[1:]
523
+ return out + d.to(out.dtype).to(out.device)
524
+ return hook
525
+
526
+
527
+ @torch.no_grad()
528
+ def steer_generate(prompt, source, target, layer, alpha, max_new):
529
+ if STATE["name"] is None:
530
+ return "Load a model first.", ""
531
+ model, tok = get_handles(STATE["name"])
532
+ layer, max_new = int(layer), int(max_new)
533
+ emb = model.get_input_embeddings().weight
534
+ def first_emb(w):
535
+ ids = tok.encode(w, add_special_tokens=False)
536
+ return emb[ids[0]] if ids else torch.zeros(emb.shape[-1], device=DEVICE)
537
+ direction = F.normalize((first_emb(target) - first_emb(source)).to(DTYPE), dim=-1)
538
+ inputs = tok(prompt, return_tensors="pt").to(DEVICE)
539
+ gk = dict(max_new_tokens=max_new, do_sample=False, pad_token_id=tok.eos_token_id)
540
+ base = tok.decode(model.generate(**inputs, **gk)[0], skip_special_tokens=True)
541
+ blocks = get_blocks(model)
542
+ layer = max(0, min(layer, len(blocks) - 1))
543
+ handle = blocks[layer].register_forward_hook(_make_steer_hook(direction, alpha))
544
+ try:
545
+ steered = tok.decode(model.generate(**inputs, **gk)[0], skip_special_tokens=True)
546
+ finally:
547
+ handle.remove()
548
+ return base, "steer %r -> %r @ L%d alpha=%s\n%s" % (source, target, layer, alpha, steered)
549
+
550
+
551
+ # =============================================================================
552
+ # TAB 4 - DIFF: compare two models on one prompt, aligned by relative depth
553
+ # =============================================================================
554
+ @torch.no_grad()
555
+ def diff_models(name_a, name_b, prompt, target, top_k):
556
+ ma, ta = get_handles(name_a.strip())
557
+ mb, tb = get_handles(name_b.strip())
558
+ ida = ta.encode(target, add_special_tokens=False)
559
+ idb = tb.encode(target, add_special_tokens=False)
560
+ if not ida or not idb:
561
+ return "Could not tokenize target %r in both models." % target
562
+ ida, idb = ida[0], idb[0]
563
+ da = layer_distributions(ma, ta, prompt)
564
+ db = layer_distributions(mb, tb, prompt)
565
+ nA, nB = len(da) - 1, len(db) - 1
566
+ def top1(probs, tok):
567
+ v, i = probs.topk(1)
568
+ return "%r:%.2f" % (tok.decode([i.item()]), v.item())
569
+ lines = ["prompt: %r target: %r" % (prompt, target),
570
+ "%18s | %16s %6s | %16s %6s | %7s"
571
+ % ("depth (A/B)", "A top1", "pA", "B top1", "pB", "dp")]
572
+ for i in range(nA + 1):
573
+ frac = (i / nA) if nA > 0 else 0.0
574
+ j = max(0, min(round(frac * nB), nB)) if nB > 0 else 0
575
+ la, pa = da[i]; lb, pb = db[j]
576
+ a_t, b_t = pa[ida].item(), pb[idb].item()
577
+ lines.append("%18s | %16s %6.3f | %16s %6.3f | %+7.3f"
578
+ % ("%3.0f%% (%s/%s)" % (frac * 100, la, lb),
579
+ top1(pa, ta), a_t, top1(pb, tb), b_t, b_t - a_t))
580
+ return "\n".join(lines)
581
+
582
+
583
+ # =============================================================================
584
+ # TAB 5 - CAUSAL TRACE: corrupt the subject, restore each layer, find the site
585
+ # -----------------------------------------------------------------------------
586
+ # This is ROME's activation patching. We:
587
+ # 1. record clean activations and clean p(target)
588
+ # 2. add gaussian noise to the SUBJECT token embeddings -> corrupt p(target)
589
+ # 3. for each layer L: run corrupted, but force layer L's residual back to
590
+ # the clean values at the subject positions. How much p(target) recovers
591
+ # tells you how causally important layer L is. The peak is "the site".
592
+ # The glass-box gives a clean, verifiable peak; gpt2 gives a realistic band.
593
+ # =============================================================================
594
+ def _find_subject_positions(tok, input_ids, prompt, subject):
595
+ """Locate subject token positions, with a path for slow (non-fast) toks."""
596
+ seq_len = input_ids.shape[1]
597
+ if getattr(tok, "is_fast", False):
598
+ enc = tok(prompt, return_tensors="pt", return_offsets_mapping=True)
599
+ cs = prompt.find(subject)
600
+ if cs >= 0:
601
+ ce = cs + len(subject)
602
+ offs = enc["offset_mapping"][0].tolist()
603
+ pos = [i for i, (s, e) in enumerate(offs) if e > cs and s < ce]
604
+ if pos:
605
+ return [p for p in pos if p != seq_len - 1], ""
606
+ else:
607
+ sub_ids = tok.encode(subject, add_special_tokens=False)
608
+ seq = input_ids[0].tolist()
609
+ pos = [i for i, t in enumerate(seq) if t in sub_ids]
610
+ if pos:
611
+ return [p for p in pos if p != seq_len - 1], ""
612
+ fb = list(range(0, max(1, seq_len - 1)))[: max(1, seq_len // 2)]
613
+ return fb, "(subject not found; using fallback window)\n"
614
+
615
+
616
+ @torch.no_grad()
617
+ def causal_trace(prompt, subject, target, noise_scale, seed):
618
+ if STATE["name"] is None:
619
+ return "Load a model first."
620
+ model, tok = get_handles(STATE["name"])
621
+ seed, noise_scale = int(seed), float(noise_scale)
622
+ inputs = tok(prompt, return_tensors="pt").to(DEVICE)
623
+ input_ids = inputs["input_ids"]
624
+ positions, note = _find_subject_positions(tok, input_ids, prompt, subject)
625
+ if not positions:
626
+ return note + "No valid subject positions."
627
+ target_ids = tok.encode(target, add_special_tokens=False)
628
+ if not target_ids:
629
+ return "Could not tokenize target %r." % target
630
+ tid = target_ids[0]
631
+
632
+ out_clean = model(**inputs, output_hidden_states=True)
633
+ clean_hs = out_clean.hidden_states
634
+ clean_p = F.softmax(out_clean.logits[0, -1].to(DTYPE), dim=-1)[tid].item()
635
+
636
+ emb_module = model.get_input_embeddings()
637
+ std = emb_module.weight.std().item()
638
+ hidden = emb_module.weight.shape[-1]
639
+ torch.manual_seed(seed)
640
+ noise = torch.randn(len(positions), hidden, device=DEVICE) * noise_scale * std
641
+
642
+ def corrupt_hook(module, inp, out):
643
+ out = out.clone()
644
+ for k, p in enumerate(positions):
645
+ out[0, p] = out[0, p] + noise[k].to(out.dtype)
646
+ return out
647
+
648
+ h = emb_module.register_forward_hook(corrupt_hook)
649
+ corrupt_p = F.softmax(model(**inputs).logits[0, -1].to(DTYPE), dim=-1)[tid].item()
650
+ h.remove()
651
+
652
+ blocks, rows = get_blocks(model), []
653
+ for l in range(len(blocks)):
654
+ clean_layer_hs = clean_hs[l + 1][0]
655
+ def restore_hook(module, inp, out, _clean=clean_layer_hs):
656
+ if isinstance(out, tuple):
657
+ h0 = out[0].clone()
658
+ for p in positions:
659
+ h0[0, p] = _clean[p].to(h0.dtype)
660
+ return (h0,) + out[1:]
661
+ h0 = out.clone()
662
+ for p in positions:
663
+ h0[0, p] = _clean[p].to(h0.dtype)
664
+ return h0
665
+ h1 = emb_module.register_forward_hook(corrupt_hook)
666
+ h2 = blocks[l].register_forward_hook(restore_hook)
667
+ p_r = F.softmax(model(**inputs).logits[0, -1].to(DTYPE), dim=-1)[tid].item()
668
+ h1.remove(); h2.remove()
669
+ rows.append((l, p_r))
670
+
671
+ denom = clean_p - corrupt_p
672
+ lines = [note + "prompt: %r" % prompt,
673
+ "subject: %r target: %r" % (subject, target),
674
+ "clean p=%.3f corrupt p=%.3f noise=%sx std" % (clean_p, corrupt_p, noise_scale),
675
+ "", "%6s | %9s | %9s" % ("layer", "p(target)", "recovery")]
676
+ best_l, best_r = 0, -1e9
677
+ for l, p_r in rows:
678
+ rec = (p_r - corrupt_p) / denom if abs(denom) > 1e-6 else 0.0
679
+ if rec > best_r:
680
+ best_r, best_l = rec, l
681
+ lines.append(" L%-3d | %9.3f | %8.1f%%" % (l, p_r, rec * 100))
682
+ lines.append("")
683
+ lines.append("# peak at L%d (%.0f%% recovery) <- the causal site" % (best_l, best_r * 100))
684
+ if abs(denom) < 1e-6:
685
+ lines.append("# (corruption didn't move p(target): on 'handmade' this is "
686
+ "EXPECTED - the fact lives in a string match, not activations.)")
687
+ return "\n".join(lines)
688
+
689
+
690
+ # =============================================================================
691
+ # UI
692
+ # =============================================================================
693
+ INTRO = """
694
+ # Compression Navigator
695
+ **An LLM is a lossy codec for text.** Training compresses a corpus into weights;
696
+ a forward pass decompresses a continuation. These five tools let you watch that
697
+ decompression and find where facts physically live.
698
+
699
+ Each tab is a real interpretability technique: **logit lens, embedding
700
+ neighbours, activation steering, cross-model diff, and causal tracing (ROME).**
701
+
702
+ ### Three models, on purpose
703
+ | name | how it stores facts | what it teaches |
704
+ |---|---|---|
705
+ | **`glassbox`** | key→value writes into the **residual stream** (like a real transformer / what ROME edits) | the tools **work and are verifiable** against ground truth you can read in the source |
706
+ | **`handmade`** | a **lookup table** keyed on the prompt string (a side channel) | a model can be **invisible** to residual-stream interpretability β€” a real limitation |
707
+ | **`gpt2`** | learned, fuzzy, **distributed** over many layers | what the real, messy thing looks like |
708
+
709
+ **Suggested order:** load `glassbox` first (see "correct"), then `handmade`
710
+ (see a failure mode), then `gpt2` (see reality). Type a name below and Load.
711
+ """
712
+
713
+ with gr.Blocks(title="Compression Navigator") as demo:
714
+ gr.Markdown(INTRO)
715
+ with gr.Row():
716
+ model_name = gr.Textbox(value="glassbox", label="model name or HF id")
717
+ load_btn = gr.Button("Load", variant="primary")
718
+ load_status = gr.Markdown()
719
+ load_btn.click(load_model, inputs=model_name, outputs=load_status)
720
+
721
+ # ---- TAB 1 -------------------------------------------------------------
722
+ with gr.Tab("1 Β· Decompress (logit lens)"):
723
+ gr.Markdown("""
724
+ ### Logit lens β€” watch the answer condense, layer by layer
725
+ **What it does:** takes the last-token residual at *every* layer and reads it
726
+ through the unembedding, as if the model had to answer right there. You see the
727
+ prediction form.
728
+
729
+ **How to read it:** each row is a layer. Watch your tracked token's probability
730
+ (right column) climb, and watch **entropy** (bits) fall as the model commits.
731
+
732
+ **Ground truth to check:**
733
+ - `glassbox` β€” `paris` is ~0 until **L2** (the fact-MLP), then jumps. Sharp and localised because you put it there.
734
+ - `handmade` β€” the answer appears suddenly with no build-up (it's a lookup, not a computation).
735
+ - `gpt2` β€” the answer accretes *gradually* across many middle/late layers. That smear is what "distributed representation" actually looks like.
736
+ """)
737
+ ll_prompt = gr.Textbox(value="the capital of france is", label="prompt")
738
+ with gr.Row():
739
+ ll_k = gr.Slider(1, 10, value=3, step=1, label="top-k per layer")
740
+ ll_track = gr.Textbox(value="paris", label="track this token's prob")
741
+ ll_out = gr.Textbox(label="output", lines=18)
742
+ gr.Button("Run").click(logit_lens, [ll_prompt, ll_k, ll_track], ll_out)
743
+
744
+ # ---- TAB 2 -------------------------------------------------------------
745
+ with gr.Tab("2 Β· Triangulate (neighbours)"):
746
+ gr.Markdown("""
747
+ ### Neighbours β€” the geometry of the vocabulary
748
+ **What it does:** ranks tokens by cosine similarity of their unembedding rows.
749
+ Directions that point the same way are "near" in the model's compressed space.
750
+
751
+ **How to read it:** high cosine = the model treats these tokens as related.
752
+
753
+ **Ground truth to check:**
754
+ - `glassbox` β€” `paris` is near `france` (cos β‰ˆ 0.48): the source deliberately makes a capital share a dimension with its country. Real geometry, by design.
755
+ - `handmade` β€” **every** cosine is 0. One-hot embeddings are mutually orthogonal, so there's no geometry at all. The tool is correctly reporting "nothing here."
756
+ - `gpt2` β€” neighbours are messy but meaningful (casing variants, plurals, semantic kin).
757
+ """)
758
+ nb_word = gr.Textbox(value="paris", label="word")
759
+ nb_k = gr.Slider(5, 25, value=10, step=1, label="top neighbours")
760
+ nb_out = gr.Textbox(label="output", lines=15)
761
+ gr.Button("Run").click(neighbors, [nb_word, nb_k], nb_out)
762
+
763
+ # ---- TAB 3 -------------------------------------------------------------
764
+ with gr.Tab("3 Β· Re-route (steering)"):
765
+ gr.Markdown("""
766
+ ### Steering β€” bend behaviour with a direction, no retraining
767
+ **What it does:** builds the vector `emb(target) βˆ’ emb(source)` and *adds* it to
768
+ a layer's output during generation. The model drifts from `source` toward
769
+ `target`. This is the cheap cousin of fine-tuning (ActAdd / representation
770
+ engineering).
771
+
772
+ **How to read it:** compare *baseline* vs *steered*. Raise **strength** until the
773
+ output flips; too high and it turns to noise (you've knocked the residual off
774
+ the manifold).
775
+
776
+ **Tips:** on `gpt2` try `from: Paris to: London` on the France prompt, layer
777
+ 0–4, strength 6–14. On `glassbox`/`handmade` the vocab is tiny β€” steering is
778
+ mostly a mechanics demo there; the real lesson lives on `gpt2`.
779
+ """)
780
+ st_prompt = gr.Textbox(value="the capital of france is", label="prompt")
781
+ with gr.Row():
782
+ st_src = gr.Textbox(value="Paris", label="from")
783
+ st_tgt = gr.Textbox(value="London", label="to")
784
+ with gr.Row():
785
+ st_layer = gr.Slider(0, 11, value=2, step=1, label="layer")
786
+ st_alpha = gr.Slider(0, 30, value=10, step=0.5, label="strength")
787
+ st_max = gr.Slider(8, 80, value=40, step=1, label="max new tokens")
788
+ st_base = gr.Textbox(label="baseline", lines=2)
789
+ st_out = gr.Textbox(label="steered", lines=3)
790
+ gr.Button("Run").click(steer_generate,
791
+ [st_prompt, st_src, st_tgt, st_layer, st_alpha, st_max],
792
+ [st_base, st_out])
793
+
794
+ # ---- TAB 4 -------------------------------------------------------------
795
+ with gr.Tab("4 Β· Diff (align by depth)"):
796
+ gr.Markdown("""
797
+ ### Diff β€” two models on one prompt, aligned by *relative* depth
798
+ **What it does:** runs the logit lens on model A and model B and lines their
799
+ layers up by percentage depth (0–100%), so you can compare a 2-layer toy with a
800
+ 12-layer gpt2 side by side. `dp` is `p_B βˆ’ p_A` for the target token.
801
+
802
+ **How to read it:** look at *where* on the depth axis each model commits to the
803
+ target. A localised model commits at one depth; a distributed one ramps up.
804
+
805
+ **Try:** A = `gpt2`, B = `glassbox`, target = `paris`. You'll see gpt2 ramp
806
+ through the middle while glassbox snaps on at its fact layer β€” the same fact,
807
+ two very different internal shapes.
808
+ """)
809
+ with gr.Row():
810
+ df_a = gr.Textbox(value="gpt2", label="model A")
811
+ df_b = gr.Textbox(value="glassbox", label="model B")
812
+ df_prompt = gr.Textbox(value="the capital of france is", label="prompt")
813
+ df_target = gr.Textbox(value="paris", label="target token")
814
+ df_k = gr.Slider(1, 5, value=1, step=1, label="top-k (display)")
815
+ df_out = gr.Textbox(label="output", lines=16)
816
+ gr.Button("Run").click(diff_models,
817
+ [df_a, df_b, df_prompt, df_target, df_k], df_out)
818
+
819
+ # ---- TAB 5 -------------------------------------------------------------
820
+ with gr.Tab("5 Β· Causal trace (ROME)"):
821
+ gr.Markdown("""
822
+ ### Causal trace β€” corrupt the subject, restore each layer, find the site
823
+ **What it does:** activation patching (Meng et al.'s ROME). It noises the
824
+ **subject** token, which breaks the prediction, then restores one layer at a
825
+ time and measures how much of the answer comes back. The layer that restores
826
+ the most is where the fact is *causally* computed.
827
+
828
+ **How to read it:** `recovery` β‰ˆ 100% means "restoring this layer is enough" β†’
829
+ the fact is read here. The peak line names the site.
830
+
831
+ **Ground truth to check:**
832
+ - `glassbox` β€” peak at **L0** (β‰ˆ100%). The fact is read at the early subject site, because the L1 "attention" re-reads the restored subject. You know this is right because you wrote the mechanism.
833
+ - `handmade` β€” `clean p` β‰ˆ `corrupt p`, so recovery is meaningless. **Expected:** the fact is a string match, untouched by activation noise. This is the headline lesson β€” patching can't see lookup behaviour.
834
+ - `gpt2` β€” a *band* of early–middle layers at the subject token light up, exactly as in the ROME paper.
835
+ """)
836
+ ct_prompt = gr.Textbox(value="the capital of france is", label="prompt")
837
+ ct_subject = gr.Textbox(value="france", label="subject to corrupt")
838
+ ct_target = gr.Textbox(value="paris", label="target token")
839
+ with gr.Row():
840
+ ct_noise = gr.Slider(0, 10, value=3, step=0.5, label="noise (x embed std)")
841
+ ct_seed = gr.Slider(0, 100, value=0, step=1, label="seed")
842
+ ct_out = gr.Textbox(label="output", lines=18)
843
+ gr.Button("Run").click(causal_trace,
844
+ [ct_prompt, ct_subject, ct_target, ct_noise, ct_seed], ct_out)
845
+
846
+ gr.Markdown("""
847
+ ---
848
+ ### Where this goes next
849
+ - **Edit loop (the VINDEX bridge):** trace β†’ pick the layer β†’ apply a ROME/MEMIT rank-1 edit to that MLP β†’ re-run the logit lens to confirm the new fact took *and* nothing else moved. The glass-box is the unit test for that pipeline before you trust it on a real model.
850
+ - **More glass-box facts / multi-hop:** add `"the currency of france is"` to force a second relation through the same subject, and watch the trace separate the two sites.
851
+ - **Attention + MLP key-value inspection:** Geva-style "what does this neuron write to the vocab" and per-head attribution.
852
+ - **Package as an HF Space** with this writeup as the README β€” it's a clean teaching artifact and a regression harness for interpretability code.
853
+ """)
854
+
855
+ demo.load(lambda: load_model("glassbox"), outputs=load_status)
856
+
857
+ if __name__ == "__main__":
858
+ demo.launch()