Chris4K's picture
Update app.py
cbcdc9d verified
Raw
History Blame Contribute Delete
70.1 kB
# =============================================================================
# COMPRESSION NAVIGATOR · extended + annotated edition
# =============================================================================
# An LLM is a lossy codec for text. Training compresses a corpus into weights;
# a forward pass decompresses a continuation. These five tools let you watch
# that decompression happen and poke at where facts physically live.
#
# The five tabs are not toys invented here - each one is a real mechanistic-
# interpretability technique you'll find in papers:
#
# 1. Decompress = LOGIT LENS (nostalgebraist, 2020)
# 2. Triangulate = EMBEDDING NEIGHBOURS (the geometry of the vocab)
# 3. Re-route = ACTIVATION STEERING (ActAdd / repr. engineering)
# 4. Diff = CROSS-MODEL ALIGNMENT (compare checkpoints by depth)
# 5. Causal trace = ACTIVATION PATCHING (ROME, Meng et al., 2022)
#
# WHY THE GLASS-BOX MODELS MATTER
# -------------------------------
# On a real model (gpt2) you never know the ground truth, so you can't tell
# whether a tool is *correct* or just producing plausible-looking output.
# This file ships two models whose internals you fully specify, so you can
# check each tool against a known answer:
#
# "handmade" - facts stored as a LOOKUP TABLE keyed on the prompt string.
# The computation happens in a side channel (string match),
# NOT in the residual stream. Lesson: such a model is almost
# invisible to residual-stream interpretability. Logit lens
# sees a sudden jump with no build-up; causal tracing finds
# nothing, because corrupting activations doesn't touch the
# string match. This is a real and underappreciated *limit*
# of these methods.
#
# "glassbox" - facts stored the way real transformers store them: as
# key->value writes into the RESIDUAL STREAM (Geva et al.'s
# "MLPs are key-value memories", which is exactly what ROME
# edits). Because the fact flows through activations, ALL five
# tools light up correctly - and you can verify they report
# the layer you actually put the fact in. This is a unit-test
# harness for interpretability code.
#
# Run order suggestion: glassbox -> handmade -> gpt2
# glassbox shows what "correct" looks like; handmade shows a failure mode;
# gpt2 shows the fuzzy, distributed real thing.
# =============================================================================
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
# Process-wide runtime settings shared by every tab.
# Prefer the GPU whenever torch reports one; otherwise stay on the CPU.
DEVICE = "cpu" if not torch.cuda.is_available() else "cuda"
# dtype passed to from_pretrained and used for every readout cast.
DTYPE = torch.float32
# name -> (model, tokenizer) cache
MODELS = dict()
# currently loaded model name
STATE = dict(name=None)
# =============================================================================
# A tiny shared tokenizer for both glass-box models.
# Case is CANONICALISED to lowercase everywhere (this fixes a real bug in the
# original: "Paris" from a pinned fact and "paris" from the Markov table became
# two different vocab entries, so the boosted token and the *tracked* token
# silently diverged - every neighbour read cos=0.000 and every tracked prob 0).
# =============================================================================
class FakeBatchEncoding(dict):
    """Plain dict of tensors that tolerates the HF-style ``.to(device)`` call."""

    def to(self, device):
        """Return this same object; ``device`` is accepted but ignored.

        This lets call sites write ``tok(...).to(DEVICE)`` for toy and real
        tokenizers alike.
        """
        return self
class SimpleTok:
    """Lower-casing whitespace tokenizer over a fixed vocabulary.

    It is not a "fast" tokenizer: no offset mapping is produced, so callers
    that need character offsets must check ``is_fast`` first.
    """

    is_fast = False

    def __init__(self, stoi, itos):
        """Store the word->id and id->word tables; "." is the EOS token."""
        self.stoi = stoi
        self.itos = itos
        self.eos_token_id = stoi["."]  # period doubles as end-of-sequence

    def _ids(self, text):
        """Map ``text`` to ids; words outside the vocab become ``<s>``."""
        # Detach periods from the preceding word so "." is its own token.
        pieces = text.lower().replace(".", " .").split()
        ids = []
        for piece in pieces:
            ids.append(self.stoi.get(piece, self.stoi["<s>"]))
        return ids

    def __call__(self, text, return_tensors=None, return_offsets_mapping=False):
        """Encode ``text`` into a batch-of-one ``FakeBatchEncoding``.

        ``return_tensors`` and ``return_offsets_mapping`` are accepted for
        call-site compatibility and otherwise ignored. Empty input yields a
        single ``<s>`` token so the sequence is never empty.
        """
        ids = self._ids(text)
        if not ids:
            ids = [self.stoi["<s>"]]
        mask = torch.ones(1, len(ids), dtype=torch.long)
        return FakeBatchEncoding(input_ids=torch.tensor([ids]), attention_mask=mask)

    def encode(self, text, add_special_tokens=False):
        """Return the plain id list for ``text`` (may be empty)."""
        return self._ids(text)

    def decode(self, ids, skip_special_tokens=False):
        """Join the words for ``ids`` with spaces; unknown ids render as "?"."""
        specials = ("<pad>", "<s>")
        words = (self.itos.get(int(i), "?") for i in ids)
        kept = [w for w in words if not (skip_special_tokens and w in specials)]
        return " ".join(kept)
class _Out:
"""Mimics a HF CausalLMOutput: .logits and (optional) .hidden_states."""
def __init__(self, logits, hidden_states):
self.logits = logits
self.hidden_states = hidden_states
def _greedy_generate(model, input_ids, max_new_tokens=20, pad_token_id=None, **_):
"""Minimal greedy decode so the steering tab works on the toy models too
(the originals had no .generate, so that tab crashed on 'handmade')."""
ids = input_ids
for _ in range(int(max_new_tokens)):
nxt = model(input_ids=ids).logits[0, -1].argmax().view(1, 1)
ids = torch.cat([ids, nxt], dim=1)
if pad_token_id is not None and int(nxt.item()) == int(pad_token_id):
break
return ids
# =============================================================================
# MODEL 1 - "handmade": facts as a LOOKUP TABLE (the side-channel glass box)
# -----------------------------------------------------------------------------
# Embeddings are the identity matrix (each token is its own one-hot). The two
# "layers" don't read the residual stream in a meaningful linear way:
# - MemoryBlock matches the *decoded prompt string* and boosts the answer.
# - MarkovBlock adds a hand-built bigram transition for the last token.
# Because MemoryBlock keys on the prompt TEXT, not on activations, this is a
# deliberate demonstration of a model that residual-stream interpretability
# cannot see. Use it as the "what failure looks like" control.
# =============================================================================
# Prompt suffix -> answer boosted by the memory block. Answers are lowercase
# and carry a leading space that consumers strip before the vocab lookup.
PINNED = {
    "the capital of france is": " paris",
    "the eiffel tower is in": " paris",
    "two plus two equals": " four",
}

# Hand-written bigram counts: current word -> {next word: weight}.
# Insertion order matters because the vocabulary is built by walking this table.
MARKOV = {
    "<s>": {"the": 3, "i": 2, "a": 1},
    "the": {"city": 2, "tower": 2, "answer": 1},
    "i": {"think": 2, "am": 1},
    "a": {"model": 2, "city": 1},
    "city": {"of": 3, "is": 1},
    "of": {"light": 2, "paris": 1},
    "tower": {"is": 3},
    "is": {"in": 2, "a": 1},
    "in": {"paris": 2, "france": 1},
    "model": {"is": 2},
    "think": {"the": 2},
    "paris": {".": 1},
    "france": {".": 1},
    "light": {".": 1},
    "four": {".": 1},
}
def _build_handmade_vocab():
    """Collect the handmade model's vocabulary in first-seen order.

    Order: the three specials, then pinned answers, then every Markov word
    with its successors, then the words of the pinned prompts. Duplicates
    keep their first position.
    """
    # A dict is used as an insertion-ordered set.
    ordered = dict.fromkeys(["<pad>", "<s>", "."])
    for answer in PINNED.values():
        ordered.setdefault(answer.strip())
    for word, successors in MARKOV.items():
        ordered.setdefault(word)
        for nxt in successors:
            ordered.setdefault(nxt)
    for key in PINNED:
        for word in key.split():
            ordered.setdefault(word)
    return list(ordered)
# Handmade-model vocabulary tables: ordered word list, word->id, id->word, size.
HM_VOCAB = _build_handmade_vocab()
HM_STOI = dict(zip(HM_VOCAB, range(len(HM_VOCAB))))
HM_ITOS = {idx: word for word, idx in HM_STOI.items()}
HM_V = len(HM_VOCAB)
class _MemoryBlock(nn.Module):
    """Lookup-table "layer": boost the pinned answer when the prompt matches.

    The match is done on the decoded ``prompt_ids`` text, never on ``x`` -
    that side channel is the whole point of the handmade model.
    """

    def forward(self, x, prompt_ids=None):
        """Return a 1-tuple holding a copy of ``x`` with +12 on matched answers."""
        boosted = x.clone()
        if prompt_ids is None:
            return (boosted,)
        words = [HM_ITOS.get(int(tok_id), "") for tok_id in prompt_ids]
        prompt_text = " ".join(words).strip()
        for key, answer in PINNED.items():
            if prompt_text.endswith(key):
                # One-hot residual: the answer's vocab index is its coordinate.
                boosted[0, -1, HM_STOI[answer.strip()]] += 12.0
        return (boosted,)
class _MarkovBlock(nn.Module):
    """Bigram "layer": add the transition row of the last prompt token."""

    def __init__(self):
        """Turn the MARKOV counts into a row-normalised (HM_V, HM_V) buffer."""
        super().__init__()
        table = torch.zeros(HM_V, HM_V)
        for src, successors in MARKOV.items():
            if src not in HM_STOI:
                continue
            total = sum(successors.values())
            for dst, weight in successors.items():
                if dst in HM_STOI:
                    table[HM_STOI[src], HM_STOI[dst]] = weight / total
        self.register_buffer("T", table)

    def forward(self, x, prompt_ids=None):
        """Return a 1-tuple with 4x the last token's row added at the last position."""
        out = x.clone()
        if prompt_ids:
            last_id = int(prompt_ids[-1])
            out[0, -1] += 4.0 * self.T[last_id]
        return (out,)
