"""Cartogemma — Gradio Space wrapping InteractiveProbe against google/gemma-3-1b-it. Faithful to the cartographer3.py / cartographer_tui.py presentation: - Context tail pane - Head Map: Layer | H0..Hn (pre-stream) | Xray | H0..Hn dv | Full dv - Futures pane (top-k branches with probs + previews) - Sparkline pane (token rank scatter across heads x layers, rendered as
)
  - REPL command bar (1-N, h, top, rew, i, spark, mute, unmute, w, l, s, q)
"""

import os
import re
import io
import sys
import json
import math
import unicodedata
from dataclasses import dataclass, field
from datetime import datetime

import torch
import torch.nn as nn
import torch.nn.functional as F
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
try:
    from transformers import AutoModelForImageTextToText
except Exception:  # older transformers
    AutoModelForImageTextToText = None

# ── Config ────────────────────────────────────────────────────────────────
# Checkpoint to probe (override with the CARTOGEMMA_MODEL env var).
DEFAULT_MODEL_ID = os.environ.get("CARTOGEMMA_MODEL", "google/gemma-3-270m-it")
# Optional Hub token for gated checkpoints; HF_TOKEN wins over the legacy name.
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGING_FACE_HUB_TOKEN")

SEED = 42               # RNG seed re-applied whenever a probe is (re)built
DEFAULT_WIDTH = 5       # branches (top-k next tokens) shown in the futures pane
DEFAULT_PREVIEW = 20    # greedy tokens generated per branch preview

if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# ── HEAT palette (mirrors cartographer3) ──────────────────────────────────
HEAT_HEX = ["#59749C", "#948863", "#ECA60F", "#FFCF67", "#219C7F", "#E0D05B", "#64D32A"]

def heat_color(p: float) -> str:
    """Return the HEAT_HEX colour for probability *p*.

    Bands run from cold (p < 0.05, HEAT_HEX[0]) to hot (p >= 0.85,
    HEAT_HEX[6]). Anything that passes no cutoff — including NaN, which
    fails every ``>=`` comparison — gets the coldest colour.
    """
    # (minimum probability, palette index), hottest band first.
    bands = ((0.85, 6), (0.70, 5), (0.50, 4), (0.30, 3), (0.15, 2), (0.05, 1))
    for cutoff, slot in bands:
        if p >= cutoff:
            return HEAT_HEX[slot]
    return HEAT_HEX[0]

def char_width(ch: str) -> int:
    """Terminal display width of a single character.

    Non-spacing and enclosing combining marks (categories Mn/Me) occupy no
    column, East-Asian wide/fullwidth characters occupy two, everything else one.
    """
    if unicodedata.category(ch) in ("Mn", "Me"):
        return 0
    return 2 if unicodedata.east_asian_width(ch) in ("W", "F") else 1

def fmt_tok(tok: str, max_w: int = 10) -> str:
    """Make *tok* printable on one line and clip it to *max_w* display columns.

    Newline, carriage return and tab are rendered as the two-character
    escapes ``\\n``, ``\\r``, ``\\t``; the result is then cut at the last
    character that still fits (widths from ``char_width``), with no ellipsis.
    """
    visible = tok.translate(str.maketrans({"\n": "\\n", "\r": "\\r", "\t": "\\t"}))
    used = 0
    kept = []
    for ch in visible:
        used += char_width(ch)
        if used > max_w:
            break
        kept.append(ch)
    return "".join(kept)

def html_escape(s: str) -> str:
    """Escape ``&``, ``<`` and ``>`` so *s* can be embedded as HTML text.

    The previous chained ``str.replace`` calls mapped each character to
    itself (a no-op), so model-generated tokens such as ``<bos>`` or
    ``</div>`` were injected into the pane markup verbatim. Quotes are left
    untouched (``quote=False``), matching the character set the original
    intended to handle.
    """
    import html as _html  # local import: keeps the module namespace unchanged
    return _html.escape(s, quote=False)

# ANSI SGR (colour / style) escape sequences such as "\033[1;31m".
ANSI_RE = re.compile(r"\033\[[0-9;]*m")
def strip_ansi(s: str) -> str:
    """Return *s* with every ANSI SGR escape sequence removed."""
    cleaned = ANSI_RE.sub("", s)
    return cleaned

# ── Architecture resolver ─────────────────────────────────────────────────
# The probe makes two tiers of assumption about a model:
#   TIER 1 (generic): output_hidden_states + a final norm + an lm_head/tied
#     embedding. Gives logit-lens (xray), per-layer Δ-residual, branches,
#     top-by-rank, inject, rewind. Works on essentially any HF decoder LM.
#   TIER 2 (architecture-specific): attention writes the residual through an
#     o_proj whose INPUT is the concatenation of contiguous per-head blocks
#     (standard MHA/GQA: in_features == num_heads * head_dim). Gives per-head
#     pre-projection, per-head Δ-residual, head muting, token rank trace.
# We auto-discover both tiers and flip `head_scan_supported` off (degrading to
# Tier 1) when the Tier-2 assumption doesn't hold (fused QKV, MLA, exotic attn).

def _resolve_text_config(cfg):
    """Return the (sub-)config that carries ``num_hidden_layers`` and friends.

    Multimodal configs (Gemma3/4, gemma3n) nest the LM under ``.text_config``.
    Tried in order: ``cfg.get_text_config()`` (any exception is ignored),
    then ``cfg.text_config``; a candidate only counts if it is not None and
    has ``num_hidden_layers``. Falls back to *cfg* itself.
    """
    getter = getattr(cfg, "get_text_config", None)
    if getter is not None:
        try:
            found = getter()
            if found is not None and hasattr(found, "num_hidden_layers"):
                return found
        except Exception:
            pass  # best effort: fall through to the attribute lookup
    nested = getattr(cfg, "text_config", None)
    if nested is None or not hasattr(nested, "num_hidden_layers"):
        return cfg
    return nested

def _is_decoder(mod):
    """True when *mod* owns a non-empty ``nn.ModuleList`` named ``layers``
    whose first block exposes an attention submodule (self_attn/attn/attention)."""
    stack = getattr(mod, "layers", None)
    if isinstance(stack, nn.ModuleList) and len(stack) > 0:
        head = stack[0]
        return any(hasattr(head, name) for name in ("self_attn", "attn", "attention"))
    return False

# Attribute paths (from the top-level model) where common HF architectures
# keep their decoder stack; tried in order before the generic tree search.
_DECODER_PATHS = [
    ("model",), ("model", "language_model"), ("language_model",),
    ("model", "model"), ("model", "language_model", "model"),
    ("transformer",), ("model", "decoder"), ("decoder",),
]

def _resolve_decoder(model):
    """Return the module that owns the decoder-block ModuleList (``.layers``).

    First walks each of ``_DECODER_PATHS``; if none reaches a decoder, picks
    the decoder-like module with the most layers anywhere in the module tree
    (the first one found wins ties). Raises RuntimeError if nothing matches.
    """
    missing = object()

    def walk(path):
        node = model
        for attr in path:
            if not hasattr(node, attr):
                return missing
            node = getattr(node, attr)
        return node

    for path in _DECODER_PATHS:
        candidate = walk(path)
        if candidate is not missing and _is_decoder(candidate):
            return candidate

    # Generic fallback: longest attention-bearing ModuleList in the tree.
    stacks = [mod for _, mod in model.named_modules() if _is_decoder(mod)]
    if stacks:
        return max(stacks, key=lambda mod: len(mod.layers))
    raise RuntimeError("Could not locate a decoder layer stack for this model.")

