| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import math |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import gradio as gr |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
# Runtime configuration shared by every tab.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
DTYPE = torch.float32  # dtype used for all logit / probability math

# Cache of loaded (model, tokenizer) pairs, keyed by model name.
MODELS = dict()
# Mutable app state: name of the currently selected model (None until loaded).
STATE = dict(name=None)
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
class FakeBatchEncoding(dict):
    """A plain dict that tolerates the HF-style ``.to(device)`` call.

    The toy models are never moved off their default device, so "moving" the
    encoding is a deliberate no-op that hands back the very same object.
    """

    def to(self, device):
        # *device* is intentionally ignored.
        same = self
        return same
|
|
|
|
class SimpleTok:
    """Whitespace tokenizer over a fixed vocab. Not 'fast' (no offset map).

    Lowercases the text, splits '.' into its own token and maps any unknown
    word to the ``<s>`` id. Implements the small slice of the HF tokenizer API
    this app relies on: ``__call__``, ``encode``, ``decode`` and
    ``eos_token_id``.
    """

    is_fast = False

    def __init__(self, stoi, itos):
        self.stoi = stoi
        self.itos = itos
        # '.' doubles as the stop token for greedy generation.
        self.eos_token_id = stoi["."]

    def _ids(self, text):
        words = text.lower().replace(".", " .").split()
        return [self.stoi[w] if w in self.stoi else self.stoi["<s>"] for w in words]

    def __call__(self, text, return_tensors=None, return_offsets_mapping=False):
        # An empty prompt still becomes a single <s> token so the model always
        # has a last position to read from.
        token_ids = self._ids(text)
        if not token_ids:
            token_ids = [self.stoi["<s>"]]
        ids_tensor = torch.tensor([token_ids])
        mask = torch.ones(1, len(token_ids), dtype=torch.long)
        return FakeBatchEncoding(input_ids=ids_tensor, attention_mask=mask)

    def encode(self, text, add_special_tokens=False):
        # There are no special tokens to add; the flag exists for API parity.
        return self._ids(text)

    def decode(self, ids, skip_special_tokens=False):
        dropped = ("<pad>", "<s>") if skip_special_tokens else ()
        words = (self.itos.get(int(i), "?") for i in ids)
        return " ".join(w for w in words if w not in dropped)
|
|
|
|
| class _Out: |
| """Mimics a HF CausalLMOutput: .logits and (optional) .hidden_states.""" |
| def __init__(self, logits, hidden_states): |
| self.logits = logits |
| self.hidden_states = hidden_states |
|
|
|
|
| def _greedy_generate(model, input_ids, max_new_tokens=20, pad_token_id=None, **_): |
| """Minimal greedy decode so the steering tab works on the toy models too |
| (the originals had no .generate, so that tab crashed on 'handmade').""" |
| ids = input_ids |
| for _ in range(int(max_new_tokens)): |
| nxt = model(input_ids=ids).logits[0, -1].argmax().view(1, 1) |
| ids = torch.cat([ids, nxt], dim=1) |
| if pad_token_id is not None and int(nxt.item()) == int(pad_token_id): |
| break |
| return ids |
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
# Prompts whose continuation the handmade model "memorises" verbatim:
# _MemoryBlock boosts the answer token whenever the prompt ends with a key.
PINNED = dict([
    ("the capital of france is", " paris"),
    ("the eiffel tower is in", " paris"),
    ("two plus two equals", " four"),
])

# Hand-written bigram weights: MARKOV[prev][next] = weight. _MarkovBlock
# normalises each row into transition probabilities.
MARKOV = {
    "<s>":    {"the": 3, "i": 2, "a": 1},
    "the":    {"city": 2, "tower": 2, "answer": 1},
    "i":      {"think": 2, "am": 1},
    "a":      {"model": 2, "city": 1},
    "city":   {"of": 3, "is": 1},
    "of":     {"light": 2, "paris": 1},
    "tower":  {"is": 3},
    "is":     {"in": 2, "a": 1},
    "in":     {"paris": 2, "france": 1},
    "model":  {"is": 2},
    "think":  {"the": 2},
    "paris":  {".": 1},
    "france": {".": 1},
    "light":  {".": 1},
    "four":   {".": 1},
}
|
|
|
|
def _build_handmade_vocab():
    """Return the handmade model's vocabulary as an ordered, duplicate-free list.

    Order: the specials ``<pad>``, ``<s>``, ``.``; then the PINNED answers
    (stripped); then each MARKOV row word followed by its successors; then the
    words of the PINNED prompts. Each word keeps its first position.
    """
    candidates = ["<pad>", "<s>", "."]
    candidates.extend(answer.strip() for answer in PINNED.values())
    for head, successors in MARKOV.items():
        candidates.append(head)
        candidates.extend(successors)
    for prompt in PINNED:
        candidates.extend(prompt.split())
    # dict.fromkeys preserves insertion order -> order-preserving dedupe.
    return list(dict.fromkeys(candidates))
|
|
|
|
# Handmade-model vocabulary and its word <-> id lookup tables.
HM_VOCAB = _build_handmade_vocab()
HM_STOI = dict(zip(HM_VOCAB, range(len(HM_VOCAB))))
HM_ITOS = {idx: word for word, idx in HM_STOI.items()}
HM_V = len(HM_VOCAB)
|
|
|
|
class _MemoryBlock(nn.Module):
    """Verbatim-memory layer of the handmade model.

    If the decoded prompt ends with a PINNED key, the answer's logit at the
    last position is raised by 12.0. NOTE: this reads prompt_ids (the token
    string), not x - that's the whole point: the fact lives in a string match,
    not in the activations.
    """

    def forward(self, x, prompt_ids=None):
        boosted = x.clone()
        if prompt_ids is None:
            return (boosted,)
        words = [HM_ITOS.get(int(tok_id), "") for tok_id in prompt_ids]
        text = " ".join(words).strip()
        for key, answer in PINNED.items():
            if not text.endswith(key):
                continue
            boosted[0, -1, HM_STOI[answer.strip()]] += 12.0
        return (boosted,)
|
|
|
|
class _MarkovBlock(nn.Module):
    """Add a hand-built bigram transition row for the last token.

    Buffer ``T[prev, next]`` holds the MARKOV weights normalised per row;
    forward adds ``4.0 * T[last prompt token]`` to the last position.
    """

    def __init__(self):
        super().__init__()
        table = torch.zeros(HM_V, HM_V)
        for prev, successors in MARKOV.items():
            if prev not in HM_STOI:
                continue
            row = HM_STOI[prev]
            total = sum(successors.values())
            for nxt, weight in successors.items():
                if nxt in HM_STOI:
                    table[row, HM_STOI[nxt]] = weight / total
        self.register_buffer("T", table)

    def forward(self, x, prompt_ids=None):
        shifted = x.clone()
        if prompt_ids:
            last_tok = int(prompt_ids[-1])
            shifted[0, -1] += 4.0 * self.T[last_tok]
        return (shifted,)
|
|
|
|
class _HMTransformer(nn.Module):
    """GPT-2-shaped container (wte / h / ln_f) for the handmade model.

    The attribute names match the ``transformer.h`` / ``transformer.ln_f``
    paths the generic layer lookups search for. The embedding is the identity
    (one-hot, width = vocab size), so residual dims are vocab indices, and
    ln_f is a no-op.
    """

    def __init__(self):
        super().__init__()
        self.wte = nn.Embedding(HM_V, HM_V)
        with torch.no_grad():
            self.wte.weight.copy_(torch.eye(HM_V))
        self.h = nn.ModuleList((_MemoryBlock(), _MarkovBlock()))
        self.ln_f = nn.Identity()
|
|
|
|
class HandmadeModel(nn.Module):
    """Toy causal LM whose 'knowledge' is written by hand.

    One-hot embeddings feed a memory block (string-matched PINNED facts) and a
    bigram block (MARKOV); an identity head reads the residual out as logits.
    Exposes the HF surface the app uses: get_input_embeddings,
    get_output_embeddings, generate and forward(..., output_hidden_states).
    """

    def __init__(self):
        super().__init__()
        self.transformer = _HMTransformer()
        self.head = nn.Linear(HM_V, HM_V, bias=False)
        with torch.no_grad():
            self.head.weight.copy_(torch.eye(HM_V))
        self.tok = SimpleTok(HM_STOI, HM_ITOS)

    def get_input_embeddings(self):
        return self.transformer.wte

    def get_output_embeddings(self):
        return self.head

    def generate(self, input_ids=None, attention_mask=None, **kw):
        # attention_mask is accepted for HF compatibility but unused.
        return _greedy_generate(self, input_ids, **kw)

    def forward(self, input_ids=None, attention_mask=None, output_hidden_states=False):
        prompt = input_ids[0].tolist()
        hidden = self.transformer.wte(input_ids).float()
        states = [hidden]
        for block in self.transformer.h:
            (hidden,) = block(hidden, prompt_ids=prompt)
            states.append(hidden)
        logits = self.head(self.transformer.ln_f(hidden))
        collected = tuple(states) if output_hidden_states else None
        return _Out(logits, collected)
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
# Glass-box model: a tiny, fully hand-wired factual-recall vocabulary.
GB_D = 32  # residual-stream width
GB_TOKS = (
    "<pad> <s> . the capital of is in "
    "france germany japan paris berlin tokyo london rome"
).split()
GB_STOI = {tok: idx for idx, tok in enumerate(GB_TOKS)}
GB_ITOS = dict((idx, tok) for tok, idx in GB_STOI.items())
GB_V = len(GB_TOKS)
# (subject, answer) pairs stored in the fact MLP, in slot order.
GB_FACTS = [("france", "paris"), ("germany", "berlin"), ("japan", "tokyo")]
|
|
|
|
def _build_gb_embeddings():
    """Return the (GB_V, GB_D) embedding table of the glass box, rows L2-normalised.

    Dims 0-8 are subject/answer features: each country shares one dim with its
    capital (0 France/Paris, 3 Germany/Berlin, 6 Japan/Tokyo). Dim 9 marks
    'is'; dims 20-30 are per-token identity features. Any token left all-zero
    gets a one-hot in dims 10-15 (10 + id % 6).
    """
    table = torch.zeros(GB_V, GB_D)
    features = {
        "france":  [(0, 1.0), (1, 0.6), (20, 0.5)],
        "paris":   [(0, 0.8), (2, 0.9), (21, 0.5)],
        "germany": [(3, 1.0), (4, 0.6), (22, 0.5)],
        "berlin":  [(3, 0.8), (5, 0.9), (23, 0.5)],
        "japan":   [(6, 1.0), (7, 0.6), (24, 0.5)],
        "tokyo":   [(6, 0.8), (8, 0.9), (25, 0.5)],
        "london":  [(27, 1.0), (28, 0.5)],
        "rome":    [(29, 1.0), (30, 0.5)],
        "is":      [(9, 1.0), (26, 0.4)],
    }
    for tok, pairs in features.items():
        row = GB_STOI[tok]
        for dim, value in pairs:
            table[row, dim] = value
    for row in range(GB_V):
        if table[row].abs().sum() == 0:
            table[row, 10 + row % 6] = 1.0
    return table / (table.norm(dim=-1, keepdim=True) + 1e-9)