class _HMTransformer(nn.Module):
    """Container exposing ``wte``, ``h`` and ``ln_f`` like a GPT-2 trunk."""

    def __init__(self):
        super().__init__()
        # one-hot embeddings: token i is the i-th basis vector
        self.wte = nn.Embedding(HM_V, HM_V)
        with torch.no_grad():
            self.wte.weight.copy_(torch.eye(HM_V))
        layers = [_MemoryBlock(), _MarkovBlock()]
        self.h = nn.ModuleList(layers)
        self.ln_f = nn.Identity()
class HandmadeModel(nn.Module):
    """Lookup-table toy model with identity embed/unembed and two blocks."""

    def __init__(self):
        super().__init__()
        self.transformer = _HMTransformer()
        self.head = nn.Linear(HM_V, HM_V, bias=False)
        with torch.no_grad():
            # identity unembed: logits are the residual itself
            self.head.weight.copy_(torch.eye(HM_V))
        self.tok = SimpleTok(HM_STOI, HM_ITOS)

    def get_input_embeddings(self):
        """Return the token-embedding module."""
        return self.transformer.wte

    def get_output_embeddings(self):
        """Return the unembedding linear layer."""
        return self.head

    def generate(self, input_ids=None, attention_mask=None, **kw):
        """Greedy-decode from ``input_ids``; ``attention_mask`` is ignored."""
        return _greedy_generate(self, input_ids, **kw)

    def forward(self, input_ids=None, attention_mask=None, output_hidden_states=False):
        """Run both blocks on the first batch row's ids and return an ``_Out``.

        ``hidden_states`` is the embedding output followed by each block's
        output when requested, otherwise ``None``.
        """
        prompt_ids = input_ids[0].tolist()
        residual = self.transformer.wte(input_ids).float()
        trace = [residual]
        for block in self.transformer.h:
            (residual,) = block(residual, prompt_ids=prompt_ids)
            trace.append(residual)
        logits = self.head(self.transformer.ln_f(residual))
        hidden = tuple(trace) if output_hidden_states else None
        return _Out(logits, hidden)
# =============================================================================
# MODEL 2 - "glassbox": facts as RESIDUAL-STREAM key->value writes
# -----------------------------------------------------------------------------
# This is the model the original was missing. It stores facts the way real
# transformers do, so every tool works AND can be checked against ground truth.
#
# Vocab + structured embeddings (d=32). Country and its capital deliberately
# SHARE an embedding dimension, so the neighbours tool finds real geometry
# (paris is near france).
#
# Four layers:
# L0 subject site : (identity here) the residual the trace will restore
# L1 pool/attention : copies subject signal from earlier positions -> last
# L2 fact MLP : key(subject+relation) -> relu -> value(answer dir) <- ROME edits this kind of layer
# L3 cleanup : identity
#
# Ground truth you can verify:
# - logit lens: the answer is INVISIBLE until L2, then appears. Compare with
# handmade (sudden, no build-up) and gpt2 (fuzzy, spread over many layers).
# - causal trace: corrupting the subject and restoring layer by layer peaks
# at L0 - because L1's "attention" re-reads the restored subject. That is
# the ROME story: the causal site is an early layer at the SUBJECT token.
# - steering / neighbours: both operate on real directions, so both work.
# =============================================================================
# Glass-box residual width.
GB_D = 32
# Fixed vocabulary; "london" and "rome" are spare answers so edits can hit a
# fresh target.
GB_TOKS = [
    "<pad>", "<s>", ".",
    "the", "capital", "of", "is", "in",
    "france", "germany", "japan",
    "paris", "berlin", "tokyo",
    "london", "rome",
]
GB_STOI = {}
for _idx, _word in enumerate(GB_TOKS):
    GB_STOI[_word] = _idx
GB_ITOS = {idx: word for word, idx in GB_STOI.items()}
GB_V = len(GB_TOKS)
# (subject, answer) pairs stored by the fact MLP.
GB_FACTS = [
    ("france", "paris"),
    ("germany", "berlin"),
    ("japan", "tokyo"),
]
def _build_gb_embeddings():
    """Build the (GB_V, GB_D) unit-norm embedding table for the glass-box.

    Each country shares its first dimension with its capital so the pair has
    a positive cosine. Tokens without a hand-set direction ("fillers") each
    get a one-hot on their own otherwise-unused axis.

    The previous ``10 + i % 6`` rule mapped token 7 ("in") onto axis 11,
    the axis already taken by token 1 ("<s>"), so the two tokens had
    identical embeddings and identical tied logits. Fillers are now handed
    distinct free axes in order; the first six keep axes 10..15 as before.

    Raises:
        ValueError: if there are more filler tokens than unused axes.
    """
    E = torch.zeros(GB_V, GB_D)

    def setd(tok, pairs):
        for d, v in pairs:
            E[GB_STOI[tok], d] = v

    # country/capital pairs share their first dim -> positive cosine (geometry!)
    setd("france", [(0, 1.0), (1, 0.6), (20, 0.5)])
    setd("paris", [(0, 0.8), (2, 0.9), (21, 0.5)])
    setd("germany", [(3, 1.0), (4, 0.6), (22, 0.5)])
    setd("berlin", [(3, 0.8), (5, 0.9), (23, 0.5)])
    setd("japan", [(6, 1.0), (7, 0.6), (24, 0.5)])
    setd("tokyo", [(6, 0.8), (8, 0.9), (25, 0.5)])
    setd("london", [(27, 1.0), (28, 0.5)])  # spare answers (own dirs)
    setd("rome", [(29, 1.0), (30, 0.5)])
    setd("is", [(9, 1.0), (26, 0.4)])  # the relation marker

    # Axes no structured token touches, in ascending order (10, 11, ...).
    free_dims = [d for d in range(GB_D) if E[:, d].abs().sum() == 0]
    for i in range(GB_V):  # give every filler its own id axis
        if E[i].abs().sum() == 0:
            if not free_dims:
                raise ValueError("no unused embedding axis left for filler tokens")
            E[i, free_dims.pop(0)] = 1.0
    return E / (E.norm(dim=-1, keepdim=True) + 1e-9)  # unit rows
# Shared embedding table (also copied into the tied unembedding).
GB_E = _build_gb_embeddings()
# Diagonal projector that keeps only the subject dims 0..8 of a residual.
GB_SUBJ = torch.zeros(GB_D, GB_D)
for _d in range(0, 9):
    GB_SUBJ[_d, _d] = 1.0
class _GBIdent(nn.Module):
def forward(self, x, prompt_ids=None):
return (x.clone(),)
class _GBPool(nn.Module):
    """Toy "attention" that moves subject information to the last position.

    The subject-projected residuals of every earlier position are summed and
    added (scaled by 0.9) to the last position. If the subject is restored
    before this block, the pooled signal is clean again - which is why the
    causal-trace peak lands at L0 rather than L1.
    """

    def forward(self, x, prompt_ids=None):
        out = x.clone()
        if x.shape[1] <= 1:
            # A single-token sequence has no earlier positions to pool.
            return (out,)
        earlier = x[0, :-1]
        pooled = (earlier @ GB_SUBJ.T).sum(0)
        out[0, -1] = out[0, -1] + 0.9 * pooled
        return (out,)
class _GBFactMLP(nn.Module):
    """Geva-style key->value memory acting on the last position.

    Rows of ``Win`` are normalised (subject + relation) keys; the relu with a
    bias gates which fact fires; columns of ``Wout`` are the answer embedding
    directions. ``Win0`` / ``Wout0`` keep pristine copies for reset.
    Structurally this is the kind of layer ROME rewrites to edit a fact.
    """

    def __init__(self):
        super().__init__()
        n_facts = len(GB_FACTS)
        keys = torch.zeros(n_facts, GB_D)
        values = torch.zeros(GB_D, n_facts)
        relation = GB_E[GB_STOI["is"]]
        for slot, (subject, answer) in enumerate(GB_FACTS):
            # 0.9 matches the scale the pool block applies to the subject.
            raw_key = (GB_E[GB_STOI[subject]] @ GB_SUBJ.T) * 0.9 + relation
            keys[slot] = raw_key / raw_key.norm()
            values[:, slot] = GB_E[GB_STOI[answer]]  # write answer direction
        self.register_buffer("Win", keys)
        self.register_buffer("Wout", values)
        # pristine backups for reset
        self.register_buffer("Win0", keys.clone())
        self.register_buffer("Wout0", values.clone())
        # tuned: clean p~0.5, corrupt p~0.07
        self.bias = 0.85
        self.gain = 6.0

    def forward(self, x, prompt_ids=None):
        """Return a 1-tuple with the gated answer write added at the last position."""
        out = x.clone()
        last = out[0, -1]
        activation = F.relu(self.Win @ last - self.bias)
        out[0, -1] = last + self.gain * (self.Wout @ activation)
        return (out,)
class _GBTransformer(nn.Module):
    """Glass-box trunk: embeddings plus identity, pool, fact-MLP, identity."""

    def __init__(self):
        super().__init__()
        self.wte = nn.Embedding(GB_V, GB_D)
        with torch.no_grad():
            self.wte.weight.copy_(GB_E)
        # Index 2 is the fact MLP; the editing code relies on that position.
        layers = [_GBIdent(), _GBPool(), _GBFactMLP(), _GBIdent()]
        self.h = nn.ModuleList(layers)
        self.ln_f = nn.Identity()
