# Cartogemma / app.py
# Uploaded by anotheruserishere via huggingface_hub ("Upload folder using huggingface_hub"), commit 969bd45 (verified).
"""Cartogemma — Gradio Space wrapping InteractiveProbe around a Gemma chat model
(default google/gemma-3-270m-it; override with the CARTOGEMMA_MODEL env var).
Faithful to the cartographer3.py / cartographer_tui.py presentation:
- Context tail pane
- Head Map: Layer | H0..Hn (pre-stream) | Xray | H0..Hn dv | Full dv
- Futures pane (top-k branches with probabilities + previews)
- Sparkline pane (token-rank scatter across heads × layers, rendered as <pre>)
- REPL command bar (1-N, h, top, rew, i, spark, mute, unmute, muted, w, l, s, r, q)
"""
import os
import re
import io
import sys
import json
import math
import unicodedata
from dataclasses import dataclass, field
from datetime import datetime
import torch
import torch.nn as nn
import torch.nn.functional as F
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
try:
from transformers import AutoModelForImageTextToText
except Exception: # older transformers
AutoModelForImageTextToText = None
# ── Config ────────────────────────────────────────────────────────────────
# Checkpoint to probe; override with the CARTOGEMMA_MODEL environment variable.
DEFAULT_MODEL_ID = os.getenv("CARTOGEMMA_MODEL", "google/gemma-3-270m-it")
# Hub auth token: HF_TOKEN first, then the legacy variable name.
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGING_FACE_HUB_TOKEN")
SEED = 42              # RNG seed applied whenever a probe is (re)initialised
DEFAULT_WIDTH = 5      # branches shown in the Futures pane
DEFAULT_PREVIEW = 20   # greedy tokens generated per branch preview
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
# ── HEAT palette (mirrors cartographer3) ──────────────────────────────────
HEAT_HEX = ["#59749C", "#948863", "#ECA60F", "#FFCF67", "#219C7F", "#E0D05B", "#64D32A"]


def heat_color(p: float) -> str:
    """Map a probability to a palette colour; hotter colours for larger p.

    Lower bounds (inclusive) of each band, hottest first: 0.85, 0.70, 0.50,
    0.30, 0.15, 0.05. Anything below 0.05 (or a value that fails every
    comparison, such as NaN) gets the coldest colour.
    """
    bands = ((0.85, 6), (0.70, 5), (0.50, 4), (0.30, 3), (0.15, 2), (0.05, 1))
    for floor, slot in bands:
        if p >= floor:
            return HEAT_HEX[slot]
    return HEAT_HEX[0]
def char_width(ch: str) -> int:
    """Return the display-column width of a single character.

    Combining marks (categories Mn/Me) occupy no column, East-Asian wide or
    fullwidth characters occupy two, and everything else occupies one.
    """
    if unicodedata.category(ch) in ("Mn", "Me"):
        return 0
    return 2 if unicodedata.east_asian_width(ch) in ("W", "F") else 1


def fmt_tok(tok: str, max_w: int = 10) -> str:
    """Make a token printable inside a fixed-width cell.

    Newline, carriage return and tab are shown as the two-character escapes
    ``\\n``, ``\\r``, ``\\t``. The result is then cut at the last character
    that still fits within ``max_w`` display columns (see char_width). No
    ellipsis is appended.
    """
    visible = tok.translate(str.maketrans({"\n": "\\n", "\r": "\\r", "\t": "\\t"}))
    used = 0
    for pos, ch in enumerate(visible):
        used += char_width(ch)
        if used > max_w:
            return visible[:pos]
    return visible
def html_escape(s: str) -> str:
    """Escape ``&``, ``<`` and ``>`` for use in HTML element text.

    Quotes are deliberately left untouched, so the result is not safe inside
    attribute values.
    """
    return s.translate({ord("&"): "&amp;", ord("<"): "&lt;", ord(">"): "&gt;"})
# Matches SGR escape sequences: ESC '[' then digits/semicolons then 'm'.
ANSI_RE = re.compile(r"\033\[[0-9;]*m")


def strip_ansi(s: str) -> str:
    """Return ``s`` with every SGR (colour/style) escape sequence removed."""
    return "".join(ANSI_RE.split(s))
# ── Architecture resolver ─────────────────────────────────────────────────
# The probe makes two tiers of assumption about a model:
# TIER 1 (generic): output_hidden_states + a final norm + an lm_head/tied
# embedding. Gives logit-lens (xray), per-layer Δ-residual, branches,
# top-by-rank, inject, rewind. Works on essentially any HF decoder LM.
# TIER 2 (architecture-specific): attention writes the residual through an
# o_proj whose INPUT is the concatenation of contiguous per-head blocks
# (standard MHA/GQA: in_features == num_heads * head_dim). Gives per-head
# pre-projection, per-head Δ-residual, head muting, token rank trace.
# We auto-discover both tiers and flip `head_scan_supported` off (degrading to
# Tier 1) when the Tier-2 assumption doesn't hold (fused QKV, MLA, exotic attn).
def _resolve_text_config(cfg):
"""Find the sub-config that actually holds num_hidden_layers, etc.
Multimodal configs (Gemma3/4, gemma3n) nest the LM under .text_config."""
if hasattr(cfg, "get_text_config"):
try:
tc = cfg.get_text_config()
if tc is not None and hasattr(tc, "num_hidden_layers"):
return tc
except Exception:
pass
tc = getattr(cfg, "text_config", None)
if tc is not None and hasattr(tc, "num_hidden_layers"):
return tc
return cfg
def _is_decoder(mod):
layers = getattr(mod, "layers", None)
if not isinstance(layers, nn.ModuleList) or len(layers) == 0:
return False
first = layers[0]
return any(hasattr(first, a) for a in ("self_attn", "attn", "attention"))
_DECODER_PATHS = [
("model",), ("model", "language_model"), ("language_model",),
("model", "model"), ("model", "language_model", "model"),
("transformer",), ("model", "decoder"), ("decoder",),
]
def _resolve_decoder(model):
"""Locate the module that owns the decoder-block ModuleList (.layers)."""
for path in _DECODER_PATHS:
obj = model
ok = True
for a in path:
if hasattr(obj, a):
obj = getattr(obj, a)
else:
ok = False
break
if ok and _is_decoder(obj):
return obj
# Generic fallback: longest attention-bearing ModuleList in the tree.
best = None
for _, mod in model.named_modules():
if _is_decoder(mod) and (best is None or len(mod.layers) > len(best.layers)):
best = mod
if best is not None:
return best
raise RuntimeError("Could not locate a decoder layer stack for this model.")
def _resolve_final_norm(decoder):
for a in ("norm", "final_layernorm", "ln_f", "final_norm"):
n = getattr(decoder, a, None)
if n is not None:
return n
return None
def _resolve_lm_head(model):
for path in (("lm_head",), ("language_model", "lm_head"), ("model", "lm_head")):
obj = model
ok = True
for a in path:
if hasattr(obj, a):
obj = getattr(obj, a)
else:
ok = False
break
if ok and obj is not None and hasattr(obj, "weight"):
return obj
return None # caller falls back to tied input embeddings
def _resolve_o_proj(layer):
"""The attention output projection whose input is per-head-concatenated."""
attn = None
for a in ("self_attn", "attn", "attention"):
attn = getattr(layer, a, None)
if attn is not None:
break
if attn is None:
return None
for a in ("o_proj", "out_proj", "wo", "dense", "c_proj"):
p = getattr(attn, a, None)
if p is not None and hasattr(p, "weight"):
return p
return None
@dataclass
class Snapshot:
    """One forward pass over the current context, shared by every pane.

    Built by InteractiveProbe.forward_snapshot(). The xray, head-scan,
    branch-preview and sparkline code all read from it instead of each
    running its own forward. head_inputs is populated only when the per-head
    (Tier-2) scan is supported for the loaded architecture.
    """
    logits: torch.Tensor  # [vocab] last-token logits (float)
    hidden_states: tuple  # len L+1, each [1, seq, hidden]
    head_inputs: dict = field(default_factory=dict)  # layer_idx -> o_proj input [1, seq, H*hd]
    entropy: float = 0.0  # next-token entropy (nats)
    top1: float = 0.0  # next-token top-1 probability
