Spaces:
Running on Zero
Running on Zero
| """Cartogemma — Gradio Space wrapping InteractiveProbe against google/gemma-3-1b-it. | |
| Faithful to the cartographer3.py / cartographer_tui.py presentation: | |
| - Context tail pane | |
| - Head Map: Layer | H0..Hn (pre-stream) | Xray | H0..Hn dv | Full dv | |
| - Futures pane (top-k branches with probs + previews) | |
| - Sparkline pane (token rank scatter across heads x layers, rendered as <pre>) | |
| - REPL command bar (1-N, h, top, rew, i, spark, mute, unmute, w, l, s, q) | |
| """ | |
| import os | |
| import re | |
| import io | |
| import sys | |
| import json | |
| import math | |
| import unicodedata | |
| from dataclasses import dataclass, field | |
| from datetime import datetime | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import gradio as gr | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed | |
| try: | |
| from transformers import AutoModelForImageTextToText | |
| except Exception: # older transformers | |
| AutoModelForImageTextToText = None | |
# ── Config ────────────────────────────────────────────────────────────────
# Model loaded when InteractiveProbe gets no explicit id; override with the
# CARTOGEMMA_MODEL environment variable.
DEFAULT_MODEL_ID = os.environ.get("CARTOGEMMA_MODEL", "google/gemma-3-270m-it")
# Optional Hub access token; when set it is passed as `token=` to
# from_pretrained. HF_TOKEN wins over HUGGING_FACE_HUB_TOKEN.
HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
# Seed handed to transformers.set_seed by InteractiveProbe.set_seed().
SEED = 42
# Initial Session.width (number of next-token branches).
DEFAULT_WIDTH = 5
# Initial Session.preview_len (tokens generated per branch preview).
DEFAULT_PREVIEW = 20
# Run on the GPU when one is visible, else CPU (also selects bf16 vs fp32 weights).
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# ── HEAT palette (mirrors cartographer3) ──────────────────────────────────
HEAT_HEX = ["#59749C", "#948863", "#ECA60F", "#FFCF67", "#219C7F", "#E0D05B", "#64D32A"]


def heat_color(p: float) -> str:
    """Map a probability to a HEAT_HEX colour; hotter colours for larger p.

    Thresholds are inclusive lower bounds, checked from hottest down; anything
    below 0.05 (including NaN, which fails every comparison) gets the coldest.
    """
    steps = ((0.85, 6), (0.70, 5), (0.50, 4), (0.30, 3), (0.15, 2), (0.05, 1))
    for floor, palette_idx in steps:
        if p >= floor:
            return HEAT_HEX[palette_idx]
    return HEAT_HEX[0]
def char_width(ch: str) -> int:
    """Display width of one character in a monospace cell.

    Combining marks (Mn/Me) take no column, wide or fullwidth East Asian
    characters take two, everything else takes one.
    """
    if unicodedata.category(ch) in ("Mn", "Me"):
        return 0
    return 2 if unicodedata.east_asian_width(ch) in ("W", "F") else 1


def fmt_tok(tok: str, max_w: int = 10) -> str:
    """Make *tok* printable in a fixed-width cell.

    Newline, carriage return and tab become their backslash escapes, then the
    text is cut at the first character that would push the display width (as
    measured by char_width) past *max_w*.
    """
    text = tok.replace("\n", "\\n").replace("\r", "\\r").replace("\t", "\\t")
    budget = max_w
    end = 0
    while end < len(text):
        need = char_width(text[end])
        if need > budget:
            break
        budget -= need
        end += 1
    return text[:end]
def html_escape(s: str) -> str:
    """Escape ``&``, ``<`` and ``>`` so *s* can be embedded as HTML text.

    ``&`` is replaced first so the entities introduced for ``<``/``>`` are not
    double-escaped. Quotes are left untouched: the render helpers in this
    module interpolate the result into element text, not attribute values.
    """
    return s.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")
# ESC [ <digits/semicolons> m  — ANSI SGR (colour/style) sequences.
ANSI_RE = re.compile(r"\033\[[0-9;]*m")


def strip_ansi(s: str) -> str:
    """Return *s* with every ANSI SGR escape sequence removed."""
    cleaned, _removed = ANSI_RE.subn("", s)
    return cleaned
| # ── Architecture resolver ───────────────────────────────────────────────── | |
| # The probe makes two tiers of assumption about a model: | |
| # TIER 1 (generic): output_hidden_states + a final norm + an lm_head/tied | |
| # embedding. Gives logit-lens (xray), per-layer Δ-residual, branches, | |
| # top-by-rank, inject, rewind. Works on essentially any HF decoder LM. | |
| # TIER 2 (architecture-specific): attention writes the residual through an | |
| # o_proj whose INPUT is the concatenation of contiguous per-head blocks | |
| # (standard MHA/GQA: in_features == num_heads * head_dim). Gives per-head | |
| # pre-projection, per-head Δ-residual, head muting, token rank trace. | |
| # We auto-discover both tiers and flip `head_scan_supported` off (degrading to | |
| # Tier 1) when the Tier-2 assumption doesn't hold (fused QKV, MLA, exotic attn). | |
| def _resolve_text_config(cfg): | |
| """Find the sub-config that actually holds num_hidden_layers, etc. | |
| Multimodal configs (Gemma3/4, gemma3n) nest the LM under .text_config.""" | |
| if hasattr(cfg, "get_text_config"): | |
| try: | |
| tc = cfg.get_text_config() | |
| if tc is not None and hasattr(tc, "num_hidden_layers"): | |
| return tc | |
| except Exception: | |
| pass | |
| tc = getattr(cfg, "text_config", None) | |
| if tc is not None and hasattr(tc, "num_hidden_layers"): | |
| return tc | |
| return cfg | |
| def _is_decoder(mod): | |
| layers = getattr(mod, "layers", None) | |
| if not isinstance(layers, nn.ModuleList) or len(layers) == 0: | |
| return False | |
| first = layers[0] | |
| return any(hasattr(first, a) for a in ("self_attn", "attn", "attention")) | |
| _DECODER_PATHS = [ | |
| ("model",), ("model", "language_model"), ("language_model",), | |
| ("model", "model"), ("model", "language_model", "model"), | |
| ("transformer",), ("model", "decoder"), ("decoder",), | |
| ] | |
| def _resolve_decoder(model): | |
| """Locate the module that owns the decoder-block ModuleList (.layers).""" | |
| for path in _DECODER_PATHS: | |
| obj = model | |
| ok = True | |
| for a in path: | |
| if hasattr(obj, a): | |
| obj = getattr(obj, a) | |
| else: | |
| ok = False | |
| break | |
| if ok and _is_decoder(obj): | |
| return obj | |
| # Generic fallback: longest attention-bearing ModuleList in the tree. | |
| best = None | |
| for _, mod in model.named_modules(): | |
| if _is_decoder(mod) and (best is None or len(mod.layers) > len(best.layers)): | |
| best = mod | |
| if best is not None: | |
| return best | |
| raise RuntimeError("Could not locate a decoder layer stack for this model.") | |
| def _resolve_final_norm(decoder): | |
| for a in ("norm", "final_layernorm", "ln_f", "final_norm"): | |
| n = getattr(decoder, a, None) | |
| if n is not None: | |
| return n | |
| return None | |
| def _resolve_lm_head(model): | |
| for path in (("lm_head",), ("language_model", "lm_head"), ("model", "lm_head")): | |
| obj = model | |
| ok = True | |
| for a in path: | |
| if hasattr(obj, a): | |
| obj = getattr(obj, a) | |
| else: | |
| ok = False | |
| break | |
| if ok and obj is not None and hasattr(obj, "weight"): | |
| return obj | |
| return None # caller falls back to tied input embeddings | |
| def _resolve_o_proj(layer): | |
| """The attention output projection whose input is per-head-concatenated.""" | |
| attn = None | |
| for a in ("self_attn", "attn", "attention"): | |
| attn = getattr(layer, a, None) | |
| if attn is not None: | |
| break | |
| if attn is None: | |
| return None | |
| for a in ("o_proj", "out_proj", "wo", "dense", "c_proj"): | |
| p = getattr(attn, a, None) | |
| if p is not None and hasattr(p, "weight"): | |
| return p | |
| return None | |
@dataclass
class Snapshot:
    """One forward pass over the current context, shared by every pane.

    Avoids the old pattern of xray / scan / branches each doing their own
    forward. ``head_inputs`` is populated only when head scanning is
    supported. Must be a dataclass: the probe constructs it with keyword
    arguments, and ``head_inputs`` relies on ``field(default_factory=dict)``
    so each snapshot gets its own dict.
    """
    logits: torch.Tensor  # [vocab] last-token logits (float)
    hidden_states: tuple  # len L+1, each [1, seq, hidden]
    head_inputs: dict = field(default_factory=dict)  # layer_idx -> o_proj input [1, seq, H*hd]
    entropy: float = 0.0  # next-token entropy (nats)
    top1: float = 0.0  # next-token top-1 probability