class GlassBoxModel(nn.Module):
    """Toy model whose facts live in residual-stream key->value writes."""

    def __init__(self):
        super().__init__()
        self.transformer = _GBTransformer()
        self.head = nn.Linear(GB_D, GB_V, bias=False)
        with torch.no_grad():
            self.head.weight.copy_(GB_E)  # tied unembed
        self.tok = SimpleTok(GB_STOI, GB_ITOS)

    # --- knowledge editing (ROME-style, exact on this key->value layer) -------
    @torch.no_grad()
    def edit_fact(self, subject, new_answer, method="rank1", strength=1.0):
        """Point ``subject``'s fact at ``new_answer`` by rewriting ``Wout``.

        Methods:
          rank1 / surgical - replace only this fact's value column, starting
                             from the pristine column.
          broadcast        - deliberately sloppy: add the same delta to every
                             column so a verifier has collateral to catch.

        ``strength`` scales the delta between the pristine value and the new
        answer's embedding.

        Raises:
            ValueError: unknown subject, unknown answer token, or unknown method.
        """
        fact_mlp = self.transformer.h[2]  # the FactMLP block
        known = [s for s, _ in GB_FACTS]
        if subject not in known:
            raise ValueError("unknown subject %r" % subject)
        if new_answer not in GB_STOI:
            raise ValueError("unknown answer token %r" % new_answer)
        slot = known.index(subject)
        pristine = fact_mlp.Wout0[:, slot]
        delta = (GB_E[GB_STOI[new_answer]] - pristine) * float(strength)
        if method == "broadcast":
            fact_mlp.Wout += delta.unsqueeze(1)  # hits every fact
        elif method in ("rank1", "surgical"):
            fact_mlp.Wout[:, slot] = pristine + delta
        else:
            raise ValueError("unknown method %r" % method)

    @torch.no_grad()
    def reset(self):
        """Copy the pristine ``Win0`` / ``Wout0`` back over the live weights."""
        fact_mlp = self.transformer.h[2]
        fact_mlp.Win.copy_(fact_mlp.Win0)
        fact_mlp.Wout.copy_(fact_mlp.Wout0)

    def get_input_embeddings(self):
        """Return the token-embedding module."""
        return self.transformer.wte

    def get_output_embeddings(self):
        """Return the unembedding linear layer."""
        return self.head

    def generate(self, input_ids=None, attention_mask=None, **kw):
        """Greedy-decode from ``input_ids``; ``attention_mask`` is ignored."""
        return _greedy_generate(self, input_ids, **kw)

    def forward(self, input_ids=None, attention_mask=None, output_hidden_states=False):
        """Run all blocks on the first batch row's ids and return an ``_Out``.

        ``hidden_states`` is the embedding output followed by each block's
        output when requested, otherwise ``None``.
        """
        prompt_ids = input_ids[0].tolist()
        residual = self.transformer.wte(input_ids).float()
        trace = [residual]
        for block in self.transformer.h:
            (residual,) = block(residual, prompt_ids=prompt_ids)
            trace.append(residual)
        logits = self.head(self.transformer.ln_f(residual))
        hidden = tuple(trace) if output_hidden_states else None
        return _Out(logits, hidden)
# =============================================================================
# REAL MODELS - resolve the architecture-specific module paths
# =============================================================================
def _resolve(model, paths):
for path in paths:
obj, ok = model, True
for part in path.split("."):
if hasattr(obj, part):
obj = getattr(obj, part)
else:
ok = False; break
if ok:
return obj
return None
def get_blocks(model):
    """Return the model's list of transformer blocks.

    Tries the attribute paths of the supported architectures in order.

    Raises:
        RuntimeError: if none of the known paths exists on ``model``.
    """
    candidates = [
        "transformer.h",
        "model.layers",
        "gpt_neox.layers",
        "model.decoder.layers",
    ]
    found = _resolve(model, candidates)
    if found is not None:
        return found
    raise RuntimeError("Could not locate transformer blocks.")
def get_final_norm(model):
    """Return the model's final norm module, or an identity callable if none is found."""
    candidates = [
        "transformer.ln_f",
        "model.norm",
        "gpt_neox.final_layer_norm",
        "model.decoder.final_layer_norm",
    ]
    found = _resolve(model, candidates)
    if found is None:
        return lambda x: x
    return found
def get_head(model):
    """Return the unembedding module via ``model.get_output_embeddings()``."""
    head = model.get_output_embeddings()
    return head
def get_handles(name):
    """Return the cached ``(model, tokenizer)`` pair for ``name``, loading on first use.

    "handmade" and "glassbox" build the toy models, which bring their own
    tokenizer. Any other name is loaded through ``AutoTokenizer`` /
    ``AutoModelForCausalLM`` and moved to ``DEVICE``.
    """
    cached = MODELS.get(name)
    if cached is not None:
        return cached
    toy_cls = {"handmade": HandmadeModel, "glassbox": GlassBoxModel}.get(name)
    if toy_cls is not None:
        toy = toy_cls().eval()
        MODELS[name] = (toy, toy.tok)
    else:
        tok = AutoTokenizer.from_pretrained(name)
        loaded = AutoModelForCausalLM.from_pretrained(name, torch_dtype=DTYPE)
        model = loaded.to(DEVICE).eval()
        MODELS[name] = (model, tok)
    return MODELS[name]
def load_model(name):
    """Load (or fetch from cache) ``name``, make it current, and report its depth."""
    cleaned = name.strip()
    model, _tok = get_handles(cleaned)
    STATE["name"] = cleaned
    n_layers = len(get_blocks(model))
    return "Loaded **%s** (%d layers)." % (cleaned, n_layers)
# =============================================================================
# Shared readout: project every layer's last-token residual to a vocab dist.
# =============================================================================
@torch.no_grad()
def layer_distributions(model, tok, prompt):
    """Project every layer's last-token residual to a vocabulary distribution.

    Returns a list of ``(label, probs)`` pairs, one per ``hidden_states``
    entry, labelled "embed", "L1", "L2", ...
    """
    encoded = tok(prompt, return_tensors="pt").to(DEVICE)
    hidden_states = model(**encoded, output_hidden_states=True).hidden_states
    final_norm = get_final_norm(model)
    unembed = get_head(model)
    last_index = len(hidden_states) - 1
    result = []
    for depth, state in enumerate(hidden_states):
        vec = state[0, -1].to(DTYPE)
        # HF convention: the LAST hidden_states entry is already post-ln_f,
        # so only intermediates get the final norm (logit-lens style).
        if depth != last_index:
            vec = final_norm(vec)
        probs = F.softmax(unembed(vec), dim=-1)
        label = "embed" if depth == 0 else "L%d" % depth
        result.append((label, probs))
    return result
def _entropy_bits(probs):
p = probs.clamp_min(1e-12)
return float(-(p * p.log()).sum() / math.log(2))
# =============================================================================
# TAB 1 - LOGIT LENS: watch the answer condense out of the residual stream
# =============================================================================
@torch.no_grad()
def logit_lens(prompt, top_k, track):
    """Render the logit-lens table for ``prompt`` on the currently loaded model.

    One row per layer: the ``top_k`` most likely next tokens with their
    probabilities, the entropy in bits, and (if ``track`` tokenizes) the
    probability of the first token of ``track``.

    ``top_k`` is clamped to the vocabulary size: ``torch.topk`` raises when
    asked for more entries than exist, which the toy models' small
    vocabularies make easy to hit. A missing ``track`` is treated as empty.

    Returns the table as one string, or a hint if no model is loaded.
    """
    if STATE["name"] is None:
        return "Load a model first."
    model, tok = get_handles(STATE["name"])
    top_k = max(0, int(top_k))
    track = track or ""
    tids = tok.encode(track, add_special_tokens=False) if track.strip() else []
    tid = tids[0] if tids else None
    dists = layer_distributions(model, tok, prompt)
    header = "layer | top tokens (prob) | entropy" \
        + (" | p(%r)" % track if tid is not None else "")
    lines = ["prompt: %r" % prompt, header, "-" * len(header)]
    for label, probs in dists:
        p, idx = probs.topk(min(top_k, probs.numel()))
        shown = " ".join("%r:%.2f" % (tok.decode([t]).replace("\n", "\\n"), v)
                         for t, v in zip(idx.tolist(), p.tolist()))
        row = "%5s | %-40s | %4.1fb" % (label, shown, _entropy_bits(probs))
        if tid is not None:
            row += " | %.3f" % probs[tid].item()
        lines.append(row)
    return "\n".join(lines)
# =============================================================================
# TAB 2 - NEIGHBOURS: the geometry of the (un)embedding space
# =============================================================================
@torch.no_grad()
def neighbors(word, top_k):
    """List the ``top_k`` nearest tokens to ``word`` by cosine in unembedding space.

    Only the first token of ``word`` is used. The query token itself is
    left out of the listing.

    The number of candidates requested from ``torch.topk`` is clamped to the
    vocabulary size; without the clamp a ``top_k`` at or above the vocab size
    raises on the small toy vocabularies.

    Returns the listing as one string, or a hint when no model is loaded or
    ``word`` does not tokenize.
    """
    if STATE["name"] is None:
        return "Load a model first."
    model, tok = get_handles(STATE["name"])
    top_k = max(0, int(top_k))
    ids = tok.encode(word, add_special_tokens=False)
    if not ids:
        return "Could not tokenize %r." % word
    tid = ids[0]
    W = F.normalize(get_head(model).weight.to(DTYPE), dim=-1)
    sims = W @ W[tid]
    # +1 leaves room for the query token, which is filtered out below.
    vals, idx = sims.topk(min(top_k + 1, sims.numel()))
    note = ""
    if STATE["name"] == "handmade":
        note = ("(handmade uses one-hot embeddings, so every token is "
                "orthogonal -> all cosines are 0 by construction. This is the "
                "tool telling the truth about a model with no vocab geometry.)\n")
    lines = [note + "neighbours of %r:" % word]
    for v, j in zip(vals.tolist(), idx.tolist()):
        if j != tid:
            lines.append(" %14r cos=%.3f" % (tok.decode([j]), v))
    return "\n".join(lines[: top_k + 1])
# =============================================================================
# TAB 3 - STEERING: bend behaviour by adding a direction, no retraining
# =============================================================================
def _make_steer_hook(direction, alpha):
d = direction * alpha
def hook(module, inp, out):
if isinstance(out, tuple):
return (out[0] + d.to(out[0].dtype).to(out[0].device),) + out[1:]
return out + d.to(out.dtype).to(out.device)
return hook
@torch.no_grad()
def steer_generate(prompt, source, target, layer, alpha, max_new):
    """Generate from ``prompt`` twice: unmodified, then with a steering vector.

    The steering direction is the normalised difference between the input
    embeddings of the first tokens of ``target`` and ``source``. It is added
    (scaled by ``alpha``) to the output of block ``layer``, which is clamped
    to the valid range.

    The zero fallback for an untokenizable word is created on the embedding
    table's own device and dtype. The toy models are never moved to
    ``DEVICE``, so allocating it on ``DEVICE`` caused a device mismatch on
    CUDA hosts.

    Returns ``(baseline_text, steered_report)``; when no model is loaded the
    pair is a hint and an empty string.
    """
    if STATE["name"] is None:
        return "Load a model first.", ""
    model, tok = get_handles(STATE["name"])
    layer, max_new = int(layer), int(max_new)
    emb = model.get_input_embeddings().weight

    def first_emb(w):
        ids = tok.encode(w, add_special_tokens=False)
        if ids:
            return emb[ids[0]]
        return torch.zeros(emb.shape[-1], device=emb.device, dtype=emb.dtype)

    direction = F.normalize((first_emb(target) - first_emb(source)).to(DTYPE), dim=-1)
    inputs = tok(prompt, return_tensors="pt").to(DEVICE)
    gk = dict(max_new_tokens=max_new, do_sample=False, pad_token_id=tok.eos_token_id)
    base = tok.decode(model.generate(**inputs, **gk)[0], skip_special_tokens=True)
    blocks = get_blocks(model)
    layer = max(0, min(layer, len(blocks) - 1))
    handle = blocks[layer].register_forward_hook(_make_steer_hook(direction, alpha))
    try:
        steered = tok.decode(model.generate(**inputs, **gk)[0], skip_special_tokens=True)
    finally:
        # Always detach the hook so the cached model is left unmodified.
        handle.remove()
    return base, "steer %r -> %r @ L%d alpha=%s\n%s" % (source, target, layer, alpha, steered)