def _resolve_final_norm(decoder):
    """Return the decoder's final norm module, or None if it has none.

    The first non-None attribute among norm / final_layernorm / ln_f /
    final_norm is used.
    """
    candidates = (getattr(decoder, name, None)
                  for name in ("norm", "final_layernorm", "ln_f", "final_norm"))
    return next((mod for mod in candidates if mod is not None), None)

def _resolve_lm_head(model):
    """Return the output projection (a module with ``.weight``) or None.

    Checks ``lm_head``, ``language_model.lm_head`` and ``model.lm_head`` in
    that order. None tells the caller to fall back to the tied input
    embeddings.
    """
    for path in (("lm_head",), ("language_model", "lm_head"), ("model", "lm_head")):
        node = model
        for attr in path:
            if not hasattr(node, attr):
                break
            node = getattr(node, attr)
        else:  # every attribute on the path existed
            if node is not None and hasattr(node, "weight"):
                return node
    return None

def _resolve_o_proj(layer):
    """Return the attention output projection whose input is per-head-concatenated.

    The attention module is the first non-None of self_attn / attn /
    attention. Inside it, the first of o_proj / out_proj / wo / dense /
    c_proj that has a ``.weight`` is returned. Returns None when either
    lookup fails.
    """
    attn = next(
        (mod for mod in (getattr(layer, name, None)
                         for name in ("self_attn", "attn", "attention"))
         if mod is not None),
        None,
    )
    if attn is None:
        return None
    for name in ("o_proj", "out_proj", "wo", "dense", "c_proj"):
        proj = getattr(attn, name, None)
        if proj is not None and hasattr(proj, "weight"):
            return proj
    return None


@dataclass
class Snapshot:
    """One forward pass over the current context, shared by every pane.
    Avoids the old pattern of xray / scan / branches each doing their own
    forward. head_inputs is populated only when head_scan is supported.

    Built by InteractiveProbe.forward_snapshot(). Field order is the
    positional-constructor contract, so append new fields only at the end.
    """
    logits: torch.Tensor                 # [vocab] last-token logits (float)
    hidden_states: tuple                 # len L+1, each [1, seq, hidden]; index 0 = embeddings
    head_inputs: dict = field(default_factory=dict)  # layer_idx -> o_proj input [1, seq, H*hd]
    entropy: float = 0.0                 # next-token entropy (nats)
    top1: float = 0.0                    # next-token top-1 probability