|
|
|
|
GB_E = _build_gb_embeddings()
# Projection onto the subject/answer feature block (dims 0-8).
GB_SUBJ = torch.zeros(GB_D, GB_D)
GB_SUBJ[:9, :9] = torch.eye(9)
|
|
|
|
| class _GBIdent(nn.Module): |
| def forward(self, x, prompt_ids=None): |
| return (x.clone(),) |
|
|
|
|
class _GBPool(nn.Module):
    """Toy 'attention' layer.

    Adds 0.9 x the sum of the subject-projected (GB_SUBJ) residuals of all
    earlier positions into the last position. Corrupting the subject earlier
    shows up here; restoring the subject BEFORE this layer is what makes the
    trace recover - that is why the causal peak lands at L0, not L1.
    """

    def forward(self, x, prompt_ids=None):
        mixed = x.clone()
        if x.shape[1] <= 1:
            return (mixed,)
        pooled = (x[0, :-1] @ GB_SUBJ.T).sum(0)
        mixed[0, -1] = mixed[0, -1] + 0.9 * pooled
        return (mixed,)
|
|
|
|
class _GBFactMLP(nn.Module):
    """Geva-style key->value memory, the layer ROME rewrites to edit a fact.

    Rows of ``Win`` are unit-norm (subject + relation) keys; a relu gate with
    threshold ``bias`` decides which fact fires; columns of ``Wout`` are answer
    embedding directions, written into the last position scaled by ``gain``.
    ``Win0`` / ``Wout0`` keep pristine copies for edits and resets.
    """

    def __init__(self):
        super().__init__()
        n_facts = len(GB_FACTS)
        keys = torch.zeros(n_facts, GB_D)
        values = torch.zeros(GB_D, n_facts)
        relation = GB_E[GB_STOI["is"]]
        for slot, (subject, answer) in enumerate(GB_FACTS):
            raw = 0.9 * (GB_E[GB_STOI[subject]] @ GB_SUBJ.T) + relation
            keys[slot] = raw / raw.norm()
            values[:, slot] = GB_E[GB_STOI[answer]]
        self.register_buffer("Win", keys)
        self.register_buffer("Wout", values)
        self.register_buffer("Win0", keys.clone())
        self.register_buffer("Wout0", values.clone())
        self.bias = 0.85  # a key contributes only if its dot product exceeds this
        self.gain = 6.0   # scale of the answer direction written back

    def forward(self, x, prompt_ids=None):
        result = x.clone()
        last = result[0, -1]
        gate = F.relu(self.Win @ last - self.bias)
        result[0, -1] = last + self.gain * (self.Wout @ gate)
        return (result,)
|
|
|
|
class _GBTransformer(nn.Module):
    """GPT-2-shaped container: wte -> [ident, pool, fact MLP, ident] -> ln_f.

    Layer index 2 (the fact MLP) is the one GlassBoxModel edits; ln_f is a
    no-op.
    """

    def __init__(self):
        super().__init__()
        self.wte = nn.Embedding(GB_V, GB_D)
        with torch.no_grad():
            self.wte.weight.copy_(GB_E)
        layers = [_GBIdent(), _GBPool(), _GBFactMLP(), _GBIdent()]
        self.h = nn.ModuleList(layers)
        self.ln_f = nn.Identity()
|
|
|
|
class GlassBoxModel(nn.Module):
    """Fully hand-wired causal LM whose facts live in one editable MLP layer.

    Layout: wte -> [ident, pool, fact MLP, ident] -> identity ln_f -> head,
    with the head initialised from GB_E (a copy, not a tied parameter).
    Besides the HF surface used by the app it offers ``edit_fact`` and
    ``reset`` for knowledge-editing experiments.
    """

    def __init__(self):
        super().__init__()
        self.transformer = _GBTransformer()
        self.head = nn.Linear(GB_D, GB_V, bias=False)
        with torch.no_grad():
            self.head.weight.copy_(GB_E)
        self.tok = SimpleTok(GB_STOI, GB_ITOS)

    @torch.no_grad()
    def edit_fact(self, subject, new_answer, method="rank1", strength=1.0):
        """Rewrite the value a fact-MLP key maps to.

        Methods:
            rank1 / surgical - the minimal update: change only this fact's value.
            broadcast - DELIBERATELY sloppy: smear the delta across ALL facts,
                so the verifier has real collateral to catch.

        The delta is (answer embedding - pristine value) * strength. Raises
        ValueError for an unknown subject, answer token, or method.
        """
        mlp = self.transformer.h[2]
        known = [s for s, _ in GB_FACTS]
        if subject not in known:
            raise ValueError("unknown subject %r" % subject)
        if new_answer not in GB_STOI:
            raise ValueError("unknown answer token %r" % new_answer)
        slot = known.index(subject)
        pristine = mlp.Wout0[:, slot]
        delta = (GB_E[GB_STOI[new_answer]] - pristine) * float(strength)
        if method == "broadcast":
            mlp.Wout += delta.unsqueeze(1)
        elif method in ("rank1", "surgical"):
            mlp.Wout[:, slot] = pristine + delta
        else:
            raise ValueError("unknown method %r" % method)

    @torch.no_grad()
    def reset(self):
        """Restore the fact MLP's Win / Wout from their pristine copies."""
        mlp = self.transformer.h[2]
        mlp.Win.copy_(mlp.Win0)
        mlp.Wout.copy_(mlp.Wout0)

    def get_input_embeddings(self):
        return self.transformer.wte

    def get_output_embeddings(self):
        return self.head

    def generate(self, input_ids=None, attention_mask=None, **kw):
        # attention_mask is accepted for HF compatibility but unused.
        return _greedy_generate(self, input_ids, **kw)

    def forward(self, input_ids=None, attention_mask=None, output_hidden_states=False):
        prompt = input_ids[0].tolist()
        hidden = self.transformer.wte(input_ids).float()
        states = [hidden]
        for block in self.transformer.h:
            (hidden,) = block(hidden, prompt_ids=prompt)
            states.append(hidden)
        logits = self.head(self.transformer.ln_f(hidden))
        collected = tuple(states) if output_hidden_states else None
        return _Out(logits, collected)
|
|
|
|
| |
| |
| |
| def _resolve(model, paths): |
| for path in paths: |
| obj, ok = model, True |
| for part in path.split("."): |
| if hasattr(obj, part): |
| obj = getattr(obj, part) |
| else: |
| ok = False; break |
| if ok: |
| return obj |
| return None |
|
|
|
|
def get_blocks(model):
    """Return the model's list of transformer blocks.

    Searches the known layouts: transformer.h, model.layers, gpt_neox.layers,
    model.decoder.layers. Raises RuntimeError if none is present.
    """
    layouts = ["transformer.h", "model.layers",
               "gpt_neox.layers", "model.decoder.layers"]
    found = _resolve(model, layouts)
    if found is None:
        raise RuntimeError("Could not locate transformer blocks.")
    return found
|
|
|
|
def get_final_norm(model):
    """Return the model's final normalisation module, or an identity function if none is found."""
    layouts = ["transformer.ln_f", "model.norm",
               "gpt_neox.final_layer_norm",
               "model.decoder.final_layer_norm"]
    found = _resolve(model, layouts)
    if found is None:
        return lambda x: x
    return found
|
|
|
|
def get_head(model):
    """Return the unembedding (LM head) module via the HF accessor."""
    head = model.get_output_embeddings()
    return head
|
|
|
|
def get_handles(name):
    """Return the cached (model, tokenizer) pair for *name*, loading on first use.

    "handmade" and "glassbox" are built in-process and are not moved to
    DEVICE. Any other name is loaded with AutoTokenizer /
    AutoModelForCausalLM.from_pretrained(name) in DTYPE, moved to DEVICE and
    put in eval mode.
    """
    if name in MODELS:
        return MODELS[name]
    builtin = {"handmade": HandmadeModel, "glassbox": GlassBoxModel}
    if name in builtin:
        toy = builtin[name]().eval()
        MODELS[name] = (toy, toy.tok)
    else:
        tokenizer = AutoTokenizer.from_pretrained(name)
        model = AutoModelForCausalLM.from_pretrained(
            name, torch_dtype=DTYPE).to(DEVICE).eval()
        MODELS[name] = (model, tokenizer)
    return MODELS[name]