# =============================================================================
# TAB 4 - DIFF: compare two models on one prompt, aligned by relative depth
# =============================================================================
@torch.no_grad()
def diff_models(name_a, name_b, prompt, target, top_k):
    """Compare two models on one prompt, aligning layers by relative depth.

    For every layer of model A the layer of model B at the same fractional
    depth is chosen. Each row shows both top-1 tokens, both probabilities of
    the first token of ``target``, and their difference (B minus A).
    ``top_k`` is accepted but not used.
    """
    model_a, tok_a = get_handles(name_a.strip())
    model_b, tok_b = get_handles(name_b.strip())
    ids_a = tok_a.encode(target, add_special_tokens=False)
    ids_b = tok_b.encode(target, add_special_tokens=False)
    if not (ids_a and ids_b):
        return "Could not tokenize target %r in both models." % target
    tid_a, tid_b = ids_a[0], ids_b[0]
    dists_a = layer_distributions(model_a, tok_a, prompt)
    dists_b = layer_distributions(model_b, tok_b, prompt)
    last_a = len(dists_a) - 1
    last_b = len(dists_b) - 1

    def top1(probs, tok):
        best_p, best_i = probs.topk(1)
        return "%r:%.2f" % (tok.decode([best_i.item()]), best_p.item())

    lines = [
        "prompt: %r target: %r" % (prompt, target),
        "%18s | %16s %6s | %16s %6s | %7s"
        % ("depth (A/B)", "A top1", "pA", "B top1", "pB", "dp"),
    ]
    for i, (label_a, probs_a) in enumerate(dists_a):
        frac = (i / last_a) if last_a > 0 else 0.0
        # Nearest B layer at the same fractional depth, kept inside bounds.
        j = max(0, min(round(frac * last_b), last_b)) if last_b > 0 else 0
        label_b, probs_b = dists_b[j]
        p_a = probs_a[tid_a].item()
        p_b = probs_b[tid_b].item()
        depth = "%3.0f%% (%s/%s)" % (frac * 100, label_a, label_b)
        lines.append("%18s | %16s %6.3f | %16s %6.3f | %+7.3f"
                     % (depth, top1(probs_a, tok_a), p_a,
                        top1(probs_b, tok_b), p_b, p_b - p_a))
    return "\n".join(lines)