# ── The Probe (lifted from cartographer3, ANSI removed) ───────────────────
class InteractiveProbe:
    """Interactive interpretability probe around one Hugging Face decoder LM.

    State:
      history_ids: current context, [1, seq].
      full_log: list of action dicts (dumped by the ``s`` command).
      muted_heads: (layer, head) pairs zeroed through o_proj pre-hooks.
      snapshot: the latest shared forward pass.

    Tier-1 features (logit lens, per-layer Δ-residual, branches) need only a
    resolvable decoder stack plus an lm_head (or tied embeddings). Tier-2
    features (per-head writes, muting, token rank trace) additionally
    require ``head_scan_supported``.
    """

    def __init__(self, model_id: str = None):
        """Load tokenizer + model and auto-discover the architecture.

        Args:
            model_id: Hub id or local path; defaults to DEFAULT_MODEL_ID.

        Raises:
            RuntimeError: if no auto-class can load the checkpoint, or no
                decoder layer stack can be located.
        """
        model_id = model_id or DEFAULT_MODEL_ID
        self.model_id = model_id
        print(f"[*] Loading {model_id} on {DEVICE}...")
        kw = {"token": HF_TOKEN} if HF_TOKEN else {}
        self.tokenizer = AutoTokenizer.from_pretrained(model_id, **kw)
        self.tokenizer.padding_side = "left"
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        dtype = torch.bfloat16 if DEVICE == "cuda" else torch.float32
        load_kw = dict(torch_dtype=dtype, **kw)
        if DEVICE == "cuda":
            load_kw["device_map"] = "auto"
        # Try text CausalLM first; fall back to the multimodal head (gemma-3-4b,
        # gemma-4, gemma3n are all image-text-to-text wrappers around a text LM).
        self.model = None
        errs = []
        loaders = [AutoModelForCausalLM]
        if AutoModelForImageTextToText is not None:
            loaders.append(AutoModelForImageTextToText)
        for loader in loaders:
            try:
                self.model = loader.from_pretrained(model_id, **load_kw)
                break
            except Exception as e:
                errs.append(f"{loader.__name__}: {type(e).__name__}: {e}")
        if self.model is None:
            raise RuntimeError("Could not load model with any auto-class:\n" + "\n".join(errs))
        if DEVICE == "cpu":
            self.model = self.model.to(DEVICE)
        self.model.eval()

        # ── Auto-discover the architecture (see resolver notes above) ──
        self._decoder = _resolve_decoder(self.model)
        self.layers = self._decoder.layers
        self.final_norm = _resolve_final_norm(self._decoder)
        self._lm_head_mod = _resolve_lm_head(self.model)
        if self._lm_head_mod is None:
            self._tied_emb_weight = self.model.get_input_embeddings().weight
        else:
            self._tied_emb_weight = None
        text_cfg = _resolve_text_config(self.model.config)
        # Trust the actual stack length over config (resolver may pick a sub-stack).
        self.num_layers = len(self.layers)
        # `or 0` guards configs that carry these attributes as None.
        self.num_heads = int(getattr(text_cfg, "num_attention_heads", 0) or 0) or 1
        self.num_kv_heads = int(getattr(text_cfg, "num_key_value_heads", None) or self.num_heads)
        hidden_size = int(getattr(text_cfg, "hidden_size", 0) or 0)
        vocab_size = (int(getattr(text_cfg, "vocab_size", 0) or 0)
                      or int(getattr(self.model.config, "vocab_size", 0) or 0))
        self.vocab_size = vocab_size or self._lm_head_out_features()
        o0 = _resolve_o_proj(self.layers[0])
        cfg_head_dim = getattr(text_cfg, "head_dim", None)
        if cfg_head_dim:
            self.head_dim = int(cfg_head_dim)
        elif o0 is not None:
            self.head_dim = o0.weight.shape[1] // self.num_heads
        elif hidden_size:
            self.head_dim = hidden_size // self.num_heads
        else:
            self.head_dim = 0
        self.hidden_size = hidden_size
        # Tier-2 (per-head) support: o_proj input must decompose as heads*head_dim.
        self.head_scan_supported = bool(
            o0 is not None
            and self.final_norm is not None
            and self.head_dim
            and o0.weight.shape[1] == self.num_heads * self.head_dim
        )
        # Token ids that end generation; used to cut padding off previews.
        self._eos_ids = self._collect_eos_ids()
        self.arch_name = type(self.model).__name__
        print(f"[*] {self.arch_name}: {self.num_layers}L x {self.num_heads}H "
              f"(kv={self.num_kv_heads}), head_dim={self.head_dim}, hidden={hidden_size}, "
              f"vocab={self.vocab_size}, head_scan={'yes' if self.head_scan_supported else 'NO (Tier-1 only)'}")
        self.history_ids = None
        self.full_log = []
        self.muted_heads = set()
        self._mute_handles = []
        self.snapshot = None

    # ── architecture accessors ──
    def _collect_eos_ids(self):
        """Return the set of EOS ids from the generation config and the tokenizer."""
        ids = set()
        gen_cfg = getattr(self.model, "generation_config", None)
        cfg_eos = getattr(gen_cfg, "eos_token_id", None)
        if isinstance(cfg_eos, (list, tuple)):
            ids.update(int(e) for e in cfg_eos if e is not None)
        elif cfg_eos is not None:
            ids.add(int(cfg_eos))
        if self.tokenizer.eos_token_id is not None:
            ids.add(int(self.tokenizer.eos_token_id))
        return ids

    def _lm_head_out_features(self):
        """Return the output size of the lm_head (or tied embedding), or 0 if unknown."""
        if self._lm_head_mod is not None and hasattr(self._lm_head_mod, "weight"):
            return self._lm_head_mod.weight.shape[0]
        if self._tied_emb_weight is not None:
            return self._tied_emb_weight.shape[0]
        return 0

    def lm_head(self, x):
        """Project hidden state to vocab logits (handles tied embeddings)."""
        if self._lm_head_mod is not None:
            return self._lm_head_mod(x)
        return F.linear(x, self._tied_emb_weight)

    def norm(self, x):
        """Apply final norm if the model has one, else identity."""
        return self.final_norm(x) if self.final_norm is not None else x

    def o_proj(self, layer_idx):
        """Return the attention output projection of decoder layer ``layer_idx`` (or None)."""
        return _resolve_o_proj(self.layers[layer_idx])

    # ── single shared forward pass ──
    def forward_snapshot(self):
        """Run ONE forward over the current context and store it as ``self.snapshot``.

        Captures:
          * last-token logits, plus their entropy and top-1 probability;
          * hidden states;
          * for Tier-2 models, each layer's o_proj input.

        HF appends the *final-normed* state as the last ``hidden_states``
        entry. Every consumer here re-applies ``self.norm``, so that entry is
        replaced by the raw output of the last decoder layer, captured with a
        hook. The entry is only replaced when the captured shape matches.

        Returns:
            The new Snapshot, or None when no context is loaded.
        """
        if self.history_ids is None:
            self.snapshot = None
            return None
        captured = {}
        last_layer_out = {}
        handles = []
        if self.head_scan_supported:
            def mk(li):
                def fn(module, args, output):
                    captured[li] = args[0].detach()
                return fn
            for li in range(self.num_layers):
                op = self.o_proj(li)
                if op is not None:
                    handles.append(op.register_forward_hook(mk(li)))

        def grab_last(module, args, output):
            # Decoder layers return either a tensor or a tuple whose first item is the residual.
            hs = output[0] if isinstance(output, (tuple, list)) else output
            if torch.is_tensor(hs):
                last_layer_out["h"] = hs.detach()
        handles.append(self.layers[-1].register_forward_hook(grab_last))
        try:
            with torch.no_grad():
                out = self.model(self.history_ids, output_hidden_states=True)
        finally:
            for h in handles:
                h.remove()
        hidden = tuple(out.hidden_states)
        raw_last = last_layer_out.get("h")
        if hidden and raw_last is not None and raw_last.shape == hidden[-1].shape:
            hidden = hidden[:-1] + (raw_last,)
        last_logits = out.logits[0, -1, :].float()
        probs = F.softmax(last_logits, dim=-1)
        ent = float(-(probs * torch.log(probs + 1e-12)).sum().item())
        top1 = float(probs.max().item())
        self.snapshot = Snapshot(
            logits=last_logits, hidden_states=hidden,
            head_inputs=captured, entropy=ent, top1=top1,
        )
        return self.snapshot

    def _head_writes(self, snap, li):
        """Return each head's last-token write through layer ``li``'s o_proj.

        The result is an [H, hidden] tensor. Returns None when the layer has
        no captured input or no o_proj.
        """
        inp = snap.head_inputs.get(li)
        op = self.o_proj(li)
        if inp is None or op is None:
            return None
        with torch.no_grad():
            wo = op.weight
            wv = wo.view(wo.shape[0], self.num_heads, self.head_dim)
            x_last = inp[0, -1, :].view(self.num_heads, self.head_dim)  # [H, hd]
            # proj[h] = Wo[:, h, :] @ x_last[h] → [H, hidden]
            return torch.einsum("khd,hd->hk", wv.to(x_last.dtype), x_last)

    def _batch_top3(self, mat: torch.Tensor, k: int = 3):
        """mat: [N, hidden] → list of N rows, each [(token, prob) * k].
        One norm + one lm_head + one top-k over the whole batch."""
        with torch.no_grad():
            normed = self.norm(mat)
            logits = self.lm_head(normed)  # [N, vocab]
            probs = F.softmax(logits, dim=-1)
            tp, ti = torch.topk(probs, k, dim=-1)  # [N, k]
        tp = tp.tolist()
        ti = ti.tolist()
        return [[(self.tokenizer.decode(ti[n][j]), tp[n][j]) for j in range(k)]
                for n in range(mat.shape[0])]

    def set_seed(self):
        """Seed every RNG transformers knows about with SEED."""
        set_seed(SEED)

    @staticmethod
    def _drop_duplicate_bos(ids, bos_id):
        """Remove one leading BOS when ``ids`` starts with ``<bos><bos>``.

        A chat template rendered with ``tokenize=False`` already contains
        BOS, and tokenizing it again adds a second one.
        """
        if bos_id is None or ids.shape[-1] < 2:
            return ids
        if int(ids[0, 0]) == bos_id and int(ids[0, 1]) == bos_id:
            return ids[:, 1:]
        return ids

    def load_start(self, prompt: str, system_prompt: str = None):
        """Replace the context with the chat-formatted prompt.

        Falls back to the raw prompt when the tokenizer has no usable chat
        template.
        """
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})
        try:
            full_prompt = self.tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            templated = True
        except Exception:
            full_prompt = prompt
            templated = False
        ids = self.tokenizer(full_prompt, return_tensors="pt").input_ids
        if templated:
            ids = self._drop_duplicate_bos(ids, self.tokenizer.bos_token_id)
        self.history_ids = ids.to(DEVICE)

    def perform_xray(self, snap: "Snapshot" = None):
        """Logit lens on the last token of every hidden state, batched.

        Stacks all last-token hidden states, then runs one norm + lm_head +
        argmax. Returns ``[(label, token), ...]``, where the label is "Emb"
        for hidden_states[0] and "L{i}" for hidden_states[i]. Returns [] when
        there is no snapshot.
        """
        snap = snap or self.snapshot
        if snap is None:
            return []
        hs = snap.hidden_states
        stacked = torch.cat([h[:, -1, :] for h in hs], dim=0)  # [L+1, hidden]
        with torch.no_grad():
            tids = self.lm_head(self.norm(stacked)).argmax(dim=-1).tolist()
        out = []
        for i, tid in enumerate(tids):
            label = "Emb" if i == 0 else f"L{i}"
            out.append((label, self.tokenizer.decode(tid)))
        return out

    def _truncate_at_eos(self, seq, start):
        """Cut ``seq`` right after the first EOS found at index >= ``start``.

        generate() keeps padding rows that finished early, so anything after
        the EOS is padding, not model output. ``seq`` is returned unchanged
        when no EOS is found.
        """
        eos_ids = getattr(self, "_eos_ids", None)
        if not eos_ids:
            return seq
        tail = seq[start:].tolist()
        hit = next((i for i, tok in enumerate(tail) if tok in eos_ids), None)
        return seq if hit is None else seq[: start + hit + 1]

    def generate_previews(self, width: int, length: int, snap: "Snapshot" = None):
        """Branch on the top-``width`` next tokens.

        Each branch is greedily continued for ``length`` tokens. Returns one
        dict per branch with the keys index, token_id, root_token, prob,
        future_text and full_ids ([1, n] ids, truncated after EOS).
        """
        snap = snap or self.snapshot
        if snap is not None:
            next_logits = snap.logits
        else:
            with torch.no_grad():
                next_logits = self.model(self.history_ids).logits[0, -1, :].float()
        probs = F.softmax(next_logits, dim=-1)
        top_probs, top_ids = torch.topk(probs, width)
        trunk_len = self.history_ids.shape[1]
        roots = top_ids.view(-1, 1).to(self.history_ids.device)
        batch = torch.cat([self.history_ids.expand(width, -1), roots], dim=1)
        # All rows have equal length (no padding), so the mask is all ones.
        attention_mask = torch.ones_like(batch)
        with torch.no_grad():
            outs = self.model.generate(
                batch, attention_mask=attention_mask, max_new_tokens=length,
                do_sample=False, pad_token_id=self.tokenizer.pad_token_id,
            )
        previews = []
        for i in range(width):
            seq = self._truncate_at_eos(outs[i], trunk_len + 1)
            root = self.tokenizer.decode(seq[trunk_len])
            future = self.tokenizer.decode(seq[trunk_len + 1:], skip_special_tokens=False)
            previews.append({
                "index": i, "token_id": top_ids[i].item(), "root_token": root,
                "prob": top_probs[i].item(),
                "future_text": future.replace("\n", "\\n"),
                "full_ids": seq.unsqueeze(0),
            })
        return previews

    def commit(self, preview_obj, tokens_to_keep=None):
        """Adopt a branch as the new context.

        With ``tokens_to_keep`` set, only the root token plus that many
        continuation tokens are kept.
        """
        full_seq = preview_obj["full_ids"]
        trunk_len = self.history_ids.shape[1]
        if tokens_to_keep is None:
            cut = full_seq.shape[1]
        else:
            cut = min(trunk_len + 1 + tokens_to_keep, full_seq.shape[1])
        self.history_ids = full_seq[:, :cut]
        committed = self.tokenizer.decode(self.history_ids[0, trunk_len:], skip_special_tokens=True)
        self.full_log.append({"type": "commit", "token": preview_obj["root_token"],
                              "full_text": committed, "kept_len": tokens_to_keep})

    def inject(self, text: str):
        """Append ``text`` (tokenized without special tokens) to the context."""
        ids = self.tokenizer(text, return_tensors="pt", add_special_tokens=False).to(DEVICE).input_ids
        self.history_ids = torch.cat([self.history_ids, ids], dim=1)
        self.full_log.append({"type": "inject", "text": text})

    def scan_heads(self, snap: "Snapshot" = None, target_layers=None, target_heads=None):
        """Decode the head map from a precomputed snapshot. No model call here.

        Note: the per-head 'pre-projection' and per-head 'Δ-residual' are the
        same quantity (a head's last-token write into the residual via
        o_proj); the original tool computed them twice. We compute once and
        fill both.

        Returns ``{"pre_stream": {li: {hi: top3}}, "delta_v": {li: {hi|"full": top3}}}``,
        or None when there is no snapshot. Out-of-range layer and head indices
        are dropped silently.
        """
        snap = snap or self.snapshot
        if snap is None:
            return None
        if target_layers is None:
            target_layers = list(range(self.num_layers))
        if target_heads is None:
            target_heads = list(range(self.num_heads))
        target_layers = [li for li in target_layers if 0 <= li < self.num_layers]
        target_heads = [hi for hi in target_heads if 0 <= hi < self.num_heads]
        hs = snap.hidden_states
        pre_stream = {li: {} for li in target_layers}
        delta_v = {li: {} for li in target_layers}
        if not target_layers:
            # e.g. `h 99`: return an empty scan instead of crashing in torch.cat([]).
            return {"pre_stream": pre_stream, "delta_v": delta_v}
        # ── Tier-1: per-layer full Δ-residual (always available) — batched ──
        full_vecs = torch.cat([hs[li + 1][:, -1, :] - hs[li][:, -1, :] for li in target_layers], dim=0)
        full_top = self._batch_top3(full_vecs)
        for row, li in enumerate(target_layers):
            delta_v[li]["full"] = full_top[row]
        if not (self.head_scan_supported and snap.head_inputs):
            return {"pre_stream": pre_stream, "delta_v": delta_v}
        # ── Tier-2: per-head residual writes — one batched lm_head over all ──
        rows, index = [], []  # index[i] = (li, hi) for rows[i]
        for li in target_layers:
            writes = self._head_writes(snap, li)
            if writes is None:
                continue
            for hi in target_heads:
                rows.append(writes[hi])
                index.append((li, hi))
        if rows:
            head_top = self._batch_top3(torch.stack(rows, dim=0))
            for (li, hi), res in zip(index, head_top):
                pre_stream[li][hi] = res
                delta_v[li][hi] = res  # identical quantity (see docstring)
        return {"pre_stream": pre_stream, "delta_v": delta_v}

    def _install_mute_hooks(self):
        """Rebuild the o_proj pre-hooks that zero the inputs of muted heads."""
        for handle in self._mute_handles:
            handle.remove()
        self._mute_handles = []
        if not self.muted_heads:
            return
        muted_by_layer = {}
        for layer_idx, head_idx in self.muted_heads:
            muted_by_layer.setdefault(layer_idx, set()).add(head_idx)

        def make_hook(muted_set, head_dim, num_heads):
            idx = sorted(muted_set)

            def fn(module, args):
                x = args[0]
                # Split heads on a private copy: reshape works on non-contiguous
                # inputs, and the clone avoids mutating a tensor the caller holds.
                xv = x.reshape(*x.shape[:-1], num_heads, head_dim).clone()
                xv[..., idx, :] = 0.0
                return (xv.reshape(x.shape),) + tuple(args[1:])
            return fn

        for li, heads in muted_by_layer.items():
            op = self.o_proj(li)
            if op is not None:
                self._mute_handles.append(
                    op.register_forward_pre_hook(make_hook(heads, self.head_dim, self.num_heads)))

    def mute_head(self, layer, head):
        """Zero one head's contribution to o_proj; returns a status string."""
        if not self.head_scan_supported:
            return f"head muting unsupported for {self.arch_name} (no per-head o_proj decomposition)"
        if not (0 <= layer < self.num_layers):
            return f"Layer {layer} OOR"
        if not (0 <= head < self.num_heads):
            return f"Head {head} OOR"
        key = (layer, head)
        if key in self.muted_heads:
            return f"L{layer}H{head} already muted"
        self.muted_heads.add(key)
        self._install_mute_hooks()
        self.full_log.append({"type": "mute", "layer": layer, "head": head})
        return f"MUTED L{layer}H{head}"

    def unmute_head(self, layer, head):
        """Restore one muted head; returns a status string."""
        key = (layer, head)
        if key not in self.muted_heads:
            return f"L{layer}H{head} not muted"
        self.muted_heads.discard(key)
        self._install_mute_hooks()
        self.full_log.append({"type": "unmute", "layer": layer, "head": head})
        return f"UNMUTED L{layer}H{head}"

    def unmute_all(self):
        """Remove every mute hook and clear the muted set; returns a status string."""
        for h in self._mute_handles:
            h.remove()
        self._mute_handles = []
        n = len(self.muted_heads)
        self.muted_heads.clear()
        self.full_log.append({"type": "unmute_all"})
        return f"UNMUTED ALL ({n} heads)"

    def sparkline_data(self, token_str: str, rank_lo: int = 0, rank_hi: int = None,
                       snap: "Snapshot" = None):
        """Rank of one token under every (head, layer) last-token write.

        ``token_str`` is either a decimal token id or text; for text, its
        first token is used.

        Returns:
            ``{"name", "tid", "rank_lo", "rank_hi", "ranks": {(hi, li): 1-based rank}}``;
            "unsupported" for Tier-1 models; None when there is no context or
            snapshot, or when the token is empty or out of range.
        """
        if self.history_ids is None:
            return None
        if not self.head_scan_supported:
            return "unsupported"
        snap = snap or self.snapshot
        if snap is None or not snap.head_inputs:
            return None
        if token_str.isdecimal():  # isdigit() also accepts e.g. '²', which int() rejects
            tid = int(token_str)
            limit = self._lm_head_out_features() or self.vocab_size
            if limit and tid >= limit:
                return None
            name = self.tokenizer.decode(tid)
        else:
            ids = self.tokenizer.encode(token_str, add_special_tokens=False)
            if not ids:
                return None
            tid = ids[0]
            name = token_str
        if rank_hi is None:
            rank_hi = self.vocab_size
        # Build [N, hidden] of every (head, layer) last-token projection, then a
        # single lm_head, then rank by COMPARISON (count of logits above target)
        # rather than a full argsort — O(vocab) vs O(vocab·log vocab) per row.
        rows, index = [], []
        for li in sorted(snap.head_inputs.keys()):
            writes = self._head_writes(snap, li)
            if writes is None:
                continue
            for hi in range(self.num_heads):
                rows.append(writes[hi])
                index.append((hi, li))
        ranks = {}
        if rows:
            with torch.no_grad():
                logits = self.lm_head(self.norm(torch.stack(rows, dim=0)))  # [N, vocab]
                if tid >= logits.shape[-1]:
                    return None
                target = logits[:, tid].unsqueeze(1)  # [N, 1]
                rk = (logits > target).sum(dim=1).add(1).tolist()  # [N]
            ranks = {key: int(r) for key, r in zip(index, rk)}
        return {"name": name, "tid": tid, "rank_lo": rank_lo, "rank_hi": rank_hi, "ranks": ranks}