| # ── The Probe (lifted from cartographer3, ANSI removed) ─────────────────── | |
class InteractiveProbe:
    """Interactive interpretability probe around one Hugging Face decoder LM.

    Owns the tokenizer and model, the current token context ``history_ids``
    (a ``[1, seq]`` id tensor), an action log ``full_log``, the set of muted
    ``(layer, head)`` pairs, and the latest :class:`Snapshot`. Architecture
    details are discovered with the module's ``_resolve_*`` helpers;
    ``head_scan_supported`` gates the per-head ("Tier-2") features.
    """

    def __init__(self, model_id: str = None):
        """Load tokenizer + model for *model_id* and discover its layout.

        Args:
            model_id: Hub id or local path; falls back to ``DEFAULT_MODEL_ID``.

        Raises:
            RuntimeError: if no auto-class can load the checkpoint, or no
                decoder layer stack can be located.
        """
        model_id = model_id or DEFAULT_MODEL_ID
        self.model_id = model_id
        print(f"[*] Loading {model_id} on {DEVICE}...")
        kw = {"token": HF_TOKEN} if HF_TOKEN else {}
        self.tokenizer = AutoTokenizer.from_pretrained(model_id, **kw)
        self.tokenizer.padding_side = "left"
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        dtype = torch.bfloat16 if DEVICE == "cuda" else torch.float32
        load_kw = dict(torch_dtype=dtype, **kw)
        if DEVICE == "cuda":
            load_kw["device_map"] = "auto"
        # Try text CausalLM first; fall back to the multimodal head (gemma-3-4b,
        # gemma-4, gemma3n are all image-text-to-text wrappers around a text LM).
        self.model = None
        errs = []
        loaders = [AutoModelForCausalLM]
        if AutoModelForImageTextToText is not None:
            loaders.append(AutoModelForImageTextToText)
        for loader in loaders:
            try:
                self.model = loader.from_pretrained(model_id, **load_kw)
                break
            except Exception as e:
                errs.append(f"{loader.__name__}: {type(e).__name__}: {e}")
        if self.model is None:
            raise RuntimeError("Could not load model with any auto-class:\n" + "\n".join(errs))
        if DEVICE == "cpu":
            self.model = self.model.to(DEVICE)
        self.model.eval()

        # ── Auto-discover the architecture (see resolver notes above) ──
        self._decoder = _resolve_decoder(self.model)
        self.layers = self._decoder.layers
        self.final_norm = _resolve_final_norm(self._decoder)
        self._lm_head_mod = _resolve_lm_head(self.model)
        if self._lm_head_mod is None:
            # No separate head: project with the (tied) input embedding matrix.
            self._tied_emb_weight = self.model.get_input_embeddings().weight
        else:
            self._tied_emb_weight = None

        text_cfg = _resolve_text_config(self.model.config)
        # Trust the actual stack length over config (resolver may pick a sub-stack).
        self.num_layers = len(self.layers)
        # _cfg_int treats missing *and* None-valued attributes as absent, so
        # configs that set e.g. num_key_value_heads=None don't crash int().
        self.num_heads = self._cfg_int(text_cfg, "num_attention_heads") or 1
        self.num_kv_heads = self._cfg_int(text_cfg, "num_key_value_heads") or self.num_heads
        hidden_size = self._cfg_int(text_cfg, "hidden_size")
        vocab_size = (self._cfg_int(text_cfg, "vocab_size")
                      or self._cfg_int(self.model.config, "vocab_size"))
        self.vocab_size = vocab_size or self._lm_head_out_features()

        o0 = _resolve_o_proj(self.layers[0])
        cfg_head_dim = self._cfg_int(text_cfg, "head_dim")
        if cfg_head_dim:
            self.head_dim = cfg_head_dim
        elif o0 is not None:
            self.head_dim = o0.weight.shape[1] // self.num_heads
        elif hidden_size:
            self.head_dim = hidden_size // self.num_heads
        else:
            self.head_dim = 0
        self.hidden_size = hidden_size
        # Tier-2 (per-head) support: o_proj input must decompose as heads*head_dim.
        self.head_scan_supported = bool(
            o0 is not None
            and self.final_norm is not None
            and self.head_dim
            and o0.weight.shape[1] == self.num_heads * self.head_dim
        )
        self.arch_name = type(self.model).__name__
        print(f"[*] {self.arch_name}: {self.num_layers}L x {self.num_heads}H "
              f"(kv={self.num_kv_heads}), head_dim={self.head_dim}, hidden={hidden_size}, "
              f"vocab={self.vocab_size}, head_scan={'yes' if self.head_scan_supported else 'NO (Tier-1 only)'}")
        self.history_ids = None
        self.full_log = []
        self.muted_heads = set()
        self._mute_handles = []
        self.snapshot = None

    # ── architecture accessors ──
    @staticmethod
    def _cfg_int(cfg, name, default=0):
        """Read integer config attribute *name*; missing, None or 0 → *default*."""
        value = getattr(cfg, name, None)
        return int(value) if value else default

    def _lm_head_out_features(self):
        """Output (vocab) dimension of the projection used by lm_head(), or 0."""
        if self._lm_head_mod is not None and hasattr(self._lm_head_mod, "weight"):
            return self._lm_head_mod.weight.shape[0]
        if self._tied_emb_weight is not None:
            return self._tied_emb_weight.shape[0]
        return 0

    def lm_head(self, x):
        """Project hidden state to vocab logits (handles tied embeddings)."""
        if self._lm_head_mod is not None:
            return self._lm_head_mod(x)
        return F.linear(x, self._tied_emb_weight)

    def norm(self, x):
        """Apply final norm if the model has one, else identity."""
        return self.final_norm(x) if self.final_norm is not None else x

    def o_proj(self, layer_idx):
        """Attention output projection of decoder block *layer_idx* (or None)."""
        return _resolve_o_proj(self.layers[layer_idx])

    def _head_projections(self, layer_idx, o_proj_input):
        """Per-head last-token writes at one layer, shape ``[H, hidden]``.

        Splits the captured o_proj input ``[1, seq, H*hd]`` into ``[H, hd]`` at
        the last position and multiplies each head by its own column slice of
        the o_proj weight. The o_proj bias (if any) is not included.
        """
        wo = self.o_proj(layer_idx).weight
        wv = wo.view(wo.shape[0], self.num_heads, self.head_dim)
        x_last = o_proj_input[0, -1, :].view(self.num_heads, self.head_dim)  # [H, hd]
        # proj[h] = Wo[:, h, :] @ x_last[h] → [H, hidden]
        return torch.einsum("khd,hd->hk", wv.to(x_last.dtype), x_last)

    # ── single shared forward pass ──
    def forward_snapshot(self):
        """Run ONE forward over the current context and store it in ``self.snapshot``.

        Captures last-token logits, all hidden states and (Tier-2 only) each
        layer's o_proj input via temporary forward hooks. Mute pre-hooks run
        before these hooks, so captured inputs already have muted heads zeroed.
        Returns the new Snapshot, or None (and clears ``snapshot``) without context.
        """
        if self.history_ids is None:
            self.snapshot = None
            return None
        captured = {}
        handles = []
        if self.head_scan_supported:
            def mk(li):
                def fn(module, args, output):
                    captured[li] = args[0].detach()
                return fn
            for li in range(self.num_layers):
                op = self.o_proj(li)
                if op is not None:
                    handles.append(op.register_forward_hook(mk(li)))
        try:
            with torch.no_grad():
                out = self.model(self.history_ids, output_hidden_states=True)
        finally:
            # Capture hooks are one-shot; always detach them, even on error.
            for h in handles:
                h.remove()
        last_logits = out.logits[0, -1, :].float()
        probs = F.softmax(last_logits, dim=-1)
        ent = float(-(probs * torch.log(probs + 1e-12)).sum().item())
        top1 = float(probs.max().item())
        self.snapshot = Snapshot(
            logits=last_logits, hidden_states=out.hidden_states,
            head_inputs=captured, entropy=ent, top1=top1,
        )
        return self.snapshot

    def _batch_top3(self, mat: torch.Tensor, k: int = 3):
        """mat: [N, hidden] → list of N rows, each [(token, prob) * k].
        One norm + one lm_head + one top-k over the whole batch."""
        with torch.no_grad():
            normed = self.norm(mat)
            logits = self.lm_head(normed)  # [N, vocab]
            probs = F.softmax(logits, dim=-1)
            tp, ti = torch.topk(probs, k, dim=-1)  # [N, k]
        tp = tp.tolist()
        ti = ti.tolist()
        return [[(self.tokenizer.decode(ti[n][j]), tp[n][j]) for j in range(k)]
                for n in range(mat.shape[0])]

    def set_seed(self):
        """Reseed transformers' RNGs with the module-level SEED."""
        set_seed(SEED)

    def load_start(self, prompt: str, system_prompt: str = None):
        """Replace the context with *prompt* (and optional system prompt).

        Uses the tokenizer's chat template with a generation prompt when it
        has one; otherwise falls back to the raw prompt text.
        """
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})
        try:
            full_prompt = self.tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            templated = True
        except Exception:
            full_prompt = prompt
            templated = False
        # A rendered chat template already carries its special tokens (e.g.
        # Gemma's <bos>); tokenizing it with add_special_tokens=True would add
        # a second BOS. Raw prompts still get the tokenizer's special tokens.
        enc = self.tokenizer(full_prompt, return_tensors="pt",
                             add_special_tokens=not templated)
        self.history_ids = enc.to(DEVICE).input_ids

    def perform_xray(self, snap: "Snapshot" = None):
        """Logit lens per layer, batched: stack all last-token hidden states,
        one norm + lm_head + argmax.

        Returns ``[("Emb", tok), ("L1", tok), ...]`` where ``"L{i}"`` decodes
        hidden_states[i]; empty list when there is no snapshot.
        NOTE(review): in many HF decoders the last hidden_states entry is
        already final-normed, so it gets normed twice here — confirm per model.
        """
        snap = snap or self.snapshot
        if snap is None:
            return []
        hs = snap.hidden_states
        stacked = torch.cat([h[:, -1, :] for h in hs], dim=0)  # [L+1, hidden]
        with torch.no_grad():
            tids = self.lm_head(self.norm(stacked)).argmax(dim=-1).tolist()
        out = []
        for i, tid in enumerate(tids):
            label = "Emb" if i == 0 else f"L{i}"
            out.append((label, self.tokenizer.decode(tid)))
        return out

    def _trim_padding(self, seq, trunk_len):
        """Drop the batch padding ``generate`` appends after a row finished early.

        Trailing ``pad_token_id`` tokens after the branch root
        (``seq[trunk_len]``) are removed. If pad doubles as EOS, one is kept
        because it is the token that actually ended generation.
        """
        pad = self.tokenizer.pad_token_id
        if pad is None:
            return seq
        end = seq.shape[0]
        floor = trunk_len + 1  # never trim the branch root token
        while end > floor and int(seq[end - 1]) == pad:
            end -= 1
        if end < seq.shape[0] and pad == self.tokenizer.eos_token_id:
            end += 1
        return seq[:end]

    def generate_previews(self, width: int, length: int, snap: "Snapshot" = None):
        """Greedy previews for the *width* most likely next tokens.

        Each branch is the current context plus one candidate token,
        continued greedily for *length* tokens. Returns a list of dicts with
        ``index``, ``token_id``, ``root_token``, ``prob``, ``future_text``
        (newlines escaped) and ``full_ids`` (``[1, n]``, batch padding
        stripped, ready for :meth:`commit`). Empty list without context or for
        ``width <= 0``.
        """
        if self.history_ids is None:
            return []
        snap = snap or self.snapshot
        if snap is not None:
            next_logits = snap.logits
        else:
            with torch.no_grad():
                next_logits = self.model(self.history_ids).logits[0, -1, :]
        width = min(int(width), next_logits.shape[-1])
        if width <= 0:
            return []
        probs = F.softmax(next_logits.float(), dim=-1)
        top_probs, top_ids = torch.topk(probs, width)
        trunk_len = self.history_ids.shape[1]
        roots = top_ids.to(self.history_ids.device).view(width, 1)
        # All rows share the trunk and add exactly one token: equal lengths, no padding.
        batch = torch.cat([self.history_ids.expand(width, -1), roots], dim=1)
        if length and length > 0:
            with torch.no_grad():
                # Greedy decoding (temperature is irrelevant when do_sample=False);
                # an explicit all-ones mask avoids inferring it from pad==eos.
                outs = self.model.generate(
                    batch,
                    attention_mask=torch.ones_like(batch),
                    max_new_tokens=int(length),
                    do_sample=False,
                    pad_token_id=self.tokenizer.pad_token_id,
                )
        else:
            outs = batch
        previews = []
        for i in range(width):
            full_seq = self._trim_padding(outs[i], trunk_len)
            root = self.tokenizer.decode(full_seq[trunk_len])
            future = self.tokenizer.decode(full_seq[trunk_len + 1:], skip_special_tokens=False)
            previews.append({
                "index": i, "token_id": top_ids[i].item(), "root_token": root,
                "prob": top_probs[i].item(),
                "future_text": future.replace("\n", "\\n"),
                "full_ids": full_seq.unsqueeze(0),
            })
        return previews

    def commit(self, preview_obj, tokens_to_keep=None):
        """Adopt a preview branch as the new context.

        Keeps the branch root plus *tokens_to_keep* preview tokens (all of
        them when None; negative values are treated as 0 so the existing
        context is never truncated), then logs the committed text.
        """
        full_seq = preview_obj["full_ids"]
        trunk_len = self.history_ids.shape[1]
        if tokens_to_keep is None:
            cut = full_seq.shape[1]
        else:
            cut = min(trunk_len + 1 + max(0, int(tokens_to_keep)), full_seq.shape[1])
        self.history_ids = full_seq[:, :cut]
        committed = self.tokenizer.decode(self.history_ids[0, trunk_len:], skip_special_tokens=True)
        self.full_log.append({"type": "commit", "token": preview_obj["root_token"],
                              "full_text": committed, "kept_len": tokens_to_keep})

    def inject(self, text: str):
        """Append *text* (tokenized without special tokens) to the context."""
        ids = self.tokenizer(text, return_tensors="pt", add_special_tokens=False).to(DEVICE).input_ids
        self.history_ids = torch.cat([self.history_ids, ids], dim=1)
        self.full_log.append({"type": "inject", "text": text})

    def scan_heads(self, snap: "Snapshot" = None, target_layers=None, target_heads=None):
        """Decode the head map from a precomputed snapshot. No model call here.

        Note: the per-head 'pre-projection' and per-head 'Δ-residual' are the
        same quantity (a head's last-token write into the residual via o_proj);
        the original tool computed them twice. We compute once and fill both.

        Returns ``{"pre_stream": {layer: {head: top3}}, "delta_v": {layer:
        {head | "full": top3}}}`` (heads only in Tier-2), or None without a
        snapshot. Out-of-range layer/head indices are ignored.
        """
        snap = snap or self.snapshot
        if snap is None:
            return None
        if target_layers is None:
            target_layers = list(range(self.num_layers))
        if target_heads is None:
            target_heads = list(range(self.num_heads))
        target_layers = [li for li in target_layers if 0 <= li < self.num_layers]
        target_heads = [hi for hi in target_heads if 0 <= hi < self.num_heads]
        if not target_layers:
            # Nothing in range (e.g. "h 99"): torch.cat([]) below would raise.
            return {"pre_stream": {}, "delta_v": {}}
        hs = snap.hidden_states
        pre_stream = {li: {} for li in target_layers}
        delta_v = {li: {} for li in target_layers}
        # ── Tier-1: per-layer full Δ-residual (always available) — batched ──
        # NOTE(review): if hs[-1] is already final-normed (common in HF), the
        # last layer's Δ mixes normed and un-normed states — confirm per model.
        full_vecs = torch.cat([hs[li + 1][:, -1, :] - hs[li][:, -1, :] for li in target_layers], dim=0)
        full_top = self._batch_top3(full_vecs)
        for row, li in enumerate(target_layers):
            delta_v[li]["full"] = full_top[row]
        do_heads = self.head_scan_supported and snap.head_inputs
        if not do_heads:
            return {"pre_stream": pre_stream, "delta_v": delta_v}
        # ── Tier-2: per-head residual writes — one batched lm_head over all ──
        rows = []
        index = []  # (li, hi) aligned with rows
        for li in target_layers:
            inp = snap.head_inputs.get(li)
            if inp is None:
                continue
            proj_all = self._head_projections(li, inp)
            for hi in target_heads:
                rows.append(proj_all[hi])
                index.append((li, hi))
        if rows:
            head_top = self._batch_top3(torch.stack(rows, dim=0))
            for (li, hi), res in zip(index, head_top):
                pre_stream[li][hi] = res
                delta_v[li][hi] = res  # identical quantity (see docstring)
        return {"pre_stream": pre_stream, "delta_v": delta_v}

    def _install_mute_hooks(self):
        """Rebuild o_proj forward-pre-hooks that zero every muted head's input slice.

        The hook edits the o_proj input in place via ``view`` (assumes that
        input is contiguous — TODO confirm for exotic attention impls).
        """
        for h in self._mute_handles:
            h.remove()
        self._mute_handles = []
        if not self.muted_heads:
            return
        muted_by_layer = {}
        for (l, h) in self.muted_heads:
            muted_by_layer.setdefault(l, set()).add(h)
        hd, nh = self.head_dim, self.num_heads
        for li, heads in muted_by_layer.items():
            def make_hook(muted_set, head_dim, num_heads):
                def fn(module, args):
                    x = args[0]
                    xv = x.view(x.shape[0], x.shape[1], num_heads, head_dim)
                    for h in muted_set:
                        xv[:, :, h, :] = 0.0
                    return (xv.reshape(x.shape),)
                return fn
            op = self.o_proj(li)
            if op is not None:
                self._mute_handles.append(op.register_forward_pre_hook(make_hook(heads, hd, nh)))

    def mute_head(self, layer, head):
        """Zero head *head* of layer *layer* in all later forwards; returns a status string."""
        if not self.head_scan_supported:
            return f"head muting unsupported for {self.arch_name} (no per-head o_proj decomposition)"
        if not (0 <= layer < self.num_layers):
            return f"Layer {layer} OOR"
        if not (0 <= head < self.num_heads):
            return f"Head {head} OOR"
        key = (layer, head)
        if key in self.muted_heads:
            return f"L{layer}H{head} already muted"
        self.muted_heads.add(key)
        self._install_mute_hooks()
        self.full_log.append({"type": "mute", "layer": layer, "head": head})
        return f"MUTED L{layer}H{head}"

    def unmute_head(self, layer, head):
        """Undo :meth:`mute_head` for one head; returns a status string."""
        key = (layer, head)
        if key not in self.muted_heads:
            return f"L{layer}H{head} not muted"
        self.muted_heads.discard(key)
        self._install_mute_hooks()
        self.full_log.append({"type": "unmute", "layer": layer, "head": head})
        return f"UNMUTED L{layer}H{head}"

    def unmute_all(self):
        """Remove every mute hook and clear the muted set; returns a status string."""
        for h in self._mute_handles:
            h.remove()
        self._mute_handles = []
        n = len(self.muted_heads)
        self.muted_heads.clear()
        self.full_log.append({"type": "unmute_all"})
        return f"UNMUTED ALL ({n} heads)"

    def sparkline_data(self, token_str: str, rank_lo: int = 0, rank_hi: int = None,
                       snap: "Snapshot" = None):
        """Rank of one token in every (head, layer) last-token write.

        *token_str* is a decimal token id, or text whose first token id is
        used. Returns None without context/snapshot, when the text tokenizes
        to nothing, or when the id is outside the output vocabulary; the
        string ``"unsupported"`` when per-head scanning is unavailable;
        otherwise ``{"name", "tid", "rank_lo", "rank_hi", "ranks"}`` with
        ``ranks[(head, layer)]`` = 1 + number of logits strictly above the target.
        """
        if self.history_ids is None:
            return None
        if not self.head_scan_supported:
            return "unsupported"
        snap = snap or self.snapshot
        if snap is None or not snap.head_inputs:
            return None
        # isdecimal, not isdigit: int() rejects digit-like chars such as "²".
        if token_str.isdecimal():
            tid, name = int(token_str), None
        else:
            ids = self.tokenizer.encode(token_str, add_special_tokens=False)
            if not ids:
                return None
            tid, name = ids[0], token_str
        n_vocab = self._lm_head_out_features() or self.vocab_size
        if n_vocab and tid >= n_vocab:
            return None  # indexing logits[:, tid] would raise
        if name is None:
            name = self.tokenizer.decode(tid)
        if rank_hi is None:
            rank_hi = self.vocab_size
        # Build [N, hidden] of every (head, layer) last-token projection, then a
        # single lm_head, then rank by COMPARISON (count of logits above target)
        # rather than a full argsort — O(vocab) vs O(vocab·log vocab) per row.
        rows, index = [], []
        for li in sorted(snap.head_inputs.keys()):
            proj_all = self._head_projections(li, snap.head_inputs[li])  # [H, hidden]
            for hi in range(self.num_heads):
                rows.append(proj_all[hi])
                index.append((hi, li))
        ranks = {}
        if rows:
            with torch.no_grad():
                logits = self.lm_head(self.norm(torch.stack(rows, dim=0)))  # [N, vocab]
                target = logits[:, tid].unsqueeze(1)  # [N, 1]
                rk = (logits > target).sum(dim=1).add(1).tolist()  # [N]
            for (hi, li), r in zip(index, rk):
                ranks[(hi, li)] = int(r)
        return {"name": name, "tid": tid, "rank_lo": rank_lo, "rank_hi": rank_hi, "ranks": ranks}