|
|
|
|
def load_model(name):
    """Load (or fetch from cache) *name*, make it the active model, and report its depth."""
    name = name.strip()
    model = get_handles(name)[0]
    STATE["name"] = name
    depth = len(get_blocks(model))
    return "Loaded **%s** (%d layers)." % (name, depth)
|
|
|
|
| |
| |
| |
@torch.no_grad()
def layer_distributions(model, tok, prompt):
    """Logit lens: the next-token distribution read out after every layer.

    Returns a list of (label, probs) pairs, one per entry of hidden_states:
    "embed" for index 0, then "L1", "L2", ... Intermediate states go through
    the final norm before the head. The last state is fed to the head as-is,
    presumably because HF models already apply the final norm to it
    (TODO confirm per architecture).
    """
    encoded = tok(prompt, return_tensors="pt").to(DEVICE)
    outputs = model(**encoded, output_hidden_states=True)
    states = outputs.hidden_states
    norm, head = get_final_norm(model), get_head(model)
    last = len(states) - 1
    result = []
    for depth, layer_state in enumerate(states):
        vec = layer_state[0, -1].to(DTYPE)
        logits = head(vec) if depth == last else head(norm(vec))
        label = "L%d" % depth if depth else "embed"
        result.append((label, F.softmax(logits, dim=-1)))
    return result
|
|
|
|
| def _entropy_bits(probs): |
| p = probs.clamp_min(1e-12) |
| return float(-(p * p.log()).sum() / math.log(2)) |
|
|
|
|
| |
| |
| |
@torch.no_grad()
def logit_lens(prompt, top_k, track):
    """Render a per-layer logit-lens table for the active model.

    Args:
        prompt: text fed to the model.
        top_k: number of top tokens shown per layer. It is clamped to the
            vocabulary size, because torch.topk raises when k exceeds it (the
            glass box has only 16 tokens).
        track: optional word; the probability of its first token gets its
            own column.

    Returns:
        The table as a string, or "Load a model first." when no model is active.
    """
    if STATE["name"] is None:
        return "Load a model first."
    model, tok = get_handles(STATE["name"])
    top_k = int(top_k)
    tids = tok.encode(track, add_special_tokens=False) if track.strip() else []
    tid = tids[0] if tids else None
    dists = layer_distributions(model, tok, prompt)
    header = "layer | top tokens (prob) | entropy" \
        + (" | p(%r)" % track if tid is not None else "")
    lines = ["prompt: %r" % prompt, header, "-" * len(header)]
    for label, probs in dists:
        # Clamp into [0, vocab]: larger k (or a negative one) makes topk raise.
        k = max(0, min(top_k, probs.shape[-1]))
        p, idx = probs.topk(k)
        shown = " ".join("%r:%.2f" % (tok.decode([t]).replace("\n", "\\n"), v)
                         for t, v in zip(idx.tolist(), p.tolist()))
        row = "%5s | %-40s | %4.1fb" % (label, shown, _entropy_bits(probs))
        if tid is not None:
            row += " | %.3f" % probs[tid].item()
        lines.append(row)
    return "\n".join(lines)
|
|
|
|
| |
| |
| |
@torch.no_grad()
def neighbors(word, top_k):
    """List the tokens whose unembedding rows are most cosine-similar to *word*'s.

    Uses the first token of *word*, excludes the token itself from the list,
    and shows at most *top_k* neighbours. The request is clamped to the
    vocabulary size so that small vocabularies (16 tokens on the glass box) do
    not make torch.topk raise.

    Returns the report string, or an error string when no model is active or
    *word* cannot be tokenized.
    """
    if STATE["name"] is None:
        return "Load a model first."
    model, tok = get_handles(STATE["name"])
    top_k = int(top_k)
    ids = tok.encode(word, add_special_tokens=False)
    if not ids:
        return "Could not tokenize %r." % word
    tid = ids[0]
    W = F.normalize(get_head(model).weight.to(DTYPE), dim=-1)
    sims = W @ W[tid]
    # +1 because the query token is normally its own nearest neighbour.
    k = max(0, min(top_k + 1, sims.shape[0]))
    vals, idx = sims.topk(k)
    note = ""
    if STATE["name"] == "handmade":
        note = ("(handmade uses one-hot embeddings, so every token is "
                "orthogonal -> all cosines are 0 by construction. This is the "
                "tool telling the truth about a model with no vocab geometry.)\n")
    lines = [note + "neighbours of %r:" % word]
    for v, j in zip(vals.tolist(), idx.tolist()):
        if j != tid:
            lines.append(" %14r cos=%.3f" % (tok.decode([j]), v))
    return "\n".join(lines[: top_k + 1])
|
|
|
|
| |
| |
| |
| def _make_steer_hook(direction, alpha): |
| d = direction * alpha |
| def hook(module, inp, out): |
| if isinstance(out, tuple): |
| return (out[0] + d.to(out[0].dtype).to(out[0].device),) + out[1:] |
| return out + d.to(out.dtype).to(out.device) |
| return hook |
|
|
|
|
@torch.no_grad()
def steer_generate(prompt, source, target, layer, alpha, max_new):
    """Greedy-generate from *prompt* with and without an activation-steering vector.

    The steering direction is the unit vector from the input embedding of
    *source*'s first token to that of *target*'s first token. It is scaled by
    *alpha* and added to the output of block *layer*, which is clamped into
    range.

    Returns:
        (baseline_text, steered_report), or ("Load a model first.", "") when no
        model is active.
    """
    if STATE["name"] is None:
        return "Load a model first.", ""
    model, tok = get_handles(STATE["name"])
    layer, max_new = int(layer), int(max_new)
    emb = model.get_input_embeddings().weight

    def first_emb(w):
        ids = tok.encode(w, add_special_tokens=False)
        if ids:
            return emb[ids[0]]
        # An un-tokenizable (e.g. empty) word contributes nothing. Build the
        # zeros on the embedding's own device/dtype: the toy models are never
        # moved to DEVICE, so device=DEVICE would mismatch on CUDA hosts.
        return torch.zeros(emb.shape[-1], device=emb.device, dtype=emb.dtype)

    direction = F.normalize((first_emb(target) - first_emb(source)).to(DTYPE), dim=-1)
    inputs = tok(prompt, return_tensors="pt").to(DEVICE)
    gk = dict(max_new_tokens=max_new, do_sample=False, pad_token_id=tok.eos_token_id)
    base = tok.decode(model.generate(**inputs, **gk)[0], skip_special_tokens=True)
    blocks = get_blocks(model)
    layer = max(0, min(layer, len(blocks) - 1))
    handle = blocks[layer].register_forward_hook(_make_steer_hook(direction, alpha))
    try:
        steered = tok.decode(model.generate(**inputs, **gk)[0], skip_special_tokens=True)
    finally:
        # Always detach the hook, or every later forward would stay steered.
        handle.remove()
    return base, "steer %r -> %r @ L%d alpha=%s\n%s" % (source, target, layer, alpha, steered)