# ── HTML renderers ────────────────────────────────────────────────────────
CELL_W = 10 # display chars per token cell
def render_context(probe: InteractiveProbe) -> str:
if probe is None or probe.history_ids is None:
return "<div class='pane-body'><i>No context loaded.</i></div>"
full = probe.tokenizer.decode(probe.history_ids[0], skip_special_tokens=False)
tail = full[-800:].replace("\n", " ")
if len(full) > 800: tail = "..." + tail
n = probe.history_ids.shape[1]
return (f"<div class='pane-body'>"
f"<div class='dim'>[{n} tokens]</div>"
f"<div class='context-text'>{html_escape(tail)}</div></div>")
def render_headmap(probe: InteractiveProbe, data: dict, xray: list) -> str:
    """Render the Head Map pane as an HTML table.

    Args:
        probe: unused here (kept for a uniform renderer signature).
        data: result of InteractiveProbe.scan_heads(): {"pre_stream": {layer: {head: top3}},
            "delta_v": {layer: {head|"full": top3}}}, where top3 is [(token, prob), ...].
        xray: result of InteractiveProbe.perform_xray(): [(label, token), ...].

    Returns:
        An HTML fragment. Placeholder text when there is no scan or it is empty.
    """
    if not data:
        return "<div class='pane-body'><i>No head-scan run yet. Try <code>h *</code> for all layers.</i></div>"
    pre, dv = data.get("pre_stream", {}), data.get("delta_v", {})
    layers = sorted(pre.keys())
    if not layers:
        return "<div class='pane-body'><i>Empty scan.</i></div>"
    # Head columns come from the first scanned layer; empty for Tier-1 scans.
    heads = sorted(h for h in pre.get(layers[0], {}).keys() if isinstance(h, int))
    xray_map = {label: tok for label, tok in xray} if xray else {}
    def cell(tok, prob):
        # One heat-coloured token cell, truncated to CELL_W display columns.
        clean = html_escape(fmt_tok(tok, CELL_W))
        c = heat_color(prob)
        return f"<td class='hm-cell' style='color:{c}'>{clean}</td>"
    rows = []
    # Header: [pre-stream heads] | Xray | [Δ-residual heads] | L_full
    hdr = ["<tr><th class='lay'>Lay</th>"]
    for h in heads: hdr.append(f"<th class='ps'>H{h}</th>")
    hdr.append("<th class='xr'>Xray</th>")
    for h in heads: hdr.append(f"<th class='dv'>H{h}<sub>dv</sub></th>")
    hdr.append("<th class='dv full'>L<sub>full</sub></th>")
    hdr.append("</tr>")
    rows.append("".join(hdr))
    # Group sub-header; inserted at index 0 so it renders ABOVE the H0.. header row.
    sub = ["<tr class='subhdr'><th></th>"]
    if heads:
        sub.append(f"<th class='ps' colspan='{len(heads)}'>per-head pre-projection (raw head → norm → lm_head)</th>")
    sub.append("<th class='xr'>logit lens</th>")
    sub.append(f"<th class='dv' colspan='{len(heads)+1}'>Δ-residual (attn+MLP+skip), "
               f"{'per head + ' if heads else ''}full layer</th>")
    sub.append("</tr>")
    rows.insert(0, "".join(sub))
    for l in layers:
        r = [f"<tr><td class='lay'>L{l}</td>"]
        for h in heads:
            if h in pre.get(l, {}):
                tok, p = pre[l][h][0]; r.append(cell(tok, p))  # top-1 of the top-3
            else:
                r.append("<td class='hm-cell'>--</td>")
        # Xray (logit lens at this layer)
        # perform_xray labels hidden_states[0] "Emb" and hidden_states[i] "L{i}",
        # so row l shows the lens of hidden_states[l].
        # NOTE(review): that is the stream *entering* layer l, while the Δ columns
        # describe layer l's write; the final hidden state never appears — confirm intended.
        label = "Emb" if l == 0 else f"L{l}"
        xt = xray_map.get(label, "--")
        r.append(f"<td class='hm-cell xr'>{html_escape(fmt_tok(xt, CELL_W))}</td>")
        # delta-v per head
        dvl = dv.get(l, {})
        for h in heads:
            if h in dvl:
                tok, p = dvl[h][0]; r.append(cell(tok, p))
            else:
                r.append("<td class='hm-cell'>--</td>")
        # full layer dv
        if "full" in dvl:
            tok, p = dvl["full"][0]
            r.append(f"<td class='hm-cell full' style='color:{heat_color(p)}'>{html_escape(fmt_tok(tok, CELL_W))}</td>")
        else:
            r.append("<td class='hm-cell full'>--</td>")
        r.append("</tr>")
        rows.append("".join(r))
    note = ""
    if not heads:
        # Tier-1 (or a scan with no heads selected): explain the missing columns.
        note = ("<div class='dim' style='padding:2px 0'>per-head decomposition unavailable for this "
                "architecture — showing logit-lens + full-layer Δ-residual only (Tier-1)</div>")
    return f"<div class='pane-body scroll'>{note}<table class='headmap'>{''.join(rows)}</table></div>"