# =============================================================================
# TAB 5 - CAUSAL TRACE: corrupt the subject, restore each layer, find the site
# -----------------------------------------------------------------------------
# This is ROME's activation patching. We:
# 1. record clean activations and clean p(target)
# 2. add gaussian noise to the SUBJECT token embeddings -> corrupt p(target)
# 3. for each layer L: run corrupted, but force layer L's residual back to
# the clean values at the subject positions. How much p(target) recovers
# tells you how causally important layer L is. The peak is "the site".
# The glass-box gives a clean, verifiable peak; gpt2 gives a realistic band.
# =============================================================================
def _find_subject_positions(tok, input_ids, prompt, subject):
"""Locate subject token positions, with a path for slow (non-fast) toks."""
seq_len = input_ids.shape[1]
if getattr(tok, "is_fast", False):
enc = tok(prompt, return_tensors="pt", return_offsets_mapping=True)
cs = prompt.find(subject)
if cs >= 0:
ce = cs + len(subject)
offs = enc["offset_mapping"][0].tolist()
pos = [i for i, (s, e) in enumerate(offs) if e > cs and s < ce]
if pos:
return [p for p in pos if p != seq_len - 1], ""
else:
sub_ids = tok.encode(subject, add_special_tokens=False)
seq = input_ids[0].tolist()
pos = [i for i, t in enumerate(seq) if t in sub_ids]
if pos:
return [p for p in pos if p != seq_len - 1], ""
fb = list(range(0, max(1, seq_len - 1)))[: max(1, seq_len // 2)]
return fb, "(subject not found; using fallback window)\n"
@torch.no_grad()
def causal_trace(prompt, subject, target, noise_scale, seed):
    """Activation-patching trace: corrupt the subject, restore one layer at a time.

    Steps:
      1. run clean, recording hidden states and p(target);
      2. add seeded gaussian noise (``noise_scale`` x the embedding table's
         std) to the subject token embeddings and record p(target);
      3. for each block, run corrupted but force that block's output at the
         subject positions back to the clean values, and report how much of
         the clean-vs-corrupt gap is recovered.

    Fixes over the previous version:
      * the noise is created on the embedding table's device (the toy models
        are never moved to ``DEVICE``, so CUDA hosts hit a device mismatch);
      * every hook is removed in a ``finally`` so a failing forward pass
        cannot leave a corrupting hook on the cached model.

    Returns the report as one string, or a short message when no model is
    loaded, no subject position is usable, or ``target`` does not tokenize.
    """
    if STATE["name"] is None:
        return "Load a model first."
    model, tok = get_handles(STATE["name"])
    seed, noise_scale = int(seed), float(noise_scale)
    inputs = tok(prompt, return_tensors="pt").to(DEVICE)
    input_ids = inputs["input_ids"]
    positions, note = _find_subject_positions(tok, input_ids, prompt, subject)
    if not positions:
        return note + "No valid subject positions."
    target_ids = tok.encode(target, add_special_tokens=False)
    if not target_ids:
        return "Could not tokenize target %r." % target
    tid = target_ids[0]

    def target_prob(logits):
        # Probability of the tracked token at the last position.
        return F.softmax(logits[0, -1].to(DTYPE), dim=-1)[tid].item()

    out_clean = model(**inputs, output_hidden_states=True)
    clean_hs = out_clean.hidden_states
    clean_p = target_prob(out_clean.logits)

    emb_module = model.get_input_embeddings()
    emb_weight = emb_module.weight
    std = emb_weight.std().item()
    hidden = emb_weight.shape[-1]
    torch.manual_seed(seed)
    noise = torch.randn(len(positions), hidden, device=emb_weight.device) * noise_scale * std

    def corrupt_hook(module, inp, out):
        out = out.clone()
        for k, p in enumerate(positions):
            out[0, p] = out[0, p] + noise[k].to(device=out.device, dtype=out.dtype)
        return out

    def make_restore_hook(clean_layer):
        # Overwrite the subject positions of a block's output with clean values.
        def restore_hook(module, inp, out):
            if isinstance(out, tuple):
                h0 = out[0].clone()
                for p in positions:
                    h0[0, p] = clean_layer[p].to(h0.dtype)
                return (h0,) + out[1:]
            h0 = out.clone()
            for p in positions:
                h0[0, p] = clean_layer[p].to(h0.dtype)
            return h0
        return restore_hook

    handle = emb_module.register_forward_hook(corrupt_hook)
    try:
        corrupt_p = target_prob(model(**inputs).logits)
    finally:
        handle.remove()

    blocks, rows = get_blocks(model), []
    for l in range(len(blocks)):
        # hidden_states[l + 1] is the output of block l (entry 0 is the embedding).
        restore = make_restore_hook(clean_hs[l + 1][0])
        h1 = emb_module.register_forward_hook(corrupt_hook)
        try:
            h2 = blocks[l].register_forward_hook(restore)
            try:
                p_r = target_prob(model(**inputs).logits)
            finally:
                h2.remove()
        finally:
            h1.remove()
        rows.append((l, p_r))

    denom = clean_p - corrupt_p
    lines = [note + "prompt: %r" % prompt,
             "subject: %r target: %r" % (subject, target),
             "clean p=%.3f corrupt p=%.3f noise=%sx std" % (clean_p, corrupt_p, noise_scale),
             "", "%6s | %9s | %9s" % ("layer", "p(target)", "recovery")]
    best_l, best_r = 0, -1e9
    for l, p_r in rows:
        # Recovery is the fraction of the clean-vs-corrupt gap that was restored.
        rec = (p_r - corrupt_p) / denom if abs(denom) > 1e-6 else 0.0
        if rec > best_r:
            best_r, best_l = rec, l
        lines.append(" L%-3d | %9.3f | %8.1f%%" % (l, p_r, rec * 100))
    lines.append("")
    lines.append("# peak at L%d (%.0f%% recovery) <- the causal site" % (best_l, best_r * 100))
    if abs(denom) < 1e-6:
        lines.append("# (corruption didn't move p(target): on 'handmade' this is "
                     "EXPECTED - the fact lives in a string match, not activations.)")
    return "\n".join(lines)
# =============================================================================
# EDIT LOOP + VERIFICATION HARNESS (the ROME sandbox)
# -----------------------------------------------------------------------------
# Apply a knowledge edit to the glass-box, then PROVE it was surgical:
# efficacy - did the target fact change to the new answer?
# specificity - did the OTHER facts stay exactly as they were? (locality)
# fluency - did the output distribution stay sane (no entropy collapse)?
# Because we own the ground truth, "nothing else broke" is checkable, not vibes.
# An optional pass sends the before/after battery to Claude for an independent
# verdict - real LLM calls verifying the edit.
# =============================================================================
# Candidate answer tokens whose probabilities the probe battery records.
GB_ANSWERS = "paris berlin tokyo london rome".split()
@torch.no_grad()
def _probe_battery(model, tok):
    """Ask the model every known capital fact and record what it says.

    Returns a dict keyed by country. Each entry holds the prompt, the
    original answer, the top-1 token and its probability, the probability of
    the original answer, per-candidate probabilities, and entropy in bits.
    """
    report = {}
    for country, capital in GB_FACTS:
        prompt = "the capital of %s is" % country
        encoded = tok(prompt, return_tensors="pt").to(DEVICE)
        last_logits = model(**encoded).logits[0, -1].to(DTYPE)
        probs = F.softmax(last_logits, dim=-1)
        best_p, best_i = probs.topk(1)
        candidates = {}
        for answer in GB_ANSWERS:
            candidates[answer] = probs[GB_STOI[answer]].item()
        report[country] = {
            "prompt": prompt,
            "orig": capital,
            "top1": tok.decode([best_i.item()]),
            "top1_p": best_p.item(),
            "p_orig": probs[GB_STOI[capital]].item(),
            "cand": candidates,
            "entropy": _entropy_bits(probs),
        }
    return report
def _verdict(before, after, subject, new_answer, drift_thresh=0.05):
eff = after[subject]["top1"] == new_answer
collateral, max_drift = [], 0.0
for c in before:
if c == subject:
continue
d = abs(after[c]["p_orig"] - before[c]["p_orig"])
max_drift = max(max_drift, d)
if after[c]["top1"] != before[c]["top1"] or d > drift_thresh:
collateral.append(c)
ent_blowup = any(abs(after[c]["entropy"] - before[c]["entropy"]) > 0.8 for c in before)
surgical = eff and not collateral and not ent_blowup
return eff, collateral, max_drift, ent_blowup, surgical
def edit_and_verify(subject, new_answer, method, strength, use_llm,
                    anthropic_key, anthropic_model, hf_token, hf_model,
                    local_url, local_model):
    """Apply one edit to the glass-box, verify it, and return the report text.

    The glass-box is reset, probed, edited, and probed again; the two
    batteries are scored by ``_verdict``. With ``use_llm`` the batteries are
    also passed to the LLM judge chain. The model is left in the edited
    state and becomes the currently loaded model.

    ``subject`` and ``new_answer`` are stripped and lower-cased once, and the
    normalised values are used everywhere. Previously the edit used the
    stripped subject while the verdict and report indexed the probe results
    with the raw string, so surrounding whitespace raised ``KeyError``, and a
    mixed-case answer could never equal the lowercase decoded top-1 token.

    An invalid subject, answer, method, or strength yields an
    "Edit failed: ..." message instead of a report.
    """
    subject = (subject or "").strip().lower()
    new_answer = (new_answer or "").strip().lower()
    model, tok = get_handles("glassbox")
    STATE["name"] = "glassbox"
    model.reset()
    before = _probe_battery(model, tok)
    try:
        model.edit_fact(subject, new_answer, method, float(strength))
    except ValueError as e:
        return "Edit failed: %s\nValid subjects: france, germany, japan. " \
               "Valid answers: %s" % (e, ", ".join(GB_ANSWERS))
    after = _probe_battery(model, tok)
    eff, collateral, max_drift, ent, surgical = _verdict(before, after, subject, new_answer)
    L = ["EDIT: %s's capital -> %r (method=%s, strength=%s)" %
         (subject, new_answer, method, strength), "",
         "%-9s | %-22s | %-22s" % ("fact", "before (top1 / p_orig)", "after (top1 / p_orig)"),
         "-" * 60]
    for c in before:
        b, a = before[c], after[c]
        flag = " <- TARGET" if c == subject else (" <- COLLATERAL" if c in collateral else "")
        L.append("%-9s | %-22s | %-22s%s" % (
            c, "%s / %.2f" % (b["top1"], b["p_orig"]),
            "%s / %.2f" % (a["top1"], a["p_orig"]), flag))
    L += ["",
          "efficacy : %s (target now says %r, p=%.2f)" %
          ("PASS" if eff else "FAIL", after[subject]["top1"], after[subject]["top1_p"]),
          "specificity : %s (max drift on other facts = %.3f%s)" %
          ("PASS" if not collateral else "FAIL: " + ", ".join(collateral),
           max_drift, "; entropy spike" if ent else ""),
          "", "VERDICT: %s" % ("SURGICAL EDIT" if surgical else "COLLATERAL DAMAGE")]
    L.append("(model is left in the edited state - inspect it in tabs 1-5, or hit Reset.)")
    llm_report = ""
    if use_llm:
        # Providers are listed in the order the judge chain receives them.
        providers = [
            {"type": "anthropic", "key": anthropic_key, "model": anthropic_model},
            {"type": "hf", "key": hf_token, "model": hf_model},
            {"type": "local", "url": local_url, "model": local_model},
        ]
        llm_report = _llm_judge_chain(before, after, subject, new_answer, providers)
        L += ["", "-" * 60, "INDEPENDENT LLM REVIEW:", llm_report]
    report = "\n".join(L)
    _log_session(subject, new_answer, method, strength, before, after,
                 eff, collateral, max_drift, surgical, llm_report)
    return report
def reset_glassbox():
    """Restore the glass-box fact weights and return a confirmation message."""
    glassbox, _tok = get_handles("glassbox")
    glassbox.reset()
    return "Glass-box weights restored to pristine. Re-run any tab to confirm."
# --- optional: real LLM calls to verify the edit, with a 3-tier fallback chain
# Anthropic (Claude) -> Hugging Face Inference -> local OpenAI-compatible server
# (e.g. LM Studio). Tries each in order; the first provider that's configured
# AND reachable wins. This means you're never blocked on one vendor being down
# or on not having an Anthropic key at all - your own RTX 5090 can be the judge.
def _build_judge_prompt(before, after, subject, new_answer):
import json
payload = {c: {"prompt": before[c]["prompt"],
"before_top1": before[c]["top1"], "before_p_orig": round(before[c]["p_orig"], 3),
"after_top1": after[c]["top1"], "after_p_orig": round(after[c]["p_orig"], 3)}
for c in before}
sys = ("You audit knowledge edits to a small language model. The intended edit "
"is: make %s's capital '%s'. Given before/after predictions for every "
"known fact, decide if the edit was SURGICAL (target changed, all other "
"facts unchanged) or caused COLLATERAL damage. Reply ONLY as JSON, no "
'prose, no markdown fences: {"verdict":"surgical|collateral",'
'"target_changed":bool,"damaged_facts":[...],"confidence":0-1,'
'"reason":"one sentence"}.') % (subject, new_answer)
return sys, json.dumps(payload)
def _parse_verdict_json(text, provider_label):
import json
clean = text.strip().strip("`")
if clean.lower().startswith("json"):
clean = clean[4:].strip()
start, end = clean.find("{"), clean.rfind("}")
if start != -1 and end != -1:
clean = clean[start:end + 1]
v = json.loads(clean)
return ("[%s] verdict=%s target_changed=%s confidence=%s\n damaged: %s\n reason: %s"
% (provider_label, v.get("verdict"), v.get("target_changed"), v.get("confidence"),
v.get("damaged_facts") or "none", v.get("reason")))
def _try_anthropic(sys, user, cfg):
import os, json
key = (cfg.get("key") or "").strip() or os.environ.get("ANTHROPIC_API_KEY", "")
if not key:
return None, "anthropic: no key configured"
body = {"model": (cfg.get("model") or "claude-sonnet-4-6").strip(),
"max_tokens": 400, "system": sys, "messages": [{"role": "user", "content": user}]}
try:
try:
import anthropic
client = anthropic.Anthropic(api_key=key)
msg = client.messages.create(**body)
text = "".join(b.text for b in msg.content if getattr(b, "type", "") == "text")
except ImportError:
import urllib.request
req = urllib.request.Request(
"https://api.anthropic.com/v1/messages", data=json.dumps(body).encode(),
headers={"x-api-key": key, "anthropic-version": "2023-06-01",
"content-type": "application/json"})
with urllib.request.urlopen(req, timeout=30) as r:
data = json.loads(r.read())
text = "".join(b.get("text", "") for b in data.get("content", [])
if b.get("type") == "text")
return _parse_verdict_json(text, "anthropic:" + body["model"]), None
except Exception as e:
return None, "anthropic failed: %s" % e
def _try_hf(sys, user, cfg):
token = (cfg.get("key") or "").strip()
model = (cfg.get("model") or "Qwen/Qwen2.5-7B-Instruct").strip()
if not token:
import os
token = os.environ.get("HF_TOKEN", "")
if not token:
return None, "hf: no token configured"
try:
from huggingface_hub import InferenceClient
client = InferenceClient(model=model, token=token)
resp = client.chat_completion(
messages=[{"role": "system", "content": sys}, {"role": "user", "content": user}],
max_tokens=400)
text = resp.choices[0].message.content
return _parse_verdict_json(text, "hf:" + model), None
except Exception as e:
return None, "hf failed: %s" % e
def _try_local(sys, user, cfg):
"""Any OpenAI-compatible /v1/chat/completions server - LM Studio, vLLM,
Ollama (with its OpenAI shim), text-generation-webui, etc."""
import json, urllib.request
url = (cfg.get("url") or "").strip().rstrip("/")
if not url:
return None, "local: no URL configured"
model = (cfg.get("model") or "local-model").strip()
body = json.dumps({"model": model, "max_tokens": 400, "temperature": 0,
"messages": [{"role": "system", "content": sys},
{"role": "user", "content": user}]}).encode()
try:
req = urllib.request.Request(
url + "/v1/chat/completions", data=body,
headers={"content-type": "application/json"})
with urllib.request.urlopen(req, timeout=20) as r:
data = json.loads(r.read())
text = data["choices"][0]["message"]["content"]
return _parse_verdict_json(text, "local:" + model + "@" + url), None
except Exception as e:
return None, "local failed: %s" % e
def _llm_judge_chain(before, after, subject, new_answer, providers):
    """Run the judge prompt through ``providers`` in order; first success wins.

    Args:
        before, after, subject, new_answer: forwarded to
            ``_build_judge_prompt``.
        providers: list of config dicts, each with a ``type`` of
            ``"anthropic"``, ``"hf"`` or ``"local"``; entries with any other
            type are ignored.

    Returns:
        The first provider's verdict summary, prefixed with a note listing
        the providers that failed before it; or, when none succeeds, a
        message listing every failure.
    """
    system_prompt, user_prompt = _build_judge_prompt(before, after, subject, new_answer)
    handlers = {"anthropic": _try_anthropic, "hf": _try_hf, "local": _try_local}
    failures = []
    for provider in providers:
        handler = handlers.get(provider["type"])
        if handler is None:
            continue
        summary, error = handler(system_prompt, user_prompt, provider)
        if summary is None:
            failures.append(error)
            continue
        if failures:
            prefix = "(skipped: %s)\n" % "; ".join(failures)
        else:
            prefix = ""
        return prefix + summary
    return ("all providers unavailable:\n " + "\n ".join(failures) +
            "\n(configure at least one: Anthropic key, HF token, or a local "
            "OpenAI-compatible server URL like http://192.168.188.25:1234)")
# --- session log: every edit+verify run is appended here as JSON, so you can
# download it, or paste the markdown block straight into a future chat with
# Claude for review ("did all work, here's the log").
SESSION_LOG = []
def _log_session(subject, new_answer, method, strength, before, after,
eff, collateral, max_drift, surgical, llm_report):
import datetime
SESSION_LOG.append({
"ts": datetime.datetime.utcnow().isoformat() + "Z",
"subject": subject, "new_answer": new_answer, "method": method,
"strength": strength, "efficacy_pass": bool(eff),
"collateral": collateral, "max_drift": round(max_drift, 4),
"verdict": "SURGICAL" if surgical else "COLLATERAL",
"before": {c: {"top1": before[c]["top1"], "p_orig": round(before[c]["p_orig"], 4)}
for c in before},
"after": {c: {"top1": after[c]["top1"], "p_orig": round(after[c]["p_orig"], 4)}
for c in after},
"llm_review": llm_report or None,
})
def export_session_log():
    """Write the session log to disk as JSON plus a paste-friendly markdown file.

    Both files go to ``/mnt/user-data/outputs`` (created if missing).

    Returns:
        ``(path, message)``: the JSON file path and a status message, or
        ``(None, message)`` when no edit has been logged yet.
    """
    import json, os
    if not SESSION_LOG:
        return None, "No edits run yet this session - nothing to export."
    os.makedirs("/mnt/user-data/outputs", exist_ok=True)
    path = "/mnt/user-data/outputs/edit_session_log.json"
    # Context managers guarantee the files are flushed and closed before the
    # path is handed back to the caller for download.
    with open(path, "w", encoding="utf-8") as fh:
        json.dump(SESSION_LOG, fh, indent=2)
    # also a markdown rendition meant to be pasted straight into a chat
    md = ["# Edit session log\n"]
    for i, e in enumerate(SESSION_LOG, 1):
        md.append("## Edit %d - %s (%s, %s, strength=%s)\n" %
                  (i, e["verdict"], e["subject"] + "->" + e["new_answer"],
                   e["method"], e["strength"]))
        md.append("- efficacy: %s, max collateral drift: %.4f, damaged: %s" %
                  ("pass" if e["efficacy_pass"] else "fail", e["max_drift"],
                   e["collateral"] or "none"))
        if e["llm_review"]:
            md.append("- LLM review: " + e["llm_review"].replace("\n", " "))
        md.append("")
    md_path = "/mnt/user-data/outputs/edit_session_log.md"
    # LLM review text may contain non-ASCII characters; pin the encoding so
    # the write does not depend on the platform default.
    with open(md_path, "w", encoding="utf-8") as fh:
        fh.write("\n".join(md))
    return path, "Wrote %d edit(s) to %s and %s" % (len(SESSION_LOG), path, md_path)
# =============================================================================
# EXPORT + UPLOAD TO HUGGING FACE
# -----------------------------------------------------------------------------
# Save the glass-box as a self-contained, reloadable repo (weights + config +
# vocab + a standalone modeling file + a model card), and optionally push it -
# and/or this whole app as a Space - to the Hub.
# =============================================================================
_MODELING_PY = '''"""Standalone glass-box model - reload with no other files.
from modeling_glassbox import load
m, tok = load(".") # folder containing config/weights/vocab
print(tok.decode(m.generate(tok("the capital of france is"))[0]))
"""
import json, torch, torch.nn as nn, torch.nn.functional as F
from safetensors.torch import load_file
def load(path="."):
cfg = json.load(open(f"{path}/config.json"))
stoi = json.load(open(f"{path}/vocab.json")); itos = {i: w for w, i in stoi.items()}
D, V = cfg["d_model"], len(stoi); facts = [tuple(f) for f in cfg["facts"]]
SUBJ = torch.zeros(D, D)
for d in range(cfg["subject_dims"]): SUBJ[d, d] = 1.0
class Tok:
is_fast = False
def __init__(s): s.eos_token_id = stoi["."]
def _ids(s, t): return [stoi.get(w, stoi["<s>"]) for w in t.lower().replace(".", " .").split()] or [stoi["<s>"]]
def __call__(s, t, **k):
import torch as T; return {"input_ids": T.tensor([s._ids(t)])}
def decode(s, ids, **k): return " ".join(itos.get(int(i), "?") for i in ids)
class Ident(nn.Module):
def forward(s, x): return (x.clone(),)
class Pool(nn.Module):
def forward(s, x):
o = x.clone()
if x.shape[1] > 1: o[0, -1] = o[0, -1] + 0.9 * (x[0, :-1] @ SUBJ.T).sum(0)
return (o,)
class FactMLP(nn.Module):
def __init__(s):
super().__init__()
s.register_buffer("Win", torch.zeros(len(facts), D))
s.register_buffer("Wout", torch.zeros(D, len(facts)))
s.bias, s.gain = cfg["bias"], cfg["gain"]
def forward(s, x):
o = x.clone(); pre = F.relu(s.Win @ o[0, -1] - s.bias)
o[0, -1] = o[0, -1] + s.gain * (s.Wout @ pre); return (o,)
class T(nn.Module):
def __init__(s):
super().__init__(); s.wte = nn.Embedding(V, D)
s.h = nn.ModuleList([Ident(), Pool(), FactMLP(), Ident()]); s.ln_f = nn.Identity()
class GlassBox(nn.Module):
def __init__(s):
super().__init__(); s.transformer = T(); s.head = nn.Linear(D, V, bias=False)
def get_input_embeddings(s): return s.transformer.wte
def forward(s, input_ids=None, **k):
x = s.transformer.wte(input_ids)
for b in s.transformer.h: (x,) = b(x)
class O: pass
o = O(); o.logits = s.head(x); return o
@torch.no_grad()
def generate(s, input_ids=None, max_new_tokens=12, **k):
ids = input_ids
for _ in range(max_new_tokens):
ids = torch.cat([ids, s(input_ids=ids).logits[0, -1].argmax().view(1, 1)], 1)
return ids
m = GlassBox().eval()
sd = load_file(f"{path}/model.safetensors")
m.load_state_dict({k: v for k, v in sd.items() if not k.endswith("0")}, strict=False)
return m, Tok()
'''
def export_glassbox(outdir="glassbox_export"):
    """Write the glass-box model to ``outdir`` as a self-contained folder.

    Files written: ``model.safetensors`` (the model's state dict),
    ``config.json``, ``vocab.json``, ``modeling_glassbox.py`` and
    ``README.md``.

    Args:
        outdir: destination directory; created if it does not exist.

    Returns:
        ``outdir``, unchanged.
    """
    import os, json
    from safetensors.torch import save_file

    def _write_text(name, text):
        # Close every file deterministically and pin the encoding instead of
        # relying on garbage collection and the platform default.
        with open(os.path.join(outdir, name), "w", encoding="utf-8") as fh:
            fh.write(text)

    os.makedirs(outdir, exist_ok=True)
    model, _ = get_handles("glassbox")
    sd = {k: v.contiguous() for k, v in model.state_dict().items()}
    save_file(sd, os.path.join(outdir, "model.safetensors"))
    config = {"model_type": "glassbox", "d_model": GB_D, "vocab_size": GB_V,
              "subject_dims": 9, "bias": model.transformer.h[2].bias,
              "gain": model.transformer.h[2].gain,
              "facts": [list(f) for f in GB_FACTS]}
    _write_text("config.json", json.dumps(config, indent=2))
    _write_text("vocab.json", json.dumps(GB_STOI, indent=2))
    _write_text("modeling_glassbox.py", _MODELING_PY)
    _write_text(
        "README.md",
        "---\nlicense: mit\ntags: [interpretability, glass-box, rome, toy-model]\n---\n\n"
        "# Glass-box interpretability model\n\n"
        "A tiny transformer-shaped model whose facts are stored as key->value "
        "writes into the residual stream, so logit-lens, activation steering and "
        "ROME causal tracing all reproduce the *known* ground truth. Built as a "
        "verification harness for interpretability code.\n\n"
        "```python\nfrom modeling_glassbox import load\n"
        "m, tok = load('.')\n"
        "print(tok.decode(m.generate(tok('the capital of france is')['input_ids'])[0]))\n```\n\n"
        "Facts: " + ", ".join("%s->%s" % f for f in GB_FACTS) + ".\n")
    return outdir
def upload_to_hf(repo_id, token, what, app_path=__file__):
    """Push the model and/or this app (as a Space) to the Hub.

    Args:
        repo_id: destination repo id such as ``"user/name"``; surrounding
            whitespace is ignored.
        token: HF write token; when blank the ``HF_TOKEN`` environment
            variable is used.
        what: ``"model"``, ``"space"`` or ``"both"``. With ``"both"`` the
            Space is pushed to ``repo_id + "-space"``.
        app_path: file uploaded as the Space's ``app.py``.

    Returns:
        A human-readable status string. Problems (missing dependency, token,
        repo id, or a failed upload) are reported in the string, not raised.
    """
    import os
    try:
        from huggingface_hub import HfApi
    except ImportError:
        return "huggingface_hub not installed. `pip install huggingface_hub`."
    token = (token or "").strip() or os.environ.get("HF_TOKEN", "")
    if not token:
        return "No HF token. Paste a write token or set HF_TOKEN."
    # Normalise once and use the cleaned id for every Hub call: previously the
    # id was validated stripped but sent unstripped, and None crashed here.
    repo_id = (repo_id or "").strip()
    if not repo_id:
        return "Enter a repo id like 'Chris4K/glassbox-interp'."
    if what not in ("model", "space", "both"):
        # Without this guard an unknown value reported an empty "Uploaded:".
        return "Nothing uploaded: 'what' must be 'model', 'space' or 'both' (got %r)." % (what,)
    api, logs = HfApi(token=token), []
    try:
        if what in ("model", "both"):
            d = export_glassbox()
            api.create_repo(repo_id, repo_type="model", exist_ok=True)
            api.upload_folder(folder_path=d, repo_id=repo_id, repo_type="model")
            logs.append("model -> https://huggingface.co/%s" % repo_id)
        if what in ("space", "both"):
            # "both" needs two distinct repos, so the Space gets a suffix.
            sid = repo_id + "-space" if what == "both" else repo_id
            api.create_repo(sid, repo_type="space", space_sdk="gradio", exist_ok=True)
            api.upload_file(path_or_fileobj=app_path, path_in_repo="app.py",
                            repo_id=sid, repo_type="space")
            req = "torch\ntransformers\ngradio\nsafetensors\nhuggingface_hub\nanthropic\n"
            api.upload_file(path_or_fileobj=req.encode(), path_in_repo="requirements.txt",
                            repo_id=sid, repo_type="space")
            logs.append("space -> https://huggingface.co/spaces/%s" % sid)
        return "Uploaded:\n " + "\n ".join(logs)
    except Exception as e:
        return "Upload failed: %s" % e
# --- upload a REAL model (e.g. a VINDEX-edited Llama checkpoint), not the toy.
# This does NOT load the model into memory (multi-GB Llama weights don't need
# to round-trip through Python) - it just pushes whatever's already on disk.
# Point it at the local folder produced by your save_pretrained()/VINDEX run:
# expects the usual HF layout (config.json + .safetensors shards + tokenizer
# files). Note: gated models (e.g. meta-llama/*) require the destination repo
# to either be your own namespace or one you have write access to - the Hub's
# license gate is independent of this upload step.
def upload_local_checkpoint(local_dir, repo_id, token, private, commit_message):
    """Upload an on-disk checkpoint folder to a Hub model repo.

    The folder is pushed as-is; a warning is prepended to the result when it
    has no ``config.json`` or no ``.safetensors``/``.bin`` file at its top
    level.

    Args:
        local_dir: path of the folder to upload.
        repo_id: destination repo id.
        token: HF write token; when blank ``HF_TOKEN`` from the environment
            is used.
        private: truthy to create the repo as private.
        commit_message: commit message; a default is used when empty.

    Returns:
        A status string; validation problems and upload failures are
        reported in the string rather than raised.
    """
    import os
    try:
        from huggingface_hub import HfApi
    except ImportError:
        return "huggingface_hub not installed. `pip install huggingface_hub`."

    local_dir = (local_dir or "").strip()
    repo_id = (repo_id or "").strip()
    if not (local_dir and os.path.isdir(local_dir)):
        return "local_dir %r does not exist or is not a directory." % local_dir
    if not repo_id:
        return "Enter a repo id like 'Chris4K/vindex-llama3-edited'."
    token = (token or "").strip() or os.environ.get("HF_TOKEN", "")
    if not token:
        return "No HF token. Paste a write token or set HF_TOKEN."

    config_present = os.path.exists(os.path.join(local_dir, "config.json"))
    weights_present = any(entry.endswith((".safetensors", ".bin"))
                          for entry in os.listdir(local_dir))
    if config_present and weights_present:
        warn = ""
    else:
        warn = ("WARNING: folder is missing config.json or weight files - this may "
                "not be a loadable HF checkpoint. Uploading anyway.\n")

    api = HfApi(token=token)
    try:
        api.create_repo(repo_id, repo_type="model", private=bool(private), exist_ok=True)
        message = (commit_message or "upload checkpoint").strip()
        api.upload_folder(folder_path=local_dir, repo_id=repo_id, repo_type="model",
                          commit_message=message)
        # Only the first dozen entries (sorted) are listed in the report.
        listed = ", ".join(sorted(os.listdir(local_dir))[:12])
        summary = "Uploaded %s -> https://huggingface.co/%s\nFiles: %s" % (
            local_dir, repo_id, listed)
        return warn + summary
    except Exception as exc:
        return warn + "Upload failed: %s" % exc
# =============================================================================
# UI
# =============================================================================
# Markdown shown at the top of the app. Blank lines separate the blocks so
# the table is not glued to the paragraph that follows it.
INTRO = """
# Compression Navigator

**An LLM is a lossy codec for text.** Training compresses a corpus into weights;
a forward pass decompresses a continuation. These five tools let you watch that
decompression and find where facts physically live.

Each tab is a real interpretability technique: **logit lens, embedding
neighbours, activation steering, cross-model diff, and causal tracing (ROME).**

### Three models, on purpose

| name | how it stores facts | what it teaches |
|---|---|---|
| **`glassbox`** | key→value writes into the **residual stream** (like a real transformer / what ROME edits) | the tools **work and are verifiable** against ground truth you can read in the source |
| **`handmade`** | a **lookup table** keyed on the prompt string (a side channel) | a model can be **invisible** to residual-stream interpretability — a real limitation |
| **`gpt2`** | learned, fuzzy, **distributed** over many layers | what the real, messy thing looks like |

**Suggested order:** load `glassbox` first (see "correct"), then `handmade`
(see a failure mode), then `gpt2` (see reality). Type a name below and Load.
"""
# Gradio UI: a model loader, seven tabs (one per tool plus edit/verify and
# export/upload) and a closing roadmap note. All markdown bodies are kept at
# column 0 inside their triple-quoted strings so they never render as
# indented code blocks.
with gr.Blocks(title="Compression Navigator") as demo:
    gr.Markdown(INTRO)
    with gr.Row():
        model_name = gr.Textbox(value="glassbox", label="model name or HF id")
        load_btn = gr.Button("Load", variant="primary")
    load_status = gr.Markdown()
    load_btn.click(load_model, inputs=model_name, outputs=load_status)

    # ---- TAB 1 -------------------------------------------------------------
    with gr.Tab("1 · Decompress (logit lens)"):
        gr.Markdown("""
### Logit lens — watch the answer condense, layer by layer

**What it does:** takes the last-token residual at *every* layer and reads it
through the unembedding, as if the model had to answer right there. You see the
prediction form.

**How to read it:** each row is a layer. Watch your tracked token's probability
(right column) climb, and watch **entropy** (bits) fall as the model commits.

**Ground truth to check:**

- `glassbox` — `paris` is ~0 until **L3** (the readout right after the fact-MLP), then jumps to ~0.51. Sharp and localised because you put it there.
- `handmade` — the answer snaps to 1.00 at **L1** with zero build-up (it's a lookup, not a computation).
- `gpt2` — the answer accretes *gradually* across many middle/late layers. That smear is what "distributed representation" actually looks like.

*(Numbering note: the lens counts from the embedding, so `L1` is after the first block. The causal-trace tab counts blocks from `L0`. So the fact-MLP is lens-`L3` / trace-block-`L2`, and its causal site shows at trace-`L0`.)*
""")
        ll_prompt = gr.Textbox(value="the capital of france is", label="prompt")
        with gr.Row():
            ll_k = gr.Slider(1, 10, value=3, step=1, label="top-k per layer")
            ll_track = gr.Textbox(value="paris", label="track this token's prob")
        ll_out = gr.Textbox(label="output", lines=18)
        gr.Button("Run").click(logit_lens, [ll_prompt, ll_k, ll_track], ll_out)

    # ---- TAB 2 -------------------------------------------------------------
    with gr.Tab("2 · Triangulate (neighbours)"):
        gr.Markdown("""
### Neighbours — the geometry of the vocabulary

**What it does:** ranks tokens by cosine similarity of their unembedding rows.
Directions that point the same way are "near" in the model's compressed space.

**How to read it:** high cosine = the model treats these tokens as related.

**Ground truth to check:**

- `glassbox` — `paris` is near `france` (cos ≈ 0.48): the source deliberately makes a capital share a dimension with its country. Real geometry, by design.
- `handmade` — **every** cosine is 0. One-hot embeddings are mutually orthogonal, so there's no geometry at all. The tool is correctly reporting "nothing here."
- `gpt2` — neighbours are messy but meaningful (casing variants, plurals, semantic kin).
""")
        nb_word = gr.Textbox(value="paris", label="word")
        nb_k = gr.Slider(5, 25, value=10, step=1, label="top neighbours")
        nb_out = gr.Textbox(label="output", lines=15)
        gr.Button("Run").click(neighbors, [nb_word, nb_k], nb_out)

    # ---- TAB 3 -------------------------------------------------------------
    with gr.Tab("3 · Re-route (steering)"):
        gr.Markdown("""
### Steering — bend behaviour with a direction, no retraining

**What it does:** builds the vector `emb(target) − emb(source)` and *adds* it to
a layer's output during generation. The model drifts from `source` toward
`target`. This is the cheap cousin of fine-tuning (ActAdd / representation
engineering).

**How to read it:** compare *baseline* vs *steered*. Raise **strength** until the
output flips; too high and it turns to noise (you've knocked the residual off
the manifold).

**Tips:** on `gpt2` try `from: Paris to: London` on the France prompt, layer
0–4, strength 6–14. On `glassbox` it works cleanly too — `from: france
to: japan` at layer 0, strength 8, flips the output from `paris` to `tokyo`
(you're pushing the residual along the subject→subject direction the fact-MLP
keys on).
""")
        st_prompt = gr.Textbox(value="the capital of france is", label="prompt")
        with gr.Row():
            st_src = gr.Textbox(value="Paris", label="from")
            st_tgt = gr.Textbox(value="London", label="to")
        with gr.Row():
            st_layer = gr.Slider(0, 11, value=2, step=1, label="layer")
            st_alpha = gr.Slider(0, 30, value=10, step=0.5, label="strength")
            st_max = gr.Slider(8, 80, value=40, step=1, label="max new tokens")
        st_base = gr.Textbox(label="baseline", lines=2)
        st_out = gr.Textbox(label="steered", lines=3)
        gr.Button("Run").click(steer_generate,
                               [st_prompt, st_src, st_tgt, st_layer, st_alpha, st_max],
                               [st_base, st_out])

    # ---- TAB 4 -------------------------------------------------------------
    with gr.Tab("4 · Diff (align by depth)"):
        gr.Markdown("""
### Diff — two models on one prompt, aligned by *relative* depth

**What it does:** runs the logit lens on model A and model B and lines their
layers up by percentage depth (0–100%), so you can compare a 2-layer toy with a
12-layer gpt2 side by side. `dp` is `p_B − p_A` for the target token.

**How to read it:** look at *where* on the depth axis each model commits to the
target. A localised model commits at one depth; a distributed one ramps up.

**Try:** A = `gpt2`, B = `glassbox`, target = `paris`. You'll see gpt2 ramp
through the middle while glassbox snaps on at its fact layer — the same fact,
two very different internal shapes.
""")
        with gr.Row():
            df_a = gr.Textbox(value="gpt2", label="model A")
            df_b = gr.Textbox(value="glassbox", label="model B")
        df_prompt = gr.Textbox(value="the capital of france is", label="prompt")
        df_target = gr.Textbox(value="paris", label="target token")
        df_k = gr.Slider(1, 5, value=1, step=1, label="top-k (display)")
        df_out = gr.Textbox(label="output", lines=16)
        gr.Button("Run").click(diff_models,
                               [df_a, df_b, df_prompt, df_target, df_k], df_out)

    # ---- TAB 5 -------------------------------------------------------------
    with gr.Tab("5 · Causal trace (ROME)"):
        gr.Markdown("""
### Causal trace — corrupt the subject, restore each layer, find the site

**What it does:** activation patching (Meng et al.'s ROME). It noises the
**subject** token, which breaks the prediction, then restores one layer at a
time and measures how much of the answer comes back. The layer that restores
the most is where the fact is *causally* computed.

**How to read it:** `recovery` ≈ 100% means "restoring this layer is enough" →
the fact is read here. The peak line names the site.

**Ground truth to check:**

- `glassbox` — peak at **L0** (≈100%). The fact is read at the early subject site, because the L1 "attention" re-reads the restored subject. You know this is right because you wrote the mechanism.
- `handmade` — `clean p` ≈ `corrupt p`, so recovery is meaningless. **Expected:** the fact is a string match, untouched by activation noise. This is the headline lesson — patching can't see lookup behaviour.
- `gpt2` — a *band* of early–middle layers at the subject token light up, exactly as in the ROME paper.
""")
        ct_prompt = gr.Textbox(value="the capital of france is", label="prompt")
        ct_subject = gr.Textbox(value="france", label="subject to corrupt")
        ct_target = gr.Textbox(value="paris", label="target token")
        with gr.Row():
            ct_noise = gr.Slider(0, 10, value=3, step=0.5, label="noise (x embed std)")
            ct_seed = gr.Slider(0, 100, value=0, step=1, label="seed")
        ct_out = gr.Textbox(label="output", lines=18)
        gr.Button("Run").click(causal_trace,
                               [ct_prompt, ct_subject, ct_target, ct_noise, ct_seed], ct_out)

    # ---- TAB 6 -------------------------------------------------------------
    with gr.Tab("6 · Edit + verify (ROME loop)"):
        gr.Markdown("""
### Edit a fact, then prove nothing else broke

**What it does:** rewrites the value one fact-MLP key maps to (the exact thing
ROME/MEMIT do on real models — this is a literal `nn.Module` weight tensor,
not a token or vocab change), then runs a verification battery over **every**
known fact to measure **efficacy** (target changed), **specificity** (others
untouched), and **fluency** (no entropy collapse).

**Two methods, on purpose:**

- `rank1` — the minimal, surgical update. Only the target fact moves → **SURGICAL**.
- `broadcast` — a deliberately sloppy edit that smears the change across all facts → the harness catches the **COLLATERAL DAMAGE**. This proves the verifier actually works, not just reports "ok" by default.

**Independent LLM review, with a fallback chain — not locked to one vendor:**
tick the box and it tries, in order: **Anthropic** (Claude, if you give a key)
→ **Hugging Face Inference** (any hosted chat model, if you give an HF token)
→ **your own local server** (LM Studio / vLLM / Ollama's OpenAI shim — anything
exposing `/v1/chat/completions`). The first one that's configured *and*
reachable answers; the rest are skipped and noted. So your own RTX 5090 can
be the judge with zero cloud calls if you just fill in the local URL.

Subjects: `france`, `germany`, `japan`. Answers: `paris, berlin, tokyo, london, rome`.

After editing, the model stays edited — go look at it in tabs 1–5 (the logit lens
will show the new answer rising; the trace still localises to L0). Hit **Reset**
to restore. Every run is appended to a session log you can download below and
paste into a future chat for review.
""")
        with gr.Row():
            ed_subj = gr.Textbox(value="france", label="subject")
            ed_new = gr.Textbox(value="london", label="new answer")
            ed_method = gr.Radio(["rank1", "broadcast"], value="rank1", label="method")
            ed_strength = gr.Slider(0.2, 2.0, value=1.0, step=0.1, label="strength")
        ed_llm = gr.Checkbox(value=False, label="also run an independent LLM review")
        with gr.Accordion("LLM review providers (tried in this order)", open=False):
            with gr.Row():
                ed_a_model = gr.Textbox(value="claude-sonnet-4-6", label="1. Anthropic model")
                ed_a_key = gr.Textbox(value="", label="Anthropic API key", type="password")
            with gr.Row():
                ed_h_model = gr.Textbox(value="Qwen/Qwen2.5-7B-Instruct",
                                        label="2. HF Inference model")
                ed_h_key = gr.Textbox(value="", label="HF token", type="password")
            with gr.Row():
                ed_l_url = gr.Textbox(value="http://192.168.188.25:1234",
                                      label="3. Local server URL (LM Studio etc.)")
                ed_l_model = gr.Textbox(value="local-model", label="local model name")
        ed_out = gr.Textbox(label="edit + verification report", lines=24)
        with gr.Row():
            # Input order must match edit_and_verify's positional parameters.
            gr.Button("Edit & verify", variant="primary").click(
                edit_and_verify,
                [ed_subj, ed_new, ed_method, ed_strength, ed_llm,
                 ed_a_key, ed_a_model, ed_h_key, ed_h_model, ed_l_url, ed_l_model],
                ed_out)
            gr.Button("Reset model").click(reset_glassbox, outputs=ed_out)
        gr.Markdown("**Session log** (every edit run above, appended):")
        with gr.Row():
            log_btn = gr.Button("Write session log to disk")
            log_file = gr.File(label="download")
        log_status = gr.Markdown()
        log_btn.click(lambda: export_session_log(), outputs=[log_file, log_status])

    # ---- TAB 7 -------------------------------------------------------------
    with gr.Tab("7 · Export / Upload to HF"):
        gr.Markdown("""
### Ship the toy glass-box

**Export** writes a self-contained, reloadable repo: weights (`safetensors`),
`config.json`, `vocab.json`, a standalone `modeling_glassbox.py` (reload with
`from modeling_glassbox import load`), and a model card.

**Upload** pushes it to the Hub. Choose `model`, `space` (this whole app,
runnable), or `both`. Paste a **write** token (or set `HF_TOKEN`).
""")
        with gr.Row():
            hf_repo = gr.Textbox(value="Chris4K/glassbox-interp", label="repo id")
            hf_what = gr.Radio(["model", "space", "both"], value="model", label="what to push")
        hf_token = gr.Textbox(value="", label="HF write token (optional)", type="password")
        hf_out = gr.Textbox(label="result", lines=6)
        with gr.Row():
            gr.Button("Export locally").click(
                lambda: "Exported to ./%s" % export_glassbox(), outputs=hf_out)
            gr.Button("Upload to HF", variant="primary").click(
                upload_to_hf, [hf_repo, hf_token, hf_what], hf_out)
        gr.Markdown("""
---

### Upload a REAL model — e.g. your VINDEX-edited Llama checkpoint

This does **not** load the model into memory and does **not** assume any
particular architecture — it just pushes whatever's already on disk at
`local_dir` (the usual `save_pretrained()` layout: `config.json` +
`*.safetensors` shards + tokenizer files) straight to a new repo. Large
weights upload fine through `upload_folder`; for very large repos consider
installing `hf_transfer` for faster throughput. If the base model is gated
(e.g. `meta-llama/*`), the gate applies to the destination repo's license
settings, not to this upload step.
""")
        with gr.Row():
            rc_dir = gr.Textbox(value="", label="local checkpoint folder (on this machine)")
            rc_repo = gr.Textbox(value="", label="destination repo id, e.g. Chris4K/vindex-llama3-edited")
        with gr.Row():
            rc_token = gr.Textbox(value="", label="HF write token (optional)", type="password")
            rc_private = gr.Checkbox(value=True, label="private repo")
        rc_msg = gr.Textbox(value="upload edited checkpoint", label="commit message")
        rc_out = gr.Textbox(label="result", lines=6)
        gr.Button("Upload real checkpoint", variant="primary").click(
            upload_local_checkpoint, [rc_dir, rc_repo, rc_token, rc_private, rc_msg], rc_out)

    # ---- closing note, shown below the tabs ---------------------------------
    gr.Markdown("""
---

### Where this goes next

- **Closing the loop (what "self-improving" would actually require):** right now a human picks every edit; the verifier just grades it. A real closed loop needs a policy that *proposes* edits on its own (e.g. scanning eval failures for wrong facts), auto-applies, and auto-commits only on a SURGICAL verdict, rolling back otherwise. The hard part — the verifier — already exists here; the proposal step doesn't yet.
- **A training-method angle worth taking seriously:** instead of accept/reject after the fact, feed the specificity battery's drift score back as a regularizer *during* the edit computation (closer to elastic weight consolidation, or the null-space projection AlphaEdit-style methods use) so collateral is penalized while solving, not caught after.
- **Real-model MEMIT:** the edit loop here is exact because the glass-box's fact layer is literally key→value. The same verify harness (efficacy / specificity / fluency + the multi-provider LLM judge) ports straight onto a gpt2/Llama MEMIT edit — the toy is the regression test you run first.
- **Multi-hop & paraphrase generalization:** add `"the currency of france is"` so two relations share a subject, and have the LLM judge auto-generate paraphrase probes to test that an edit generalizes, not just memorizes the one prompt.
- **Attribution view:** Geva-style "what does this neuron write to the vocab", per-head attention attribution.
- **It already ships:** tab 7 pushes the toy model and this whole app (as a Space) to your Hub, or a real local checkpoint folder to its own repo.
""")

    # Load the glass-box model as soon as the page opens so every tab is
    # usable without pressing Load first.
    demo.load(lambda: load_model("glassbox"), outputs=load_status)


if __name__ == "__main__":
    demo.launch()