|
|
|
|
| |
| |
| |
@torch.no_grad()
def diff_models(name_a, name_b, prompt, target, top_k):
    """Compare two models' logit-lens trajectories for *target* at matched depths.

    Each layer of model A is paired with the layer of model B at the same
    fractional depth (rounded). Each row shows both models' top-1 token, their
    p(target) and the difference pB - pA. *top_k* is accepted but unused.

    Returns the table as a string, or an error string if *target* cannot be
    tokenized by both models.
    """
    model_a, tok_a = get_handles(name_a.strip())
    model_b, tok_b = get_handles(name_b.strip())
    ids_a = tok_a.encode(target, add_special_tokens=False)
    ids_b = tok_b.encode(target, add_special_tokens=False)
    if not (ids_a and ids_b):
        return "Could not tokenize target %r in both models." % target
    tid_a, tid_b = ids_a[0], ids_b[0]
    dists_a = layer_distributions(model_a, tok_a, prompt)
    dists_b = layer_distributions(model_b, tok_b, prompt)
    last_a, last_b = len(dists_a) - 1, len(dists_b) - 1

    def top1(probs, tok):
        value, index = probs.topk(1)
        return "%r:%.2f" % (tok.decode([index.item()]), value.item())

    rows = [
        "prompt: %r target: %r" % (prompt, target),
        "%18s | %16s %6s | %16s %6s | %7s"
        % ("depth (A/B)", "A top1", "pA", "B top1", "pB", "dp"),
    ]
    for i, (label_a, probs_a) in enumerate(dists_a):
        frac = (i / last_a) if last_a > 0 else 0.0
        j = max(0, min(round(frac * last_b), last_b)) if last_b > 0 else 0
        label_b, probs_b = dists_b[j]
        p_a, p_b = probs_a[tid_a].item(), probs_b[tid_b].item()
        depth = "%3.0f%% (%s/%s)" % (frac * 100, label_a, label_b)
        rows.append("%18s | %16s %6.3f | %16s %6.3f | %+7.3f"
                    % (depth, top1(probs_a, tok_a), p_a,
                       top1(probs_b, tok_b), p_b, p_b - p_a))
    return "\n".join(rows)
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| def _find_subject_positions(tok, input_ids, prompt, subject): |
| """Locate subject token positions, with a path for slow (non-fast) toks.""" |
| seq_len = input_ids.shape[1] |
| if getattr(tok, "is_fast", False): |
| enc = tok(prompt, return_tensors="pt", return_offsets_mapping=True) |
| cs = prompt.find(subject) |
| if cs >= 0: |
| ce = cs + len(subject) |
| offs = enc["offset_mapping"][0].tolist() |
| pos = [i for i, (s, e) in enumerate(offs) if e > cs and s < ce] |
| if pos: |
| return [p for p in pos if p != seq_len - 1], "" |
| else: |
| sub_ids = tok.encode(subject, add_special_tokens=False) |
| seq = input_ids[0].tolist() |
| pos = [i for i, t in enumerate(seq) if t in sub_ids] |
| if pos: |
| return [p for p in pos if p != seq_len - 1], "" |
| fb = list(range(0, max(1, seq_len - 1)))[: max(1, seq_len // 2)] |
| return fb, "(subject not found; using fallback window)\n" |
|
|
|
|
@torch.no_grad()
def causal_trace(prompt, subject, target, noise_scale, seed):
    """Causal trace (activation patching) of where *target* is computed.

    1. Clean run: record every hidden state and p(target) at the last position.
    2. Corrupted run: seeded Gaussian noise (noise_scale x the std of the
       input-embedding matrix) is added to the embeddings at the subject
       positions.
    3. For each block l: repeat the corrupted run, but overwrite block l's
       output at the subject positions with its clean value.

    Recovery = (p_restored - p_corrupt) / (p_clean - p_corrupt). The block with
    the highest recovery is reported as the causal site.

    Returns:
        The report text, or an error string when no model is active or the
        subject/target cannot be located.
    """
    if STATE["name"] is None:
        return "Load a model first."
    model, tok = get_handles(STATE["name"])
    seed, noise_scale = int(seed), float(noise_scale)
    inputs = tok(prompt, return_tensors="pt").to(DEVICE)
    input_ids = inputs["input_ids"]
    positions, note = _find_subject_positions(tok, input_ids, prompt, subject)
    if not positions:
        return note + "No valid subject positions."
    target_ids = tok.encode(target, add_special_tokens=False)
    if not target_ids:
        return "Could not tokenize target %r." % target
    tid = target_ids[0]

    def target_prob(logits):
        # p(target) at the final position, computed in DTYPE.
        return F.softmax(logits[0, -1].to(DTYPE), dim=-1)[tid].item()

    out_clean = model(**inputs, output_hidden_states=True)
    clean_hs = out_clean.hidden_states
    clean_p = target_prob(out_clean.logits)

    emb_module = model.get_input_embeddings()
    emb_weight = emb_module.weight
    std = emb_weight.std().item()
    hidden = emb_weight.shape[-1]
    torch.manual_seed(seed)
    # Sample on the embedding's own device: the toy models are never moved to
    # DEVICE, so noise drawn on "cuda" could not be added to their CPU
    # activations.
    noise = torch.randn(len(positions), hidden, device=emb_weight.device) * noise_scale * std

    def corrupt_hook(module, inp, out):
        out = out.clone()
        for k, p in enumerate(positions):
            out[0, p] = out[0, p] + noise[k].to(device=out.device, dtype=out.dtype)
        return out

    # Hooks are removed in `finally`. A corrupt_hook left behind by a failed
    # forward would silently poison every later run of this cached model.
    h = emb_module.register_forward_hook(corrupt_hook)
    try:
        corrupt_p = target_prob(model(**inputs).logits)
    finally:
        h.remove()

    blocks, rows = get_blocks(model), []
    for l in range(len(blocks)):
        # hidden_states[l + 1] is taken as the output of block l
        # (hidden_states[0] is the embedding output).
        clean_layer_hs = clean_hs[l + 1][0]

        def restore_hook(module, inp, out, _clean=clean_layer_hs):
            if isinstance(out, tuple):
                h0 = out[0].clone()
                for p in positions:
                    h0[0, p] = _clean[p].to(h0.dtype)
                return (h0,) + out[1:]
            h0 = out.clone()
            for p in positions:
                h0[0, p] = _clean[p].to(h0.dtype)
            return h0

        h1 = emb_module.register_forward_hook(corrupt_hook)
        try:
            h2 = blocks[l].register_forward_hook(restore_hook)
            try:
                p_r = target_prob(model(**inputs).logits)
            finally:
                h2.remove()
        finally:
            h1.remove()
        rows.append((l, p_r))

    denom = clean_p - corrupt_p
    lines = [note + "prompt: %r" % prompt,
             "subject: %r target: %r" % (subject, target),
             "clean p=%.3f corrupt p=%.3f noise=%sx std" % (clean_p, corrupt_p, noise_scale),
             "", "%6s | %9s | %9s" % ("layer", "p(target)", "recovery")]
    best_l, best_r = 0, -1e9
    for l, p_r in rows:
        rec = (p_r - corrupt_p) / denom if abs(denom) > 1e-6 else 0.0
        if rec > best_r:
            best_r, best_l = rec, l
        lines.append(" L%-3d | %9.3f | %8.1f%%" % (l, p_r, rec * 100))
    lines.append("")
    lines.append("# peak at L%d (%.0f%% recovery) <- the causal site" % (best_l, best_r * 100))
    if abs(denom) < 1e-6:
        lines.append("# (corruption didn't move p(target): on 'handmade' this is "
                     "EXPECTED - the fact lives in a string match, not activations.)")
    return "\n".join(lines)
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
# Candidate capitals whose probabilities the probe battery records for every fact.
GB_ANSWERS = "paris berlin tokyo london rome".split()
|
|
|
|
@torch.no_grad()
def _probe_battery(model, tok):
    """Query every known fact ("the capital of <country> is") and record the answer.

    Returns:
        {country: {"prompt", "orig", "top1", "top1_p", "p_orig", "cand",
        "entropy"}} in GB_FACTS order, where "cand" maps each GB_ANSWERS token
        to its probability and "entropy" is in bits.
    """
    def probe(country, orig):
        prompt = "the capital of %s is" % country
        encoded = tok(prompt, return_tensors="pt").to(DEVICE)
        last_logits = model(**encoded).logits[0, -1]
        probs = F.softmax(last_logits.to(DTYPE), dim=-1)
        best_p, best_i = probs.topk(1)
        return {
            "prompt": prompt,
            "orig": orig,
            "top1": tok.decode([best_i.item()]),
            "top1_p": best_p.item(),
            "p_orig": probs[GB_STOI[orig]].item(),
            "cand": {ans: probs[GB_STOI[ans]].item() for ans in GB_ANSWERS},
            "entropy": _entropy_bits(probs),
        }

    return {country: probe(country, orig) for country, orig in GB_FACTS}
|
|
|
|
| def _verdict(before, after, subject, new_answer, drift_thresh=0.05): |
| eff = after[subject]["top1"] == new_answer |
| collateral, max_drift = [], 0.0 |
| for c in before: |
| if c == subject: |
| continue |
| d = abs(after[c]["p_orig"] - before[c]["p_orig"]) |
| max_drift = max(max_drift, d) |
| if after[c]["top1"] != before[c]["top1"] or d > drift_thresh: |
| collateral.append(c) |
| ent_blowup = any(abs(after[c]["entropy"] - before[c]["entropy"]) > 0.8 for c in before) |
| surgical = eff and not collateral and not ent_blowup |
| return eff, collateral, max_drift, ent_blowup, surgical |
|
|
|
|
def edit_and_verify(subject, new_answer, method, strength, use_llm,
                    anthropic_key, anthropic_model, hf_token, hf_model,
                    local_url, local_model):
    """Apply one knowledge edit to the glass-box model and audit the result.

    Steps: reset the glass box, probe every fact, apply
    ``edit_fact(subject, new_answer, method, strength)``, probe again and score
    the edit with ``_verdict``. When *use_llm* is truthy, the before/after data
    is also sent to the configured LLM judges (Anthropic, then HF, then a local
    OpenAI-compatible server). Each completed edit is appended to the session
    log.

    Returns:
        The text report, or an "Edit failed: ..." message when the edit is
        rejected. The model is left in the edited state.
    """
    # Normalise once so the edit, the verdict, the report and the session log
    # all use the same key. Previously only edit_fact saw the stripped value,
    # so " france" edited fine and then raised KeyError in _verdict. Lowercase
    # to match SimpleTok, which lowercases every prompt.
    subject = subject.strip().lower()
    new_answer = new_answer.strip().lower()
    model, tok = get_handles("glassbox")
    STATE["name"] = "glassbox"
    model.reset()
    before = _probe_battery(model, tok)
    try:
        model.edit_fact(subject, new_answer, method, float(strength))
    except ValueError as e:
        return "Edit failed: %s\nValid subjects: %s. Valid answers: %s" % (
            e, ", ".join(s for s, _ in GB_FACTS), ", ".join(GB_ANSWERS))
    after = _probe_battery(model, tok)
    eff, collateral, max_drift, ent, surgical = _verdict(before, after, subject, new_answer)

    L = ["EDIT: %s's capital -> %r (method=%s, strength=%s)" %
         (subject, new_answer, method, strength), "",
         "%-9s | %-22s | %-22s" % ("fact", "before (top1 / p_orig)", "after (top1 / p_orig)"),
         "-" * 60]
    for c in before:
        b, a = before[c], after[c]
        flag = " <- TARGET" if c == subject else (" <- COLLATERAL" if c in collateral else "")
        L.append("%-9s | %-22s | %-22s%s" % (
            c, "%s / %.2f" % (b["top1"], b["p_orig"]),
            "%s / %.2f" % (a["top1"], a["p_orig"]), flag))
    L += ["",
          "efficacy : %s (target now says %r, p=%.2f)" %
          ("PASS" if eff else "FAIL", after[subject]["top1"], after[subject]["top1_p"]),
          "specificity : %s (max drift on other facts = %.3f%s)" %
          ("PASS" if not collateral else "FAIL: " + ", ".join(collateral),
           max_drift, "; entropy spike" if ent else ""),
          "", "VERDICT: %s" % ("SURGICAL EDIT" if surgical else "COLLATERAL DAMAGE")]
    L.append("(model is left in the edited state - inspect it in tabs 1-5, or hit Reset.)")

    llm_report = ""
    if use_llm:
        providers = [
            {"type": "anthropic", "key": anthropic_key, "model": anthropic_model},
            {"type": "hf", "key": hf_token, "model": hf_model},
            {"type": "local", "url": local_url, "model": local_model},
        ]
        llm_report = _llm_judge_chain(before, after, subject, new_answer, providers)
        L += ["", "-" * 60, "INDEPENDENT LLM REVIEW:", llm_report]

    report = "\n".join(L)
    _log_session(subject, new_answer, method, strength, before, after,
                 eff, collateral, max_drift, surgical, llm_report)
    return report
|
|
|
|
def reset_glassbox():
    """Restore the glass-box fact MLP to its pristine weights and report it."""
    glassbox = get_handles("glassbox")[0]
    glassbox.reset()
    return "Glass-box weights restored to pristine. Re-run any tab to confirm."
|
|
|
|
| |
| |
| |
| |
| |
| def _build_judge_prompt(before, after, subject, new_answer): |
| import json |
| payload = {c: {"prompt": before[c]["prompt"], |
| "before_top1": before[c]["top1"], "before_p_orig": round(before[c]["p_orig"], 3), |
| "after_top1": after[c]["top1"], "after_p_orig": round(after[c]["p_orig"], 3)} |
| for c in before} |
| sys = ("You audit knowledge edits to a small language model. The intended edit " |
| "is: make %s's capital '%s'. Given before/after predictions for every " |
| "known fact, decide if the edit was SURGICAL (target changed, all other " |
| "facts unchanged) or caused COLLATERAL damage. Reply ONLY as JSON, no " |
| 'prose, no markdown fences: {"verdict":"surgical|collateral",' |
| '"target_changed":bool,"damaged_facts":[...],"confidence":0-1,' |
| '"reason":"one sentence"}.') % (subject, new_answer) |
| return sys, json.dumps(payload) |
|
|
|
|
| def _parse_verdict_json(text, provider_label): |
| import json |
| clean = text.strip().strip("`") |
| if clean.lower().startswith("json"): |
| clean = clean[4:].strip() |
| start, end = clean.find("{"), clean.rfind("}") |
| if start != -1 and end != -1: |
| clean = clean[start:end + 1] |
| v = json.loads(clean) |
| return ("[%s] verdict=%s target_changed=%s confidence=%s\n damaged: %s\n reason: %s" |
| % (provider_label, v.get("verdict"), v.get("target_changed"), v.get("confidence"), |
| v.get("damaged_facts") or "none", v.get("reason"))) |
|
|
|
|
| def _try_anthropic(sys, user, cfg): |
| import os, json |
| key = (cfg.get("key") or "").strip() or os.environ.get("ANTHROPIC_API_KEY", "") |
| if not key: |
| return None, "anthropic: no key configured" |
| body = {"model": (cfg.get("model") or "claude-sonnet-4-6").strip(), |
| "max_tokens": 400, "system": sys, "messages": [{"role": "user", "content": user}]} |
| try: |
| try: |
| import anthropic |
| client = anthropic.Anthropic(api_key=key) |
| msg = client.messages.create(**body) |
| text = "".join(b.text for b in msg.content if getattr(b, "type", "") == "text") |
| except ImportError: |
| import urllib.request |
| req = urllib.request.Request( |
| "https://api.anthropic.com/v1/messages", data=json.dumps(body).encode(), |
| headers={"x-api-key": key, "anthropic-version": "2023-06-01", |
| "content-type": "application/json"}) |
| with urllib.request.urlopen(req, timeout=30) as r: |
| data = json.loads(r.read()) |
| text = "".join(b.get("text", "") for b in data.get("content", []) |
| if b.get("type") == "text") |
| return _parse_verdict_json(text, "anthropic:" + body["model"]), None |
| except Exception as e: |
| return None, "anthropic failed: %s" % e |
|
|
|
|
| def _try_hf(sys, user, cfg): |
| token = (cfg.get("key") or "").strip() |
| model = (cfg.get("model") or "Qwen/Qwen2.5-7B-Instruct").strip() |
| if not token: |
| import os |
| token = os.environ.get("HF_TOKEN", "") |
| if not token: |
| return None, "hf: no token configured" |
| try: |
| from huggingface_hub import InferenceClient |
| client = InferenceClient(model=model, token=token) |
| resp = client.chat_completion( |
| messages=[{"role": "system", "content": sys}, {"role": "user", "content": user}], |
| max_tokens=400) |
| text = resp.choices[0].message.content |
| return _parse_verdict_json(text, "hf:" + model), None |
| except Exception as e: |
| return None, "hf failed: %s" % e |
|
|
|
|
| def _try_local(sys, user, cfg): |
| """Any OpenAI-compatible /v1/chat/completions server - LM Studio, vLLM, |
| Ollama (with its OpenAI shim), text-generation-webui, etc.""" |
| import json, urllib.request |
| url = (cfg.get("url") or "").strip().rstrip("/") |
| if not url: |
| return None, "local: no URL configured" |
| model = (cfg.get("model") or "local-model").strip() |
| body = json.dumps({"model": model, "max_tokens": 400, "temperature": 0, |
| "messages": [{"role": "system", "content": sys}, |
| {"role": "user", "content": user}]}).encode() |
| try: |
| req = urllib.request.Request( |
| url + "/v1/chat/completions", data=body, |
| headers={"content-type": "application/json"}) |
| with urllib.request.urlopen(req, timeout=20) as r: |
| data = json.loads(r.read()) |
| text = data["choices"][0]["message"]["content"] |
| return _parse_verdict_json(text, "local:" + model + "@" + url), None |
| except Exception as e: |
| return None, "local failed: %s" % e |
|
|
|
|
def _llm_judge_chain(before, after, subject, new_answer, providers):
    """Try each configured LLM judge in order and return the first successful review.

    Providers with an unknown "type" are ignored. Failures before the first
    success are listed in a "(skipped: ...)" prefix. If every provider fails,
    all error messages are returned together with configuration hints.
    """
    system, user = _build_judge_prompt(before, after, subject, new_answer)
    handlers = {"anthropic": _try_anthropic, "hf": _try_hf, "local": _try_local}
    failures = []
    for cfg in providers:
        handler = handlers.get(cfg["type"])
        if handler is None:
            continue
        review, error = handler(system, user, cfg)
        if review is None:
            failures.append(error)
            continue
        prefix = ("(skipped: %s)\n" % "; ".join(failures)) if failures else ""
        return prefix + review
    return ("all providers unavailable:\n " + "\n ".join(failures) +
            "\n(configure at least one: Anthropic key, HF token, or a local "
            "OpenAI-compatible server URL like http://192.168.188.25:1234)")
|
|
|
|
| |
| |
| |
| SESSION_LOG = [] |
|
|
|
|
| def _log_session(subject, new_answer, method, strength, before, after, |
| eff, collateral, max_drift, surgical, llm_report): |
| import datetime |
| SESSION_LOG.append({ |
| "ts": datetime.datetime.utcnow().isoformat() + "Z", |
| "subject": subject, "new_answer": new_answer, "method": method, |
| "strength": strength, "efficacy_pass": bool(eff), |
| "collateral": collateral, "max_drift": round(max_drift, 4), |
| "verdict": "SURGICAL" if surgical else "COLLATERAL", |
| "before": {c: {"top1": before[c]["top1"], "p_orig": round(before[c]["p_orig"], 4)} |
| for c in before}, |
| "after": {c: {"top1": after[c]["top1"], "p_orig": round(after[c]["p_orig"], 4)} |
| for c in after}, |
| "llm_review": llm_report or None, |
| }) |
|
|
|
|
def export_session_log(outdir="/mnt/user-data/outputs"):
    """Write the in-memory ``SESSION_LOG`` to JSON and Markdown files.

    Produces ``edit_session_log.json`` (the raw entries) and
    ``edit_session_log.md`` (one human-readable section per edit) inside
    *outdir*, creating the directory if needed.

    Args:
        outdir: destination folder. Defaults to the original hard-coded
            ``/mnt/user-data/outputs`` so existing callers are unaffected.

    Returns:
        ``(json_path, status)`` on success, or ``(None, reason)`` when the log
        is empty or the files could not be written (the first element feeds
        the download widget, as the empty-log case already did).
    """
    import json, os
    if not SESSION_LOG:
        return None, "No edits run yet this session - nothing to export."
    path = os.path.join(outdir, "edit_session_log.json")
    md_path = os.path.join(outdir, "edit_session_log.md")

    # Serialize everything before touching the disk so a non-JSON-able value
    # raises without leaving a truncated file behind.
    payload = json.dumps(SESSION_LOG, indent=2)
    md = ["# Edit session log\n"]
    for i, e in enumerate(SESSION_LOG, 1):
        md.append("## Edit %d - %s (%s, %s, strength=%s)\n" %
                  (i, e["verdict"], e["subject"] + "->" + e["new_answer"],
                   e["method"], e["strength"]))
        md.append("- efficacy: %s, max collateral drift: %.4f, damaged: %s" %
                  ("pass" if e["efficacy_pass"] else "fail", e["max_drift"],
                   e["collateral"] or "none"))
        if e["llm_review"]:
            # Keep each review on a single markdown bullet line.
            md.append("- LLM review: " + e["llm_review"].replace("\n", " "))
        md.append("")

    try:
        os.makedirs(outdir, exist_ok=True)
        with open(path, "w", encoding="utf-8") as f:
            f.write(payload)
        with open(md_path, "w", encoding="utf-8") as f:
            f.write("\n".join(md))
    except OSError as err:
        # e.g. a read-only or absent mount when running outside the sandbox
        # this default path was written for.
        return None, "Could not write session log to %s: %s" % (outdir, err)
    return path, "Wrote %d edit(s) to %s and %s" % (len(SESSION_LOG), path, md_path)
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
# Source of the standalone ``modeling_glassbox.py`` that export_glassbox writes
# next to the weights. It rebuilds the glass-box from config.json, vocab.json
# and model.safetensors without needing this app.
# NOTE(review): the loader drops state-dict keys ending in "0" and loads with
# strict=False (kept from the original exporter) - confirm against the
# in-app model's state_dict keys; strict=False hides missing Win/Wout.
_MODELING_PY = '''"""Standalone glass-box model - reload with no other files.

    from modeling_glassbox import load
    m, tok = load(".")   # folder containing config/weights/vocab
    print(tok.decode(m.generate(tok("the capital of france is")["input_ids"])[0]))
"""
import json, torch, torch.nn as nn, torch.nn.functional as F
from safetensors.torch import load_file


def _read_json(p):
    with open(p, encoding="utf-8") as f:
        return json.load(f)


def load(path="."):
    cfg = _read_json(f"{path}/config.json")
    stoi = _read_json(f"{path}/vocab.json"); itos = {i: w for w, i in stoi.items()}
    D, V = cfg["d_model"], len(stoi); facts = [tuple(f) for f in cfg["facts"]]
    SUBJ = torch.zeros(D, D)
    for d in range(cfg["subject_dims"]): SUBJ[d, d] = 1.0

    class Tok:
        is_fast = False
        def __init__(s): s.eos_token_id = stoi["."]
        def _ids(s, t): return [stoi.get(w, stoi["<s>"]) for w in t.lower().replace(".", " .").split()] or [stoi["<s>"]]
        def __call__(s, t, **k): return {"input_ids": torch.tensor([s._ids(t)])}
        def decode(s, ids, **k): return " ".join(itos.get(int(i), "?") for i in ids)

    class Ident(nn.Module):
        def forward(s, x): return (x.clone(),)

    class Pool(nn.Module):
        def forward(s, x):
            o = x.clone()
            if x.shape[1] > 1: o[0, -1] = o[0, -1] + 0.9 * (x[0, :-1] @ SUBJ.T).sum(0)
            return (o,)

    class FactMLP(nn.Module):
        def __init__(s):
            super().__init__()
            s.register_buffer("Win", torch.zeros(len(facts), D))
            s.register_buffer("Wout", torch.zeros(D, len(facts)))
            s.bias, s.gain = cfg["bias"], cfg["gain"]
        def forward(s, x):
            o = x.clone(); pre = F.relu(s.Win @ o[0, -1] - s.bias)
            o[0, -1] = o[0, -1] + s.gain * (s.Wout @ pre); return (o,)

    class T(nn.Module):
        def __init__(s):
            super().__init__(); s.wte = nn.Embedding(V, D)
            s.h = nn.ModuleList([Ident(), Pool(), FactMLP(), Ident()]); s.ln_f = nn.Identity()

    class GlassBox(nn.Module):
        def __init__(s):
            super().__init__(); s.transformer = T(); s.head = nn.Linear(D, V, bias=False)
        def get_input_embeddings(s): return s.transformer.wte
        def forward(s, input_ids=None, **k):
            x = s.transformer.wte(input_ids)
            for b in s.transformer.h: (x,) = b(x)
            class O: pass
            o = O(); o.logits = s.head(x); return o
        @torch.no_grad()
        def generate(s, input_ids=None, max_new_tokens=12, **k):
            # Accept the tokenizer's dict output as well as a plain id tensor.
            ids = input_ids["input_ids"] if isinstance(input_ids, dict) else input_ids
            for _ in range(max_new_tokens):
                ids = torch.cat([ids, s(input_ids=ids).logits[0, -1].argmax().view(1, 1)], 1)
            return ids

    m = GlassBox().eval()
    sd = load_file(f"{path}/model.safetensors")
    m.load_state_dict({k: v for k, v in sd.items() if not k.endswith("0")}, strict=False)
    return m, Tok()
'''
|
|
|
|
def export_glassbox(outdir="glassbox_export"):
    """Export the live ``glassbox`` model as a self-contained, reloadable repo.

    Writes into *outdir* (created if needed): ``model.safetensors`` (the
    model's current state dict), ``config.json``, ``vocab.json``, the
    standalone ``modeling_glassbox.py`` loader and a ``README.md`` model card.

    Returns:
        *outdir*, unchanged, so callers can hand it straight to an upload.
    """
    import os, json
    from safetensors.torch import save_file

    def _write_text(name, text):
        # Context manager + explicit UTF-8: no leaked handles, no locale surprises.
        with open(os.path.join(outdir, name), "w", encoding="utf-8") as f:
            f.write(text)

    os.makedirs(outdir, exist_ok=True)
    model, _ = get_handles("glassbox")
    fact_mlp = model.transformer.h[2]
    sd = {k: v.contiguous() for k, v in model.state_dict().items()}
    save_file(sd, os.path.join(outdir, "model.safetensors"))
    # Serialize before writing so a non-JSON-able value can't leave a
    # truncated config.json behind.
    # NOTE(review): subject_dims=9 is hard-coded; presumably it must match the
    # in-app glassbox's subject subspace - confirm against its definition.
    config = {"model_type": "glassbox", "d_model": GB_D, "vocab_size": GB_V,
              "subject_dims": 9, "bias": fact_mlp.bias,
              "gain": fact_mlp.gain,
              "facts": [list(f) for f in GB_FACTS]}
    _write_text("config.json", json.dumps(config, indent=2))
    _write_text("vocab.json", json.dumps(GB_STOI, indent=2))
    _write_text("modeling_glassbox.py", _MODELING_PY)
    _write_text("README.md",
                "---\nlicense: mit\ntags: [interpretability, glass-box, rome, toy-model]\n---\n\n"
                "# Glass-box interpretability model\n\n"
                "A tiny transformer-shaped model whose facts are stored as key->value "
                "writes into the residual stream, so logit-lens, activation steering and "
                "ROME causal tracing all reproduce the *known* ground truth. Built as a "
                "verification harness for interpretability code.\n\n"
                "```python\nfrom modeling_glassbox import load\n"
                "m, tok = load('.')\n"
                "print(tok.decode(m.generate(tok('the capital of france is')['input_ids'])[0]))\n```\n\n"
                "Facts: " + ", ".join("%s->%s" % f for f in GB_FACTS) + ".\n")
    return outdir
|
|
|
|
def upload_to_hf(repo_id, token, what, app_path=__file__):
    """Push the glass-box model and/or this app (as a Gradio Space) to the Hub.

    Args:
        repo_id: destination id like "user/name"; surrounding whitespace is
            ignored. With what == "both" the Space is created at
            ``repo_id + "-space"``.
        token: HF write token; falls back to the HF_TOKEN env var when blank.
        what: "model", "space" or "both".
        app_path: file uploaded as the Space's ``app.py`` (defaults to this file).

    Returns:
        A human-readable status string. Upload errors are caught and reported
        in it rather than raised.
    """
    import os
    try:
        from huggingface_hub import HfApi
    except ImportError:
        return "huggingface_hub not installed. `pip install huggingface_hub`."
    token = (token or "").strip() or os.environ.get("HF_TOKEN", "")
    if not token:
        return "No HF token. Paste a write token or set HF_TOKEN."
    # Tolerate None and pass the *stripped* id to the API, matching
    # upload_local_checkpoint.
    repo_id = (repo_id or "").strip()
    if not repo_id:
        return "Enter a repo id like 'Chris4K/glassbox-interp'."
    if what not in ("model", "space", "both"):
        # Otherwise nothing would be pushed yet "Uploaded:" reported.
        return "Choose what to push: 'model', 'space' or 'both'."
    api, logs = HfApi(token=token), []
    try:
        if what in ("model", "both"):
            d = export_glassbox()
            api.create_repo(repo_id, repo_type="model", exist_ok=True)
            api.upload_folder(folder_path=d, repo_id=repo_id, repo_type="model")
            logs.append("model -> https://huggingface.co/%s" % repo_id)
        if what in ("space", "both"):
            sid = repo_id + "-space" if what == "both" else repo_id
            api.create_repo(sid, repo_type="space", space_sdk="gradio", exist_ok=True)
            api.upload_file(path_or_fileobj=app_path, path_in_repo="app.py",
                            repo_id=sid, repo_type="space")
            req = "torch\ntransformers\ngradio\nsafetensors\nhuggingface_hub\nanthropic\n"
            api.upload_file(path_or_fileobj=req.encode(), path_in_repo="requirements.txt",
                            repo_id=sid, repo_type="space")
            logs.append("space -> https://huggingface.co/spaces/%s" % sid)
        return "Uploaded:\n " + "\n ".join(logs)
    except Exception as e:
        # UI boundary: surface any Hub/network failure as text.
        return "Upload failed: %s" % e
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
def upload_local_checkpoint(local_dir, repo_id, token, private, commit_message):
    """Push an on-disk checkpoint folder, as-is, to a Hub model repo.

    Nothing is loaded into memory. The folder is only checked for a
    ``config.json`` plus ``*.safetensors``/``*.bin`` files; if those are
    missing, a warning is prepended but the upload still proceeds.

    Args:
        local_dir: folder to upload; ``~`` is expanded.
        repo_id: destination "user/name" (whitespace stripped).
        token: HF write token; falls back to the HF_TOKEN env var when blank.
        private: create the repo as private (only applies on first creation).
        commit_message: blank/whitespace falls back to "upload checkpoint".

    Returns:
        A human-readable status string. Upload errors are caught and reported.
    """
    import os
    try:
        from huggingface_hub import HfApi
    except ImportError:
        return "huggingface_hub not installed. `pip install huggingface_hub`."
    local_dir = os.path.expanduser((local_dir or "").strip())
    repo_id = (repo_id or "").strip()
    if not local_dir or not os.path.isdir(local_dir):
        return "local_dir %r does not exist or is not a directory." % local_dir
    if not repo_id:
        return "Enter a repo id like 'Chris4K/vindex-llama3-edited'."
    token = (token or "").strip() or os.environ.get("HF_TOKEN", "")
    if not token:
        return "No HF token. Paste a write token or set HF_TOKEN."

    files = sorted(os.listdir(local_dir))
    has_cfg = os.path.exists(os.path.join(local_dir, "config.json"))
    has_weights = any(f.endswith((".safetensors", ".bin")) for f in files)
    warn = "" if (has_cfg and has_weights) else (
        "WARNING: folder is missing config.json or weight files - this may "
        "not be a loadable HF checkpoint. Uploading anyway.\n")
    # Strip *before* defaulting so "   " doesn't become an empty message.
    message = (commit_message or "").strip() or "upload checkpoint"
    shown = ", ".join(files[:12])
    if len(files) > 12:
        shown += ", ... (+%d more)" % (len(files) - 12)

    api = HfApi(token=token)
    try:
        api.create_repo(repo_id, repo_type="model", private=bool(private), exist_ok=True)
        api.upload_folder(folder_path=local_dir, repo_id=repo_id, repo_type="model",
                          commit_message=message)
        return (warn + "Uploaded %s -> https://huggingface.co/%s\n"
                "Files: %s" % (local_dir, repo_id, shown))
    except Exception as e:
        # UI boundary: surface any Hub/network failure as text.
        return warn + "Upload failed: %s" % e
|
|
|
|
| |
| |
| |
# Markdown shown at the top of the app: what the tool is, the three bundled
# model types, and the suggested order to explore them in.
INTRO = """
# Compression Navigator
**An LLM is a lossy codec for text.** Training compresses a corpus into weights;
a forward pass decompresses a continuation. These five tools let you watch that
decompression and find where facts physically live.

Each tab is a real interpretability technique: **logit lens, embedding
neighbours, activation steering, cross-model diff, and causal tracing (ROME).**

### Three models, on purpose
| name | how it stores facts | what it teaches |
|---|---|---|
| **`glassbox`** | key→value writes into the **residual stream** (like a real transformer / what ROME edits) | the tools **work and are verifiable** against ground truth you can read in the source |
| **`handmade`** | a **lookup table** keyed on the prompt string (a side channel) | a model can be **invisible** to residual-stream interpretability — a real limitation |
| **`gpt2`** | learned, fuzzy, **distributed** over many layers | what the real, messy thing looks like |

**Suggested order:** load `glassbox` first (see "correct"), then `handmade`
(see a failure mode), then `gpt2` (see reality). Type a name below and Load.
"""
|
|
# ---------------------------------------------------------------------------
# Gradio UI: a shared model loader on top, then one tab per technique,
# the export/upload tab, and a closing roadmap section.
# ---------------------------------------------------------------------------
with gr.Blocks(title="Compression Navigator") as demo:
    gr.Markdown(INTRO)

    # Shared model loader; load_model's return value fills the status line.
    with gr.Row():
        model_name = gr.Textbox(value="glassbox", label="model name or HF id")
        load_btn = gr.Button("Load", variant="primary")
    load_status = gr.Markdown()
    load_btn.click(load_model, inputs=model_name, outputs=load_status)

    # Tab 1: logit lens over every layer's last-token residual.
    with gr.Tab("1 · Decompress (logit lens)"):
        gr.Markdown("""
### Logit lens — watch the answer condense, layer by layer
**What it does:** takes the last-token residual at *every* layer and reads it
through the unembedding, as if the model had to answer right there. You see the
prediction form.

**How to read it:** each row is a layer. Watch your tracked token's probability
(right column) climb, and watch **entropy** (bits) fall as the model commits.

**Ground truth to check:**
- `glassbox` — `paris` is ~0 until **L3** (the readout right after the fact-MLP), then jumps to ~0.51. Sharp and localised because you put it there.
- `handmade` — the answer snaps to 1.00 at **L1** with zero build-up (it's a lookup, not a computation).
- `gpt2` — the answer accretes *gradually* across many middle/late layers. That smear is what "distributed representation" actually looks like.

*(Numbering note: the lens counts from the embedding, so `L1` is after the first block. The causal-trace tab counts blocks from `L0`. So the fact-MLP is lens-`L3` / trace-block-`L2`, and its causal site shows at trace-`L0`.)*
""")
        ll_prompt = gr.Textbox(value="the capital of france is", label="prompt")
        with gr.Row():
            ll_k = gr.Slider(1, 10, value=3, step=1, label="top-k per layer")
            ll_track = gr.Textbox(value="paris", label="track this token's prob")
        ll_out = gr.Textbox(label="output", lines=18)
        gr.Button("Run").click(logit_lens, [ll_prompt, ll_k, ll_track], ll_out)

    # Tab 2: nearest neighbours by unembedding cosine similarity.
    with gr.Tab("2 · Triangulate (neighbours)"):
        gr.Markdown("""
### Neighbours — the geometry of the vocabulary
**What it does:** ranks tokens by cosine similarity of their unembedding rows.
Directions that point the same way are "near" in the model's compressed space.

**How to read it:** high cosine = the model treats these tokens as related.

**Ground truth to check:**
- `glassbox` — `paris` is near `france` (cos ≈ 0.48): the source deliberately makes a capital share a dimension with its country. Real geometry, by design.
- `handmade` — **every** cosine is 0. One-hot embeddings are mutually orthogonal, so there's no geometry at all. The tool is correctly reporting "nothing here."
- `gpt2` — neighbours are messy but meaningful (casing variants, plurals, semantic kin).
""")
        nb_word = gr.Textbox(value="paris", label="word")
        nb_k = gr.Slider(5, 25, value=10, step=1, label="top neighbours")
        nb_out = gr.Textbox(label="output", lines=15)
        gr.Button("Run").click(neighbors, [nb_word, nb_k], nb_out)

    # Tab 3: activation steering (baseline vs steered generation).
    with gr.Tab("3 · Re-route (steering)"):
        gr.Markdown("""
### Steering — bend behaviour with a direction, no retraining
**What it does:** builds the vector `emb(target) − emb(source)` and *adds* it to
a layer's output during generation. The model drifts from `source` toward
`target`. This is the cheap cousin of fine-tuning (ActAdd / representation
engineering).

**How to read it:** compare *baseline* vs *steered*. Raise **strength** until the
output flips; too high and it turns to noise (you've knocked the residual off
the manifold).

**Tips:** on `gpt2` try `from: Paris to: London` on the France prompt, layer
0–4, strength 6–14. On `glassbox` it works cleanly too — `from: france
to: japan` at layer 0, strength 8, flips the output from `paris` to `tokyo`
(you're pushing the residual along the subject→subject direction the fact-MLP
keys on).
""")
        st_prompt = gr.Textbox(value="the capital of france is", label="prompt")
        with gr.Row():
            st_src = gr.Textbox(value="Paris", label="from")
            st_tgt = gr.Textbox(value="London", label="to")
        with gr.Row():
            st_layer = gr.Slider(0, 11, value=2, step=1, label="layer")
            st_alpha = gr.Slider(0, 30, value=10, step=0.5, label="strength")
            st_max = gr.Slider(8, 80, value=40, step=1, label="max new tokens")
        st_base = gr.Textbox(label="baseline", lines=2)
        st_out = gr.Textbox(label="steered", lines=3)
        gr.Button("Run").click(steer_generate,
                               [st_prompt, st_src, st_tgt, st_layer, st_alpha, st_max],
                               [st_base, st_out])

    # Tab 4: two-model logit-lens diff aligned by relative depth.
    with gr.Tab("4 · Diff (align by depth)"):
        gr.Markdown("""
### Diff — two models on one prompt, aligned by *relative* depth
**What it does:** runs the logit lens on model A and model B and lines their
layers up by percentage depth (0–100%), so you can compare a 2-layer toy with a
12-layer gpt2 side by side. `dp` is `p_B − p_A` for the target token.

**How to read it:** look at *where* on the depth axis each model commits to the
target. A localised model commits at one depth; a distributed one ramps up.

**Try:** A = `gpt2`, B = `glassbox`, target = `paris`. You'll see gpt2 ramp
through the middle while glassbox snaps on at its fact layer — the same fact,
two very different internal shapes.
""")
        with gr.Row():
            df_a = gr.Textbox(value="gpt2", label="model A")
            df_b = gr.Textbox(value="glassbox", label="model B")
        df_prompt = gr.Textbox(value="the capital of france is", label="prompt")
        df_target = gr.Textbox(value="paris", label="target token")
        df_k = gr.Slider(1, 5, value=1, step=1, label="top-k (display)")
        df_out = gr.Textbox(label="output", lines=16)
        gr.Button("Run").click(diff_models,
                               [df_a, df_b, df_prompt, df_target, df_k], df_out)

    # Tab 5: ROME-style causal tracing (corrupt subject, restore per layer).
    with gr.Tab("5 · Causal trace (ROME)"):
        gr.Markdown("""
### Causal trace — corrupt the subject, restore each layer, find the site
**What it does:** activation patching (Meng et al.'s ROME). It noises the
**subject** token, which breaks the prediction, then restores one layer at a
time and measures how much of the answer comes back. The layer that restores
the most is where the fact is *causally* computed.

**How to read it:** `recovery` ≈ 100% means "restoring this layer is enough" —
the fact is read here. The peak line names the site.

**Ground truth to check:**
- `glassbox` — peak at **L0** (≈100%). The fact is read at the early subject site, because the L1 "attention" re-reads the restored subject. You know this is right because you wrote the mechanism.
- `handmade` — `clean p` ≈ `corrupt p`, so recovery is meaningless. **Expected:** the fact is a string match, untouched by activation noise. This is the headline lesson — patching can't see lookup behaviour.
- `gpt2` — a *band* of early–middle layers at the subject token light up, exactly as in the ROME paper.
""")
        ct_prompt = gr.Textbox(value="the capital of france is", label="prompt")
        ct_subject = gr.Textbox(value="france", label="subject to corrupt")
        ct_target = gr.Textbox(value="paris", label="target token")
        with gr.Row():
            ct_noise = gr.Slider(0, 10, value=3, step=0.5, label="noise (x embed std)")
            ct_seed = gr.Slider(0, 100, value=0, step=1, label="seed")
        ct_out = gr.Textbox(label="output", lines=18)
        gr.Button("Run").click(causal_trace,
                               [ct_prompt, ct_subject, ct_target, ct_noise, ct_seed], ct_out)

    # Tab 6: edit a fact, run the verification battery, optional LLM review,
    # plus the session-log export.
    with gr.Tab("6 · Edit + verify (ROME loop)"):
        gr.Markdown("""
### Edit a fact, then prove nothing else broke
**What it does:** rewrites the value one fact-MLP key maps to (the exact thing
ROME/MEMIT do on real models — this is a literal `nn.Module` weight tensor,
not a token or vocab change), then runs a verification battery over **every**
known fact to measure **efficacy** (target changed), **specificity** (others
untouched), and **fluency** (no entropy collapse).

**Two methods, on purpose:**
- `rank1` — the minimal, surgical update. Only the target fact moves → **SURGICAL**.
- `broadcast` — a deliberately sloppy edit that smears the change across all facts → the harness catches the **COLLATERAL DAMAGE**. This proves the verifier actually works, not just reports "ok" by default.

**Independent LLM review, with a fallback chain — not locked to one vendor:**
tick the box and it tries, in order: **Anthropic** (Claude, if you give a key)
→ **Hugging Face Inference** (any hosted chat model, if you give an HF token)
→ **your own local server** (LM Studio / vLLM / Ollama's OpenAI shim — anything
exposing `/v1/chat/completions`). The first one that's configured *and*
reachable answers; the rest are skipped and noted. So your own RTX 5090 can
be the judge with zero cloud calls if you just fill in the local URL.

Subjects: `france`, `germany`, `japan`. Answers: `paris, berlin, tokyo, london, rome`.
After editing, the model stays edited — go look at it in tabs 1–5 (the logit lens
will show the new answer rising; the trace still localises to L0). Hit **Reset**
to restore. Every run is appended to a session log you can download below and
paste into a future chat for review.
""")
        with gr.Row():
            ed_subj = gr.Textbox(value="france", label="subject")
            ed_new = gr.Textbox(value="london", label="new answer")
            ed_method = gr.Radio(["rank1", "broadcast"], value="rank1", label="method")
            ed_strength = gr.Slider(0.2, 2.0, value=1.0, step=0.1, label="strength")
        ed_llm = gr.Checkbox(value=False, label="also run an independent LLM review")
        with gr.Accordion("LLM review providers (tried in this order)", open=False):
            with gr.Row():
                ed_a_model = gr.Textbox(value="claude-sonnet-4-6", label="1. Anthropic model")
                ed_a_key = gr.Textbox(value="", label="Anthropic API key", type="password")
            with gr.Row():
                ed_h_model = gr.Textbox(value="Qwen/Qwen2.5-7B-Instruct",
                                        label="2. HF Inference model")
                ed_h_key = gr.Textbox(value="", label="HF token", type="password")
            with gr.Row():
                ed_l_url = gr.Textbox(value="http://192.168.188.25:1234",
                                      label="3. Local server URL (LM Studio etc.)")
                ed_l_model = gr.Textbox(value="local-model", label="local model name")
        ed_out = gr.Textbox(label="edit + verification report", lines=24)
        with gr.Row():
            gr.Button("Edit & verify", variant="primary").click(
                edit_and_verify,
                [ed_subj, ed_new, ed_method, ed_strength, ed_llm,
                 ed_a_key, ed_a_model, ed_h_key, ed_h_model, ed_l_url, ed_l_model],
                ed_out)
            gr.Button("Reset model").click(reset_glassbox, outputs=ed_out)
        gr.Markdown("**Session log** (every edit run above, appended):")
        with gr.Row():
            log_btn = gr.Button("Write session log to disk")
            log_file = gr.File(label="download")
        log_status = gr.Markdown()
        log_btn.click(lambda: export_session_log(), outputs=[log_file, log_status])

    # Tab 7: export/upload the toy glass-box, or upload any local checkpoint.
    with gr.Tab("7 · Export / Upload to HF"):
        gr.Markdown("""
### Ship the toy glass-box
**Export** writes a self-contained, reloadable repo: weights (`safetensors`),
`config.json`, `vocab.json`, a standalone `modeling_glassbox.py` (reload with
`from modeling_glassbox import load`), and a model card.

**Upload** pushes it to the Hub. Choose `model`, `space` (this whole app,
runnable), or `both`. Paste a **write** token (or set `HF_TOKEN`).
""")
        with gr.Row():
            hf_repo = gr.Textbox(value="Chris4K/glassbox-interp", label="repo id")
            hf_what = gr.Radio(["model", "space", "both"], value="model", label="what to push")
            hf_token = gr.Textbox(value="", label="HF write token (optional)", type="password")
        hf_out = gr.Textbox(label="result", lines=6)
        with gr.Row():
            gr.Button("Export locally").click(
                lambda: "Exported to ./%s" % export_glassbox(), outputs=hf_out)
            gr.Button("Upload to HF", variant="primary").click(
                upload_to_hf, [hf_repo, hf_token, hf_what], hf_out)

        gr.Markdown("""
---
### Upload a REAL model — e.g. your VINDEX-edited Llama checkpoint
This does **not** load the model into memory and does **not** assume any
particular architecture — it just pushes whatever's already on disk at
`local_dir` (the usual `save_pretrained()` layout: `config.json` +
`*.safetensors` shards + tokenizer files) straight to a new repo. Large
weights upload fine through `upload_folder`; for very large repos consider
installing `hf_transfer` for faster throughput. If the base model is gated
(e.g. `meta-llama/*`), the gate applies to the destination repo's license
settings, not to this upload step.
""")
        with gr.Row():
            rc_dir = gr.Textbox(value="", label="local checkpoint folder (on this machine)")
            rc_repo = gr.Textbox(value="", label="destination repo id, e.g. Chris4K/vindex-llama3-edited")
        with gr.Row():
            rc_token = gr.Textbox(value="", label="HF write token (optional)", type="password")
            rc_private = gr.Checkbox(value=True, label="private repo")
            rc_msg = gr.Textbox(value="upload edited checkpoint", label="commit message")
        rc_out = gr.Textbox(label="result", lines=6)
        gr.Button("Upload real checkpoint", variant="primary").click(
            upload_local_checkpoint, [rc_dir, rc_repo, rc_token, rc_private, rc_msg], rc_out)

    # Closing roadmap text, shown below the tabs.
    gr.Markdown("""
---
### Where this goes next
- **Closing the loop (what "self-improving" would actually require):** right now a human picks every edit; the verifier just grades it. A real closed loop needs a policy that *proposes* edits on its own (e.g. scanning eval failures for wrong facts), auto-applies, and auto-commits only on a SURGICAL verdict, rolling back otherwise. The hard part — the verifier — already exists here; the proposal step doesn't yet.
- **A training-method angle worth taking seriously:** instead of accept/reject after the fact, feed the specificity battery's drift score back as a regularizer *during* the edit computation (closer to elastic weight consolidation, or the null-space projection AlphaEdit-style methods use) so collateral is penalized while solving, not caught after.
- **Real-model MEMIT:** the edit loop here is exact because the glass-box's fact layer is literally key→value. The same verify harness (efficacy / specificity / fluency + the multi-provider LLM judge) ports straight onto a gpt2/Llama MEMIT edit — the toy is the regression test you run first.
- **Multi-hop & paraphrase generalization:** add `"the currency of france is"` so two relations share a subject, and have the LLM judge auto-generate paraphrase probes to test that an edit generalizes, not just memorizes the one prompt.
- **Attribution view:** Geva-style "what does this neuron write to the vocab", per-head attention attribution.
- **It already ships:** tab 7 pushes the toy model and this whole app (as a Space) to your Hub, or a real local checkpoint folder to its own repo.
""")

    # On page open, load the default `glassbox` model and show its status.
    # Registered inside the Blocks context, where Gradio attaches event listeners.
    demo.load(lambda: load_model("glassbox"), outputs=load_status)
|
|
# Script entry point: start the Gradio server only when this file is executed
# directly; importing the module just builds `demo` without launching it.
if __name__ == "__main__":
    demo.launch()