def render_futures(options: list, selected: int = None) -> str:
    """Render the Futures pane: one table row per branch preview.

    Args:
        options: preview dicts from InteractiveProbe.generate_previews().
        selected: 1-based branch number to highlight with the 'selected' class.
    """
    if not options:
        return "<div class='pane-body'><i>No branches yet — click Load.</i></div>"

    def _row(opt):
        num = opt["index"] + 1
        raw = opt["root_token"]
        short = raw.strip()[:12] or raw[:12]
        if short:
            hints = (f"<code>{num},5</code>·<code>{num},20</code>·"
                     f"<code>spark {html_escape(short)}</code>")
        else:
            hints = f"<code>{num},5</code>"
        row_cls = "selected" if selected == num else ""
        return "".join((
            f"<tr class='{row_cls}'>",
            f"<td class='idx'><code class='cmd-pill'>{num}</code></td>",
            f"<td style='color:{heat_color(opt['prob'])}'>{opt['prob']:.1%}</td>",
            f"<td class='tok'>{html_escape(fmt_tok(raw, 16))}</td>",
            f"<td class='preview'>{html_escape(opt['future_text'][:120])}…</td>",
            f"<td class='cmd-hints dim'>{hints}</td>",
            "</tr>",
        ))

    header = "<tr><th>cmd</th><th>prob</th><th>token</th><th>preview</th><th>hints</th></tr>"
    body = "".join(_row(opt) for opt in options)
    return f"<div class='pane-body scroll'><table class='futures'>{header}{body}</table></div>"