# ── HTML renderers ────────────────────────────────────────────────────────
# Width budget, in display columns as counted by fmt_tok/char_width, for the
# token text of each head-map cell.
CELL_W = 10 # display chars per token cell
def render_context(probe: InteractiveProbe) -> str:
    """Context pane: token count plus the last 800 decoded characters.

    Special tokens are kept, newlines are flattened to spaces, and a leading
    "..." marks that earlier text was cut off.
    """
    if probe is None or probe.history_ids is None:
        return "<div class='pane-body'><i>No context loaded.</i></div>"
    ids = probe.history_ids
    decoded = probe.tokenizer.decode(ids[0], skip_special_tokens=False)
    prefix = "..." if len(decoded) > 800 else ""
    tail = prefix + decoded[-800:].replace("\n", " ")
    token_count = ids.shape[1]
    return ("<div class='pane-body'>"
            f"<div class='dim'>[{token_count} tokens]</div>"
            f"<div class='context-text'>{html_escape(tail)}</div></div>")
def render_headmap(probe: InteractiveProbe, data: dict, xray: list) -> str:
    """Render the head map as an HTML table, one row per scanned layer.

    Column groups: per-head "pre-stream" top-1 tokens, the logit-lens (Xray)
    token, per-head Δ-residual top-1 tokens, and the full-layer Δ-residual
    top-1. Token cells are truncated to CELL_W columns and coloured with
    heat_color of their probability; missing entries render as "--".

    Args:
        probe: not read by this function.
        data: scan_heads output, ``{"pre_stream": {layer: {head: [(tok, p), ...]}},
            "delta_v": {layer: {head | "full": [(tok, p), ...]}}}``; falsy → placeholder.
        xray: ``(label, token)`` pairs from perform_xray ("Emb", "L1", ...).

    Returns:
        An HTML fragment string.
    """
    if not data:
        return "<div class='pane-body'><i>No head-scan run yet. Try <code>h *</code> for all layers.</i></div>"
    pre, dv = data.get("pre_stream", {}), data.get("delta_v", {})
    layers = sorted(pre.keys())
    if not layers:
        return "<div class='pane-body'><i>Empty scan.</i></div>"
    # Head columns come from the first layer's int keys; Tier-1 scans have none.
    heads = sorted(h for h in pre.get(layers[0], {}).keys() if isinstance(h, int))
    xray_map = {label: tok for label, tok in xray} if xray else {}
    def cell(tok, prob):
        # One coloured <td> holding a truncated, escaped top-1 token.
        clean = html_escape(fmt_tok(tok, CELL_W))
        c = heat_color(prob)
        return f"<td class='hm-cell' style='color:{c}'>{clean}</td>"
    rows = []
    # Header: [pre-stream heads] | Xray | [Δ-residual heads] | L_full
    hdr = ["<tr><th class='lay'>Lay</th>"]
    for h in heads: hdr.append(f"<th class='ps'>H{h}</th>")
    hdr.append("<th class='xr'>Xray</th>")
    for h in heads: hdr.append(f"<th class='dv'>H{h}<sub>dv</sub></th>")
    hdr.append("<th class='dv full'>L<sub>full</sub></th>")
    hdr.append("</tr>")
    rows.append("".join(hdr))
    # Descriptive caption row; inserted at index 0 so it renders above the H-labels.
    # NOTE(review): scan_heads stores the same per-head o_proj write in both
    # pre_stream and delta_v, so the two per-head groups show identical tokens,
    # and per-head cells are attention-only writes rather than "attn+MLP+skip";
    # the captions below may deserve revisiting.
    sub = ["<tr class='subhdr'><th></th>"]
    if heads:
        sub.append(f"<th class='ps' colspan='{len(heads)}'>per-head pre-projection (raw head → norm → lm_head)</th>")
    sub.append("<th class='xr'>logit lens</th>")
    sub.append(f"<th class='dv' colspan='{len(heads)+1}'>Δ-residual (attn+MLP+skip), "
               f"{'per head + ' if heads else ''}full layer</th>")
    sub.append("</tr>")
    rows.insert(0, "".join(sub))
    for l in layers:
        r = [f"<tr><td class='lay'>L{l}</td>"]
        for h in heads:
            if h in pre.get(l, {}):
                tok, p = pre[l][h][0]; r.append(cell(tok, p))
            else:
                r.append("<td class='hm-cell'>--</td>")
        # Xray (logit lens at this layer).
        # NOTE(review): perform_xray labels hidden_states[i] as "L{i}", so row l
        # shows the residual *entering* layer l ("Emb" for l == 0), while the
        # Δ columns show layer l's own write — confirm this offset is intended.
        label = "Emb" if l == 0 else f"L{l}"
        xt = xray_map.get(label, "--")
        r.append(f"<td class='hm-cell xr'>{html_escape(fmt_tok(xt, CELL_W))}</td>")
        # delta-v per head
        dvl = dv.get(l, {})
        for h in heads:
            if h in dvl:
                tok, p = dvl[h][0]; r.append(cell(tok, p))
            else:
                r.append("<td class='hm-cell'>--</td>")
        # full layer dv (Tier-1, present for every scanned layer)
        if "full" in dvl:
            tok, p = dvl["full"][0]
            r.append(f"<td class='hm-cell full' style='color:{heat_color(p)}'>{html_escape(fmt_tok(tok, CELL_W))}</td>")
        else:
            r.append("<td class='hm-cell full'>--</td>")
        r.append("</tr>")
        rows.append("".join(r))
    note = ""
    if not heads:
        note = ("<div class='dim' style='padding:2px 0'>per-head decomposition unavailable for this "
                "architecture — showing logit-lens + full-layer Δ-residual only (Tier-1)</div>")
    return f"<div class='pane-body scroll'>{note}<table class='headmap'>{''.join(rows)}</table></div>"