# ── The Probe (lifted from cartographer3, ANSI removed) ───────────────────
class InteractiveProbe:
    """Stateful probe around a causal LM backing every Cartogemma pane.

    Holds the token context (``history_ids``), one shared forward pass
    (``snapshot``), the head-mute hooks, and a ``full_log`` of user actions.
    Tier-1 analyses (logit lens, per-layer Δ-residual, branches, inject,
    commit) need only a decoder stack, a final norm and an lm_head. Tier-2
    analyses (per-head map, head muting, token rank trace) additionally need
    ``head_scan_supported`` (o_proj input == num_heads * head_dim).

    Methods that change the context or the mute set reset ``snapshot`` to
    None, so a stale forward pass is never read. Call ``forward_snapshot()``
    again before using snapshot-backed panes.
    """

    def __init__(self, model_id: str = None):
        model_id = model_id or DEFAULT_MODEL_ID
        self.model_id = model_id
        print(f"[*] Loading {model_id} on {DEVICE}...")
        kw = {"token": HF_TOKEN} if HF_TOKEN else {}
        self.tokenizer = AutoTokenizer.from_pretrained(model_id, **kw)
        self.tokenizer.padding_side = "left"
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        dtype = torch.bfloat16 if DEVICE == "cuda" else torch.float32
        load_kw = dict(torch_dtype=dtype, **kw)
        if DEVICE == "cuda":
            load_kw["device_map"] = "auto"
        # Try text CausalLM first; fall back to the multimodal head (gemma-3-4b,
        # gemma-4, gemma3n are all image-text-to-text wrappers around a text LM).
        self.model = None
        errs = []
        loaders = [AutoModelForCausalLM]
        if AutoModelForImageTextToText is not None:
            loaders.append(AutoModelForImageTextToText)
        for loader in loaders:
            try:
                self.model = loader.from_pretrained(model_id, **load_kw)
                break
            except Exception as e:
                errs.append(f"{loader.__name__}: {type(e).__name__}: {e}")
        if self.model is None:
            raise RuntimeError("Could not load model with any auto-class:\n" + "\n".join(errs))
        if DEVICE == "cpu":
            self.model = self.model.to(DEVICE)
        self.model.eval()

        # ── Auto-discover the architecture (see resolver notes above) ──
        self._decoder = _resolve_decoder(self.model)
        self.layers = self._decoder.layers
        self.final_norm = _resolve_final_norm(self._decoder)
        self._lm_head_mod = _resolve_lm_head(self.model)
        if self._lm_head_mod is None:
            self._tied_emb_weight = self.model.get_input_embeddings().weight
        else:
            self._tied_emb_weight = None

        text_cfg = _resolve_text_config(self.model.config)
        # Trust the actual stack length over config (resolver may pick a sub-stack).
        self.num_layers = len(self.layers)
        # Config fields may be absent OR explicitly None; _cfg_int treats both
        # as "missing" instead of crashing in int(None).
        self.num_heads = self._cfg_int(text_cfg, "num_attention_heads") or 1
        self.num_kv_heads = self._cfg_int(text_cfg, "num_key_value_heads", self.num_heads)
        hidden_size = self._cfg_int(text_cfg, "hidden_size")
        vocab_size = (self._cfg_int(text_cfg, "vocab_size")
                      or self._cfg_int(self.model.config, "vocab_size"))
        self.vocab_size = vocab_size or self._lm_head_out_features()

        o0 = _resolve_o_proj(self.layers[0])
        cfg_head_dim = getattr(text_cfg, "head_dim", None)
        if cfg_head_dim:
            self.head_dim = int(cfg_head_dim)
        elif o0 is not None:
            self.head_dim = o0.weight.shape[1] // self.num_heads
        elif hidden_size:
            self.head_dim = hidden_size // self.num_heads
        else:
            self.head_dim = 0
        self.hidden_size = hidden_size

        # Tier-2 (per-head) support: o_proj input must decompose as heads*head_dim.
        self.head_scan_supported = bool(
            o0 is not None
            and self.final_norm is not None
            and self.head_dim
            and o0.weight.shape[1] == self.num_heads * self.head_dim
        )
        self.arch_name = type(self.model).__name__

        print(f"[*] {self.arch_name}: {self.num_layers}L x {self.num_heads}H "
              f"(kv={self.num_kv_heads}), head_dim={self.head_dim}, hidden={hidden_size}, "
              f"vocab={self.vocab_size}, head_scan={'yes' if self.head_scan_supported else 'NO (Tier-1 only)'}")

        self.history_ids = None
        self.full_log = []
        self.muted_heads = set()
        self._mute_handles = []
        self.snapshot = None

    # ── architecture accessors ──
    @staticmethod
    def _cfg_int(cfg, name, default=0):
        """``int(cfg.<name>)``, or *default* when the attribute is missing or None."""
        value = getattr(cfg, name, None)
        return int(value) if value is not None else default

    def _lm_head_out_features(self):
        """Output width of the vocab projection (lm_head or tied embedding); 0 if unknown."""
        if self._lm_head_mod is not None and hasattr(self._lm_head_mod, "weight"):
            return self._lm_head_mod.weight.shape[0]
        if self._tied_emb_weight is not None:
            return self._tied_emb_weight.shape[0]
        return 0

    def lm_head(self, x):
        """Project hidden state to vocab logits (handles tied embeddings)."""
        if self._lm_head_mod is not None:
            return self._lm_head_mod(x)
        return F.linear(x, self._tied_emb_weight)

    def norm(self, x):
        """Apply final norm if the model has one, else identity."""
        return self.final_norm(x) if self.final_norm is not None else x

    def o_proj(self, layer_idx):
        """Attention output projection of decoder layer *layer_idx* (or None)."""
        return _resolve_o_proj(self.layers[layer_idx])

    # ── single shared forward pass ──
    def forward_snapshot(self):
        """Run ONE forward over the current context and cache it in ``self.snapshot``.

        Captures last-token logits, entropy and top-1 probability, all hidden
        states and, for Tier-2 models, each layer's o_proj input. All panes
        read from this.

        HF models typically return the FINAL hidden state *after* the final
        norm. Re-normalising it in the logit lens would double-normalise the
        last row, and the last layer's Δ-residual would subtract an un-normed
        vector from a normed one. So the raw output of the last decoder layer
        is captured with a hook and substituted for ``hidden_states[-1]`` when
        the shapes agree. This is a no-op for models that report it un-normed.

        Returns the new Snapshot, or None (clearing the cache) when no
        context is loaded.
        """
        if self.history_ids is None:
            self.snapshot = None
            return None
        captured = {}
        last_out = {}
        handles = []

        def grab_last(module, args, output):
            # Decoder layers return a tensor or a tuple led by the residual
            # stream, depending on the transformers version.
            hs = output[0] if isinstance(output, (tuple, list)) else output
            if torch.is_tensor(hs):
                last_out["hs"] = hs.detach()

        handles.append(self.layers[-1].register_forward_hook(grab_last))
        if self.head_scan_supported:
            def mk(li):
                def fn(module, args, output):
                    captured[li] = args[0].detach()
                return fn
            for li in range(self.num_layers):
                op = self.o_proj(li)
                if op is not None:
                    handles.append(op.register_forward_hook(mk(li)))
        try:
            with torch.no_grad():
                out = self.model(self.history_ids, output_hidden_states=True)
        finally:
            for h in handles:
                h.remove()

        hidden_states = tuple(out.hidden_states)
        raw = last_out.get("hs")
        if raw is not None and hidden_states and raw.shape == hidden_states[-1].shape:
            ref = hidden_states[-1]
            hidden_states = hidden_states[:-1] + (raw.to(device=ref.device, dtype=ref.dtype),)

        last_logits = out.logits[0, -1, :].float()
        probs = F.softmax(last_logits, dim=-1)
        # entr(p) = -p*ln(p) with entr(0) = 0: no epsilon, no 0*log(0) NaN.
        ent = float(torch.special.entr(probs).sum().item())
        top1 = float(probs.max().item())
        self.snapshot = Snapshot(
            logits=last_logits, hidden_states=hidden_states,
            head_inputs=captured, entropy=ent, top1=top1,
        )
        return self.snapshot

    def _batch_top3(self, mat: torch.Tensor, k: int = 3):
        """mat: [N, hidden] → list of N rows, each [(token, prob) * k].
        One norm + one lm_head + one top-k over the whole batch."""
        with torch.no_grad():
            logits = self.lm_head(self.norm(mat))                    # [N, vocab]
            # softmax in float32: bf16 probabilities are too coarse to display
            probs = F.softmax(logits.float(), dim=-1)
            tp, ti = torch.topk(probs, min(k, probs.shape[-1]), dim=-1)  # [N, k]
        return [[(self.tokenizer.decode(tok), prob) for tok, prob in zip(ids, ps)]
                for ids, ps in zip(ti.tolist(), tp.tolist())]

    def _head_writes(self, snap, layer_idx):
        """Per-head last-token residual writes for one layer, as [H, hidden].

        Splits o_proj's weight into per-head column blocks and applies each to
        that head's slice of the captured o_proj input (bias excluded).
        Returns None when no o_proj input was captured for the layer.
        """
        inp = snap.head_inputs.get(layer_idx)
        if inp is None:
            return None
        with torch.no_grad():
            wo = self.o_proj(layer_idx).weight
            wv = wo.view(wo.shape[0], self.num_heads, self.head_dim)
            x_last = inp[0, -1, :].view(self.num_heads, self.head_dim)   # [H, hd]
            # proj[h] = Wo[:, h, :] @ x_last[h]  →  [H, hidden]
            return torch.einsum("khd,hd->hk", wv.to(x_last.dtype), x_last)

    def set_seed(self):
        """Re-seed every RNG with the module-level SEED."""
        set_seed(SEED)

    def load_start(self, prompt: str, system_prompt: str = None):
        """Reset the context to *prompt* (optionally with a system turn).

        Uses the tokenizer's chat template when it has one (falling back to
        the raw prompt). Templates such as Gemma's already start with the BOS
        token, so special tokens are only added when the rendered text does
        not begin with it. Otherwise BOS would be duplicated.
        """
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})
        try:
            full_prompt = self.tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
        except Exception:
            full_prompt = prompt
        bos = getattr(self.tokenizer, "bos_token", None)
        add_special = not (bos and full_prompt.startswith(bos))
        enc = self.tokenizer(full_prompt, return_tensors="pt", add_special_tokens=add_special)
        self.history_ids = enc.to(DEVICE).input_ids
        self.snapshot = None

    def perform_xray(self, snap: "Snapshot" = None):
        """Logit lens per layer, batched: stack all last-token hidden states,
        one norm + lm_head + argmax. Returns [(label, token)] with labels
        "Emb", "L1".."LN"; [] when there is no snapshot."""
        snap = snap or self.snapshot
        if snap is None: return []
        stacked = torch.cat([h[:, -1, :] for h in snap.hidden_states], dim=0)   # [L+1, hidden]
        with torch.no_grad():
            tids = self.lm_head(self.norm(stacked)).argmax(dim=-1).tolist()
        return [("Emb" if i == 0 else f"L{i}", self.tokenizer.decode(tid))
                for i, tid in enumerate(tids)]

    def _strip_generation_padding(self, seq, trunk_len):
        """Drop trailing pad tokens from one generated row (1-D tensor).

        Batched generate pads rows that hit EOS early; without this,
        committing such a branch would push padding into the context. The
        root token at *trunk_len* is never removed. When pad doubles as EOS
        one copy is kept, since the first one is the real end-of-sequence.
        """
        pad = self.tokenizer.pad_token_id
        if pad is None:
            return seq
        end = seq.shape[0]
        while end > trunk_len + 1 and int(seq[end - 1]) == pad:
            end -= 1
        if pad == self.tokenizer.eos_token_id and end < seq.shape[0]:
            end += 1
        return seq[:end]

    def generate_previews(self, width: int, length: int, snap: "Snapshot" = None):
        """Expand the top-*width* next tokens into greedy previews of *length* tokens.

        Returns one dict per branch (index, token_id, root_token, prob,
        future_text, full_ids [1, seq]); [] when there is no context or
        *width* <= 0. *width* is capped at the vocabulary size.
        """
        if self.history_ids is None:
            return []
        snap = snap or self.snapshot
        if snap is not None:
            next_logits = snap.logits
        else:
            with torch.no_grad():
                next_logits = self.model(self.history_ids).logits[0, -1, :].float()
        probs = F.softmax(next_logits, dim=-1)
        width = min(int(width), probs.shape[-1])
        if width <= 0:
            return []
        top_probs, top_ids = torch.topk(probs, width)

        trunk_len = self.history_ids.shape[1]
        roots = top_ids.to(self.history_ids.device).view(-1, 1)
        batch = torch.cat([self.history_ids.expand(width, -1), roots], dim=1)
        if int(length) > 0:
            with torch.no_grad():
                # Greedy decoding: no temperature (it is ignored, and warned
                # about, when do_sample=False). Rows share one length, so the
                # attention mask is all ones.
                outs = self.model.generate(
                    batch, attention_mask=torch.ones_like(batch),
                    max_new_tokens=int(length), do_sample=False,
                    pad_token_id=self.tokenizer.pad_token_id,
                )
        else:
            outs = batch

        previews = []
        for i, (tid, p, row) in enumerate(zip(top_ids.tolist(), top_probs.tolist(), outs)):
            full_seq = self._strip_generation_padding(row, trunk_len)
            root = self.tokenizer.decode(full_seq[trunk_len])
            future = self.tokenizer.decode(full_seq[trunk_len + 1:], skip_special_tokens=False)
            previews.append({
                "index": i, "token_id": tid, "root_token": root, "prob": p,
                "future_text": future.replace("\n", "\\n"),
                "full_ids": full_seq.unsqueeze(0),
            })
        return previews

    def commit(self, preview_obj, tokens_to_keep=None):
        """Adopt a branch: root token plus up to *tokens_to_keep* preview tokens
        (all of them when None; negative values keep just the root).

        Raises ValueError if the branch was generated from a different
        context, because committing it would silently rewind the conversation.
        """
        full_seq = preview_obj["full_ids"]
        trunk_len = self.history_ids.shape[1]
        if full_seq.shape[1] <= trunk_len or not torch.equal(
                full_seq[:, :trunk_len].to(self.history_ids.device), self.history_ids):
            raise ValueError("branch is stale (context changed since it was generated); refresh first")
        if tokens_to_keep is None:
            cut = full_seq.shape[1]
        else:
            cut = min(trunk_len + 1 + max(0, int(tokens_to_keep)), full_seq.shape[1])
        self.history_ids = full_seq[:, :cut]
        self.snapshot = None
        committed = self.tokenizer.decode(self.history_ids[0, trunk_len:], skip_special_tokens=True)
        self.full_log.append({"type": "commit", "token": preview_obj["root_token"],
                              "full_text": committed, "kept_len": tokens_to_keep})

    def inject(self, text: str):
        """Append *text* (tokenized without special tokens) to the context."""
        ids = self.tokenizer(text, return_tensors="pt", add_special_tokens=False).to(DEVICE).input_ids
        if self.history_ids is None:
            self.history_ids = ids
        else:
            self.history_ids = torch.cat([self.history_ids, ids.to(self.history_ids.device)], dim=1)
        self.snapshot = None
        self.full_log.append({"type": "inject", "text": text})

    def scan_heads(self, snap: "Snapshot" = None, target_layers=None, target_heads=None):
        """Decode the head map from a precomputed snapshot. No model call here.
        Note: the per-head 'pre-projection' and per-head 'Δ-residual' are the
        same quantity (a head's last-token write into the residual via o_proj);
        the original tool computed them twice. We compute once and fill both.
        Out-of-range layer/head indices are dropped; returns None without a
        snapshot."""
        snap = snap or self.snapshot
        if snap is None: return None
        if target_layers is None: target_layers = list(range(self.num_layers))
        if target_heads is None: target_heads = list(range(self.num_heads))
        target_layers = [li for li in target_layers if 0 <= li < self.num_layers]
        target_heads = [hi for hi in target_heads if 0 <= hi < self.num_heads]

        pre_stream = {li: {} for li in target_layers}
        delta_v = {li: {} for li in target_layers}
        if not target_layers:
            # e.g. "h 99": nothing in range, and torch.cat([]) below would raise.
            return {"pre_stream": pre_stream, "delta_v": delta_v}

        # ── Tier-1: per-layer full Δ-residual (always available) — batched ──
        hs = snap.hidden_states
        full_vecs = torch.cat([hs[li + 1][:, -1, :] - hs[li][:, -1, :] for li in target_layers], dim=0)
        for li, top in zip(target_layers, self._batch_top3(full_vecs)):
            delta_v[li]["full"] = top

        if not (self.head_scan_supported and snap.head_inputs):
            return {"pre_stream": pre_stream, "delta_v": delta_v}

        # ── Tier-2: per-head residual writes — one batched lm_head over all ──
        rows = []
        index = []   # (li, hi) aligned with rows
        for li in target_layers:
            proj_all = self._head_writes(snap, li)
            if proj_all is None: continue
            for hi in target_heads:
                rows.append(proj_all[hi])
                index.append((li, hi))
        if rows:
            head_top = self._batch_top3(torch.stack(rows, dim=0))
            for (li, hi), res in zip(index, head_top):
                pre_stream[li][hi] = res
                delta_v[li][hi] = res   # identical quantity (see docstring)
        return {"pre_stream": pre_stream, "delta_v": delta_v}

    def _install_mute_hooks(self):
        """(Re)install one o_proj forward-pre-hook per layer with muted heads.

        The hook zeroes the muted heads' slices of o_proj's input on a copy,
        so the attention output tensor held by the caller is never modified
        in place. Works for any leading shape.
        """
        for h in self._mute_handles: h.remove()
        self._mute_handles = []
        if not self.muted_heads: return
        muted_by_layer = {}
        for (l, h) in self.muted_heads:
            muted_by_layer.setdefault(l, set()).add(h)
        hd, nh = self.head_dim, self.num_heads

        def make_hook(head_idx, head_dim, num_heads):
            def fn(module, args):
                x = args[0]
                y = x.clone(memory_format=torch.contiguous_format)
                y.view(*y.shape[:-1], num_heads, head_dim)[..., head_idx, :] = 0
                return (y,) + tuple(args[1:])
            return fn

        for li, heads in muted_by_layer.items():
            op = self.o_proj(li)
            if op is not None:
                self._mute_handles.append(
                    op.register_forward_pre_hook(make_hook(sorted(heads), hd, nh)))

    def mute_head(self, layer, head):
        """Zero one head's contribution in all later forwards; returns a status string."""
        if not self.head_scan_supported:
            return f"head muting unsupported for {self.arch_name} (no per-head o_proj decomposition)"
        if not (0 <= layer < self.num_layers): return f"Layer {layer} OOR"
        if not (0 <= head < self.num_heads): return f"Head {head} OOR"
        key = (layer, head)
        if key in self.muted_heads: return f"L{layer}H{head} already muted"
        self.muted_heads.add(key)
        self._install_mute_hooks()
        self.snapshot = None   # model behaviour changed
        self.full_log.append({"type": "mute", "layer": layer, "head": head})
        return f"MUTED L{layer}H{head}"

    def unmute_head(self, layer, head):
        """Undo mute_head for one head; returns a status string."""
        key = (layer, head)
        if key not in self.muted_heads: return f"L{layer}H{head} not muted"
        self.muted_heads.discard(key)
        self._install_mute_hooks()
        self.snapshot = None
        self.full_log.append({"type": "unmute", "layer": layer, "head": head})
        return f"UNMUTED L{layer}H{head}"

    def unmute_all(self):
        """Remove every mute hook; returns a status string with the count."""
        for h in self._mute_handles: h.remove()
        self._mute_handles = []
        n = len(self.muted_heads)
        self.muted_heads.clear()
        self.snapshot = None
        self.full_log.append({"type": "unmute_all"})
        return f"UNMUTED ALL ({n} heads)"

    def sparkline_data(self, token_str: str, rank_lo: int = 0, rank_hi: int = None,
                       snap: "Snapshot" = None):
        """Rank of one token in every (head, layer) per-head write's logit lens.

        *token_str* is a token id when all-decimal, otherwise text whose first
        token is used. Returns "unsupported" on Tier-1 models, None when there
        is no context/snapshot or the token is unknown / out of range, else
        {"name", "tid", "rank_lo", "rank_hi", "ranks": {(head, layer): rank}}.
        """
        if self.history_ids is None: return None
        if not self.head_scan_supported:
            return "unsupported"
        snap = snap or self.snapshot
        if snap is None or not snap.head_inputs:
            return None
        # isdecimal, not isdigit: "²".isdigit() is True but int("²") raises.
        if token_str.isdecimal():
            tid, name = int(token_str), None
        else:
            ids = self.tokenizer.encode(token_str, add_special_tokens=False)
            if not ids: return None
            tid, name = ids[0], token_str
        n_out = self._lm_head_out_features()
        if n_out and not 0 <= tid < n_out:
            return None   # would IndexError on logits[:, tid]
        if name is None:
            name = self.tokenizer.decode(tid)
        if rank_hi is None: rank_hi = self.vocab_size

        # Build [N, hidden] of every (head, layer) last-token projection, then a
        # single lm_head, then rank by COMPARISON (count of logits above target)
        # rather than a full argsort — O(vocab) vs O(vocab·log vocab) per row.
        rows, index = [], []
        for li in sorted(snap.head_inputs.keys()):
            proj_all = self._head_writes(snap, li)
            if proj_all is None: continue
            for hi in range(self.num_heads):
                rows.append(proj_all[hi]); index.append((hi, li))
        ranks = {}
        if rows:
            with torch.no_grad():
                logits = self.lm_head(self.norm(torch.stack(rows, dim=0)))      # [N, vocab]
                target = logits[:, tid].unsqueeze(1)                            # [N, 1]
                rk = (logits > target).sum(dim=1).add(1).tolist()               # [N]
            for key, r in zip(index, rk):
                ranks[key] = int(r)
        return {"name": name, "tid": tid, "rank_lo": rank_lo, "rank_hi": rank_hi, "ranks": ranks}