def render_sparkline(probe: InteractiveProbe, sd: dict) -> str:
    """Render the token-rank scatter as a <pre> block.

    The scatter has one band of `band_h` text rows per head and one column
    per layer.

    Args:
        probe: supplies num_layers / num_heads for the grid size.
        sd: result of InteractiveProbe.sparkline_data(): keys name, tid, rank_lo,
            rank_hi and ranks ({(head, layer): 1-based rank}).

    Markers: ● rank inside [rank_lo, rank_hi] (higher row = better rank),
    ▲ rank better than rank_lo, ▼ worse than rank_hi. Pairs missing from
    `ranks` are also drawn as ▼.
    """
    if not sd:
        return "<div class='pane-body'><i>No sparkline. Try <code>spark &lt;token&gt; [lo:hi]</code>.</i></div>"
    n_layers = probe.num_layers
    n_heads = probe.num_heads
    rank_lo, rank_hi = sd["rank_lo"], sd["rank_hi"]
    span = max(rank_hi - rank_lo, 1)  # avoid division by zero for lo == hi
    ranks = sd["ranks"]
    band_h = 8  # text rows per head band
    head_colors = ["#64D32A", "#FFCF67", "#21C5C5", "#E78BFF"]  # cycled per head
    lines = []
    name = sd["name"].replace("\n", "\\n")[:15]
    lines.append(f"Scatter: '{html_escape(name)}' (id={sd['tid']}) rank {rank_lo:,} (top) → {rank_hi:,} (bottom)")
    hdr = " " + "".join(f"{l:>3}" for l in range(n_layers))
    lines.append(f"<span class='dim'>{html_escape(hdr)}</span>")
    lines.append(f" {'─' * (n_layers * 3)}")
    for hi in range(n_heads):
        color = head_colors[hi % len(head_colors)]
        grid = [[" "] * n_layers for _ in range(band_h)]
        for li in range(n_layers):
            r = ranks.get((hi, li), rank_hi + 1)  # missing → treated as below the window
            if r < rank_lo:
                grid[0][li] = "▲"
            elif r > rank_hi:
                grid[band_h - 1][li] = "▼"
            else:
                # Linear position of the rank inside the window, clamped to the band.
                frac = (r - rank_lo) / span
                row = int(frac * (band_h - 1))
                grid[max(0, min(band_h - 1, row))][li] = "●"
        for ri in range(band_h):
            # Y-axis labels: rank_lo on the top row, midpoint in the middle, rank_hi at the bottom.
            if ri == 0:
                label = f"H{hi} {rank_lo:>6}│"
            elif ri == band_h - 1:
                label = f" {rank_hi:>6}│"
            elif ri == band_h // 2:
                mid = rank_lo + span // 2
                label = f" {mid:>6}│"
            else:
                label = f" │"
            cells = []
            for li in range(n_layers):
                ch = grid[ri][li]
                if ch in ("●", "▲", "▼"):
                    cells.append(f"<span style='color:{color}'> {ch} </span>")
                else:
                    cells.append(f" {ch} ")
            lines.append(html_escape(label) + "".join(cells))
        if hi < n_heads - 1:
            # Dotted separator between head bands (not after the last one).
            lines.append(f" {'┈' * (n_layers * 3)}")
    lines.append(f" {'─' * (n_layers * 3)}")
    return f"<div class='pane-body scroll'><pre class='spark'>{chr(10).join(lines)}</pre></div>"
def render_muted(probe: InteractiveProbe) -> str:
    """Return a 'MUTED: L<l>H<h>, ...' tag in sorted order.

    Returns '' when there is no probe or no head is muted.
    """
    if probe is not None and probe.muted_heads:
        labels = [f"L{layer}H{head}" for layer, head in sorted(probe.muted_heads)]
        return "<span class='muted-tag'>MUTED: " + ", ".join(labels) + "</span>"
    return ""
def render_summary(s) -> str:
    """Compact run-summary strip: arch · tokens · entropy · top1 · selected · mode · muted."""
    probe = s.probe
    if probe is None:
        return "<div class='summary dim'>no model loaded</div>"
    n_tok = 0 if probe.history_ids is None else probe.history_ids.shape[1]
    snap = probe.snapshot
    if snap is not None:
        ent, top1 = f"{snap.entropy:.2f}", f"{snap.top1*100:.1f}%"
    else:
        ent = top1 = "—"
    tier = "T2" if probe.head_scan_supported else "T1"
    chosen = getattr(s, "last_selected", None)
    sel_str = "—" if chosen is None else f"branch [{chosen}]"
    if probe.muted_heads:
        muted = f"<span style='color:#ff6b6b'>muted {len(probe.muted_heads)}</span>"
    else:
        muted = "muted 0"
    parts = (
        f"<b>{html_escape(probe.arch_name)}</b> <span class='dim'>{html_escape(probe.model_id)}</span>",
        f"<span class='dim'>tier</span> {tier}",
        f"<span class='dim'>L</span>{probe.num_layers} <span class='dim'>×H</span>{probe.num_heads}",
        f"<span class='dim'>tok</span> {n_tok}",
        f"<span class='dim'>H</span>=<b>{ent}</b>",
        f"<span class='dim'>top1</span>=<b>{top1}</b>",
        f"<span class='dim'>last</span> {sel_str}",
        f"<span class='dim'>w</span>={s.width} <span class='dim'>len</span>={s.preview_len}",
        muted,
    )
    separator = " <span class='sep'>·</span> "
    return f"<div class='summary'>{separator.join(parts)}</div>"