def render_futures(options: list, selected: int = None) -> str:
    """Futures pane: one table row per branch preview.

    Each row shows the 1-based command number, the branch probability
    (coloured by heat_color), the root token, up to 120 characters of the
    greedy continuation, and command hints. The row whose number equals
    *selected* gets the "selected" class.
    """
    if not options:
        return "<div class='pane-body'><i>No branches yet — click Load.</i></div>"

    def branch_row(opt):
        num = opt["index"] + 1
        root = opt["root_token"]
        hint_token = root.strip()[:12] or root[:12]
        if hint_token:
            hint_html = (f"<code>{num},5</code>·<code>{num},20</code>·"
                         f"<code>spark {html_escape(hint_token)}</code>")
        else:
            hint_html = f"<code>{num},5</code>"
        row_cls = "selected" if selected == num else ""
        cells = [
            f"<td class='idx'><code class='cmd-pill'>{num}</code></td>",
            f"<td style='color:{heat_color(opt['prob'])}'>{opt['prob']:.1%}</td>",
            f"<td class='tok'>{html_escape(fmt_tok(root, 16))}</td>",
            f"<td class='preview'>{html_escape(opt['future_text'][:120])}…</td>",
            f"<td class='cmd-hints dim'>{hint_html}</td>",
        ]
        return f"<tr class='{row_cls}'>" + "".join(cells) + "</tr>"

    header = "<tr><th>cmd</th><th>prob</th><th>token</th><th>preview</th><th>hints</th></tr>"
    body = "".join(branch_row(opt) for opt in options)
    return f"<div class='pane-body scroll'><table class='futures'>{header}{body}</table></div>"