# ── HTML renderers ────────────────────────────────────────────────────────
# Display-column budget (as measured by fmt_tok) for one token cell.
CELL_W: int = 10

def render_context(probe: InteractiveProbe) -> str:
    if probe is None or probe.history_ids is None:
        return "
No context loaded.
" full = probe.tokenizer.decode(probe.history_ids[0], skip_special_tokens=False) tail = full[-800:].replace("\n", " ") if len(full) > 800: tail = "..." + tail n = probe.history_ids.shape[1] return (f"
" f"
[{n} tokens]
" f"
{html_escape(tail)}
") def render_headmap(probe: InteractiveProbe, data: dict, xray: list) -> str: if not data: return "
No head-scan run yet. Try h * for all layers.
" pre, dv = data.get("pre_stream", {}), data.get("delta_v", {}) layers = sorted(pre.keys()) if not layers: return "
Empty scan.
" heads = sorted(h for h in pre.get(layers[0], {}).keys() if isinstance(h, int)) xray_map = {label: tok for label, tok in xray} if xray else {} def cell(tok, prob): clean = html_escape(fmt_tok(tok, CELL_W)) c = heat_color(prob) return f"{clean}" rows = [] # Header: [pre-stream heads] | Xray | [Δ-residual heads] | L_full hdr = ["Lay"] for h in heads: hdr.append(f"H{h}") hdr.append("Xray") for h in heads: hdr.append(f"H{h}dv") hdr.append("Lfull") hdr.append("") rows.append("".join(hdr)) sub = [""] if heads: sub.append(f"per-head pre-projection (raw head → norm → lm_head)") sub.append("logit lens") sub.append(f"Δ-residual (attn+MLP+skip), " f"{'per head + ' if heads else ''}full layer") sub.append("") rows.insert(0, "".join(sub)) for l in layers: r = [f"L{l}"] for h in heads: if h in pre.get(l, {}): tok, p = pre[l][h][0]; r.append(cell(tok, p)) else: r.append("--") # Xray (logit lens at this layer) label = "Emb" if l == 0 else f"L{l}" xt = xray_map.get(label, "--") r.append(f"{html_escape(fmt_tok(xt, CELL_W))}") # delta-v per head dvl = dv.get(l, {}) for h in heads: if h in dvl: tok, p = dvl[h][0]; r.append(cell(tok, p)) else: r.append("--") # full layer dv if "full" in dvl: tok, p = dvl["full"][0] r.append(f"{html_escape(fmt_tok(tok, CELL_W))}") else: r.append("--") r.append("") rows.append("".join(r)) note = "" if not heads: note = ("
per-head decomposition unavailable for this " "architecture — showing logit-lens + full-layer Δ-residual only (Tier-1)
") return f"
{note}{''.join(rows)}
" def render_futures(options: list, selected: int = None) -> str: if not options: return "
No branches yet — click Load.
" rows = ["cmdprobtokenpreviewhints"] for o in options: n = o["index"] + 1 prob = f"{o['prob']:.1%}" col = heat_color(o["prob"]) root_clean = o["root_token"].strip()[:12] or o["root_token"][:12] hints = (f"{n},5·{n},20·" f"spark {html_escape(root_clean)}" if root_clean else f"{n},5") cls = "selected" if selected == n else "" rows.append( f"" f"{n}" f"{prob}" f"{html_escape(fmt_tok(o['root_token'], 16))}" f"{html_escape(o['future_text'][:120])}…" f"{hints}" f"" ) return f"
{''.join(rows)}
" def render_sparkline(probe: InteractiveProbe, sd: dict) -> str: if not sd: return "
No sparkline. Try spark <token> [lo:hi].
" n_layers = probe.num_layers n_heads = probe.num_heads rank_lo, rank_hi = sd["rank_lo"], sd["rank_hi"] span = max(rank_hi - rank_lo, 1) ranks = sd["ranks"] band_h = 8 head_colors = ["#64D32A", "#FFCF67", "#21C5C5", "#E78BFF"] lines = [] name = sd["name"].replace("\n", "\\n")[:15] lines.append(f"Scatter: '{html_escape(name)}' (id={sd['tid']}) rank {rank_lo:,} (top) → {rank_hi:,} (bottom)") hdr = " " + "".join(f"{l:>3}" for l in range(n_layers)) lines.append(f"{html_escape(hdr)}") lines.append(f" {'─' * (n_layers * 3)}") for hi in range(n_heads): color = head_colors[hi % len(head_colors)] grid = [[" "] * n_layers for _ in range(band_h)] for li in range(n_layers): r = ranks.get((hi, li), rank_hi + 1) if r < rank_lo: grid[0][li] = "▲" elif r > rank_hi: grid[band_h - 1][li] = "▼" else: frac = (r - rank_lo) / span row = int(frac * (band_h - 1)) grid[max(0, min(band_h - 1, row))][li] = "●" for ri in range(band_h): if ri == 0: label = f"H{hi} {rank_lo:>6}│" elif ri == band_h - 1: label = f" {rank_hi:>6}│" elif ri == band_h // 2: mid = rank_lo + span // 2 label = f" {mid:>6}│" else: label = f" │" cells = [] for li in range(n_layers): ch = grid[ri][li] if ch in ("●", "▲", "▼"): cells.append(f" {ch} ") else: cells.append(f" {ch} ") lines.append(html_escape(label) + "".join(cells)) if hi < n_heads - 1: lines.append(f" {'┈' * (n_layers * 3)}") lines.append(f" {'─' * (n_layers * 3)}") return f"
{chr(10).join(lines)}
" def render_muted(probe: InteractiveProbe) -> str: if probe is None or not probe.muted_heads: return "" items = ", ".join(f"L{l}H{h}" for l, h in sorted(probe.muted_heads)) return f"MUTED: {items}" def render_summary(s) -> str: """Compact run-summary strip: arch · tokens · entropy · top1 · selected · mode · muted.""" if s.probe is None: return "
no model loaded
" p = s.probe n_tok = p.history_ids.shape[1] if p.history_ids is not None else 0 snap = p.snapshot ent = f"{snap.entropy:.2f}" if snap else "—" top1 = f"{snap.top1*100:.1f}%" if snap else "—" tier = "T2" if p.head_scan_supported else "T1" sel = getattr(s, "last_selected", None) sel_str = f"branch [{sel}]" if sel is not None else "—" muted = (f"muted {len(p.muted_heads)}" if p.muted_heads else "muted 0") arch = html_escape(p.arch_name) mid = html_escape(p.model_id) items = [ f"{arch} {mid}", f"tier {tier}", f"L{p.num_layers} ×H{p.num_heads}", f"tok {n_tok}", f"H={ent}", f"top1={top1}", f"last {sel_str}", f"w={s.width} len={s.preview_len}", muted, ] return "
" + " · ".join(items) + "
" # ── State container ────────────────────────────────────────────────────── class Session: def __init__(self): self.probe: InteractiveProbe = None self.options: list = [] self.scan_data: dict = None self.xray: list = [] self.spark: dict = None self.width = DEFAULT_WIDTH self.preview_len = DEFAULT_PREVIEW self.transcript: list = [] # [(cmd, status), ...] self.loaded_model = None self.loaded_system = None self.loaded_user = None self.last_selected = None # which branch index was last committed def _reset_panes(self): self.options = [] self.scan_data = None self.xray = [] self.spark = None self.transcript = [] self.last_selected = None def load(self, model_id, system_prompt, user_prompt, force=False): """(Re)build the probe when model/prompt changed. Returns True if rebuilt. Fixes the stale-session bug: clicking Load with a new model/prompt now actually resets state instead of silently keeping the old probe.""" changed = ( force or self.probe is None or self.loaded_model != model_id or self.loaded_system != system_prompt or self.loaded_user != user_prompt ) if not changed: return False self.probe = InteractiveProbe(model_id=model_id) self.probe.set_seed() self.probe.load_start(user_prompt or "", system_prompt or None) self._reset_panes() self.loaded_model, self.loaded_system, self.loaded_user = model_id, system_prompt, user_prompt return True def ensure(self, model_id: str, system_prompt: str = None, user_prompt: str = ""): """Lazy load for the first command if the user typed before clicking Load.""" if self.probe is None: self.load(model_id, system_prompt, user_prompt or "") def panes_html(s: Session): return ( render_summary(s), render_context(s.probe), render_headmap(s.probe, s.scan_data, s.xray), render_futures(s.options, s.last_selected), render_sparkline(s.probe, s.spark), render_muted(s.probe), ) def status_html(s: Session) -> str: if not s.transcript: return "
Ready.
" rows = [] for cmd, status in s.transcript[-20:]: rows.append(f"
{html_escape(cmd)} " f"— {html_escape(status)}
") return f"
{''.join(rows)}
" def handle_command(cmd: str, s: Session, model_id: str, system_prompt: str): cmd = (cmd or "").strip() if not cmd: return panes_html(s) + (status_html(s),) # Lazy load model on first command (so UI renders fast) s.ensure(model_id=model_id, system_prompt=system_prompt) status = "ok" try: lc = cmd.lower() if cmd[0].isdigit(): parts = cmd.split(",") idx = int(parts[0]) - 1 if 0 <= idx < len(s.options): keep_len = int(parts[1]) if len(parts) > 1 else None s.probe.commit(s.options[idx], keep_len) s.last_selected = idx + 1 _refresh_after_step(s) status = f"selected branch [{idx+1}]" else: status = f"invalid index {idx+1}" elif lc.startswith("i "): s.probe.inject(cmd[2:]) _refresh_after_step(s) status = f"injected '{cmd[2:][:30]}'" elif lc.startswith("h "): parts = cmd.split() target_layers = None if parts[1] == "*" else [int(parts[1])] target_heads = None if len(parts) > 2: target_heads = None if parts[2] == "*" else [int(parts[2])] if s.probe.snapshot is None: s.probe.forward_snapshot() s.scan_data = s.probe.scan_heads(target_layers=target_layers, target_heads=target_heads) s.xray = s.probe.perform_xray() status = f"head-scan: layers={parts[1]} heads={parts[2] if len(parts)>2 else '*'}" elif lc.startswith("top "): ranks = [int(r.strip()) for r in cmd[4:].split(",")] for rank in ranks: with torch.no_grad(): out = s.probe.model(s.probe.history_ids) logits = out.logits[0, -1, :] probs = F.softmax(logits, dim=-1) sp, si = torch.sort(probs, descending=True) if rank < 1 or rank > len(si): continue tid = si[rank - 1].item() s.probe.history_ids = torch.cat( [s.probe.history_ids, torch.tensor([[tid]], device=DEVICE)], dim=1 ) s.probe.full_log.append({"type": "top_select", "rank": rank, "token": s.probe.tokenizer.decode([tid]), "token_id": tid, "prob": sp[rank-1].item()}) _refresh_after_step(s) status = f"top-selected {len(ranks)} tokens" elif lc.startswith("rew "): n = int(cmd[4:].strip()) cur = s.probe.history_ids.shape[1] new = max(1, cur - n) s.probe.history_ids = 
s.probe.history_ids[:, :new] s.probe.full_log.append({"type": "rewind", "tokens_removed": cur - new}) _refresh_after_step(s) status = f"rewound {cur - new} tokens" elif lc.startswith("w "): s.width = int(cmd[2:]); status = f"width={s.width}" _refresh_futures(s) elif lc.startswith("l "): s.preview_len = int(cmd[2:]); status = f"preview_len={s.preview_len}" _refresh_futures(s) elif lc.startswith("spark "): parts = cmd[6:].strip().split() tok = parts[0] rl, rh = 0, None if len(parts) > 1 and ":" in parts[1]: a, b = parts[1].split(":") rl, rh = int(a), int(b) sd = s.probe.sparkline_data(tok, rl, rh) if sd == "unsupported": status = f"token rank trace unsupported for {s.probe.arch_name} (Tier-1 only)" elif sd is None: status = f"token '{tok}' not found in vocab" else: s.spark = sd status = f"spark '{tok}' [{rl}:{rh if rh else 'vocab'}]" elif lc.startswith("mute ") and not lc.startswith("muted"): parts = cmd.split() status = s.probe.mute_head(int(parts[1]), int(parts[2])) _refresh_after_step(s) elif lc == "unmute all": status = s.probe.unmute_all() _refresh_after_step(s) elif lc.startswith("unmute "): parts = cmd.split() status = s.probe.unmute_head(int(parts[1]), int(parts[2])) _refresh_after_step(s) elif lc == "muted": status = f"muted: {sorted(s.probe.muted_heads)}" elif lc == "s": ts = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") fname = f"/tmp/session_{ts}.json" with open(fname, "w") as f: json.dump(s.probe.full_log, f, indent=2) status = f"saved {fname}" elif lc == "refresh" or lc == "r": _refresh_after_step(s) status = "refreshed" elif lc == "q": status = "(q is a no-op in the web UI — close the tab)" else: status = f"unknown: {cmd}" except Exception as e: status = f"error: {type(e).__name__}: {e}" s.transcript.append((cmd, status)) return panes_html(s) + (status_html(s),) def _refresh_futures(s: Session): """Re-derive branches from the existing snapshot (context unchanged).""" if s.probe is None or s.probe.history_ids is None: return if s.probe.snapshot is None: 
s.probe.forward_snapshot() s.options = s.probe.generate_previews(s.width, s.preview_len) def _refresh_after_step(s: Session): """Context changed: ONE forward, then every pane reads that snapshot.""" if s.probe is None or s.probe.history_ids is None: return s.probe.forward_snapshot() s.options = s.probe.generate_previews(s.width, s.preview_len) s.scan_data = s.probe.scan_heads() s.xray = s.probe.perform_xray() def initial_load(model_id: str, system_prompt: str, user_prompt: str, s: Session): rebuilt = s.load(model_id, system_prompt or None, user_prompt or "") _refresh_after_step(s) tier = "Tier-2 (per-head)" if s.probe.head_scan_supported else "Tier-1 only" verb = "loaded" if rebuilt else "ready" s.transcript.append(("(load)", f"{verb} {s.probe.arch_name} {s.probe.model_id} " f"— {s.probe.num_layers}L×{s.probe.num_heads}H, {tier}")) return panes_html(s) + (status_html(s),) # ── Gradio UI ───────────────────────────────────────────────────────────── CSS = """ /* ── Global tightening ── */ .gradio-container { max-width: 100% !important; padding: 6px 10px !important; } .gradio-container .main { padding: 0 !important; gap: 4px !important; } .gradio-container .gap, .gradio-container .form { gap: 4px !important; } .gradio-container .block { padding: 3px !important; border-radius: 4px !important; } .gradio-container .gr-block { padding: 0 !important; } .gradio-container .prose { margin: 0 !important; } /* ── Header: kill the giant title margins ── */ #hdr h1 { font-size: 18px !important; margin: 0 !important; padding: 0 !important; line-height: 1.2 !important; } #hdr p, #hdr em { font-size: 11px !important; margin: 0 !important; padding: 0 !important; color: #8b949e !important; line-height: 1.2 !important; } #hdr { margin: 0 !important; padding: 4px 0 !important; } /* ── Top bar: compact textboxes ── */ #topbar { gap: 6px !important; } #topbar > div { padding: 0 !important; } #topbar textarea, #topbar input[type="text"] { font-size: 12px !important; padding: 4px 6px 
!important; min-height: 28px !important; line-height: 1.3 !important; } #topbar label, #topbar .gr-text-input label, #topbar span[data-testid="block-label"] { font-size: 10px !important; padding: 0 0 2px 0 !important; margin: 0 !important; text-transform: uppercase; letter-spacing: 0.5px; color: #8b949e !important; } #topbar button { min-height: 56px !important; align-self: stretch !important; font-size: 13px !important; padding: 4px 8px !important; } /* ── CMD bar tightening ── */ #cmdbar textarea, #cmdbar input[type="text"] { font-size: 12px !important; padding: 4px 8px !important; min-height: 28px !important; font-family: 'JetBrains Mono', 'Fira Code', Consolas, monospace !important; } #cmdbar label, #cmdbar span[data-testid="block-label"] { font-size: 10px !important; color: #8b949e !important; margin: 0 !important; padding: 0 0 2px 0 !important; } #cmdbar button { min-height: 36px !important; font-size: 12px !important; } /* ── Help line (small, dim) ── */ #help { font-size: 10.5px !important; color: #8b949e !important; margin: 2px 0 !important; padding: 0 !important; } #help code { font-size: 10.5px !important; padding: 0 2px !important; background: #161b22 !important; } #help p { margin: 0 !important; line-height: 1.5 !important; } /* ── Kill empty muted-row gap before Context ── */ #muted-row { padding: 0 !important; margin: 0 !important; min-height: 0 !important; } #muted-row:empty { display: none !important; } #muted-row > * { padding: 0 !important; margin: 0 !important; } /* ── Panes ── */ .pane { border: 1px solid #30363d; border-radius: 4px; background: #0d1117; color: #c9d1d9; font-family: 'JetBrains Mono', 'Fira Code', Consolas, monospace; font-size: 12px; } .pane h3 { margin: 0; padding: 3px 8px; background: #161b22; border-bottom: 1px solid #30363d; font-size: 10px; letter-spacing: 0.5px; text-transform: uppercase; color: #8b949e; } .pane-body { padding: 5px 8px; max-height: 440px; overflow: auto; } .pane-body.scroll { overflow: auto; } .dim { 
color: #6e7681; } .context-text { white-space: pre-wrap; word-break: break-word; margin-top: 4px; } table.headmap { border-collapse: collapse; font-size: 11px; } table.headmap th, table.headmap td { padding: 1px 4px; border: 1px solid #21262d; text-align: center; white-space: pre; } table.headmap th.ps { background: #112; } table.headmap th.xr { background: #221; color: #ECA60F; } table.headmap th.dv { background: #122; } table.headmap th.full { background: #133; color: #64D32A; } table.headmap td.lay { color: #8b949e; } table.headmap td.xr { color: #ECA60F; background: #15110a; } table.headmap td.full { background: #0a1410; } table.headmap tr.subhdr th { background: #161b22; color: #8b949e; font-weight: normal; font-size: 10px; } table.futures { width: 100%; border-collapse: collapse; } table.futures th { background: #161b22; padding: 3px 6px; text-align: left; color: #8b949e; font-size: 10px; text-transform: uppercase; letter-spacing: 0.4px; } table.futures td { padding: 5px 6px; border-bottom: 1px solid #21262d; vertical-align: top; } table.futures tr:hover { background: #11161c; } table.futures tr.selected { background: #0f2418; outline: 1px solid #64D32A; } table.futures td.idx { color: #6e7681; } .cmd-pill { background: #21262d; color: #c9d1d9; padding: 1px 7px; border-radius: 3px; font-weight: bold; font-size: 11px; border: 1px solid #30363d; } table.futures tr.selected .cmd-pill { background: #1a4022; border-color: #64D32A; color: #b8f0c4; } table.futures td.tok { color: #c9d1d9; font-weight: bold; white-space: pre; } table.futures td.preview { color: #8b949e; max-width: 320px; } table.futures td.cmd-hints code { background: transparent; padding: 0 1px; color: #8b949e; font-size: 10.5px; } /* ── Run-summary strip ── */ #summary-strip { padding: 0 !important; margin: 0 !important; } .summary { font-family: 'JetBrains Mono', Consolas, monospace; font-size: 11px; padding: 4px 8px; background: #0d1117; border: 1px solid #30363d; border-radius: 4px; color: 
#c9d1d9; line-height: 1.5; } .summary .sep { color: #30363d; padding: 0 4px; } .summary .dim { color: #6e7681; } /* ── Locked-height instrument row: widening branches does not reflow head map ── */ #dual-row { align-items: stretch !important; } #dual-row > div { display: flex; flex-direction: column; } #dual-row #col-branch, #dual-row #col-head { height: 55vh; min-height: 360px; max-height: 640px; } #dual-row .pane { height: 100%; display: flex; flex-direction: column; overflow: hidden; } #dual-row .pane h3 { flex: 0 0 auto; } #dual-row .pane .pane-body { flex: 1 1 auto; max-height: none !important; overflow: auto; } /* ── Placeholder legibility (was being swallowed by the dark theme) ── */ .gradio-container textarea::placeholder, .gradio-container input[type="text"]::placeholder { color: #6e7681 !important; opacity: 1 !important; } #cmdbar textarea::placeholder, #cmdbar input[type="text"]::placeholder { color: #8b949e !important; opacity: 1 !important; } pre.spark { margin: 0; font-size: 11px; line-height: 1.1; } .muted-tag { color: #ff6b6b; font-weight: bold; padding: 2px 6px; } .prompt { color: #64D32A; } """ HELP = ( "`1-N` select branch · `1,5` select + keep N tokens · `i ` inject text · " "`h * | h L | h L H` head×layer scan · `top R1,R2,…` pick by rank · `rew N` rewind · " "`spark [lo:hi]` token rank trace · `w N` width · `l N` length · " "`mute L H` · `unmute L H` · `unmute all` · `muted` · `r` refresh · `s` save" ) def build_ui(): with gr.Blocks(title="Cartogemma", css=CSS, theme=gr.themes.Base()) as demo: gr.Markdown("# Cartogemma \n*Mechanistic probe — logit lens · head map · branches · token rank trace — " "on the Gemma family (3 / 3n / 4) and most HF decoder LMs (Qwen, Llama, …). 
" "Architecture is auto-discovered; per-head views degrade gracefully when unavailable.*", elem_id="hdr") session = gr.State(Session()) # ── Compact top bar: model | system | user | load ── MODEL_PRESETS = [ "google/gemma-3-270m-it", "google/gemma-3-1b-it", "google/gemma-3-4b-it", "google/gemma-3n-E2B-it", "google/gemma-4-E2B-it", "google/gemma-4-E4B-it", "Qwen/Qwen3-0.6B", "meta-llama/Llama-3.2-1B-Instruct", ] with gr.Row(elem_id="topbar"): model_id = gr.Dropdown( choices=MODEL_PRESETS, value=DEFAULT_MODEL_ID, label="Model", allow_custom_value=True, filterable=True, scale=3, container=True, ) system_prompt = gr.Textbox(label="System", value="", lines=1, max_lines=3, scale=4, container=True) user_prompt = gr.Textbox( label="User", value=("Hi, I need something like a cake recipe but it can't include eggs, sugar, " "gluten-- so no flour, actually make that no grains of any sort, as well as " "liquid and solid ingredients, none of them either. Thanks!"), lines=1, max_lines=3, scale=5, container=True, ) load_btn = gr.Button("Load", variant="primary", scale=1, min_width=80) # ── Run-summary strip: arch · tier · L×H · tokens · entropy · top1 · selected · mode · muted ── summary_pane = gr.HTML(value="
no model loaded
", elem_id="summary-strip") muted_html = gr.HTML(elem_id="muted-row") # ── Context: full width ── ctx_pane = gr.HTML(value="