# ── State container ──────────────────────────────────────────────────────
class Session:
    """Per-session UI state: the probe plus the cached data behind every pane."""

    def __init__(self):
        self.probe: InteractiveProbe = None
        self.options: list = []
        self.scan_data: dict = None
        self.xray: list = []
        self.spark: dict = None
        self.width = DEFAULT_WIDTH
        self.preview_len = DEFAULT_PREVIEW
        self.transcript: list = []  # [(cmd, status), ...]
        self.loaded_model = None
        self.loaded_system = None
        self.loaded_user = None
        self.last_selected = None  # which branch index was last committed

    def _reset_panes(self):
        """Clear every cached pane and the command transcript."""
        self.options = []
        self.scan_data = None
        self.xray = []
        self.spark = None
        self.transcript = []
        self.last_selected = None

    def load(self, model_id, system_prompt, user_prompt, force=False):
        """(Re)initialise the session when the model or the prompts changed.

        A new model id, ``force=True`` or no probe yet → a fresh
        InteractiveProbe is loaded. If only the prompts changed, the loaded
        probe is reused: its mutes, log and snapshot are cleared and the
        context is re-seeded. This avoids reloading every weight for each
        prompt edit while leaving the same state a fresh probe would have.

        Returns:
            True if the session was (re)initialised, False if nothing changed.
        """
        changed = (
            force or self.probe is None
            or self.loaded_model != model_id
            or self.loaded_system != system_prompt
            or self.loaded_user != user_prompt
        )
        if not changed:
            return False
        reuse = not force and self.probe is not None and self.loaded_model == model_id
        if reuse:
            probe = self.probe
            probe.unmute_all()  # removes the hooks; its log entry is discarded below
            probe.full_log = []
            probe.snapshot = None
        else:
            probe = InteractiveProbe(model_id=model_id)
        probe.set_seed()
        probe.load_start(user_prompt or "", system_prompt or None)
        self.probe = probe
        self._reset_panes()
        self.loaded_model, self.loaded_system, self.loaded_user = model_id, system_prompt, user_prompt
        return True

    def ensure(self, model_id: str, system_prompt: str = None, user_prompt: str = ""):
        """Lazy load for the first command if the user typed before clicking Load."""
        if self.probe is None:
            self.load(model_id, system_prompt, user_prompt or "")
def panes_html(s: Session):
    """Render every pane, in UI output order.

    Order: summary, context, head map, futures, sparkline, muted tag.
    """
    summary = render_summary(s)
    context = render_context(s.probe)
    headmap = render_headmap(s.probe, s.scan_data, s.xray)
    futures = render_futures(s.options, s.last_selected)
    sparkline = render_sparkline(s.probe, s.spark)
    muted = render_muted(s.probe)
    return summary, context, headmap, futures, sparkline, muted
def status_html(s: Session) -> str:
    """Render the last 20 (command, status) transcript entries, or 'Ready.' when empty."""
    if not s.transcript:
        return "<div class='pane-body dim'>Ready.</div>"
    entries = "".join(
        f"<div><span class='prompt'>›</span> <code>{html_escape(cmd)}</code> "
        f"<span class='dim'>— {html_escape(status)}</span></div>"
        for cmd, status in s.transcript[-20:]
    )
    return f"<div class='pane-body scroll'>{entries}</div>"
def handle_command(cmd: str, s: Session, model_id: str, system_prompt: str):
    """Execute one REPL command and re-render every pane.

    Commands:
      N[,k]                  commit branch N, optionally keeping k continuation tokens
      i TEXT                 inject TEXT
      h L|* [H|*]            head scan
      top r1,r2,...          append tokens by rank
      rew N                  rewind N tokens
      w N / l N              set branch width / preview length
      spark TOK [lo:hi]      token rank scatter
      mute L H / unmute L H / unmute all / muted
      s                      save the log
      r | refresh            recompute every pane
      q                      no-op in the web UI

    Failures while executing a command are reported in the transcript.
    Model loading (``ensure``) is outside that guard, so load errors propagate.

    Returns:
        ``panes_html(s) + (status_html(s),)``.
    """
    cmd = (cmd or "").strip()
    if not cmd:
        return panes_html(s) + (status_html(s),)
    # Lazy load model on first command (so UI renders fast)
    s.ensure(model_id=model_id, system_prompt=system_prompt)
    status = "ok"
    try:
        probe = s.probe
        # A lazy load sets the context but never runs a forward pass; take one
        # now so h / spark / the summary have a snapshot to read.
        if probe.snapshot is None and probe.history_ids is not None:
            probe.forward_snapshot()
        lc = cmd.lower()
        if cmd[0].isdigit():
            parts = cmd.split(",")
            idx = int(parts[0]) - 1
            if 0 <= idx < len(s.options):
                keep_len = int(parts[1]) if len(parts) > 1 else None
                probe.commit(s.options[idx], keep_len)
                s.last_selected = idx + 1
                _refresh_after_step(s)
                status = f"selected branch [{idx+1}]"
            else:
                status = f"invalid index {idx+1}"
        elif lc.startswith("i "):
            probe.inject(cmd[2:])
            _refresh_after_step(s)
            status = f"injected '{cmd[2:][:30]}'"
        elif lc.startswith("h "):
            parts = cmd.split()
            target_layers = None if parts[1] == "*" else [int(parts[1])]
            target_heads = None
            if len(parts) > 2:
                target_heads = None if parts[2] == "*" else [int(parts[2])]
            s.scan_data = probe.scan_heads(target_layers=target_layers, target_heads=target_heads)
            s.xray = probe.perform_xray()
            status = f"head-scan: layers={parts[1]} heads={parts[2] if len(parts)>2 else '*'}"
        elif lc.startswith("top "):
            ranks = [int(r.strip()) for r in cmd[4:].split(",")]
            applied = 0
            for rank in ranks:
                if rank < 1:
                    continue
                with torch.no_grad():
                    logits = probe.model(probe.history_ids).logits[0, -1, :].float()
                probs = F.softmax(logits, dim=-1)
                if rank > probs.shape[-1]:
                    continue
                # top-k up to `rank` instead of sorting the whole vocabulary.
                top_p, top_i = torch.topk(probs, rank)
                tid = int(top_i[-1].item())
                probe.history_ids = torch.cat(
                    [probe.history_ids, torch.tensor([[tid]], device=DEVICE)], dim=1
                )
                probe.full_log.append({"type": "top_select", "rank": rank,
                                       "token": probe.tokenizer.decode([tid]),
                                       "token_id": tid, "prob": float(top_p[-1].item())})
                applied += 1
            _refresh_after_step(s)
            status = f"top-selected {applied} tokens"
            if applied < len(ranks):
                status += f" ({len(ranks) - applied} out of range)"
        elif lc.startswith("rew "):
            n = max(0, int(cmd[4:].strip()))  # a negative count would "grow" the log entry
            cur = probe.history_ids.shape[1]
            new = max(1, cur - n)
            probe.history_ids = probe.history_ids[:, :new]
            probe.full_log.append({"type": "rewind", "tokens_removed": cur - new})
            _refresh_after_step(s)
            status = f"rewound {cur - new} tokens"
        elif lc.startswith("w "):
            s.width = int(cmd[2:])
            status = f"width={s.width}"
            _refresh_futures(s)
        elif lc.startswith("l "):
            s.preview_len = int(cmd[2:])
            status = f"preview_len={s.preview_len}"
            _refresh_futures(s)
        elif lc.startswith("spark "):
            parts = cmd[6:].strip().split()
            tok = parts[0]
            rl, rh = 0, None
            if len(parts) > 1 and ":" in parts[1]:
                a, b = parts[1].split(":")
                rl, rh = int(a), int(b)
            sd = probe.sparkline_data(tok, rl, rh)
            if sd == "unsupported":
                status = f"token rank trace unsupported for {probe.arch_name} (Tier-1 only)"
            elif sd is None:
                status = f"token '{tok}' not found in vocab"
            else:
                s.spark = sd
                status = f"spark '{tok}' [{rl}:{rh if rh else 'vocab'}]"
        elif lc.startswith("mute "):
            parts = cmd.split()
            status = probe.mute_head(int(parts[1]), int(parts[2]))
            _refresh_after_step(s)
        elif lc == "unmute all":
            status = probe.unmute_all()
            _refresh_after_step(s)
        elif lc.startswith("unmute "):
            parts = cmd.split()
            status = probe.unmute_head(int(parts[1]), int(parts[2]))
            _refresh_after_step(s)
        elif lc == "muted":
            status = f"muted: {sorted(probe.muted_heads)}"
        elif lc == "s":
            ts = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
            fname = f"/tmp/session_{ts}.json"
            with open(fname, "w", encoding="utf-8") as f:
                json.dump(probe.full_log, f, indent=2)
            status = f"saved {fname}"
        elif lc in ("refresh", "r"):
            _refresh_after_step(s)
            status = "refreshed"
        elif lc == "q":
            status = "(q is a no-op in the web UI — close the tab)"
        else:
            status = f"unknown: {cmd}"
    except Exception as e:
        status = f"error: {type(e).__name__}: {e}"
    s.transcript.append((cmd, status))
    return panes_html(s) + (status_html(s),)