def render_sparkline(probe: "InteractiveProbe", sd: dict) -> str:
    """Render the token-rank scatter from ``InteractiveProbe.sparkline_data``.

    Each attention head gets a band of 8 text rows and each layer a 3-column
    cell. ``●`` places the target token's rank for that (head, layer)
    linearly between ``rank_lo`` (top row) and ``rank_hi`` (bottom row);
    ``▲`` / ``▼`` mark ranks above / below that window (a missing entry is
    drawn as ``▼``). The left gutter has one fixed width, so every ``│``
    and the layer header line up whatever the head count or rank magnitude.

    Args:
        probe: supplies ``num_layers`` / ``num_heads`` for the grid.
        sd: sparkline_data result (``name``, ``tid``, ``rank_lo``, ``rank_hi``,
            ``ranks`` keyed by ``(head, layer)``), its ``"unsupported"``
            sentinel string, or a falsy value.

    Returns:
        An HTML fragment with the plot in a ``<pre>``.
    """
    if not sd:
        return "<div class='pane-body'><i>No sparkline. Try <code>spark &lt;token&gt; [lo:hi]</code>.</i></div>"
    if isinstance(sd, str):
        # sparkline_data returns "unsupported" when the model has no per-head
        # o_proj decomposition; indexing it like a dict would raise.
        return ("<div class='pane-body'><i>Sparkline unavailable: per-head decomposition "
                "is not supported for this architecture.</i></div>")
    n_layers = probe.num_layers
    n_heads = probe.num_heads
    rank_lo, rank_hi = sd["rank_lo"], sd["rank_hi"]
    span = max(rank_hi - rank_lo, 1)
    ranks = sd["ranks"]
    band_h = 8
    head_colors = ["#64D32A", "#FFCF67", "#21C5C5", "#E78BFF"]

    # Gutter = "H<idx> " tag + right-aligned rank + "│"; sized for the widest case.
    tag_w = len(f"H{max(n_heads - 1, 0)} ")
    num_w = max(6, len(str(rank_lo)), len(str(rank_hi)), len(str(rank_lo + span // 2)))
    tag_pad = " " * tag_w
    gutter = " " * (tag_w + num_w + 1)  # +1 for the "│" column
    rule_w = n_layers * 3

    lines = []
    name = sd["name"].replace("\n", "\\n")[:15]
    lines.append(f"Scatter: '{html_escape(name)}' (id={sd['tid']}) rank {rank_lo:,} (top) → {rank_hi:,} (bottom)")
    hdr = gutter + "".join(f"{l:>3}" for l in range(n_layers))
    lines.append(f"<span class='dim'>{html_escape(hdr)}</span>")
    lines.append(gutter + "─" * rule_w)
    for hi in range(n_heads):
        color = head_colors[hi % len(head_colors)]
        grid = [[" "] * n_layers for _ in range(band_h)]
        for li in range(n_layers):
            r = ranks.get((hi, li), rank_hi + 1)
            if r < rank_lo:
                grid[0][li] = "▲"
            elif r > rank_hi:
                grid[band_h - 1][li] = "▼"
            else:
                frac = (r - rank_lo) / span
                row = int(frac * (band_h - 1))
                grid[max(0, min(band_h - 1, row))][li] = "●"
        tag = f"H{hi} "
        for ri in range(band_h):
            if ri == 0:
                label = f"{tag:<{tag_w}}{rank_lo:>{num_w}}│"
            elif ri == band_h - 1:
                label = f"{tag_pad}{rank_hi:>{num_w}}│"
            elif ri == band_h // 2:
                label = f"{tag_pad}{rank_lo + span // 2:>{num_w}}│"
            else:
                label = " " * (tag_w + num_w) + "│"
            cells = []
            for ch in grid[ri]:
                if ch in ("●", "▲", "▼"):
                    cells.append(f"<span style='color:{color}'> {ch} </span>")
                else:
                    cells.append(f" {ch} ")
            lines.append(html_escape(label) + "".join(cells))
        if hi < n_heads - 1:
            lines.append(gutter + "┈" * rule_w)
    lines.append(gutter + "─" * rule_w)
    return f"<div class='pane-body scroll'><pre class='spark'>{chr(10).join(lines)}</pre></div>"
def render_muted(probe: InteractiveProbe) -> str:
    """Badge listing muted heads as ``L<layer>H<head>`` in sorted order.

    Returns an empty string when there is no probe or nothing is muted.
    """
    if probe is not None and probe.muted_heads:
        labels = [f"L{layer}H{head}" for layer, head in sorted(probe.muted_heads)]
        return "<span class='muted-tag'>MUTED: " + ", ".join(labels) + "</span>"
    return ""
def render_summary(s) -> str:
    """Compact run-summary strip: arch · tokens · entropy · top1 · selected · mode · muted.

    *s* is a Session-like object (``probe``, ``width``, ``preview_len`` and
    optionally ``last_selected``). Entropy/top-1 come from the probe's last
    snapshot and show "—" when there is none.
    """
    probe = s.probe
    if probe is None:
        return "<div class='summary dim'>no model loaded</div>"
    ids = probe.history_ids
    token_count = 0 if ids is None else ids.shape[1]
    snapshot = probe.snapshot
    if snapshot:
        entropy_txt = f"{snapshot.entropy:.2f}"
        top1_txt = f"{snapshot.top1*100:.1f}%"
    else:
        entropy_txt = "—"
        top1_txt = "—"
    tier_txt = "T2" if probe.head_scan_supported else "T1"
    last = getattr(s, "last_selected", None)
    last_txt = "—" if last is None else f"branch [{last}]"
    if probe.muted_heads:
        muted_txt = f"<span style='color:#ff6b6b'>muted {len(probe.muted_heads)}</span>"
    else:
        muted_txt = "muted 0"
    parts = (
        f"<b>{html_escape(probe.arch_name)}</b> <span class='dim'>{html_escape(probe.model_id)}</span>",
        f"<span class='dim'>tier</span> {tier_txt}",
        f"<span class='dim'>L</span>{probe.num_layers} <span class='dim'>×H</span>{probe.num_heads}",
        f"<span class='dim'>tok</span> {token_count}",
        f"<span class='dim'>H</span>=<b>{entropy_txt}</b>",
        f"<span class='dim'>top1</span>=<b>{top1_txt}</b>",
        f"<span class='dim'>last</span> {last_txt}",
        f"<span class='dim'>w</span>={s.width} <span class='dim'>len</span>={s.preview_len}",
        muted_txt,
    )
    return "<div class='summary'>" + " <span class='sep'>·</span> ".join(parts) + "</div>"
| # ── State container ────────────────────────────────────────────────────── | |
class Session:
    """Per-user UI state: the probe plus whatever each pane last showed.

    Attributes:
        probe: the loaded InteractiveProbe, or None before the first load.
        options: branch previews for the Futures pane.
        scan_data / xray: head-map data and logit-lens labels.
        spark: sparkline data for the Sparkline pane.
        width / preview_len: branch count and preview length settings.
        transcript: ``(command, status)`` pairs for the status pane.
        loaded_model / loaded_system / loaded_user: inputs the current
            context was built from, used to detect changes.
        last_selected: 1-based index of the last committed branch.
    """

    def __init__(self):
        self.probe: InteractiveProbe = None
        self.options: list = []
        self.scan_data: dict = None
        self.xray: list = []
        self.spark: dict = None
        self.width = DEFAULT_WIDTH
        self.preview_len = DEFAULT_PREVIEW
        self.transcript: list = []  # [(cmd, status), ...]
        self.loaded_model = None
        self.loaded_system = None
        self.loaded_user = None
        self.last_selected = None  # which branch index was last committed

    def _reset_panes(self):
        """Clear everything derived from the previous context."""
        self.options = []
        self.scan_data = None
        self.xray = []
        self.spark = None
        self.transcript = []
        self.last_selected = None

    def load(self, model_id, system_prompt, user_prompt, force=False):
        """(Re)build the context when the model or either prompt changed.

        Returns True if anything was rebuilt, False when the inputs match the
        current context and *force* is false. Clicking Load with a new model
        or prompt always resets state (the old stale-session bug).

        Weights are (re)loaded only when the model id changes, on *force*, or
        when there is no probe yet. A prompt-only change reuses the loaded
        model and just resets its per-context state (mute hooks, action log,
        snapshot) before re-tokenizing. If loading a new model raises,
        ``probe`` is left as None so the next load/ensure retries.
        """
        reload_model = force or self.probe is None or self.loaded_model != model_id
        prompts_same = (self.loaded_system == system_prompt
                        and self.loaded_user == user_prompt)
        if not reload_model and prompts_same:
            return False
        if reload_model:
            # Drop this session's reference to the old model before loading the
            # new one, so both are not held by the session at the same time.
            self.probe = None
            self.probe = InteractiveProbe(model_id=model_id)
        else:
            self.probe.unmute_all()  # also removes mute hooks from the reused model
            self.probe.full_log = []
            self.probe.snapshot = None
        self.probe.set_seed()
        self.probe.load_start(user_prompt or "", system_prompt or None)
        self._reset_panes()
        self.loaded_model, self.loaded_system, self.loaded_user = model_id, system_prompt, user_prompt
        return True

    def ensure(self, model_id: str, system_prompt: str = None, user_prompt: str = ""):
        """Lazy load for the first command if the user typed before clicking Load."""
        if self.probe is None:
            self.load(model_id, system_prompt, user_prompt or "")
def panes_html(s: Session):
    """Render the six session-driven panes in build_ui's output order.

    Returns a tuple (summary, context, head map, branches, token rank trace,
    muted-heads row). Callers append the transcript HTML themselves.
    """
    probe = s.probe
    summary = render_summary(s)
    context = render_context(probe)
    head_map = render_headmap(probe, s.scan_data, s.xray)
    branches = render_futures(s.options, s.last_selected)
    trace = render_sparkline(probe, s.spark)
    muted = render_muted(probe)
    return summary, context, head_map, branches, trace, muted
def status_html(s: Session) -> str:
    """Render the transcript pane body.

    Shows the last 20 (command, status) entries of ``s.transcript``, oldest
    first, with both fields HTML-escaped. Shows a dim "Ready." line when the
    transcript is empty.
    """
    if s.transcript:
        body = "".join(
            f"<div><span class='prompt'>›</span> <code>{html_escape(entry_cmd)}</code> "
            f"<span class='dim'>— {html_escape(entry_status)}</span></div>"
            for entry_cmd, entry_status in s.transcript[-20:]
        )
        return f"<div class='pane-body scroll'>{body}</div>"
    return "<div class='pane-body dim'>Ready.</div>"
def handle_command(cmd: str, s: Session, model_id: str, system_prompt: str):
    """Execute one command-bar command against the session and re-render.

    Commands (keywords are case-insensitive):
      N / N,K              commit branch N (1-based), optionally passing K as keep length
      i <text>             inject text
      h L|* [H|*]          head scan for one layer/head, ``*`` meaning all
      top R1,R2,...        append the rank-R next token, once per listed rank
      rew N                drop the last N tokens (at least one token is kept)
      w N / l N            branch width / preview length
      spark <tok> [lo:hi]  token rank trace
      mute L H / unmute L H / unmute all / muted
      r | refresh, s (save the probe log under /tmp), q (no-op)

    Args:
        cmd: raw command-bar text.
        s: session state; mutated in place.
        model_id: model to lazy-load if nothing has been loaded yet.
        system_prompt: system prompt used for that lazy load.

    Returns:
        ``panes_html(s)`` followed by ``status_html(s)``.

    Exceptions raised while loading or executing the command are caught and
    recorded in the transcript as ``error: <Type>: <message>``.
    """
    cmd = (cmd or "").strip()
    if not cmd:
        return panes_html(s) + (status_html(s),)
    status = "ok"
    try:
        if s.probe is None:
            # Lazy load on the first command (so the UI renders fast), then
            # prime every pane. Without this the branch list is still empty
            # and a first "1" would be rejected as an invalid index.
            s.ensure(model_id=model_id, system_prompt=system_prompt)
            _refresh_after_step(s)
        probe = s.probe
        lc = cmd.lower()
        if cmd[0].isdigit():
            parts = cmd.split(",")
            idx = int(parts[0]) - 1
            if 0 <= idx < len(s.options):
                keep_len = int(parts[1]) if len(parts) > 1 else None
                probe.commit(s.options[idx], keep_len)
                s.last_selected = idx + 1
                _refresh_after_step(s)
                status = f"selected branch [{idx+1}]"
            else:
                status = f"invalid index {idx+1}"
        elif lc.startswith("i "):
            probe.inject(cmd[2:])
            _refresh_after_step(s)
            status = f"injected '{cmd[2:][:30]}'"
        elif lc.startswith("h "):
            parts = cmd.split()
            target_layers = None if parts[1] == "*" else [int(parts[1])]
            target_heads = None
            if len(parts) > 2:
                target_heads = None if parts[2] == "*" else [int(parts[2])]
            if probe.snapshot is None:
                probe.forward_snapshot()
            s.scan_data = probe.scan_heads(target_layers=target_layers, target_heads=target_heads)
            s.xray = probe.perform_xray()
            status = f"head-scan: layers={parts[1]} heads={parts[2] if len(parts)>2 else '*'}"
        elif lc.startswith("top "):
            ranks = [int(r.strip()) for r in cmd[4:].split(",")]
            picked = 0
            for rank in ranks:
                if rank < 1:
                    continue  # ranks are 1-based; skip before paying for a forward
                with torch.no_grad():
                    out = probe.model(probe.history_ids)
                probs = F.softmax(out.logits[0, -1, :], dim=-1)
                sp, si = torch.sort(probs, descending=True)
                if rank > len(si):
                    continue
                tid = si[rank - 1].item()
                hist = probe.history_ids
                # Match the existing ids' device/dtype rather than assuming DEVICE.
                next_tok = torch.tensor([[tid]], device=hist.device, dtype=hist.dtype)
                probe.history_ids = torch.cat([hist, next_tok], dim=1)
                probe.full_log.append({"type": "top_select", "rank": rank,
                                       "token": probe.tokenizer.decode([tid]),
                                       "token_id": tid, "prob": sp[rank-1].item()})
                picked += 1
            _refresh_after_step(s)
            status = f"top-selected {picked} tokens"
            skipped = len(ranks) - picked
            if skipped:
                status += f" ({skipped} out-of-range rank(s) skipped)"
        elif lc.startswith("rew "):
            # Clamp at 0: a negative count would otherwise log a negative removal.
            n = max(0, int(cmd[4:].strip()))
            cur = probe.history_ids.shape[1]
            new = max(1, cur - n)
            probe.history_ids = probe.history_ids[:, :new]
            probe.full_log.append({"type": "rewind", "tokens_removed": cur - new})
            _refresh_after_step(s)
            status = f"rewound {cur - new} tokens"
        elif lc.startswith("w "):
            s.width = int(cmd[2:])
            status = f"width={s.width}"
            _refresh_futures(s)
        elif lc.startswith("l "):
            s.preview_len = int(cmd[2:])
            status = f"preview_len={s.preview_len}"
            _refresh_futures(s)
        elif lc.startswith("spark "):
            parts = cmd[6:].strip().split()
            tok = parts[0]
            rl, rh = 0, None
            if len(parts) > 1 and ":" in parts[1]:
                a, b = parts[1].split(":")
                rl, rh = int(a), int(b)
            sd = probe.sparkline_data(tok, rl, rh)
            if sd == "unsupported":
                status = f"token rank trace unsupported for {probe.arch_name} (Tier-1 only)"
            elif sd is None:
                status = f"token '{tok}' not found in vocab"
            else:
                s.spark = sd
                # rh == 0 is an explicit bound, not "whole vocab".
                status = f"spark '{tok}' [{rl}:{rh if rh is not None else 'vocab'}]"
        elif lc.startswith("mute "):
            parts = cmd.split()
            status = probe.mute_head(int(parts[1]), int(parts[2]))
            _refresh_after_step(s)
        elif lc == "unmute all":
            status = probe.unmute_all()
            _refresh_after_step(s)
        elif lc.startswith("unmute "):
            parts = cmd.split()
            status = probe.unmute_head(int(parts[1]), int(parts[2]))
            _refresh_after_step(s)
        elif lc == "muted":
            status = f"muted: {sorted(probe.muted_heads)}"
        elif lc == "s":
            ts = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
            fname = f"/tmp/session_{ts}.json"
            with open(fname, "w", encoding="utf-8") as fh:
                json.dump(probe.full_log, fh, indent=2)
            status = f"saved {fname}"
        elif lc in ("refresh", "r"):
            _refresh_after_step(s)
            status = "refreshed"
        elif lc == "q":
            status = "(q is a no-op in the web UI — close the tab)"
        else:
            status = f"unknown: {cmd}"
    except Exception as e:  # UI boundary: report in the transcript instead of crashing
        status = f"error: {type(e).__name__}: {e}"
    s.transcript.append((cmd, status))
    return panes_html(s) + (status_html(s),)
def _refresh_futures(s: Session):
    """Rebuild only the branch list, reusing the current snapshot.

    Used when a display setting (width / preview length) changes but the
    context does not. A forward pass runs only if no snapshot exists yet.
    Does nothing before a probe with history is available.
    """
    probe = s.probe
    if probe is not None and probe.history_ids is not None:
        if probe.snapshot is None:
            probe.forward_snapshot()
        s.options = probe.generate_previews(s.width, s.preview_len)
def _refresh_after_step(s: Session):
    """Recompute every derived pane after the context changed.

    Runs exactly one forward pass, then derives branches, head-scan data
    and x-ray rows from that snapshot, in that order. Does nothing before a
    probe with history is available.
    """
    probe = s.probe
    if probe is None or probe.history_ids is None:
        return
    probe.forward_snapshot()
    s.options = probe.generate_previews(s.width, s.preview_len)
    s.scan_data = probe.scan_heads()
    s.xray = probe.perform_xray()
def initial_load(model_id: str, system_prompt: str, user_prompt: str, s: Session):
    """Handle the Load button: (re)build the probe if needed and render.

    An empty system prompt is passed on as None and an empty user prompt as "".
    A full refresh always runs, then a "(load)" transcript line is appended:
    "loaded" when a new probe was built, "ready" when the existing one was kept.

    Returns:
        ``panes_html(s)`` followed by ``status_html(s)``.
    """
    was_rebuilt = s.load(model_id, system_prompt or None, user_prompt or "")
    _refresh_after_step(s)
    probe = s.probe
    if probe.head_scan_supported:
        tier = "Tier-2 (per-head)"
    else:
        tier = "Tier-1 only"
    verb = "loaded" if was_rebuilt else "ready"
    load_line = (f"{verb} {probe.arch_name} {probe.model_id} "
                 f"— {probe.num_layers}L×{probe.num_heads}H, {tier}")
    s.transcript.append(("(load)", load_line))
    return panes_html(s) + (status_html(s),)
| # ── Gradio UI ───────────────────────────────────────────────────────────── | |
# Stylesheet passed to gr.Blocks(css=...) in build_ui(). The id selectors target
# the elem_id values assigned there (#hdr, #topbar, #summary-strip, #muted-row,
# #cmdbar, #dual-row, #col-branch, #col-head). The class selectors target pane
# markup (.pane, .pane-body, .summary, .prompt, ...). table.headmap / table.futures
# are presumably emitted by render_headmap / render_futures (not visible here).
# NOTE(review): no element in build_ui() uses elem_id="help", so the #help rules
# look unused — confirm before removing.
CSS = """
/* ── Global tightening ── */
.gradio-container { max-width: 100% !important; padding: 6px 10px !important; }
.gradio-container .main { padding: 0 !important; gap: 4px !important; }
.gradio-container .gap, .gradio-container .form { gap: 4px !important; }
.gradio-container .block { padding: 3px !important; border-radius: 4px !important; }
.gradio-container .gr-block { padding: 0 !important; }
.gradio-container .prose { margin: 0 !important; }
/* ── Header: kill the giant title margins ── */
#hdr h1 { font-size: 18px !important; margin: 0 !important; padding: 0 !important; line-height: 1.2 !important; }
#hdr p, #hdr em { font-size: 11px !important; margin: 0 !important; padding: 0 !important;
color: #8b949e !important; line-height: 1.2 !important; }
#hdr { margin: 0 !important; padding: 4px 0 !important; }
/* ── Top bar: compact textboxes ── */
#topbar { gap: 6px !important; }
#topbar > div { padding: 0 !important; }
#topbar textarea, #topbar input[type="text"] {
font-size: 12px !important; padding: 4px 6px !important; min-height: 28px !important;
line-height: 1.3 !important;
}
#topbar label, #topbar .gr-text-input label, #topbar span[data-testid="block-label"] {
font-size: 10px !important; padding: 0 0 2px 0 !important; margin: 0 !important;
text-transform: uppercase; letter-spacing: 0.5px; color: #8b949e !important;
}
#topbar button { min-height: 56px !important; align-self: stretch !important;
font-size: 13px !important; padding: 4px 8px !important; }
/* ── CMD bar tightening ── */
#cmdbar textarea, #cmdbar input[type="text"] {
font-size: 12px !important; padding: 4px 8px !important; min-height: 28px !important;
font-family: 'JetBrains Mono', 'Fira Code', Consolas, monospace !important;
}
#cmdbar label, #cmdbar span[data-testid="block-label"] {
font-size: 10px !important; color: #8b949e !important; margin: 0 !important; padding: 0 0 2px 0 !important;
}
#cmdbar button { min-height: 36px !important; font-size: 12px !important; }
/* ── Help line (small, dim) ── */
#help { font-size: 10.5px !important; color: #8b949e !important; margin: 2px 0 !important; padding: 0 !important; }
#help code { font-size: 10.5px !important; padding: 0 2px !important; background: #161b22 !important; }
#help p { margin: 0 !important; line-height: 1.5 !important; }
/* ── Kill empty muted-row gap before Context ── */
#muted-row { padding: 0 !important; margin: 0 !important; min-height: 0 !important; }
#muted-row:empty { display: none !important; }
#muted-row > * { padding: 0 !important; margin: 0 !important; }
/* ── Panes ── */
.pane { border: 1px solid #30363d; border-radius: 4px; background: #0d1117; color: #c9d1d9;
font-family: 'JetBrains Mono', 'Fira Code', Consolas, monospace; font-size: 12px; }
.pane h3 { margin: 0; padding: 3px 8px; background: #161b22; border-bottom: 1px solid #30363d;
font-size: 10px; letter-spacing: 0.5px; text-transform: uppercase; color: #8b949e; }
.pane-body { padding: 5px 8px; max-height: 440px; overflow: auto; }
.pane-body.scroll { overflow: auto; }
.dim { color: #6e7681; }
.context-text { white-space: pre-wrap; word-break: break-word; margin-top: 4px; }
table.headmap { border-collapse: collapse; font-size: 11px; }
table.headmap th, table.headmap td { padding: 1px 4px; border: 1px solid #21262d; text-align: center; white-space: pre; }
table.headmap th.ps { background: #112; }
table.headmap th.xr { background: #221; color: #ECA60F; }
table.headmap th.dv { background: #122; }
table.headmap th.full { background: #133; color: #64D32A; }
table.headmap td.lay { color: #8b949e; }
table.headmap td.xr { color: #ECA60F; background: #15110a; }
table.headmap td.full { background: #0a1410; }
table.headmap tr.subhdr th { background: #161b22; color: #8b949e; font-weight: normal; font-size: 10px; }
table.futures { width: 100%; border-collapse: collapse; }
table.futures th { background: #161b22; padding: 3px 6px; text-align: left; color: #8b949e;
font-size: 10px; text-transform: uppercase; letter-spacing: 0.4px; }
table.futures td { padding: 5px 6px; border-bottom: 1px solid #21262d; vertical-align: top; }
table.futures tr:hover { background: #11161c; }
table.futures tr.selected { background: #0f2418; outline: 1px solid #64D32A; }
table.futures td.idx { color: #6e7681; }
.cmd-pill { background: #21262d; color: #c9d1d9; padding: 1px 7px; border-radius: 3px;
font-weight: bold; font-size: 11px; border: 1px solid #30363d; }
table.futures tr.selected .cmd-pill { background: #1a4022; border-color: #64D32A; color: #b8f0c4; }
table.futures td.tok { color: #c9d1d9; font-weight: bold; white-space: pre; }
table.futures td.preview { color: #8b949e; max-width: 320px; }
table.futures td.cmd-hints code { background: transparent; padding: 0 1px; color: #8b949e; font-size: 10.5px; }
/* ── Run-summary strip ── */
#summary-strip { padding: 0 !important; margin: 0 !important; }
.summary { font-family: 'JetBrains Mono', Consolas, monospace; font-size: 11px;
padding: 4px 8px; background: #0d1117; border: 1px solid #30363d; border-radius: 4px;
color: #c9d1d9; line-height: 1.5; }
.summary .sep { color: #30363d; padding: 0 4px; }
.summary .dim { color: #6e7681; }
/* ── Locked-height instrument row: widening branches does not reflow head map ── */
#dual-row { align-items: stretch !important; }
#dual-row > div { display: flex; flex-direction: column; }
#dual-row #col-branch, #dual-row #col-head { height: 55vh; min-height: 360px; max-height: 640px; }
#dual-row .pane { height: 100%; display: flex; flex-direction: column; overflow: hidden; }
#dual-row .pane h3 { flex: 0 0 auto; }
#dual-row .pane .pane-body { flex: 1 1 auto; max-height: none !important; overflow: auto; }
/* ── Placeholder legibility (was being swallowed by the dark theme) ── */
.gradio-container textarea::placeholder,
.gradio-container input[type="text"]::placeholder {
color: #6e7681 !important; opacity: 1 !important;
}
#cmdbar textarea::placeholder, #cmdbar input[type="text"]::placeholder {
color: #8b949e !important; opacity: 1 !important;
}
pre.spark { margin: 0; font-size: 11px; line-height: 1.1; }
.muted-tag { color: #ff6b6b; font-weight: bold; padding: 2px 6px; }
.prompt { color: #64D32A; }
"""
# One-line command reference (Markdown, inline-code command names), with the
# entries separated by " · ".
HELP = " · ".join((
    "`1-N` select branch",
    "`1,5` select + keep N tokens",
    "`i <text>` inject text",
    "`h * | h L | h L H` head×layer scan",
    "`top R1,R2,…` pick by rank",
    "`rew N` rewind",
    "`spark <tok> [lo:hi]` token rank trace",
    "`w N` width",
    "`l N` length",
    "`mute L H`",
    "`unmute L H`",
    "`unmute all`",
    "`muted`",
    "`r` refresh",
    "`s` save",
))
def build_ui():
    """Build the Cartogemma Gradio Blocks app (not launched).

    Layout, top to bottom: header; top bar (model / system / user / Load);
    run-summary strip; muted-heads row; Context pane; command bar; a
    locked-height row with Branches | Head Map; then the Token Rank Trace
    and Transcript panes.

    Load calls initial_load(). Run and Enter in the command bar call
    handle_command(). Every handler re-renders all panes and threads the
    per-browser Session through gr.State.

    Returns:
        The gr.Blocks instance.
    """
    # Pane titles shared by the placeholder markup and by _wrap, so the initial
    # and rendered panes cannot drift apart.
    ctx_title = "Context"
    hm_title = ("Head Map "
                "— per-head pre-projection · logit-lens · per-head Δ-residual · L<sub>full</sub>")
    fut_title = "Branches (next-token top-k)"
    spark_title = ("Token Rank Trace "
                   "(per-(head,layer) rank of a chosen token)")
    log_title = "Transcript"

    def _pane(title, body):
        """Wrap already-rendered body HTML in a titled pane."""
        return f"<div class='pane'><h3>{title}</h3>{body}</div>"

    with gr.Blocks(title="Cartogemma", css=CSS, theme=gr.themes.Base()) as demo:
        gr.Markdown("# Cartogemma \n*Mechanistic probe — logit lens · head map · branches · token rank trace — "
                    "on the Gemma family (3 / 3n / 4) and most HF decoder LMs (Qwen, Llama, …). "
                    "Architecture is auto-discovered; per-head views degrade gracefully when unavailable.*",
                    elem_id="hdr")
        session = gr.State(Session())
        # ── Compact top bar: model | system | user | load ──
        MODEL_PRESETS = [
            "google/gemma-3-270m-it",
            "google/gemma-3-1b-it",
            "google/gemma-3-4b-it",
            "google/gemma-3n-E2B-it",
            "google/gemma-4-E2B-it",
            "google/gemma-4-E4B-it",
            "Qwen/Qwen3-0.6B",
            "meta-llama/Llama-3.2-1B-Instruct",
        ]
        with gr.Row(elem_id="topbar"):
            model_id = gr.Dropdown(
                choices=MODEL_PRESETS, value=DEFAULT_MODEL_ID, label="Model",
                allow_custom_value=True, filterable=True, scale=3, container=True,
            )
            system_prompt = gr.Textbox(label="System", value="",
                                       lines=1, max_lines=3, scale=4, container=True)
            user_prompt = gr.Textbox(
                label="User",
                value=("Hi, I need something like a cake recipe but it can't include eggs, sugar, "
                       "gluten-- so no flour, actually make that no grains of any sort, as well as "
                       "liquid and solid ingredients, none of them either. Thanks!"),
                lines=1, max_lines=3, scale=5, container=True,
            )
            load_btn = gr.Button("Load", variant="primary", scale=1, min_width=80)
        # ── Run-summary strip: arch · tier · L×H · tokens · entropy · top1 · selected · mode · muted ──
        summary_pane = gr.HTML(value="<div class='summary dim'>no model loaded</div>", elem_id="summary-strip")
        muted_html = gr.HTML(elem_id="muted-row")
        # ── Context: full width ──
        ctx_pane = gr.HTML(value=_pane(
            ctx_title, "<div class='pane-body'><i>Click <b>Load</b> to begin.</i></div>"))
        # ── Command bar: directly under Context (eye scans context → types command) ──
        with gr.Row(elem_id="cmdbar"):
            cmd = gr.Textbox(
                label="CMD",
                placeholder=(
                    "1-N select branch · 1,5 select+keep · i <text> inject · "
                    "h *|h L|h L H scan · top R1,R2,… rank-pick · rew N · "
                    "spark <tok> [lo:hi] rank trace · w N · l N · mute L H · unmute L H|all · muted · r · s"
                ),
                scale=10, autofocus=True, lines=1, max_lines=1, container=True,
            )
            run_btn = gr.Button("Run", scale=1, variant="primary", min_width=70)
        # ── Instrument grid: Branches | Head Map on one row, Spark + Transcript below.
        #    Row heights locked so widening branches (w 30) doesn't reflow neighbors. ──
        with gr.Row(elem_id="dual-row"):
            with gr.Column(scale=3, min_width=280, elem_id="col-branch"):
                fut_pane = gr.HTML(value=_pane(
                    fut_title,
                    "<div class='pane-body'><i>Top-k continuations appear after Load.</i></div>"))
            with gr.Column(scale=7, elem_id="col-head"):
                hm_pane = gr.HTML(value=_pane(
                    hm_title, "<div class='pane-body'><i>Auto-runs after every step.</i></div>"))
        # The hint is HTML: a literal "<token>" would be parsed as an empty tag
        # and disappear, so the angle brackets are escaped.
        spark_pane = gr.HTML(value=_pane(
            spark_title, "<div class='pane-body'><i>Run <code>spark &lt;token&gt;</code>.</i></div>"))
        status = gr.HTML(value=_pane(log_title, "<div class='pane-body dim'>Ready.</div>"))

        def _wrap(summary, ctx, hm, fut, spk, mute, st):
            """Give each rendered body its pane title. Summary and muted row pass through."""
            return (
                summary,
                _pane(ctx_title, ctx),
                _pane(hm_title, hm),
                _pane(fut_title, fut),
                _pane(spark_title, spk),
                mute,
                _pane(log_title, st),
            )

        def on_load(mid, sys_p, user_p, s):
            summary, ctx, hm, fut, spk, mute, st = initial_load(mid, sys_p, user_p, s)
            return _wrap(summary, ctx, hm, fut, spk, mute, st) + (s,)

        def on_cmd(c, s, mid, sys_p):
            summary, ctx, hm, fut, spk, mute, st = handle_command(c, s, mid, sys_p)
            # "" clears the command bar after each run.
            return _wrap(summary, ctx, hm, fut, spk, mute, st) + ("", s)

        OUT_BASE = [summary_pane, ctx_pane, hm_pane, fut_pane, spark_pane, muted_html, status]
        load_btn.click(
            on_load,
            inputs=[model_id, system_prompt, user_prompt, session],
            outputs=OUT_BASE + [session],
        )
        run_btn.click(
            on_cmd,
            inputs=[cmd, session, model_id, system_prompt],
            outputs=OUT_BASE + [cmd, session],
        )
        cmd.submit(
            on_cmd,
            inputs=[cmd, session, model_id, system_prompt],
            outputs=OUT_BASE + [cmd, session],
        )
    return demo
if __name__ == "__main__":
    # Script entry point: build the Blocks UI and serve it with default settings.
    app = build_ui()
    app.launch()