Context

" "
Click Load to begin.
") # ── Command bar: directly under Context (eye scans context → types command) ── with gr.Row(elem_id="cmdbar"): cmd = gr.Textbox( label="CMD", placeholder=( "1-N select branch · 1,5 select+keep · i inject · " "h *|h L|h L H scan · top R1,R2,… rank-pick · rew N · " "spark [lo:hi] rank trace · w N · l N · mute L H · unmute L H|all · muted · r · s" ), scale=10, autofocus=True, lines=1, max_lines=1, container=True, ) run_btn = gr.Button("Run", scale=1, variant="primary", min_width=70) # ── Instrument grid: Branches | Head Map on one row, Spark + Transcript below. # Row heights locked so widening branches (w 30) doesn't reflow neighbors. ── with gr.Row(elem_id="dual-row"): with gr.Column(scale=3, min_width=280, elem_id="col-branch"): fut_pane = gr.HTML(value="

Branches (next-token top-k)

" "
Top-k continuations appear after Load.
") with gr.Column(scale=7, elem_id="col-head"): hm_pane = gr.HTML(value="

Head Map " "— per-head pre-projection · logit-lens · per-head Δ-residual · Lfull

" "
Auto-runs after every step.
") spark_pane = gr.HTML(value="

Token Rank Trace " "(per-(head,layer) rank of a chosen token)