def _refresh_futures(s: Session):
    """Rebuild the branch list for the current context without advancing it.

    A fresh forward pass runs only when the probe has no cached snapshot yet.
    Otherwise the existing snapshot is reused. Does nothing until both a probe
    and a token history exist. Result is stored on ``s.options``.
    """
    probe = s.probe
    if probe is None or probe.history_ids is None:
        return
    if probe.snapshot is None:
        probe.forward_snapshot()
    s.options = probe.generate_previews(s.width, s.preview_len)
def _refresh_after_step(s: Session):
    """Recompute every pane's data after the context has changed.

    Runs exactly one ``forward_snapshot()``. The branch previews, head scan
    and x-ray are then recomputed from it, in that order. Does nothing until
    both a probe and a token history exist.
    """
    probe = s.probe
    if probe is None or probe.history_ids is None:
        return
    # The snapshot must be refreshed before any of the readers below run.
    probe.forward_snapshot()
    s.options = probe.generate_previews(s.width, s.preview_len)
    s.scan_data = probe.scan_heads()
    s.xray = probe.perform_xray()
def initial_load(model_id: str, system_prompt: str, user_prompt: str, s: Session):
    """Load (or reuse) a model for this session, then render every pane.

    Args:
        model_id: hub id passed to ``s.load``.
        system_prompt: optional system message. An empty string is passed on
            as ``None``.
        user_prompt: user message seeding the context. ``None`` or empty is
            passed on as ``""``.
        s: per-browser session state. It is mutated in place (probe, pane
            data, transcript).

    Returns:
        The tuple from ``panes_html(s)`` with ``status_html(s)`` appended.
    """
    rebuilt = s.load(model_id, system_prompt or None, user_prompt or "")
    _refresh_after_step(s)
    probe = s.probe
    tier = "Tier-2 (per-head)" if probe.head_scan_supported else "Tier-1 only"
    # s.load's return value picks the transcript verb:
    # truthy -> "loaded", falsy -> "ready".
    verb = "loaded" if rebuilt else "ready"
    # Shape is shown as "<layers>L×<heads>H". Without the separator the two
    # counts ran together (18 layers × 4 heads printed as "184H").
    s.transcript.append(("(load)", f"{verb} {probe.arch_name} {probe.model_id} "
                                   f"— {probe.num_layers}L×{probe.num_heads}H, {tier}"))
    return panes_html(s) + (status_html(s),)
