# =============================================================================
#  COMPRESSION NAVIGATOR  ·  extended + annotated edition
# =============================================================================
# An LLM is a lossy codec for text. Training compresses a corpus into weights;
# a forward pass decompresses a continuation. These five tools let you watch
# that decompression happen and poke at where facts physically live.
#
# The five tabs are not toys invented here - each one is a real
# mechanistic-interpretability technique you'll find in papers:
#
#   1. Decompress   = LOGIT LENS            (nostalgebraist, 2020)
#   2. Triangulate  = EMBEDDING NEIGHBOURS  (the geometry of the vocab)
#   3. Re-route     = ACTIVATION STEERING   (ActAdd / repr. engineering)
#   4. Diff         = CROSS-MODEL ALIGNMENT (compare checkpoints by depth)
#   5. Causal trace = ACTIVATION PATCHING   (ROME, Meng et al., 2022)
#
# WHY THE GLASS-BOX MODELS MATTER
# -------------------------------
# On a real model (gpt2) you never know the ground truth, so you can't tell
# whether a tool is *correct* or just producing plausible-looking output.
# This file ships two models whose internals you fully specify, so you can
# check each tool against a known answer:
#
#   "handmade"  - facts stored as a LOOKUP TABLE keyed on the prompt string.
#                 The computation happens in a side channel (string match),
#                 NOT in the residual stream. Lesson: such a model is almost
#                 invisible to residual-stream interpretability. Logit lens
#                 sees a sudden jump with no build-up; causal tracing finds
#                 nothing, because corrupting activations doesn't touch the
#                 string match. This is a real and underappreciated *limit*
#                 of these methods.
#
#   "glassbox"  - facts stored the way real transformers store them: as
#                 key->value writes into the RESIDUAL STREAM (Geva et al.'s
#                 "MLPs are key-value memories", which is exactly what ROME
#                 edits).
#                 Because the fact flows through activations, ALL five
#                 tools light up correctly - and you can verify they report
#                 the layer you actually put the fact in. This is a unit-test
#                 harness for interpretability code.
#
# Run order suggestion:  glassbox -> handmade -> gpt2
#   glassbox shows what "correct" looks like; handmade shows a failure mode;
#   gpt2 shows the fuzzy, distributed real thing.
# =============================================================================

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float32
MODELS = {}               # name -> (model, tokenizer) cache
STATE = {"name": None}    # currently loaded model name


# =============================================================================
#  A tiny shared tokenizer for both glass-box models.
#  Case is CANONICALISED to lowercase everywhere (this fixes a real bug in the
#  original: "Paris" from a pinned fact and "paris" from the Markov table became
#  two different vocab entries, so the boosted token and the *tracked* token
#  silently diverged - every neighbour read cos=0.000 and every tracked prob 0).
# =============================================================================
class FakeBatchEncoding(dict):
    def to(self, device):       # let callers do tok(...).to(DEVICE) safely
        return self


class SimpleTok:
    """Whitespace tokenizer over a fixed vocab.
    Not 'fast' (no offset map)."""
    is_fast = False

    def __init__(self, stoi, itos):
        self.stoi, self.itos = stoi, itos
        self.eos_token_id = stoi["."]      # period doubles as end-of-sequence

    def _ids(self, text):
        words = text.lower().replace(".", " .").split()
        return [self.stoi.get(w, self.stoi[""]) for w in words]

    def __call__(self, text, return_tensors=None, return_offsets_mapping=False):
        ids = self._ids(text) or [self.stoi[""]]
        return FakeBatchEncoding(
            input_ids=torch.tensor([ids]),
            attention_mask=torch.ones(1, len(ids), dtype=torch.long),
        )

    def encode(self, text, add_special_tokens=False):
        return self._ids(text)

    def decode(self, ids, skip_special_tokens=False):
        out = []
        for i in ids:
            w = self.itos.get(int(i), "?")
            if skip_special_tokens and w in ("", ""):
                continue
            out.append(w)
        return " ".join(out)


class _Out:
    """Mimics a HF CausalLMOutput: .logits and (optional) .hidden_states."""
    def __init__(self, logits, hidden_states):
        self.logits = logits
        self.hidden_states = hidden_states


def _greedy_generate(model, input_ids, max_new_tokens=20, pad_token_id=None, **_):
    """Minimal greedy decode so the steering tab works on the toy models too
    (the originals had no .generate, so that tab crashed on 'handmade')."""
    ids = input_ids
    for _ in range(int(max_new_tokens)):
        nxt = model(input_ids=ids).logits[0, -1].argmax().view(1, 1)
        ids = torch.cat([ids, nxt], dim=1)
        if pad_token_id is not None and int(nxt.item()) == int(pad_token_id):
            break
    return ids


# =============================================================================
#  MODEL 1 - "handmade": facts as a LOOKUP TABLE (the side-channel glass box)
# -----------------------------------------------------------------------------
#  Embeddings are the identity matrix (each token is its own one-hot). The two
#  "layers" don't read the residual stream in a meaningful linear way:
#    - MemoryBlock matches the *decoded prompt string* and boosts the answer.
#    - MarkovBlock adds a hand-built bigram transition for the last token.
#  Because MemoryBlock keys on the prompt TEXT, not on activations, this is a
#  deliberate demonstration of a model that residual-stream interpretability
#  cannot see. Use it as the "what failure looks like" control.
# =============================================================================
PINNED = {                          # answers are lowercase now (bug fix)
    "the capital of france is": " paris",
    "the eiffel tower is in": " paris",
    "two plus two equals": " four",
}
MARKOV = {
    "": {"the": 3, "i": 2, "a": 1},
    "the": {"city": 2, "tower": 2, "answer": 1},
    "i": {"think": 2, "am": 1},
    "a": {"model": 2, "city": 1},
    "city": {"of": 3, "is": 1},
    "of": {"light": 2, "paris": 1},
    "tower": {"is": 3},
    "is": {"in": 2, "a": 1},
    "in": {"paris": 2, "france": 1},
    "model": {"is": 2},
    "think": {"the": 2},
    "paris": {".": 1},
    "france": {".": 1},
    "light": {".": 1},
    "four": {".": 1},
}


def _build_handmade_vocab():
    toks, seen = ["", "", "."], {"", "", "."}

    def add(w):
        if w not in seen:
            toks.append(w)
            seen.add(w)

    for v in PINNED.values():
        add(v.strip())
    for w, nxts in MARKOV.items():
        add(w)
        for x in nxts:
            add(x)
    for k in PINNED:
        for w in k.split():
            add(w)
    return toks


HM_VOCAB = _build_handmade_vocab()
HM_STOI = {w: i for i, w in enumerate(HM_VOCAB)}
HM_ITOS = {i: w for w, i in HM_STOI.items()}
HM_V = len(HM_VOCAB)


class _MemoryBlock(nn.Module):
    """If the decoded prompt ends with a pinned key, slam the answer logit.
    NOTE: this reads prompt_ids (the string), not x - that's the whole point."""
    def forward(self, x, prompt_ids=None):
        out = x.clone()
        if prompt_ids is not None:
            text = " ".join(HM_ITOS.get(int(i), "") for i in prompt_ids).strip()
            for key, ans in PINNED.items():
                if text.endswith(key):
                    out[0, -1, HM_STOI[ans.strip()]] += 12.0
        return (out,)


class _MarkovBlock(nn.Module):
    """Add a hand-built bigram transition row for the last token."""
    def __init__(self):
        super().__init__()
        T = torch.zeros(HM_V, HM_V)
        for w, nxts in MARKOV.items():
            if w in HM_STOI:
                tot = sum(nxts.values())
                for x, wt in nxts.items():
                    if x in HM_STOI:
                        T[HM_STOI[w], HM_STOI[x]] = wt / tot
        self.register_buffer("T", T)

    def forward(self, x, prompt_ids=None):
        out = x.clone()
        if prompt_ids:
            out[0, -1] += 4.0 * self.T[int(prompt_ids[-1])]
        return (out,)


class _HMTransformer(nn.Module):
    def __init__(self):
        super().__init__()
        self.wte = nn.Embedding(HM_V, HM_V)
        with torch.no_grad():
            self.wte.weight.copy_(torch.eye(HM_V))     # one-hot embeddings
        self.h = nn.ModuleList([_MemoryBlock(), _MarkovBlock()])
        self.ln_f = nn.Identity()


class HandmadeModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.transformer = _HMTransformer()
        self.head = nn.Linear(HM_V, HM_V, bias=False)
        with torch.no_grad():
            self.head.weight.copy_(torch.eye(HM_V))    # identity unembed
        self.tok = SimpleTok(HM_STOI, HM_ITOS)

    def get_input_embeddings(self):
        return self.transformer.wte

    def get_output_embeddings(self):
        return self.head

    def generate(self, input_ids=None, attention_mask=None, **kw):
        return _greedy_generate(self, input_ids, **kw)

    def forward(self, input_ids=None, attention_mask=None, output_hidden_states=False):
        ids = input_ids[0].tolist()
        x = self.transformer.wte(input_ids).float()
        hs = [x]
        h = x
        for blk in self.transformer.h:
            (h,) = blk(h, prompt_ids=ids)
            hs.append(h)
        logits = self.head(self.transformer.ln_f(h))
        return _Out(logits, tuple(hs) if output_hidden_states else None)


# =============================================================================
#  MODEL 2 - "glassbox": facts as RESIDUAL-STREAM key->value writes
# -----------------------------------------------------------------------------
#  This is the model the original was missing. It stores facts the way real
#  transformers do, so every tool works AND can be checked against ground truth.
#
#  Vocab + structured embeddings (d=32). Country and its capital deliberately
#  SHARE an embedding dimension, so the neighbours tool finds real geometry
#  (paris is near france).
#
#  Four layers:
#    L0  subject site   : (identity here) the residual the trace will restore
#    L1  pool/attention : copies subject signal from earlier positions -> last
#    L2  fact MLP       : key(subject+relation) -> relu -> value(answer dir)
#                         <- ROME edits this kind of layer
#    L3  cleanup        : identity
#
#  Ground truth you can verify:
#    - logit lens: the answer is INVISIBLE until L2, then appears. Compare with
#      handmade (sudden, no build-up) and gpt2 (fuzzy, spread over many layers).
#    - causal trace: corrupting the subject and restoring layer by layer peaks
#      at L0 - because L1's "attention" re-reads the restored subject. That is
#      the ROME story: the causal site is an early layer at the SUBJECT token.
#    - steering / neighbours: both operate on real directions, so both work.
# =============================================================================
GB_D = 32
GB_TOKS = ["", "", ".", "the", "capital", "of", "is", "in",
           "france", "germany", "japan", "paris", "berlin", "tokyo",
           "london", "rome"]   # spare answers so edits can hit a fresh target
GB_STOI = {w: i for i, w in enumerate(GB_TOKS)}
GB_ITOS = {i: w for w, i in GB_STOI.items()}
GB_V = len(GB_TOKS)
GB_FACTS = [("france", "paris"), ("germany", "berlin"), ("japan", "tokyo")]


def _build_gb_embeddings():
    E = torch.zeros(GB_V, GB_D)

    def setd(tok, pairs):
        for d, v in pairs:
            E[GB_STOI[tok], d] = v

    # country/capital pairs share their first dim -> positive cosine (geometry!)
    setd("france",  [(0, 1.0), (1, 0.6), (20, 0.5)])
    setd("paris",   [(0, 0.8), (2, 0.9), (21, 0.5)])
    setd("germany", [(3, 1.0), (4, 0.6), (22, 0.5)])
    setd("berlin",  [(3, 0.8), (5, 0.9), (23, 0.5)])
    setd("japan",   [(6, 1.0), (7, 0.6), (24, 0.5)])
    setd("tokyo",   [(6, 0.8), (8, 0.9), (25, 0.5)])
    setd("london",  [(27, 1.0), (28, 0.5)])           # spare answers (own dirs)
    setd("rome",    [(29, 1.0), (30, 0.5)])
    setd("is",      [(9, 1.0), (26, 0.4)])            # the relation marker
    for i, t in enumerate(GB_TOKS):                   # give fillers an id
        if E[i].abs().sum() == 0:
            E[i, 10 + i % 6] = 1.0
    return E / (E.norm(dim=-1, keepdim=True) + 1e-9)  # unit rows


GB_E = _build_gb_embeddings()
GB_SUBJ = torch.zeros(GB_D, GB_D)     # projector onto subject dims 0..8
for _d in range(9):
    GB_SUBJ[_d, _d] = 1.0


class _GBIdent(nn.Module):
    def forward(self, x, prompt_ids=None):
        return (x.clone(),)


class _GBPool(nn.Module):
    """Toy 'attention': sum the subject-projected residual of all earlier
    positions into the last position. Corrupting the subject earlier shows up
    here; restoring the subject BEFORE this layer is what makes the trace
    recover - that is why the causal peak lands at L0, not L1."""
    def forward(self, x, prompt_ids=None):
        out = x.clone()
        if x.shape[1] > 1:
            pooled = (x[0, :-1] @ GB_SUBJ.T).sum(0)
            out[0, -1] = out[0, -1] + 0.9 * pooled
        return (out,)


class _GBFactMLP(nn.Module):
    """Geva-style key->value memory. W_in rows are (subject+relation) keys;
    relu gates which fact fires; W_out columns are answer unembed directions.
    This is structurally the exact layer ROME rewrites to edit a fact."""
    def __init__(self):
        super().__init__()
        Win = torch.zeros(len(GB_FACTS), GB_D)
        Wout = torch.zeros(GB_D, len(GB_FACTS))
        rel = GB_E[GB_STOI["is"]]
        for k, (s, a) in enumerate(GB_FACTS):
            key = (GB_E[GB_STOI[s]] @ GB_SUBJ.T) * 0.9 + rel
            Win[k] = key / key.norm()
            Wout[:, k] = GB_E[GB_STOI[a]]          # write answer direction
        self.register_buffer("Win", Win)
        self.register_buffer("Wout", Wout)
        self.register_buffer("Win0", Win.clone())   # pristine backups for reset
        self.register_buffer("Wout0", Wout.clone())
        self.bias, self.gain = 0.85, 6.0            # tuned: clean p~0.5, corrupt p~0.07

    def forward(self, x, prompt_ids=None):
        out = x.clone()
        pre = F.relu(self.Win @ out[0, -1] - self.bias)
        out[0, -1] = out[0, -1] + self.gain * (self.Wout @ pre)
        return (out,)


class _GBTransformer(nn.Module):
    def __init__(self):
        super().__init__()
        self.wte = nn.Embedding(GB_V, GB_D)
        with torch.no_grad():
            self.wte.weight.copy_(GB_E)
        self.h = nn.ModuleList([_GBIdent(), _GBPool(), _GBFactMLP(), _GBIdent()])
        self.ln_f = nn.Identity()


class GlassBoxModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.transformer = _GBTransformer()
        self.head = nn.Linear(GB_D, GB_V, bias=False)
        with torch.no_grad():
            self.head.weight.copy_(GB_E)           # tied unembed
        self.tok = SimpleTok(GB_STOI, GB_ITOS)

    # --- knowledge editing (ROME-style, exact on this key->value layer) -------
    @torch.no_grad()
    def edit_fact(self, subject, new_answer, method="rank1", strength=1.0):
        """Rewrite the value a fact-MLP key maps to. Methods:
          rank1 / surgical - the minimal update: change only this fact's value.
          broadcast        - DELIBERATELY sloppy: smear the delta across ALL
                             facts, so the verifier has real collateral to catch.
        """
        fm = self.transformer.h[2]                  # the FactMLP block
        subjects = [s for s, _ in GB_FACTS]
        if subject not in subjects:
            raise ValueError("unknown subject %r" % subject)
        if new_answer not in GB_STOI:
            raise ValueError("unknown answer token %r" % new_answer)
        k = subjects.index(subject)
        delta = (GB_E[GB_STOI[new_answer]] - fm.Wout0[:, k]) * float(strength)
        if method in ("rank1", "surgical"):
            fm.Wout[:, k] = fm.Wout0[:, k] + delta
        elif method == "broadcast":
            fm.Wout += delta.unsqueeze(1)           # hits every fact
        else:
            raise ValueError("unknown method %r" % method)

    @torch.no_grad()
    def reset(self):
        fm = self.transformer.h[2]
        fm.Win.copy_(fm.Win0)
        fm.Wout.copy_(fm.Wout0)

    def get_input_embeddings(self):
        return self.transformer.wte

    def get_output_embeddings(self):
        return self.head

    def generate(self, input_ids=None, attention_mask=None, **kw):
        return _greedy_generate(self, input_ids, **kw)

    def forward(self, input_ids=None, attention_mask=None, output_hidden_states=False):
        ids = input_ids[0].tolist()
        x = self.transformer.wte(input_ids).float()
        hs = [x]
        h = x
        for blk in self.transformer.h:
            (h,) = blk(h, prompt_ids=ids)
            hs.append(h)
        logits = self.head(self.transformer.ln_f(h))
        return _Out(logits, tuple(hs) if output_hidden_states else None)


# =============================================================================
#  REAL MODELS - resolve the architecture-specific module paths
# =============================================================================
def _resolve(model, paths):
    for path in paths:
        obj, ok = model, True
        for part in path.split("."):
            if hasattr(obj, part):
                obj = getattr(obj, part)
            else:
                ok = False
                break
        if ok:
            return obj
    return None


def get_blocks(model):
    blocks = _resolve(model, ["transformer.h", "model.layers",
                              "gpt_neox.layers", "model.decoder.layers"])
    if blocks is None:
        raise RuntimeError("Could not locate transformer blocks.")
    return blocks


def get_final_norm(model):
    norm = _resolve(model, ["transformer.ln_f", "model.norm",
                            "gpt_neox.final_layer_norm",
                            "model.decoder.final_layer_norm"])
    return norm if norm is not None else (lambda x: x)


def get_head(model):
    return model.get_output_embeddings()


def get_handles(name):
    if name not in MODELS:
        if name == "handmade":
            m = HandmadeModel().eval()
            MODELS[name] = (m, m.tok)
        elif name == "glassbox":
            m = GlassBoxModel().eval()
            MODELS[name] = (m, m.tok)
        else:
            tok = AutoTokenizer.from_pretrained(name)
            model = AutoModelForCausalLM.from_pretrained(
                name, torch_dtype=DTYPE).to(DEVICE).eval()
            MODELS[name] = (model, tok)
    return MODELS[name]


def load_model(name):
    name = name.strip()
    model, _ = get_handles(name)
    STATE["name"] = name
    return "Loaded **%s** (%d layers)." % (name, len(get_blocks(model)))


# =============================================================================
#  Shared readout: project every layer's last-token residual to a vocab dist.
# =============================================================================
@torch.no_grad()
def layer_distributions(model, tok, prompt):
    inputs = tok(prompt, return_tensors="pt").to(DEVICE)
    out = model(**inputs, output_hidden_states=True)
    hs = out.hidden_states
    norm, head, n = get_final_norm(model), get_head(model), len(hs)
    dists = []
    # hidden_states[0] is the embedding output and hidden_states[i] is the
    # output of block i-1, so the row labelled "L<i>" here corresponds to
    # blocks[i-1], the 0-based index that the causal trace reports.
    for i, layer_hs in enumerate(hs):
        vec = layer_hs[0, -1].to(DTYPE)
        # HF convention: the LAST hidden_states entry is already post-ln_f,
        # so skip norm there; apply ln_f to intermediates (logit-lens style).
        logits = head(vec) if i == n - 1 else head(norm(vec))
        dists.append(("embed" if i == 0 else "L%d" % i, F.softmax(logits, dim=-1)))
    return dists


def _entropy_bits(probs):
    p = probs.clamp_min(1e-12)
    return float(-(p * p.log()).sum() / math.log(2))


# =============================================================================
#  TAB 1 - LOGIT LENS: watch the answer condense out of the residual stream
# =============================================================================
@torch.no_grad()
def logit_lens(prompt, top_k, track):
    if STATE["name"] is None:
        return "Load a model first."
    model, tok = get_handles(STATE["name"])
    top_k = int(top_k)
    tids = tok.encode(track, add_special_tokens=False) if track.strip() else []
    tid = tids[0] if tids else None
    dists = layer_distributions(model, tok, prompt)
    header = "layer | top tokens (prob) | entropy" \
        + (" | p(%r)" % track if tid is not None else "")
    lines = ["prompt: %r" % prompt, header, "-" * len(header)]
    for label, probs in dists:
        # Clamp k: the toy vocabularies can be smaller than the requested top_k.
        p, idx = probs.topk(min(top_k, probs.numel()))
        shown = " ".join("%r:%.2f" % (tok.decode([t]).replace("\n", "\\n"), v)
                         for t, v in zip(idx.tolist(), p.tolist()))
        row = "%5s | %-40s | %4.1fb" % (label, shown, _entropy_bits(probs))
        if tid is not None:
            row += " | %.3f" % probs[tid].item()
        lines.append(row)
    return "\n".join(lines)


# =============================================================================
#  TAB 2 - NEIGHBOURS: the geometry of the (un)embedding space
# =============================================================================
@torch.no_grad()
def neighbors(word, top_k):
    if STATE["name"] is None:
        return "Load a model first."
    model, tok = get_handles(STATE["name"])
    top_k = int(top_k)
    ids = tok.encode(word, add_special_tokens=False)
    if not ids:
        return "Could not tokenize %r." % word
    tid = ids[0]
    W = F.normalize(get_head(model).weight.to(DTYPE), dim=-1)
    sims = W @ W[tid]
    # Ask for one extra hit (the word itself), but never more than the vocab holds.
    vals, idx = sims.topk(min(top_k + 1, sims.numel()))
    note = ""
    if STATE["name"] == "handmade":
        note = ("(handmade uses one-hot embeddings, so every token is "
                "orthogonal -> all cosines are 0 by construction. This is the "
                "tool telling the truth about a model with no vocab geometry.)\n")
    lines = [note + "neighbours of %r:" % word]
    for v, j in zip(vals.tolist(), idx.tolist()):
        if j != tid:
            lines.append(" %14r cos=%.3f" % (tok.decode([j]), v))
    return "\n".join(lines[: top_k + 1])


# =============================================================================
#  TAB 3 - STEERING: bend behaviour by adding a direction, no retraining
# =============================================================================
def _make_steer_hook(direction, alpha):
    d = direction * alpha

    def hook(module, inp, out):
        if isinstance(out, tuple):
            return (out[0] + d.to(out[0].dtype).to(out[0].device),) + out[1:]
        return out + d.to(out.dtype).to(out.device)

    return hook


@torch.no_grad()
def steer_generate(prompt, source, target, layer, alpha, max_new):
    if STATE["name"] is None:
        return "Load a model first.", ""
    model, tok = get_handles(STATE["name"])
    layer, max_new = int(layer), int(max_new)
    emb = model.get_input_embeddings().weight

    def first_emb(w):
        ids = tok.encode(w, add_special_tokens=False)
        # Fall back to zeros on the embedding's own device: the toy models
        # stay on CPU even when DEVICE is "cuda".
        return emb[ids[0]] if ids else torch.zeros(emb.shape[-1], device=emb.device)

    direction = F.normalize((first_emb(target) - first_emb(source)).to(DTYPE), dim=-1)
    inputs = tok(prompt, return_tensors="pt").to(DEVICE)
    gk = dict(max_new_tokens=max_new, do_sample=False, pad_token_id=tok.eos_token_id)
    base = tok.decode(model.generate(**inputs, **gk)[0], skip_special_tokens=True)
    blocks = get_blocks(model)
    layer = max(0, min(layer, len(blocks) - 1))
    handle = blocks[layer].register_forward_hook(_make_steer_hook(direction, alpha))
    try:
        steered = tok.decode(model.generate(**inputs, **gk)[0], skip_special_tokens=True)
    finally:
        handle.remove()
    return base, "steer %r -> %r @ L%d alpha=%s\n%s" % (source, target, layer, alpha, steered)


# =============================================================================
#  TAB 4 - DIFF: compare two models on one prompt, aligned by relative depth
# =============================================================================
@torch.no_grad()
def diff_models(name_a, name_b, prompt, target, top_k):
    ma, ta = get_handles(name_a.strip())
    mb, tb = get_handles(name_b.strip())
    ida = ta.encode(target, add_special_tokens=False)
    idb = tb.encode(target, add_special_tokens=False)
    if not ida or not idb:
        return "Could not tokenize target %r in both models." % target
    ida, idb = ida[0], idb[0]
    da = layer_distributions(ma, ta, prompt)
    db = layer_distributions(mb, tb, prompt)
    nA, nB = len(da) - 1, len(db) - 1

    def top1(probs, tok):
        v, i = probs.topk(1)
        return "%r:%.2f" % (tok.decode([i.item()]), v.item())

    lines = ["prompt: %r target: %r" % (prompt, target),
             "%18s | %16s %6s | %16s %6s | %7s"
             % ("depth (A/B)", "A top1", "pA", "B top1", "pB", "dp")]
    for i in range(nA + 1):
        # Map layer i of model A to the layer of model B at the same relative depth.
        frac = (i / nA) if nA > 0 else 0.0
        j = max(0, min(round(frac * nB), nB)) if nB > 0 else 0
        la, pa = da[i]
        lb, pb = db[j]
        a_t, b_t = pa[ida].item(), pb[idb].item()
        lines.append("%18s | %16s %6.3f | %16s %6.3f | %+7.3f" % (
            "%3.0f%% (%s/%s)" % (frac * 100, la, lb),
            top1(pa, ta), a_t, top1(pb, tb), b_t, b_t - a_t))
    return "\n".join(lines)


# =============================================================================
#  TAB 5 - CAUSAL TRACE: corrupt the subject, restore each layer, find the site
# -----------------------------------------------------------------------------
#  This is ROME's activation patching. We:
#    1. record clean activations and clean p(target)
#    2. add Gaussian noise to the SUBJECT token embeddings -> corrupt p(target)
#    3. for each layer L: run corrupted, but force layer L's residual back to
#       the clean values at the subject positions. How much p(target) recovers
#       tells you how causally important layer L is. The peak is "the site".
#  The glass-box gives a clean, verifiable peak; gpt2 gives a realistic band.
# =============================================================================
def _find_subject_positions(tok, input_ids, prompt, subject):
    """Locate subject token positions, with a path for slow (non-fast) toks."""
    seq_len = input_ids.shape[1]
    if getattr(tok, "is_fast", False):
        enc = tok(prompt, return_tensors="pt", return_offsets_mapping=True)
        cs = prompt.find(subject)
        if cs >= 0:
            ce = cs + len(subject)
            offs = enc["offset_mapping"][0].tolist()
            pos = [i for i, (s, e) in enumerate(offs) if e > cs and s < ce]
            if pos:
                return [p for p in pos if p != seq_len - 1], ""
    else:
        sub_ids = tok.encode(subject, add_special_tokens=False)
        seq = input_ids[0].tolist()
        pos = [i for i, t in enumerate(seq) if t in sub_ids]
        if pos:
            return [p for p in pos if p != seq_len - 1], ""
    fb = list(range(0, max(1, seq_len - 1)))[: max(1, seq_len // 2)]
    return fb, "(subject not found; using fallback window)\n"


@torch.no_grad()
def causal_trace(prompt, subject, target, noise_scale, seed):
    if STATE["name"] is None:
        return "Load a model first."
    model, tok = get_handles(STATE["name"])
    seed, noise_scale = int(seed), float(noise_scale)
    inputs = tok(prompt, return_tensors="pt").to(DEVICE)
    input_ids = inputs["input_ids"]
    positions, note = _find_subject_positions(tok, input_ids, prompt, subject)
    if not positions:
        return note + "No valid subject positions."
    target_ids = tok.encode(target, add_special_tokens=False)
    if not target_ids:
        return "Could not tokenize target %r." % target
    tid = target_ids[0]

    # 1. clean run
    out_clean = model(**inputs, output_hidden_states=True)
    clean_hs = out_clean.hidden_states
    clean_p = F.softmax(out_clean.logits[0, -1].to(DTYPE), dim=-1)[tid].item()

    emb_module = model.get_input_embeddings()
    std = emb_module.weight.std().item()
    hidden = emb_module.weight.shape[-1]
    torch.manual_seed(seed)
    # Create the noise on the embedding's own device: the toy models stay on
    # CPU even when DEVICE is "cuda", so using DEVICE here would mismatch.
    noise = torch.randn(len(positions), hidden,
                        device=emb_module.weight.device) * noise_scale * std

    def corrupt_hook(module, inp, out):
        out = out.clone()
        for k, p in enumerate(positions):
            out[0, p] = out[0, p] + noise[k].to(out.dtype)
        return out

    # 2. corrupted run (hooks are removed even if the forward pass raises)
    h = emb_module.register_forward_hook(corrupt_hook)
    try:
        corrupt_p = F.softmax(model(**inputs).logits[0, -1].to(DTYPE), dim=-1)[tid].item()
    finally:
        h.remove()

    # 3. corrupted run with one layer restored at the subject positions
    blocks, rows = get_blocks(model), []
    for l in range(len(blocks)):
        clean_layer_hs = clean_hs[l + 1][0]

        def restore_hook(module, inp, out, _clean=clean_layer_hs):
            if isinstance(out, tuple):
                h0 = out[0].clone()
                for p in positions:
                    h0[0, p] = _clean[p].to(h0.dtype)
                return (h0,) + out[1:]
            h0 = out.clone()
            for p in positions:
                h0[0, p] = _clean[p].to(h0.dtype)
            return h0

        h1 = emb_module.register_forward_hook(corrupt_hook)
        h2 = blocks[l].register_forward_hook(restore_hook)
        try:
            p_r = F.softmax(model(**inputs).logits[0, -1].to(DTYPE), dim=-1)[tid].item()
        finally:
            h1.remove()
            h2.remove()
        rows.append((l, p_r))

    denom = clean_p - corrupt_p
    lines = [note + "prompt: %r" % prompt,
             "subject: %r target: %r" % (subject, target),
             "clean p=%.3f corrupt p=%.3f noise=%sx std" % (clean_p, corrupt_p, noise_scale),
             "",
             "%6s | %9s | %9s" % ("layer", "p(target)", "recovery")]
    best_l, best_r = 0, -1e9
    for l, p_r in rows:
        # recovery = fraction of the clean-vs-corrupt gap that restoring layer l closes
        rec = (p_r - corrupt_p) / denom if abs(denom) > 1e-6 else 0.0
        if rec > best_r:
            best_r, best_l = rec, l
        lines.append(" L%-3d | %9.3f | %8.1f%%" % (l, p_r, rec * 100))
    lines.append("")
    lines.append("# peak at L%d (%.0f%% recovery) <- the causal site" % (best_l, best_r * 100))
    if abs(denom) < 1e-6:
        lines.append("# (corruption didn't move p(target): on 'handmade' this is "
                     "EXPECTED - the fact lives in a string match, not activations.)")
    return "\n".join(lines)


# =============================================================================
#  EDIT LOOP + VERIFICATION HARNESS (the ROME sandbox)
# -----------------------------------------------------------------------------
#  Apply a knowledge edit to the glass-box, then PROVE it was surgical:
#    efficacy    - did the target fact change to the new answer?
#    specificity - did the OTHER facts stay exactly as they were? (locality)
#    fluency     - did the output distribution stay sane (no entropy collapse)?
#  Because we own the ground truth, "nothing else broke" is checkable, not vibes.
#  An optional pass sends the before/after battery to Claude for an independent
#  verdict - real LLM calls verifying the edit.
# =============================================================================
GB_ANSWERS = ["paris", "berlin", "tokyo", "london", "rome"]


@torch.no_grad()
def _probe_battery(model, tok):
    """Run every known fact + a neutral format probe; record what the model says."""
    rows = {}
    for country, orig in GB_FACTS:
        prompt = "the capital of %s is" % country
        probs = F.softmax(model(**tok(prompt, return_tensors="pt").to(DEVICE)
                                ).logits[0, -1].to(DTYPE), dim=-1)
        v, i = probs.topk(1)
        rows[country] = {
            "prompt": prompt,
            "orig": orig,
            "top1": tok.decode([i.item()]),
            "top1_p": v.item(),
            "p_orig": probs[GB_STOI[orig]].item(),
            "cand": {a: probs[GB_STOI[a]].item() for a in GB_ANSWERS},
            "entropy": _entropy_bits(probs),
        }
    return rows


def _verdict(before, after, subject, new_answer, drift_thresh=0.05):
    eff = after[subject]["top1"] == new_answer
    collateral, max_drift = [], 0.0
    for c in before:
        if c == subject:
            continue
        d = abs(after[c]["p_orig"] - before[c]["p_orig"])
        max_drift = max(max_drift, d)
        if after[c]["top1"] != before[c]["top1"] or d > drift_thresh:
            collateral.append(c)
    ent_blowup = any(abs(after[c]["entropy"] - before[c]["entropy"]) > 0.8
                     for c in before)
    surgical = eff and not collateral and not ent_blowup
    return eff, collateral, max_drift, ent_blowup, surgical


def edit_and_verify(subject, new_answer, method, strength, use_llm,
                    anthropic_key, anthropic_model, hf_token, hf_model,
                    local_url, local_model):
    model, tok = get_handles("glassbox")
    STATE["name"] = "glassbox"
    # Normalise once, so the edit, the verdict lookup and the report all use
    # the same key (the vocab is lowercase; stray whitespace would otherwise
    # raise a KeyError in _verdict).
    subject, new_answer = subject.strip().lower(), new_answer.strip().lower()
    model.reset()
    before = _probe_battery(model, tok)
    try:
        model.edit_fact(subject, new_answer, method, float(strength))
    except ValueError as e:
        return ("Edit failed: %s\nValid subjects: france, germany, japan. "
                "Valid answers: %s" % (e, ", ".join(GB_ANSWERS)))
    after = _probe_battery(model, tok)
    eff, collateral, max_drift, ent, surgical = _verdict(before, after, subject, new_answer)

    L = ["EDIT: %s's capital -> %r (method=%s, strength=%s)"
         % (subject, new_answer, method, strength),
         "",
         "%-9s | %-22s | %-22s" % ("fact", "before (top1 / p_orig)", "after (top1 / p_orig)"),
         "-" * 60]
    for c in before:
        b, a = before[c], after[c]
        flag = " <- TARGET" if c == subject else (" <- COLLATERAL" if c in collateral else "")
        L.append("%-9s | %-22s | %-22s%s" % (
            c,
            "%s / %.2f" % (b["top1"], b["p_orig"]),
            "%s / %.2f" % (a["top1"], a["p_orig"]),
            flag))
    L += ["",
          "efficacy : %s (target now says %r, p=%.2f)"
          % ("PASS" if eff else "FAIL", after[subject]["top1"], after[subject]["top1_p"]),
          "specificity : %s (max drift on other facts = %.3f%s)"
          % ("PASS" if not collateral else "FAIL: " + ", ".join(collateral),
             max_drift, "; entropy spike" if ent else ""),
          "",
          "VERDICT: %s" % ("SURGICAL EDIT" if surgical else "COLLATERAL DAMAGE")]
    L.append("(model is left in the edited state - inspect it in tabs 1-5, or hit Reset.)")

    llm_report = ""
    if use_llm:
        providers = [
            {"type": "anthropic", "key": anthropic_key, "model": anthropic_model},
            {"type": "hf", "key": hf_token, "model": hf_model},
            {"type": "local", "url": local_url, "model": local_model},
        ]
        llm_report = _llm_judge_chain(before, after, subject, new_answer, providers)
        L += ["", "-" * 60, "INDEPENDENT LLM REVIEW:", llm_report]
    report = "\n".join(L)
    _log_session(subject, new_answer, method,
                 strength, before, after, eff, collateral, max_drift, surgical,
                 llm_report)
    return report


def reset_glassbox():
    model, _ = get_handles("glassbox")
    model.reset()
    return "Glass-box weights restored to pristine. Re-run any tab to confirm."


# --- optional: real LLM calls to verify the edit, with a 3-tier fallback chain
# Anthropic (Claude) -> Hugging Face Inference -> local OpenAI-compatible server
# (e.g. LM Studio). Tries each in order; the first provider that's configured
# AND reachable wins. This means you're never blocked on one vendor being down
# or on not having an Anthropic key at all - your own RTX 5090 can be the judge.
def _build_judge_prompt(before, after, subject, new_answer):
    import json
    payload = {c: {"prompt": before[c]["prompt"],
                   "before_top1": before[c]["top1"],
                   "before_p_orig": round(before[c]["p_orig"], 3),
                   "after_top1": after[c]["top1"],
                   "after_p_orig": round(after[c]["p_orig"], 3)}
               for c in before}
    sys = ("You audit knowledge edits to a small language model. The intended edit "
           "is: make %s's capital '%s'. Given before/after predictions for every "
           "known fact, decide if the edit was SURGICAL (target changed, all other "
           "facts unchanged) or caused COLLATERAL damage. Reply ONLY as JSON, no "
           'prose, no markdown fences: {"verdict":"surgical|collateral",'
           '"target_changed":bool,"damaged_facts":[...],"confidence":0-1,'
           '"reason":"one sentence"}.') % (subject, new_answer)
    return sys, json.dumps(payload)


def _parse_verdict_json(text, provider_label):
    import json
    clean = text.strip().strip("`")
    if clean.lower().startswith("json"):
        clean = clean[4:].strip()
    start, end = clean.find("{"), clean.rfind("}")
    if start != -1 and end != -1:
        clean = clean[start:end + 1]
    v = json.loads(clean)
    return ("[%s] verdict=%s target_changed=%s confidence=%s\n damaged: %s\n reason: %s"
            % (provider_label, v.get("verdict"), v.get("target_changed"),
               v.get("confidence"), v.get("damaged_facts") or "none", v.get("reason")))


def _try_anthropic(sys, user, cfg):
    import os, json
    key = (cfg.get("key") or "").strip() or os.environ.get("ANTHROPIC_API_KEY", "")
    if not key:
        return None, "anthropic: no key configured"
    body = {"model": (cfg.get("model") or "claude-sonnet-4-6").strip(),
            "max_tokens": 400,
            "system": sys,
            "messages": [{"role": "user", "content": user}]}
    try:
        try:
            import anthropic
            client = anthropic.Anthropic(api_key=key)
            msg = client.messages.create(**body)
            text = "".join(b.text for b in msg.content
                           if getattr(b, "type", "") == "text")
        except ImportError:
            # SDK not installed: fall back to a plain HTTPS request.
            import urllib.request
            req = urllib.request.Request(
                "https://api.anthropic.com/v1/messages",
                data=json.dumps(body).encode(),
                headers={"x-api-key": key,
                         "anthropic-version": "2023-06-01",
                         "content-type": "application/json"})
            with urllib.request.urlopen(req, timeout=30) as r:
                data = json.loads(r.read())
            text = "".join(b.get("text", "") for b in data.get("content", [])
                           if b.get("type") == "text")
        return _parse_verdict_json(text, "anthropic:" + body["model"]), None
    except Exception as e:
        return None, "anthropic failed: %s" % e


def _try_hf(sys, user, cfg):
    token = (cfg.get("key") or "").strip()
    model = (cfg.get("model") or "Qwen/Qwen2.5-7B-Instruct").strip()
    if not token:
        import os
        token = os.environ.get("HF_TOKEN", "")
    if not token:
        return None, "hf: no token configured"
    try:
        from huggingface_hub import InferenceClient
        client = InferenceClient(model=model, token=token)
        resp = client.chat_completion(
            messages=[{"role": "system", "content": sys},
                      {"role": "user", "content": user}],
            max_tokens=400)
        text = resp.choices[0].message.content
        return _parse_verdict_json(text, "hf:" + model), None
    except Exception as e:
        return None, "hf failed: %s" % e


def _try_local(sys, user, cfg):
    """Any OpenAI-compatible /v1/chat/completions server - LM Studio, vLLM,
    Ollama (with its OpenAI shim), text-generation-webui, etc."""
    import json, urllib.request
    url = (cfg.get("url") or "").strip().rstrip("/")
    if not url:
        return None, "local: no URL configured"
    model = (cfg.get("model") or "local-model").strip()
    body = json.dumps({"model": model, "max_tokens": 400, "temperature": 0,
                       "messages": [{"role": "system", "content": sys},
                                    {"role": "user", "content": user}]}).encode()
    try:
        req = urllib.request.Request(
            url + "/v1/chat/completions", data=body,
            headers={"content-type": "application/json"})
        with urllib.request.urlopen(req, timeout=20) as r:
            data = json.loads(r.read())
        text = data["choices"][0]["message"]["content"]
        return _parse_verdict_json(text, "local:" + model + "@" + url), None
    except Exception as e:
        return None, "local failed: %s" % e


def _llm_judge_chain(before, after, subject, new_answer, providers):
    sys, user = _build_judge_prompt(before, after, subject, new_answer)
    dispatch = {"anthropic": _try_anthropic, "hf": _try_hf, "local": _try_local}
    skipped = []
    for cfg in providers:
        fn = dispatch.get(cfg["type"])
        if fn is None:
            continue
        result, err = fn(sys, user, cfg)
        if result is not None:
            note = ("" if not skipped else "(skipped: %s)\n" % "; ".join(skipped))
            return note + result
        skipped.append(err)
    return ("all providers unavailable:\n " + "\n ".join(skipped) +
            "\n(configure at least one: Anthropic key, HF token, or a local "
            "OpenAI-compatible server URL like http://192.168.188.25:1234)")


# --- session log: every edit+verify run is appended here as JSON, so you can
# download it, or paste the markdown block straight into a future chat with
# Claude for review ("did all work, here's the log").
SESSION_LOG = []


def _log_session(subject, new_answer, method, strength, before, after,
                 eff, collateral, max_drift, surgical, llm_report):
    import datetime
    SESSION_LOG.append({
        "ts": datetime.datetime.utcnow().isoformat() + "Z",
        "subject": subject,
        "new_answer": new_answer,
        "method": method,
        "strength": strength,
        "efficacy_pass": bool(eff),
        "collateral": collateral,
        "max_drift": round(max_drift, 4),
        "verdict": "SURGICAL" if surgical else "COLLATERAL",
        "before": {c: {"top1": before[c]["top1"],
                       "p_orig": round(before[c]["p_orig"], 4)} for c in before},
        "after": {c: {"top1": after[c]["top1"],
                      "p_orig": round(after[c]["p_orig"], 4)} for c in after},
        "llm_review": llm_report or None,
    })


def export_session_log():
    import json, os
    if not SESSION_LOG:
        return None, "No edits run yet this session - nothing to export."
    os.makedirs("/mnt/user-data/outputs", exist_ok=True)
    path = "/mnt/user-data/outputs/edit_session_log.json"
    with open(path, "w") as fh:
        json.dump(SESSION_LOG, fh, indent=2)
    # also a markdown rendition meant to be pasted straight into a chat
    md = ["# Edit session log\n"]
    for i, e in enumerate(SESSION_LOG, 1):
        md.append("## Edit %d - %s (%s, %s, strength=%s)\n"
                  % (i, e["verdict"], e["subject"] + "->" + e["new_answer"],
                     e["method"], e["strength"]))
        md.append("- efficacy: %s, max collateral drift: %.4f, damaged: %s"
                  % ("pass" if e["efficacy_pass"] else "fail", e["max_drift"],
                     e["collateral"] or "none"))
        if e["llm_review"]:
            md.append("- LLM review: " + e["llm_review"].replace("\n", " "))
        md.append("")
    md_path = "/mnt/user-data/outputs/edit_session_log.md"
    with open(md_path, "w") as fh:
        fh.write("\n".join(md))
    return path, "Wrote %d edit(s) to %s and %s" % (len(SESSION_LOG), path, md_path)


# =============================================================================
# EXPORT + UPLOAD TO HUGGING FACE
# -----------------------------------------------------------------------------
# Save the glass-box as a self-contained, reloadable repo (weights + config +
# vocab + a standalone modeling file + a model card), and optionally push it -
# and/or this whole app as a Space - to the Hub.
# =============================================================================

_MODELING_PY = '''"""Standalone glass-box model - reload with no other files.

    from modeling_glassbox import load
    m, tok = load(".")            # folder containing config/weights/vocab
    print(tok.decode(m.generate(tok("the capital of france is")["input_ids"])[0]))
"""
import json, torch, torch.nn as nn, torch.nn.functional as F
from safetensors.torch import load_file


def load(path="."):
    cfg = json.load(open(f"{path}/config.json"))
    stoi = json.load(open(f"{path}/vocab.json")); itos = {i: w for w, i in stoi.items()}
    D, V = cfg["d_model"], len(stoi); facts = [tuple(f) for f in cfg["facts"]]
    SUBJ = torch.zeros(D, D)
    for d in range(cfg["subject_dims"]): SUBJ[d, d] = 1.0

    class Tok:
        is_fast = False
        def __init__(s): s.eos_token_id = stoi["."]
        def _ids(s, t):
            return [stoi.get(w, stoi[""]) for w in t.lower().replace(".", " .").split()] or [stoi[""]]
        def __call__(s, t, **k):
            import torch as T; return {"input_ids": T.tensor([s._ids(t)])}
        def decode(s, ids, **k): return " ".join(itos.get(int(i), "?") for i in ids)

    class Ident(nn.Module):
        def forward(s, x): return (x.clone(),)

    class Pool(nn.Module):
        def forward(s, x):
            o = x.clone()
            if x.shape[1] > 1:
                o[0, -1] = o[0, -1] + 0.9 * (x[0, :-1] @ SUBJ.T).sum(0)
            return (o,)

    class FactMLP(nn.Module):
        def __init__(s):
            super().__init__()
            s.register_buffer("Win", torch.zeros(len(facts), D))
            s.register_buffer("Wout", torch.zeros(D, len(facts)))
            s.bias, s.gain = cfg["bias"], cfg["gain"]
        def forward(s, x):
            o = x.clone(); pre = F.relu(s.Win @ o[0, -1] - s.bias)
            o[0, -1] = o[0, -1] + s.gain * (s.Wout @ pre); return (o,)

    class T(nn.Module):
        def __init__(s):
            super().__init__(); s.wte = nn.Embedding(V, D)
            s.h = nn.ModuleList([Ident(), Pool(), FactMLP(), Ident()]); s.ln_f = nn.Identity()

    class GlassBox(nn.Module):
        def __init__(s):
            super().__init__(); s.transformer = T(); s.head = nn.Linear(D, V, bias=False)
        def get_input_embeddings(s): return s.transformer.wte
        def forward(s, input_ids=None, **k):
            x = s.transformer.wte(input_ids)
            for b in s.transformer.h: (x,) = b(x)
            class O: pass
            o = O(); o.logits = s.head(x); return o
        @torch.no_grad()
        def generate(s, input_ids=None, max_new_tokens=12, **k):
            ids = input_ids
            for _ in range(max_new_tokens):
                ids = torch.cat([ids, s(input_ids=ids).logits[0, -1].argmax().view(1, 1)], 1)
            return ids

    m = GlassBox().eval()
    sd = load_file(f"{path}/model.safetensors")
    m.load_state_dict({k: v for k, v in sd.items() if not k.endswith("0")}, strict=False)
    return m, Tok()
'''


def export_glassbox(outdir="glassbox_export"):
    import os, json
    from safetensors.torch import save_file
    os.makedirs(outdir, exist_ok=True)
    model, _ = get_handles("glassbox")
    sd = {k: v.contiguous() for k, v in model.state_dict().items()}
    save_file(sd, os.path.join(outdir, "model.safetensors"))
    with open(os.path.join(outdir, "config.json"), "w") as fh:
        json.dump({"model_type": "glassbox", "d_model": GB_D, "vocab_size": GB_V,
                   "subject_dims": 9, "bias": model.transformer.h[2].bias,
                   "gain": model.transformer.h[2].gain,
                   "facts": [list(f) for f in GB_FACTS]},
                  fh, indent=2)
    with open(os.path.join(outdir, "vocab.json"), "w") as fh:
        json.dump(GB_STOI, fh, indent=2)
    with open(os.path.join(outdir, "modeling_glassbox.py"), "w") as fh:
        fh.write(_MODELING_PY)
    with open(os.path.join(outdir, "README.md"), "w") as fh:
        fh.write(
            "---\nlicense: mit\ntags: [interpretability, glass-box, rome, toy-model]\n---\n\n"
            "# Glass-box interpretability model\n\n"
            "A tiny transformer-shaped model whose facts are stored as key->value "
            "writes into the residual stream, so logit-lens, activation steering and "
            "ROME causal tracing all reproduce the *known* ground truth. Built as a "
            "verification harness for interpretability code.\n\n"
            "```python\nfrom modeling_glassbox import load\n"
            "m, tok = load('.')\n"
            "print(tok.decode(m.generate(tok('the capital of france is')['input_ids'])[0]))\n```\n\n"
            "Facts: " + ", ".join("%s->%s" % f for f in GB_FACTS) + ".\n")
    return outdir


def upload_to_hf(repo_id, token, what, app_path=__file__):
    """Push the model and/or this app (as a Space) to the Hub."""
    import os
    try:
        from huggingface_hub import HfApi
    except ImportError:
        return "huggingface_hub not installed. `pip install huggingface_hub`."
    token = (token or "").strip() or os.environ.get("HF_TOKEN", "")
    if not token:
        return "No HF token. Paste a write token or set HF_TOKEN."
    repo_id = (repo_id or "").strip()
    if not repo_id:
        return "Enter a repo id like 'Chris4K/glassbox-interp'."
    api, logs = HfApi(token=token), []
    try:
        if what in ("model", "both"):
            d = export_glassbox()
            api.create_repo(repo_id, repo_type="model", exist_ok=True)
            api.upload_folder(folder_path=d, repo_id=repo_id, repo_type="model")
            logs.append("model -> https://huggingface.co/%s" % repo_id)
        if what in ("space", "both"):
            sid = repo_id + "-space" if what == "both" else repo_id
            api.create_repo(sid, repo_type="space", space_sdk="gradio", exist_ok=True)
            api.upload_file(path_or_fileobj=app_path, path_in_repo="app.py",
                            repo_id=sid, repo_type="space")
            req = "torch\ntransformers\ngradio\nsafetensors\nhuggingface_hub\nanthropic\n"
            api.upload_file(path_or_fileobj=req.encode(), path_in_repo="requirements.txt",
                            repo_id=sid, repo_type="space")
            logs.append("space -> https://huggingface.co/spaces/%s" % sid)
        return "Uploaded:\n " + "\n ".join(logs)
    except Exception as e:
        return "Upload failed: %s" % e


# --- upload a REAL model (e.g. a VINDEX-edited Llama checkpoint), not the toy.
# This does NOT load the model into memory (multi-GB Llama weights don't need
# to round-trip through Python) - it just pushes whatever's already on disk.
# Point it at the local folder produced by your save_pretrained()/VINDEX run:
# expects the usual HF layout (config.json + .safetensors shards + tokenizer
# files). Note: the destination repo must be in your own namespace or one you
# have write access to. If the base model is gated (e.g. meta-llama/*), that
# gate governs downloads of the base repo and is independent of this upload step.

def upload_local_checkpoint(local_dir, repo_id, token, private, commit_message):
    import os
    try:
        from huggingface_hub import HfApi
    except ImportError:
        return "huggingface_hub not installed. `pip install huggingface_hub`."
    local_dir = (local_dir or "").strip()
    repo_id = (repo_id or "").strip()
    if not local_dir or not os.path.isdir(local_dir):
        return "local_dir %r does not exist or is not a directory." % local_dir
    if not repo_id:
        return "Enter a repo id like 'Chris4K/vindex-llama3-edited'."
    token = (token or "").strip() or os.environ.get("HF_TOKEN", "")
    if not token:
        return "No HF token. Paste a write token or set HF_TOKEN."
    has_cfg = os.path.exists(os.path.join(local_dir, "config.json"))
    has_weights = any(f.endswith((".safetensors", ".bin")) for f in os.listdir(local_dir))
    warn = "" if (has_cfg and has_weights) else (
        "WARNING: folder is missing config.json or weight files - this may "
        "not be a loadable HF checkpoint. Uploading anyway.\n")
    api = HfApi(token=token)
    try:
        api.create_repo(repo_id, repo_type="model", private=bool(private), exist_ok=True)
        api.upload_folder(folder_path=local_dir, repo_id=repo_id, repo_type="model",
                          commit_message=(commit_message or "upload checkpoint").strip())
        return (warn + "Uploaded %s -> https://huggingface.co/%s\n"
                "Files: %s" % (local_dir, repo_id,
                               ", ".join(sorted(os.listdir(local_dir))[:12])))
    except Exception as e:
        return warn + "Upload failed: %s" % e


# =============================================================================
# UI
# =============================================================================
INTRO = """
# Compression Navigator
**An LLM is a lossy codec for text.** Training compresses a corpus into weights;
a forward pass decompresses a continuation. These five tools let you watch that
decompression and find where facts physically live.
Each tab is a real interpretability technique: **logit lens, embedding neighbours,
activation steering, cross-model diff, and causal tracing (ROME).**

### Three models, on purpose

| name | how it stores facts | what it teaches |
|---|---|---|
| **`glassbox`** | key→value writes into the **residual stream** (like a real transformer / what ROME edits) | the tools **work and are verifiable** against ground truth you can read in the source |
| **`handmade`** | a **lookup table** keyed on the prompt string (a side channel) | a model can be **invisible** to residual-stream interpretability — a real limitation |
| **`gpt2`** | learned, fuzzy, **distributed** over many layers | what the real, messy thing looks like |

**Suggested order:** load `glassbox` first (see "correct"), then `handmade` (see a
failure mode), then `gpt2` (see reality). Type a name below and Load.
"""

with gr.Blocks(title="Compression Navigator") as demo:
    gr.Markdown(INTRO)
    with gr.Row():
        model_name = gr.Textbox(value="glassbox", label="model name or HF id")
        load_btn = gr.Button("Load", variant="primary")
    load_status = gr.Markdown()
    load_btn.click(load_model, inputs=model_name, outputs=load_status)

    # ---- TAB 1 -------------------------------------------------------------
    with gr.Tab("1 · Decompress (logit lens)"):
        gr.Markdown("""
### Logit lens — watch the answer condense, layer by layer
**What it does:** takes the last-token residual at *every* layer and reads it through
the unembedding, as if the model had to answer right there. You see the prediction form.

**How to read it:** each row is a layer. Watch your tracked token's probability (right
column) climb, and watch **entropy** (bits) fall as the model commits.

**Ground truth to check:**
- `glassbox` — `paris` is ~0 until **L3** (the readout right after the fact-MLP), then jumps to ~0.51. Sharp and localised because you put it there.
- `handmade` — the answer snaps to 1.00 at **L1** with zero build-up (it's a lookup, not a computation).
- `gpt2` — the answer accretes *gradually* across many middle/late layers. That smear is what "distributed representation" actually looks like.

*(Numbering note: the lens counts from the embedding, so `L1` is after the first block.
The causal-trace tab counts blocks from `L0`. So the fact-MLP is lens-`L3` /
trace-block-`L2`, and its causal site shows at trace-`L0`.)*
""")
        ll_prompt = gr.Textbox(value="the capital of france is", label="prompt")
        with gr.Row():
            ll_k = gr.Slider(1, 10, value=3, step=1, label="top-k per layer")
            ll_track = gr.Textbox(value="paris", label="track this token's prob")
        ll_out = gr.Textbox(label="output", lines=18)
        gr.Button("Run").click(logit_lens, [ll_prompt, ll_k, ll_track], ll_out)

    # ---- TAB 2 -------------------------------------------------------------
    with gr.Tab("2 · Triangulate (neighbours)"):
        gr.Markdown("""
### Neighbours — the geometry of the vocabulary
**What it does:** ranks tokens by cosine similarity of their unembedding rows.
Directions that point the same way are "near" in the model's compressed space.

**How to read it:** high cosine = the model treats these tokens as related.

**Ground truth to check:**
- `glassbox` — `paris` is near `france` (cos ≈ 0.48): the source deliberately makes a capital share a dimension with its country. Real geometry, by design.
- `handmade` — **every** cosine is 0. One-hot embeddings are mutually orthogonal, so there's no geometry at all. The tool is correctly reporting "nothing here."
- `gpt2` — neighbours are messy but meaningful (casing variants, plurals, semantic kin).
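
**Sketch (illustrative only):** the ranking above comes down to a cosine over unembedding rows. Below is a minimal stdlib version, assuming made-up 3-d vectors; the helper names `cosine` and `neighbours` and every number are hypothetical, not the app's own `neighbors` function or the real glassbox weights.

```python
import math

def cosine(u, v):
    # Cosine similarity of two equal-length vectors; 0.0 if either is all zeros.
    dot = sum(a * b for a, b in zip(u, v))
    nu = math.sqrt(sum(a * a for a in u))
    nv = math.sqrt(sum(b * b for b in v))
    return dot / (nu * nv) if nu and nv else 0.0

def neighbours(word, rows, k=3):
    # Rank every other token by cosine similarity to `word`, best first.
    scored = [(other, cosine(rows[word], vec))
              for other, vec in rows.items() if other != word]
    return sorted(scored, key=lambda pair: pair[1], reverse=True)[:k]

# Hypothetical unembedding rows (assumption: invented for this example).
shared = {"paris": [1.0, 1.0, 0.0], "france": [1.0, 0.0, 1.0], "tokyo": [0.0, 0.0, 1.0]}
one_hot = {"paris": [1.0, 0.0, 0.0], "france": [0.0, 1.0, 0.0], "tokyo": [0.0, 0.0, 1.0]}

print(neighbours("paris", shared))   # france ranks first: it shares dimension 0 with paris
print(neighbours("paris", one_hot))  # every cosine is 0.0: orthogonal rows, no geometry
```

The one-hot case mirrors the `handmade` bullet: with mutually orthogonal rows there is nothing for the ranking to find.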
""") nb_word = gr.Textbox(value="paris", label="word") nb_k = gr.Slider(5, 25, value=10, step=1, label="top neighbours") nb_out = gr.Textbox(label="output", lines=15) gr.Button("Run").click(neighbors, [nb_word, nb_k], nb_out) # ---- TAB 3 ------------------------------------------------------------- with gr.Tab("3 · Re-route (steering)"): gr.Markdown(""" ### Steering — bend behaviour with a direction, no retraining **What it does:** builds the vector `emb(target) − emb(source)` and *adds* it to a layer's output during generation. The model drifts from `source` toward `target`. This is the cheap cousin of fine-tuning (ActAdd / representation engineering). **How to read it:** compare *baseline* vs *steered*. Raise **strength** until the output flips; too high and it turns to noise (you've knocked the residual off the manifold). **Tips:** on `gpt2` try `from: Paris to: London` on the France prompt, layer 0–4, strength 6–14. On `glassbox` it works cleanly too — `from: france to: japan` at layer 0, strength 8, flips the output from `paris` to `tokyo` (you're pushing the residual along the subject→subject direction the fact-MLP keys on). 
""") st_prompt = gr.Textbox(value="the capital of france is", label="prompt") with gr.Row(): st_src = gr.Textbox(value="Paris", label="from") st_tgt = gr.Textbox(value="London", label="to") with gr.Row(): st_layer = gr.Slider(0, 11, value=2, step=1, label="layer") st_alpha = gr.Slider(0, 30, value=10, step=0.5, label="strength") st_max = gr.Slider(8, 80, value=40, step=1, label="max new tokens") st_base = gr.Textbox(label="baseline", lines=2) st_out = gr.Textbox(label="steered", lines=3) gr.Button("Run").click(steer_generate, [st_prompt, st_src, st_tgt, st_layer, st_alpha, st_max], [st_base, st_out]) # ---- TAB 4 ------------------------------------------------------------- with gr.Tab("4 · Diff (align by depth)"): gr.Markdown(""" ### Diff — two models on one prompt, aligned by *relative* depth **What it does:** runs the logit lens on model A and model B and lines their layers up by percentage depth (0–100%), so you can compare a 2-layer toy with a 12-layer gpt2 side by side. `dp` is `p_B − p_A` for the target token. **How to read it:** look at *where* on the depth axis each model commits to the target. A localised model commits at one depth; a distributed one ramps up. **Try:** A = `gpt2`, B = `glassbox`, target = `paris`. You'll see gpt2 ramp through the middle while glassbox snaps on at its fact layer — the same fact, two very different internal shapes. 
""") with gr.Row(): df_a = gr.Textbox(value="gpt2", label="model A") df_b = gr.Textbox(value="glassbox", label="model B") df_prompt = gr.Textbox(value="the capital of france is", label="prompt") df_target = gr.Textbox(value="paris", label="target token") df_k = gr.Slider(1, 5, value=1, step=1, label="top-k (display)") df_out = gr.Textbox(label="output", lines=16) gr.Button("Run").click(diff_models, [df_a, df_b, df_prompt, df_target, df_k], df_out) # ---- TAB 5 ------------------------------------------------------------- with gr.Tab("5 · Causal trace (ROME)"): gr.Markdown(""" ### Causal trace — corrupt the subject, restore each layer, find the site **What it does:** activation patching (Meng et al.'s ROME). It noises the **subject** token, which breaks the prediction, then restores one layer at a time and measures how much of the answer comes back. The layer that restores the most is where the fact is *causally* computed. **How to read it:** `recovery` ≈ 100% means "restoring this layer is enough" → the fact is read here. The peak line names the site. **Ground truth to check:** - `glassbox` — peak at **L0** (≈100%). The fact is read at the early subject site, because the L1 "attention" re-reads the restored subject. You know this is right because you wrote the mechanism. - `handmade` — `clean p` ≈ `corrupt p`, so recovery is meaningless. **Expected:** the fact is a string match, untouched by activation noise. This is the headline lesson — patching can't see lookup behaviour. - `gpt2` — a *band* of early–middle layers at the subject token light up, exactly as in the ROME paper. 
""") ct_prompt = gr.Textbox(value="the capital of france is", label="prompt") ct_subject = gr.Textbox(value="france", label="subject to corrupt") ct_target = gr.Textbox(value="paris", label="target token") with gr.Row(): ct_noise = gr.Slider(0, 10, value=3, step=0.5, label="noise (x embed std)") ct_seed = gr.Slider(0, 100, value=0, step=1, label="seed") ct_out = gr.Textbox(label="output", lines=18) gr.Button("Run").click(causal_trace, [ct_prompt, ct_subject, ct_target, ct_noise, ct_seed], ct_out) # ---- TAB 6 ------------------------------------------------------------- with gr.Tab("6 · Edit + verify (ROME loop)"): gr.Markdown(""" ### Edit a fact, then prove nothing else broke **What it does:** rewrites the value one fact-MLP key maps to (the exact thing ROME/MEMIT do on real models — this is a literal `nn.Module` weight tensor, not a token or vocab change), then runs a verification battery over **every** known fact to measure **efficacy** (target changed), **specificity** (others untouched), and **fluency** (no entropy collapse). **Two methods, on purpose:** - `rank1` — the minimal, surgical update. Only the target fact moves → **SURGICAL**. - `broadcast` — a deliberately sloppy edit that smears the change across all facts → the harness catches the **COLLATERAL DAMAGE**. This proves the verifier actually works, not just reports "ok" by default. **Independent LLM review, with a fallback chain — not locked to one vendor:** tick the box and it tries, in order: **Anthropic** (Claude, if you give a key) → **Hugging Face Inference** (any hosted chat model, if you give an HF token) → **your own local server** (LM Studio / vLLM / Ollama's OpenAI shim — anything exposing `/v1/chat/completions`). The first one that's configured *and* reachable answers; the rest are skipped and noted. So your own RTX 5090 can be the judge with zero cloud calls if you just fill in the local URL. Subjects: `france`, `germany`, `japan`. Answers: `paris, berlin, tokyo, london, rome`. 

After editing, the model stays edited — go look at it in tabs 1–5 (the logit lens
will show the new answer rising; the trace still localises to L0). Hit **Reset** to
restore. Every run is appended to a session log you can download below and paste into
a future chat for review.
""")
        with gr.Row():
            ed_subj = gr.Textbox(value="france", label="subject")
            ed_new = gr.Textbox(value="london", label="new answer")
            ed_method = gr.Radio(["rank1", "broadcast"], value="rank1", label="method")
            ed_strength = gr.Slider(0.2, 2.0, value=1.0, step=0.1, label="strength")
        ed_llm = gr.Checkbox(value=False, label="also run an independent LLM review")
        with gr.Accordion("LLM review providers (tried in this order)", open=False):
            with gr.Row():
                ed_a_model = gr.Textbox(value="claude-sonnet-4-6", label="1. Anthropic model")
                ed_a_key = gr.Textbox(value="", label="Anthropic API key", type="password")
            with gr.Row():
                ed_h_model = gr.Textbox(value="Qwen/Qwen2.5-7B-Instruct",
                                        label="2. HF Inference model")
                ed_h_key = gr.Textbox(value="", label="HF token", type="password")
            with gr.Row():
                ed_l_url = gr.Textbox(value="http://192.168.188.25:1234",
                                      label="3. Local server URL (LM Studio etc.)")
                ed_l_model = gr.Textbox(value="local-model", label="local model name")
        ed_out = gr.Textbox(label="edit + verification report", lines=24)
        with gr.Row():
            gr.Button("Edit & verify", variant="primary").click(
                edit_and_verify,
                [ed_subj, ed_new, ed_method, ed_strength, ed_llm,
                 ed_a_key, ed_a_model, ed_h_key, ed_h_model, ed_l_url, ed_l_model],
                ed_out)
            gr.Button("Reset model").click(reset_glassbox, outputs=ed_out)
        gr.Markdown("**Session log** (every edit run above, appended):")
        with gr.Row():
            log_btn = gr.Button("Write session log to disk")
            log_file = gr.File(label="download")
        log_status = gr.Markdown()
        log_btn.click(lambda: export_session_log(), outputs=[log_file, log_status])

    # ---- TAB 7 -------------------------------------------------------------
    with gr.Tab("7 · Export / Upload to HF"):
        gr.Markdown("""
### Ship the toy glass-box
**Export** writes a self-contained, reloadable repo: weights (`safetensors`),
`config.json`, `vocab.json`, a standalone `modeling_glassbox.py` (reload with
`from modeling_glassbox import load`), and a model card.

**Upload** pushes it to the Hub. Choose `model`, `space` (this whole app, runnable),
or `both`. Paste a **write** token (or set `HF_TOKEN`).
""")
        with gr.Row():
            hf_repo = gr.Textbox(value="Chris4K/glassbox-interp", label="repo id")
            hf_what = gr.Radio(["model", "space", "both"], value="model", label="what to push")
        hf_token = gr.Textbox(value="", label="HF write token (optional)", type="password")
        hf_out = gr.Textbox(label="result", lines=6)
        with gr.Row():
            gr.Button("Export locally").click(
                lambda: "Exported to ./%s" % export_glassbox(), outputs=hf_out)
            gr.Button("Upload to HF", variant="primary").click(
                upload_to_hf, [hf_repo, hf_token, hf_what], hf_out)

        gr.Markdown("""
---
### Upload a REAL model — e.g. your VINDEX-edited Llama checkpoint
This does **not** load the model into memory and does **not** assume any particular
architecture — it just pushes whatever's already on disk at `local_dir` (the usual
`save_pretrained()` layout: `config.json` + `*.safetensors` shards + tokenizer files)
straight to a new repo. Large weights upload fine through `upload_folder`; for very
large repos consider installing `hf_transfer` for faster throughput. If the base model
is gated (e.g. `meta-llama/*`), that gate controls who can download the base repo; it
does not block this upload step.
""")
        with gr.Row():
            rc_dir = gr.Textbox(value="", label="local checkpoint folder (on this machine)")
            rc_repo = gr.Textbox(value="",
                                 label="destination repo id, e.g. Chris4K/vindex-llama3-edited")
        with gr.Row():
            rc_token = gr.Textbox(value="", label="HF write token (optional)", type="password")
            rc_private = gr.Checkbox(value=True, label="private repo")
        rc_msg = gr.Textbox(value="upload edited checkpoint", label="commit message")
        rc_out = gr.Textbox(label="result", lines=6)
        gr.Button("Upload real checkpoint", variant="primary").click(
            upload_local_checkpoint,
            [rc_dir, rc_repo, rc_token, rc_private, rc_msg], rc_out)

        gr.Markdown("""
---
### Where this goes next
- **Closing the loop (what "self-improving" would actually require):** right now a human picks every edit; the verifier just grades it. A real closed loop needs a policy that *proposes* edits on its own (e.g. scanning eval failures for wrong facts), auto-applies, and auto-commits only on a SURGICAL verdict, rolling back otherwise. The hard part — the verifier — already exists here; the proposal step doesn't yet.
- **A training-method angle worth taking seriously:** instead of accept/reject after the fact, feed the specificity battery's drift score back as a regularizer *during* the edit computation (closer to elastic weight consolidation, or the null-space projection AlphaEdit-style methods use) so collateral is penalized while solving, not caught after.
- **Real-model MEMIT:** the edit loop here is exact because the glass-box's fact layer is literally key→value. The same verify harness (efficacy / specificity / fluency + the multi-provider LLM judge) ports straight onto a gpt2/Llama MEMIT edit — the toy is the regression test you run first.
- **Multi-hop & paraphrase generalization:** add `"the currency of france is"` so two relations share a subject, and have the LLM judge auto-generate paraphrase probes to test that an edit generalizes, not just memorizes the one prompt.
- **Attribution view:** Geva-style "what does this neuron write to the vocab", per-head attention attribution.
- **It already ships:** tab 7 pushes the toy model and this whole app (as a Space) to your Hub, or a real local checkpoint folder to its own repo.
""")

    demo.load(lambda: load_model("glassbox"), outputs=load_status)


if __name__ == "__main__":
    demo.launch()