" "
Run spark <token>.
") status = gr.HTML(value="

Transcript

" "
Ready.
") # Wrap renderers so each output is its own pane (with title) def _wrap(summary, ctx, hm, fut, sp, mute, st): return ( summary, f"

Context

{ctx}
", f"

Head Map " f"— per-head pre-projection · logit-lens · per-head Δ-residual · Lfull

{hm}
", f"

Branches (next-token top-k)

{fut}
", f"

Token Rank Trace " f"(per-(head,layer) rank of a chosen token)

{sp}
", mute, f"

Transcript

{st}
", ) def on_load(mid, sp, up, s): summary, ctx, hm, fut, spk, mute, st = initial_load(mid, sp, up, s) return _wrap(summary, ctx, hm, fut, spk, mute, st) + (s,) def on_cmd(c, s, mid, sp): summary, ctx, hm, fut, spk, mute, st = handle_command(c, s, mid, sp) return _wrap(summary, ctx, hm, fut, spk, mute, st) + ("", s) OUT_BASE = [summary_pane, ctx_pane, hm_pane, fut_pane, spark_pane, muted_html, status] load_btn.click( on_load, inputs=[model_id, system_prompt, user_prompt, session], outputs=OUT_BASE + [session], ) run_btn.click( on_cmd, inputs=[cmd, session, model_id, system_prompt], outputs=OUT_BASE + [cmd, session], ) cmd.submit( on_cmd, inputs=[cmd, session, model_id, system_prompt], outputs=OUT_BASE + [cmd, session], ) return demo if __name__ == "__main__": build_ui().launch()
# NOTE(review): build_ui() builds the Cartogemma Blocks layout and wires two
# entry points into the same seven output panes (OUT_BASE):
#   - the Load button -> on_load -> initial_load(); it also returns the session;
#   - the Run button / Enter in the CMD box -> on_cmd -> handle_command(); on_cmd
#     additionally returns "" for the CMD textbox and the session state.
# _wrap() wraps each rendered fragment in its titled pane. The summary and
# muted fragments are passed through unwrapped.
# NOTE(review): in this copy the markup inside the gr.HTML(value=...) strings,
# the _wrap() f-strings, the CMD placeholder and HELP appears to have been
# stripped. The literals break across lines with only their text left, and
# "i inject" / "spark [lo:hi]" have lost their argument placeholders.
# The CSS classes (.pane, .pane h3, .pane-body) suggest the original pane
# wrapper markup; TODO restore these literals from version control.
# NOTE(review): HELP is not referenced by build_ui(), and no element uses
# elem_id="help" even though CSS styles #help. Confirm whether HELP is dead.