# ── Gradio UI ─────────────────────────────────────────────────────────────
# Stylesheet handed to gr.Blocks(css=...) in build_ui(). The #id selectors
# target elem_id values assigned in build_ui (#hdr, #topbar, #summary-strip,
# #muted-row, #cmdbar, #dual-row, #col-branch, #col-head). The class selectors
# (.pane, .pane-body, table.headmap, table.futures, pre.spark, .muted-tag, ...)
# are presumably emitted by the pane renderers elsewhere in this file; verify
# before renaming any of them.
# NOTE(review): no component in build_ui uses elem_id="help", so the #help
# rules look unused. `.pane-body.scroll` only repeats the `overflow: auto`
# already set on `.pane-body`. Confirm both before pruning.
CSS = """
/* ── Global tightening ── */
.gradio-container { max-width: 100% !important; padding: 6px 10px !important; }
.gradio-container .main { padding: 0 !important; gap: 4px !important; }
.gradio-container .gap, .gradio-container .form { gap: 4px !important; }
.gradio-container .block { padding: 3px !important; border-radius: 4px !important; }
.gradio-container .gr-block { padding: 0 !important; }
.gradio-container .prose { margin: 0 !important; }
/* ── Header: kill the giant title margins ── */
#hdr h1 { font-size: 18px !important; margin: 0 !important; padding: 0 !important; line-height: 1.2 !important; }
#hdr p, #hdr em { font-size: 11px !important; margin: 0 !important; padding: 0 !important;
color: #8b949e !important; line-height: 1.2 !important; }
#hdr { margin: 0 !important; padding: 4px 0 !important; }
/* ── Top bar: compact textboxes ── */
#topbar { gap: 6px !important; }
#topbar > div { padding: 0 !important; }
#topbar textarea, #topbar input[type="text"] {
font-size: 12px !important; padding: 4px 6px !important; min-height: 28px !important;
line-height: 1.3 !important;
}
#topbar label, #topbar .gr-text-input label, #topbar span[data-testid="block-label"] {
font-size: 10px !important; padding: 0 0 2px 0 !important; margin: 0 !important;
text-transform: uppercase; letter-spacing: 0.5px; color: #8b949e !important;
}
#topbar button { min-height: 56px !important; align-self: stretch !important;
font-size: 13px !important; padding: 4px 8px !important; }
/* ── CMD bar tightening ── */
#cmdbar textarea, #cmdbar input[type="text"] {
font-size: 12px !important; padding: 4px 8px !important; min-height: 28px !important;
font-family: 'JetBrains Mono', 'Fira Code', Consolas, monospace !important;
}
#cmdbar label, #cmdbar span[data-testid="block-label"] {
font-size: 10px !important; color: #8b949e !important; margin: 0 !important; padding: 0 0 2px 0 !important;
}
#cmdbar button { min-height: 36px !important; font-size: 12px !important; }
/* ── Help line (small, dim) ── */
#help { font-size: 10.5px !important; color: #8b949e !important; margin: 2px 0 !important; padding: 0 !important; }
#help code { font-size: 10.5px !important; padding: 0 2px !important; background: #161b22 !important; }
#help p { margin: 0 !important; line-height: 1.5 !important; }
/* ── Kill empty muted-row gap before Context ── */
#muted-row { padding: 0 !important; margin: 0 !important; min-height: 0 !important; }
#muted-row:empty { display: none !important; }
#muted-row > * { padding: 0 !important; margin: 0 !important; }
/* ── Panes ── */
.pane { border: 1px solid #30363d; border-radius: 4px; background: #0d1117; color: #c9d1d9;
font-family: 'JetBrains Mono', 'Fira Code', Consolas, monospace; font-size: 12px; }
.pane h3 { margin: 0; padding: 3px 8px; background: #161b22; border-bottom: 1px solid #30363d;
font-size: 10px; letter-spacing: 0.5px; text-transform: uppercase; color: #8b949e; }
.pane-body { padding: 5px 8px; max-height: 440px; overflow: auto; }
.pane-body.scroll { overflow: auto; }
.dim { color: #6e7681; }
.context-text { white-space: pre-wrap; word-break: break-word; margin-top: 4px; }
table.headmap { border-collapse: collapse; font-size: 11px; }
table.headmap th, table.headmap td { padding: 1px 4px; border: 1px solid #21262d; text-align: center; white-space: pre; }
table.headmap th.ps { background: #112; }
table.headmap th.xr { background: #221; color: #ECA60F; }
table.headmap th.dv { background: #122; }
table.headmap th.full { background: #133; color: #64D32A; }
table.headmap td.lay { color: #8b949e; }
table.headmap td.xr { color: #ECA60F; background: #15110a; }
table.headmap td.full { background: #0a1410; }
table.headmap tr.subhdr th { background: #161b22; color: #8b949e; font-weight: normal; font-size: 10px; }
table.futures { width: 100%; border-collapse: collapse; }
table.futures th { background: #161b22; padding: 3px 6px; text-align: left; color: #8b949e;
font-size: 10px; text-transform: uppercase; letter-spacing: 0.4px; }
table.futures td { padding: 5px 6px; border-bottom: 1px solid #21262d; vertical-align: top; }
table.futures tr:hover { background: #11161c; }
table.futures tr.selected { background: #0f2418; outline: 1px solid #64D32A; }
table.futures td.idx { color: #6e7681; }
.cmd-pill { background: #21262d; color: #c9d1d9; padding: 1px 7px; border-radius: 3px;
font-weight: bold; font-size: 11px; border: 1px solid #30363d; }
table.futures tr.selected .cmd-pill { background: #1a4022; border-color: #64D32A; color: #b8f0c4; }
table.futures td.tok { color: #c9d1d9; font-weight: bold; white-space: pre; }
table.futures td.preview { color: #8b949e; max-width: 320px; }
table.futures td.cmd-hints code { background: transparent; padding: 0 1px; color: #8b949e; font-size: 10.5px; }
/* ── Run-summary strip ── */
#summary-strip { padding: 0 !important; margin: 0 !important; }
.summary { font-family: 'JetBrains Mono', Consolas, monospace; font-size: 11px;
padding: 4px 8px; background: #0d1117; border: 1px solid #30363d; border-radius: 4px;
color: #c9d1d9; line-height: 1.5; }
.summary .sep { color: #30363d; padding: 0 4px; }
.summary .dim { color: #6e7681; }
/* ── Locked-height instrument row: widening branches does not reflow head map ── */
#dual-row { align-items: stretch !important; }
#dual-row > div { display: flex; flex-direction: column; }
#dual-row #col-branch, #dual-row #col-head { height: 55vh; min-height: 360px; max-height: 640px; }
#dual-row .pane { height: 100%; display: flex; flex-direction: column; overflow: hidden; }
#dual-row .pane h3 { flex: 0 0 auto; }
#dual-row .pane .pane-body { flex: 1 1 auto; max-height: none !important; overflow: auto; }
/* ── Placeholder legibility (was being swallowed by the dark theme) ── */
.gradio-container textarea::placeholder,
.gradio-container input[type="text"]::placeholder {
color: #6e7681 !important; opacity: 1 !important;
}
#cmdbar textarea::placeholder, #cmdbar input[type="text"]::placeholder {
color: #8b949e !important; opacity: 1 !important;
}
pre.spark { margin: 0; font-size: 11px; line-height: 1.1; }
.muted-tag { color: #ff6b6b; font-weight: bold; padding: 2px 6px; }
.prompt { color: #64D32A; }
"""
# Markdown cheat-sheet of REPL commands: one entry per command, joined with
# " · ". NOTE(review): build_ui does not reference it; the CMD placeholder
# carries its own copy of this list.
HELP = " · ".join([
    "`1-N` select branch",
    "`1,5` select + keep N tokens",
    "`i <text>` inject text",
    "`h * | h L | h L H` head×layer scan",
    "`top R1,R2,…` pick by rank",
    "`rew N` rewind",
    "`spark <tok> [lo:hi]` token rank trace",
    "`w N` width",
    "`l N` length",
    "`mute L H`",
    "`unmute L H`",
    "`unmute all`",
    "`muted`",
    "`r` refresh",
    "`s` save",
])
def build_ui():
    """Assemble the Cartogemma Blocks app and wire its event handlers.

    Layout, top to bottom:
      - the header;
      - the Model / System / User / Load bar;
      - the run-summary strip and the muted-heads row;
      - the Context pane and the CMD bar;
      - Branches and Head Map side by side;
      - the Token Rank Trace and Transcript panes.

    Returns:
        The ``gr.Blocks`` app, not yet launched.
    """
    # Titles are shared by the placeholder HTML and the live renders in
    # _wrap, so both paths emit the same <h3> header for each pane.
    ctx_title = "Context"
    hm_title = ("Head Map — per-head pre-projection · logit-lens · "
                "per-head Δ-residual · L<sub>full</sub>")
    fut_title = "Branches (next-token top-k)"
    spark_title = "Token Rank Trace (per-(head,layer) rank of a chosen token)"
    log_title = "Transcript"

    def _pane(title, body):
        """Wrap an already-rendered pane body in a titled .pane container."""
        return f"<div class='pane'><h3>{title}</h3>{body}</div>"

    presets = [
        "google/gemma-3-270m-it",
        "google/gemma-3-1b-it",
        "google/gemma-3-4b-it",
        "google/gemma-3n-E2B-it",
        "google/gemma-4-E2B-it",
        "google/gemma-4-E4B-it",
        "Qwen/Qwen3-0.6B",
        "meta-llama/Llama-3.2-1B-Instruct",
    ]

    with gr.Blocks(title="Cartogemma", css=CSS, theme=gr.themes.Base()) as demo:
        gr.Markdown("# Cartogemma \n*Mechanistic probe — logit lens · head map · branches · token rank trace — "
                    "on the Gemma family (3 / 3n / 4) and most HF decoder LMs (Qwen, Llama, …). "
                    "Architecture is auto-discovered; per-head views degrade gracefully when unavailable.*",
                    elem_id="hdr")
        session_state = gr.State(Session())

        # ── Compact top bar: model | system | user | load ──
        with gr.Row(elem_id="topbar"):
            model_dd = gr.Dropdown(
                choices=presets, value=DEFAULT_MODEL_ID, label="Model",
                allow_custom_value=True, filterable=True, scale=3, container=True,
            )
            sys_box = gr.Textbox(label="System", value="",
                                 lines=1, max_lines=3, scale=4, container=True)
            user_box = gr.Textbox(
                label="User",
                value=("Hi, I need something like a cake recipe but it can't include eggs, sugar, "
                       "gluten-- so no flour, actually make that no grains of any sort, as well as "
                       "liquid and solid ingredients, none of them either. Thanks!"),
                lines=1, max_lines=3, scale=5, container=True,
            )
            load_btn = gr.Button("Load", variant="primary", scale=1, min_width=80)

        # ── Run-summary strip, muted-heads row, then full-width Context ──
        summary_pane = gr.HTML(value="<div class='summary dim'>no model loaded</div>",
                               elem_id="summary-strip")
        muted_html = gr.HTML(elem_id="muted-row")
        ctx_pane = gr.HTML(value=_pane(
            ctx_title, "<div class='pane-body'><i>Click <b>Load</b> to begin.</i></div>"))

        # ── Command bar sits right under Context (eye: context → command) ──
        with gr.Row(elem_id="cmdbar"):
            cmd_box = gr.Textbox(
                label="CMD",
                placeholder=(
                    "1-N select branch · 1,5 select+keep · i <text> inject · "
                    "h *|h L|h L H scan · top R1,R2,… rank-pick · rew N · "
                    "spark <tok> [lo:hi] rank trace · w N · l N · mute L H · unmute L H|all · muted · r · s"
                ),
                scale=10, autofocus=True, lines=1, max_lines=1, container=True,
            )
            run_btn = gr.Button("Run", scale=1, variant="primary", min_width=70)

        # ── Branches | Head Map share one locked-height row (see #dual-row CSS),
        #    so widening branches (w 30) doesn't reflow the head map. ──
        with gr.Row(elem_id="dual-row"):
            with gr.Column(scale=3, min_width=280, elem_id="col-branch"):
                fut_pane = gr.HTML(value=_pane(
                    fut_title,
                    "<div class='pane-body'><i>Top-k continuations appear after Load.</i></div>"))
            with gr.Column(scale=7, elem_id="col-head"):
                hm_pane = gr.HTML(value=_pane(
                    hm_title, "<div class='pane-body'><i>Auto-runs after every step.</i></div>"))

        spark_pane = gr.HTML(value=_pane(
            spark_title,
            "<div class='pane-body'><i>Run <code>spark &lt;token&gt;</code>.</i></div>"))
        log_pane = gr.HTML(value=_pane(log_title, "<div class='pane-body dim'>Ready.</div>"))

        def _wrap(summary, ctx, hm, fut, sp, mute, st):
            """Give each rendered body its titled pane; summary/mute pass through."""
            return (
                summary,
                _pane(ctx_title, ctx),
                _pane(hm_title, hm),
                _pane(fut_title, fut),
                _pane(spark_title, sp),
                mute,
                _pane(log_title, st),
            )

        def on_load(mid, sp, up, s):
            summary, ctx, hm, fut, spk, mute, st = initial_load(mid, sp, up, s)
            return _wrap(summary, ctx, hm, fut, spk, mute, st) + (s,)

        def on_cmd(c, s, mid, sp):
            summary, ctx, hm, fut, spk, mute, st = handle_command(c, s, mid, sp)
            # The trailing "" clears the CMD textbox after each command.
            return _wrap(summary, ctx, hm, fut, spk, mute, st) + ("", s)

        pane_outputs = [summary_pane, ctx_pane, hm_pane, fut_pane, spark_pane, muted_html, log_pane]
        load_btn.click(
            on_load,
            inputs=[model_dd, sys_box, user_box, session_state],
            outputs=pane_outputs + [session_state],
        )
        # The Run button and pressing Enter in CMD share one handler.
        for trigger in (run_btn.click, cmd_box.submit):
            trigger(
                on_cmd,
                inputs=[cmd_box, session_state, model_dd, sys_box],
                outputs=pane_outputs + [cmd_box, session_state],
            )
    return demo
if __name__ == "__main__":
    # Script entry point: build the Blocks app, then start its server.
    app = build_ui()
    app.launch()