# VINDEX / app.py
# Exported from a Hugging Face Space file view (author: Chris4K,
# commit "Update app.py", revision bdbe22a verified).
#!/usr/bin/env python3
"""
VINDEX β€” FastAPI + D3 single-file LLM knowledge editor
Phase 1: Engine extensions | Phase 2: FastAPI | Phase 3: D3 frontend
pip install transformers torch fastapi uvicorn pydantic
python temp.py
"""
# ══════════════════════════════════════════════════════════════
# ENGINE
# ══════════════════════════════════════════════════════════════
import re, json
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
class ArchAdapter:
    """Uniform access to per-layer FFN (MLP) weights across model families.

    Two weight layouts are normalised:
      * "gpt2"  — ``layer.mlp.c_fc`` / ``c_proj`` (HF Conv1D stores weights
                  transposed relative to nn.Linear, hence the ``.T`` below)
      * "gated" — ``layer.mlp.gate_proj`` / ``down_proj`` (nn.Linear)

    Callers always see gate weights as [ffn_dim, hidden] and down weights
    as [hidden, ffn_dim], regardless of the underlying architecture.
    """

    def __init__(self, model):
        self.model = model
        self.style = self._detect_style()     # "gpt2" | "gated"
        self.n_layers = self._count_layers()

    def _detect_style(self):
        """Classify by config.model_type, falling back to probing layer 0."""
        t = self.model.config.model_type
        if t in ("gpt2", "gpt_neo", "gpt_neox"):
            return "gpt2"
        if t in ("llama", "mistral", "qwen2", "gemma", "gemma2", "phi3",
                 "falcon", "codellama", "deepseek", "internlm2"):
            return "gated"
        try:
            if hasattr(self._layer(0).mlp, "gate_proj"):
                return "gated"
        except Exception:
            # Probe failed (unknown module layout) — assume the gpt2 shape.
            pass
        return "gpt2"

    def _layer(self, i):
        """Return the i-th transformer block for either module layout."""
        m = self.model
        if hasattr(m, "transformer") and hasattr(m.transformer, "h"):
            return m.transformer.h[i]
        if hasattr(m, "model") and hasattr(m.model, "layers"):
            return m.model.layers[i]
        raise ValueError("Unknown model structure")

    def _count_layers(self):
        """Number of transformer blocks; raises on unknown layouts."""
        m = self.model
        if hasattr(m, "transformer") and hasattr(m.transformer, "h"):
            return len(m.transformer.h)
        if hasattr(m, "model") and hasattr(m.model, "layers"):
            return len(m.model.layers)
        raise ValueError("Cannot count layers")

    def get_ffn_weights(self, li):
        """Return (W_gate [ffn, hidden], W_down [hidden, ffn]) for layer li."""
        layer = self._layer(li)
        if self.style == "gpt2":
            # Conv1D weights are [in, out]; transpose to canonical orientation.
            return layer.mlp.c_fc.weight.detach().T, layer.mlp.c_proj.weight.detach().T
        return layer.mlp.gate_proj.weight.detach(), layer.mlp.down_proj.weight.detach()

    def set_ffn_weights(self, li, Wg, Wd):
        """Write canonical-orientation weights back into layer li in place."""
        layer = self._layer(li)
        with torch.no_grad():
            if self.style == "gpt2":
                layer.mlp.c_fc.weight.copy_(Wg.T)
                layer.mlp.c_proj.weight.copy_(Wd.T)
            else:
                layer.mlp.gate_proj.weight.copy_(Wg)
                layer.mlp.down_proj.weight.copy_(Wd)

    def get_embedding(self):
        """Input embedding matrix [vocab, hidden]."""
        return self.model.get_input_embeddings().weight.detach()

    def get_unembedding(self):
        """Output projection matrix; falls back to tied input embeddings."""
        if hasattr(self.model, "lm_head"):
            return self.model.lm_head.weight.detach()
        return self.get_embedding()
class VIndex:
def __init__(self, model_name: str, device: Optional[str] = None):
    """Load tokenizer + model and initialise patch bookkeeping.

    The knowledge-bank window (kb_start..kb_end) defaults to the upper
    two thirds of the layer stack.
    """
    self.model_name = model_name
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    self.device = device
    self.patches: List[Dict] = []               # applied edit records
    self._base_weights: Optional[Dict] = None   # lazy pristine snapshot
    self.tok = AutoTokenizer.from_pretrained(model_name)
    self.model = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=torch.float32).to(self.device)
    self.model.eval()
    if self.tok.pad_token is None:
        self.tok.pad_token = self.tok.eos_token
    self.arch = ArchAdapter(self.model)
    total_layers = self.arch.n_layers
    self.kb_start = total_layers // 3
    self.kb_end = total_layers
@property
def info(self):
return (f"{self.model_name} | {self.arch.n_layers} layers | "
f"style={self.arch.style} | kb=L{self.kb_start}-L{self.kb_end-1} | {self.device}")
# ── utils ──────────────────────────────────────────────────
def embed(self, text: str):
ids = self.tok.encode(text, add_special_tokens=False)
if not ids: raise ValueError(f"Cannot tokenize: {text!r}")
E = self.arch.get_embedding().to(self.device)
return E[ids].mean(0)
def decode_down_col(self, col, top_k=5):
scores = self.arch.get_unembedding().to(self.device) @ col.to(self.device)
top = scores.topk(top_k)
return [(self.tok.decode([i.item()]).strip(), v.item())
for i,v in zip(top.indices,top.values) if v.item()>0]
def token_id(self, word: str) -> int:
ids = self.tok.encode(word, add_special_tokens=False)
return ids[0] if ids else 0
def _forward(self, prompt: str):
inputs = self.tok(prompt, return_tensors="pt").to(self.device)
with torch.no_grad():
out = self.model(**inputs)
return out.logits[0,-1]
# ── Phase 1: new engine methods ────────────────────────────
def _get_subject_activations(self, prompt: str, subject: str
                             ) -> Tuple[Dict[int, torch.Tensor], int]:
    """Capture h_L[last_subject_token_pos] at every layer via forward hooks.

    Returns ({layer_index: hidden_state_vector}, subject_pos).  The
    position is the LAST token of the first subsequence match of the
    subject's ids in the prompt ids; with causal attention that position
    aggregates the whole subject.  Defaults to 0 when nothing matches.
    """
    enc = self.tok(prompt, return_tensors="pt").to(self.device)
    ids = enc["input_ids"][0].tolist()
    # Subsequence match to find subject position
    subj_ids = self.tok.encode(subject, add_special_tokens=False)
    subject_pos = 0
    for start in range(len(ids) - len(subj_ids) + 1):
        if ids[start:start+len(subj_ids)] == subj_ids:
            subject_pos = start + len(subj_ids) - 1
            break
    else:
        # Fallback: find any single token from subject_ids
        for si in subj_ids:
            if si in ids:
                subject_pos = ids.index(si)
                break
    activations: Dict[int, torch.Tensor] = {}
    handles = []
    def make_hook(li):
        def hook(m, inp, out):
            # Decoder blocks may return (hidden_states, ...) tuples.
            h = out[0] if isinstance(out, tuple) else out
            activations[li] = h[0, subject_pos].detach().clone()
        return hook
    for li in range(self.arch.n_layers):
        handles.append(self.arch._layer(li).register_forward_hook(make_hook(li)))
    with torch.no_grad():
        self.model(**enc)
    # Always detach hooks so later forward passes stay clean.
    for h in handles:
        h.remove()
    return activations, subject_pos
def locate(self, prompt: str, subject: str, target: str) -> Dict:
    """Diagnostic: combines trace + activation-guided similarity scan.

    phase_layer = the layer with the biggest *relative* rank improvement
    of the target token (only counted while the previous rank was > 5).
    For each knowledge-bank layer, reports the gate slot whose row is
    most cosine-similar to the captured subject activation.
    """
    trace_stats = self.trace(prompt, target)
    # Find phase_layer: biggest relative rank drop
    phase_layer = 0
    best_drop = 0.0
    prev_rank = None
    for s in trace_stats:
        if prev_rank is not None and prev_rank > 5:
            drop = (prev_rank - s["rank"]) / prev_rank
            if drop > best_drop:
                best_drop = drop
                phase_layer = s["l"]
        prev_rank = s["rank"]
    activations, subject_pos = self._get_subject_activations(prompt, subject)
    layer_scores = []
    for li in range(self.kb_start, self.kb_end):
        h_L = activations.get(li)
        if h_L is None:
            # No captured activation for this layer -> neutral entry.
            layer_scores.append({"layer": li, "max_sim": 0.0, "best_slot": -1, "top_tokens": []})
            continue
        Wg, Wd = self.arch.get_ffn_weights(li)
        Wg = Wg.to(self.device)
        # Cosine similarity of the residual state against every gate row.
        h_n = F.normalize(h_L, dim=0)
        sims = F.normalize(Wg, dim=1) @ h_n
        best_slot = int(sims.argmax().item())
        max_sim = float(sims[best_slot].item())
        Wd = Wd.to(self.device)
        top_tokens = self.decode_down_col(Wd[:, best_slot], top_k=3)
        layer_scores.append({
            "layer": li, "max_sim": round(max_sim, 4),
            "best_slot": best_slot,
            "top_tokens": [{"tok": t, "score": round(s,2)} for t,s in top_tokens]
        })
    return {
        "phase_layer": phase_layer,
        "subject_pos": subject_pos,
        "layer_scores": layer_scores,
        "trace": trace_stats,
    }
def precise_update(self, prompt: str, subject: str, relation: str,
                   new_target: str, top_k: int = 3, scale: float = 1.0,
                   log: Optional[List[str]] = None):
    """Activation-guided update: uses h_L[subject_pos] instead of embed(subject).

    Falls back to update() when no prompt is given, and to insert()
    when no slot's gate row matches the activation (sim < 0.05).
    Returns the applied ops list (or the fallback's return value).
    """
    if log is None: log = []
    if not prompt:
        # Without a prompt there is no activation to guide the edit.
        return self.update(subject, relation, new_target, top_k=top_k, scale=scale, log=log)
    self._snapshot()
    activations, subject_pos = self._get_subject_activations(prompt, subject)
    tv = self.embed(new_target)
    log.append(f"PRECISE_UPDATE: '{subject}' -[{relation}]-> '{new_target}'")
    log.append(f" subject_pos={subject_pos} scale={scale} top_k={top_k}")
    candidates = []
    for li in range(self.kb_start, self.kb_end):
        h_L = activations.get(li)
        if h_L is None: continue
        Wg, _ = self.arch.get_ffn_weights(li)
        Wg = Wg.to(self.device)
        # Rank gate rows by cosine similarity to the captured activation.
        h_n = F.normalize(h_L, dim=0)
        sims = F.normalize(Wg, dim=1) @ h_n
        k = min(top_k, sims.shape[0])
        vals, idxs = sims.topk(k)
        for v, idx in zip(vals, idxs):
            candidates.append((v.item(), li, idx.item()))
        log.append(f" L{li}: " + " ".join(f"{v:.4f}@s{i}" for v,i in zip(vals,idxs)))
    candidates.sort(key=lambda x: -x[0])
    best_sim = candidates[0][0] if candidates else 0.0
    # NOTE(review): the /3.5 figure below is a heuristic estimate, not measured.
    log.append(f"\n Best activation_sim = {best_sim:.4f} (embed-based would be ~{best_sim/3.5:.4f})")
    if best_sim < 0.05:
        log.append(" ⚠ sim < 0.05 β†’ INSERT fallback")
        return self.insert(subject, relation, new_target, log=log)
    chosen = [c for c in candidates if c[0] >= 0.05][:top_k]
    ops = []
    for sim, li, slot in chosen:
        _, Wd = self.arch.get_ffn_weights(li)
        Wd = Wd.to(self.device)
        # Preserve the column's norm so layer output statistics stay stable.
        col_norm = Wd[:, slot].norm().item()
        new_col = (F.normalize(tv, dim=0)*col_norm*scale).cpu().tolist()
        ops.append({"op":"update_down","layer":li,"slot":slot,
                    "down_col":new_col,"activation_sim":round(sim,4)})
        log.append(f" βœ“ L{li} slot {slot}: act_sim={sim:.4f} col_norm={col_norm:.4f}")
    self.patches.append({"type":"PRECISE_UPDATE","entity":subject,"relation":relation,
                         "new_target":new_target,"ops":ops})
    self._apply_all_patches()
    log.append(f"\n βœ“ Applied {len(ops)} op(s), patch #{len(self.patches)}")
    return ops
def suppress(self, entity: str, top_k: int = 3, factor: float = 0.0,
log: Optional[List[str]] = None):
"""Scale down (or zero) matching down columns."""
if log is None: log = []
self._snapshot()
ev_n = F.normalize(self.embed(entity), dim=0)
log.append(f"SUPPRESS: '{entity}' factor={factor} top_k={top_k}")
ops = []
candidates = []
for li in range(self.kb_start, self.kb_end):
Wg, _ = self.arch.get_ffn_weights(li)
Wg = Wg.to(self.device)
sims = F.normalize(Wg, dim=1) @ ev_n
k = min(top_k, sims.shape[0])
vals, idxs = sims.topk(k)
for v, idx in zip(vals, idxs):
candidates.append((v.item(), li, idx.item()))
candidates.sort(key=lambda x: -x[0])
chosen = [c for c in candidates if c[0] >= 0.05][:top_k]
for sim, li, slot in chosen:
_, Wd = self.arch.get_ffn_weights(li)
Wd = Wd.to(self.device)
new_col = (Wd[:, slot] * factor).cpu().tolist()
ops.append({"op":"update_down","layer":li,"slot":slot,"down_col":new_col})
log.append(f" L{li} slot {slot}: gate_sim={sim:.4f} factor={factor}")
self.patches.append({"type":"SUPPRESS","entity":entity,"factor":factor,"ops":ops})
self._apply_all_patches()
log.append(f" βœ“ Suppressed {len(ops)} slot(s)")
return {"ops": len(ops), "log": log}
def amplify(self, entity: str, top_k: int = 3, factor: float = 2.0,
log: Optional[List[str]] = None):
"""Scale up matching down columns."""
if log is None: log = []
self._snapshot()
ev_n = F.normalize(self.embed(entity), dim=0)
log.append(f"AMPLIFY: '{entity}' factor={factor} top_k={top_k}")
ops = []
candidates = []
for li in range(self.kb_start, self.kb_end):
Wg, _ = self.arch.get_ffn_weights(li)
Wg = Wg.to(self.device)
sims = F.normalize(Wg, dim=1) @ ev_n
k = min(top_k, sims.shape[0])
vals, idxs = sims.topk(k)
for v, idx in zip(vals, idxs):
candidates.append((v.item(), li, idx.item()))
candidates.sort(key=lambda x: -x[0])
chosen = [c for c in candidates if c[0] >= 0.05][:top_k]
for sim, li, slot in chosen:
_, Wd = self.arch.get_ffn_weights(li)
Wd = Wd.to(self.device)
new_col = (Wd[:, slot] * factor).cpu().tolist()
ops.append({"op":"update_down","layer":li,"slot":slot,"down_col":new_col})
log.append(f" L{li} slot {slot}: gate_sim={sim:.4f} factor={factor}")
self.patches.append({"type":"AMPLIFY","entity":entity,"factor":factor,"ops":ops})
self._apply_all_patches()
log.append(f" βœ“ Amplified {len(ops)} slot(s)")
return {"ops": len(ops), "log": log}
def style_shift(self, anchor_entity: str, from_concept: str, to_concept: str,
top_k: int = 3, strength: float = 0.5,
log: Optional[List[str]] = None):
"""Add a direction vector to matching down columns."""
if log is None: log = []
self._snapshot()
ev_n = F.normalize(self.embed(anchor_entity), dim=0)
from_v = self.embed(from_concept)
to_v = self.embed(to_concept)
dir_v = F.normalize(to_v - from_v, dim=0)
log.append(f"STYLE_SHIFT: anchor='{anchor_entity}' {from_concept!r}β†’{to_concept!r} strength={strength}")
ops = []
candidates = []
for li in range(self.kb_start, self.kb_end):
Wg, _ = self.arch.get_ffn_weights(li)
Wg = Wg.to(self.device)
sims = F.normalize(Wg, dim=1) @ ev_n
k = min(top_k, sims.shape[0])
vals, idxs = sims.topk(k)
for v, idx in zip(vals, idxs):
candidates.append((v.item(), li, idx.item()))
candidates.sort(key=lambda x: -x[0])
chosen = [c for c in candidates if c[0] >= 0.05][:top_k]
for sim, li, slot in chosen:
_, Wd = self.arch.get_ffn_weights(li)
Wd = Wd.to(self.device)
col = Wd[:, slot]
new_col = (col + dir_v * col.norm() * strength).cpu().tolist()
ops.append({"op":"update_down","layer":li,"slot":slot,"down_col":new_col})
log.append(f" L{li} slot {slot}: gate_sim={sim:.4f} col_norm={col.norm():.4f}")
self.patches.append({"type":"STYLE_SHIFT","entity":anchor_entity,
"from":from_concept,"to":to_concept,"ops":ops})
self._apply_all_patches()
log.append(f" βœ“ Style-shifted {len(ops)} slot(s)")
return {"ops": len(ops), "log": log}
def multi_edit(self, facts: List[Dict], mode: str = "UPDATE",
alpha: float = 0.25, top_k: int = 3, scale: float = 1.0):
"""Apply a batch of edits sequentially."""
results = []
for f in facts:
log: List[str] = []
try:
entity = f["entity"]
relation = f.get("relation", "")
new_target = f["new_target"]
prompt = f.get("prompt", "")
if mode == "PRECISE" and prompt:
ops = self.precise_update(prompt, entity, relation, new_target,
top_k=top_k, scale=scale, log=log)
elif mode == "INSERT":
ops = self.insert(entity, relation, new_target,
alpha=alpha, spread=top_k, log=log)
else:
ops = self.update(entity, relation, new_target,
top_k=top_k, scale=scale, log=log)
results.append({"entity":entity,"status":"ok",
"ops": len(ops) if isinstance(ops,(list,)) else 1,
"log":log})
except Exception as e:
results.append({"entity":f.get("entity","?"),"status":"error",
"error":str(e),"log":log})
return results
def gate_heatmap(self, entity: str, use_activation: bool = False,
                 prompt: Optional[str] = None) -> Dict:
    """Full layer-by-slot similarity matrix with top token decoding.

    The query vector is the static entity embedding, or — when
    use_activation is set AND a prompt is given — the per-layer residual
    activation at the subject position (falling back to the embedding
    for layers without a captured activation).
    """
    if use_activation and prompt:
        activations, _ = self._get_subject_activations(prompt, entity)
    else:
        activations = {}
    ev_n = F.normalize(self.embed(entity), dim=0)
    layers_out = []
    for li in range(self.kb_start, self.kb_end):
        Wg, Wd = self.arch.get_ffn_weights(li)
        Wg = Wg.to(self.device); Wd = Wd.to(self.device)
        # Prefer the captured activation; otherwise use the embedding query.
        if use_activation and li in activations:
            q = F.normalize(activations[li], dim=0)
        else:
            q = ev_n
        sims = F.normalize(Wg, dim=1) @ q
        top_slots_count = min(20, sims.shape[0])
        vals, idxs = sims.topk(top_slots_count)
        slots = []
        for v, idx in zip(vals, idxs):
            top_toks = self.decode_down_col(Wd[:, idx], top_k=3)
            slots.append({
                "slot": int(idx.item()),
                "sim": round(float(v.item()), 4),
                "top_tokens": [{"tok": t, "score": round(s,2)} for t,s in top_toks]
            })
        layers_out.append({"layer": li, "slots": slots})
    return {"layers": layers_out}
def dry_run(self, entity: str, new_target: str, top_k: int = 3,
            scale: float = 1.0, prompt: Optional[str] = None) -> Dict:
    """Same logic as precise_update/update but does NOT mutate weights.

    Reports the candidate slots that WOULD be patched, their current
    decoded tokens, and the injection norm that *scale* would produce.
    """
    if prompt:
        activations, subject_pos = self._get_subject_activations(prompt, entity)
        use_act = True
    else:
        use_act = False
    # NOTE(review): subject_pos is computed above but never used here.
    ev_n = F.normalize(self.embed(entity), dim=0)
    # tv is unused below, but embed() raising validates new_target early.
    tv = self.embed(new_target)
    candidates = []
    for li in range(self.kb_start, self.kb_end):
        Wg, Wd = self.arch.get_ffn_weights(li)
        Wg = Wg.to(self.device); Wd = Wd.to(self.device)
        if use_act and li in activations:
            q = F.normalize(activations[li], dim=0)
        else:
            q = ev_n
        sims = F.normalize(Wg, dim=1) @ q
        k = min(top_k, sims.shape[0])
        vals, idxs = sims.topk(k)
        for v, idx in zip(vals, idxs):
            col_norm = Wd[:, idx].norm().item()
            top_toks = self.decode_down_col(Wd[:, idx], top_k=3)
            candidates.append({
                "layer": li, "slot": int(idx.item()),
                "sim": round(float(v.item()), 4),
                "col_norm": round(col_norm, 4),
                "inject_norm": round(col_norm * scale, 4),
                "current_top": [{"tok":t,"score":round(s,2)} for t,s in top_toks]
            })
    candidates.sort(key=lambda x: -x["sim"])
    best_sim = candidates[0]["sim"] if candidates else 0.0
    # Same 0.05 threshold as precise_update's fallback decision.
    chosen = [c for c in candidates if c["sim"] >= 0.05][:top_k]
    return {
        "candidates": chosen,
        "best_sim": best_sim,
        "would_patch": len(chosen),
        "new_target": new_target,
        "mode": "activation-guided" if use_act else "embed-based"
    }
# ── Phase 1+: mechanistic attribution ─────────────────────
def gradient_slot_scores(self, prompt: str, target: str) -> Dict:
    """One backward pass: grad norm of d(-log p(target))/dW_down[:,slot]
    per layer.  Identifies which slots causally contributed to this
    prediction via gradient signal.

    Returns {"layer_scores": [{"layer", "max_grad", "top_slots"}...]}
    covering ALL layers, not just the knowledge-bank window.
    """
    target_id = self.token_id(target)
    # Temporarily enable grad on down-proj weights
    down_params: List[Tuple[int, torch.nn.Parameter]] = []
    for li in range(self.arch.n_layers):
        layer = self.arch._layer(li)
        p = layer.mlp.c_proj.weight if self.arch.style == "gpt2" \
            else layer.mlp.down_proj.weight
        p.requires_grad_(True)
        down_params.append((li, p))
    self.model.zero_grad()
    inputs = self.tok(prompt, return_tensors="pt").to(self.device)
    out = self.model(**inputs)
    # NLL of the target token at the final position.
    loss = -F.log_softmax(out.logits[0, -1], dim=-1)[target_id]
    loss.backward()
    layer_scores = []
    for li, p in down_params:
        grad = p.grad
        # NOTE(review): this leaves requires_grad=False on the down-proj
        # params afterwards, regardless of their state before the call.
        p.requires_grad_(False)
        if grad is None:
            layer_scores.append({"layer": li, "max_grad": 0.0, "top_slots": []})
            continue
        # gpt2: c_proj.weight [ffn_dim, hidden] β†’ rows = slots
        # gated: down_proj.weight [hidden, ffn_dim] β†’ cols = slots
        slot_norms = grad.norm(dim=1) if self.arch.style == "gpt2" \
            else grad.norm(dim=0)  # [ffn_dim]
        k = min(20, slot_norms.shape[0])
        vals, idxs = slot_norms.topk(k)
        layer_scores.append({
            "layer": li,
            "max_grad": round(float(vals[0].item()), 6),
            "top_slots": [{"slot": int(idx.item()),
                           "grad_norm": round(float(v.item()), 6)}
                          for idx, v in zip(idxs, vals)]
        })
    self.model.zero_grad()
    return {"layer_scores": layer_scores}
def causal_patch_trace(self, prompt: str, subject: str, target: str,
                       noise_std: float = 0.1) -> Dict:
    """ROME-style causal tracing.

    Corrupts the subject token embeddings with Gaussian noise, then for
    each knowledge-bank layer measures how much patching that layer's
    clean hidden state (at the subject positions) restores p(target).
    Expensive: O(n_layers) forward passes.

    Returns clean/corrupt probabilities, the subject positions, and a
    per-layer list of patch_prob and indirect_effect
    (= patch_prob - corrupt_prob).
    """
    target_id = self.token_id(target)
    inputs = self.tok(prompt, return_tensors="pt").to(self.device)
    ids = inputs["input_ids"][0].tolist()
    # Find subject token positions via subsequence match.
    subj_ids = self.tok.encode(subject, add_special_tokens=False)
    subj_pos: List[int] = []
    for start in range(len(ids) - len(subj_ids) + 1):
        if ids[start:start+len(subj_ids)] == subj_ids:
            subj_pos = list(range(start, start+len(subj_ids)))
            break
    if not subj_pos:
        # Fallback: any single subject token found in the prompt.
        for si in subj_ids:
            if si in ids:
                subj_pos = [ids.index(si)]
                break
    if not subj_pos:
        subj_pos = [0]  # last resort: corrupt the first token
    # Clean forward — capture every layer's hidden states.
    clean_hs: Dict[int, torch.Tensor] = {}
    clean_handles = []
    def _mk_clean(li):
        def _h(m, inp, out):
            h = out[0] if isinstance(out, tuple) else out
            clean_hs[li] = h[0].detach().clone()  # [seq, hidden]
        return _h
    for li in range(self.arch.n_layers):
        clean_handles.append(self.arch._layer(li).register_forward_hook(_mk_clean(li)))
    with torch.no_grad():
        clean_out = self.model(**inputs)
    for h in clean_handles: h.remove()
    clean_prob = float(torch.softmax(clean_out.logits[0,-1], dim=-1)[target_id].item())
    # Corrupted embeddings: Gaussian noise at the subject positions,
    # scaled by the embedding std times noise_std.
    E = self.arch.get_embedding().to(self.device)
    emb = E[inputs["input_ids"][0]].unsqueeze(0).clone()  # [1, seq, hidden]
    noise_scale = emb.std().item() * noise_std
    for pos in subj_pos:
        emb[0, pos] += torch.randn_like(emb[0, pos]) * noise_scale
    with torch.no_grad():
        corr_out = self.model(inputs_embeds=emb)
    corr_prob = float(torch.softmax(corr_out.logits[0,-1], dim=-1)[target_id].item())
    # Causal patch sweep: restore one layer's clean state at a time.
    results = []
    for li in range(self.kb_start, self.kb_end):
        def _mk_patch(target_li):
            def _h(m, inp, out):
                if target_li not in clean_hs:
                    return out
                is_tuple = isinstance(out, tuple)
                h = list(out) if is_tuple else [out]
                clean = clean_hs[target_li]
                for pos in subj_pos:
                    if pos < clean.shape[0]:
                        h[0][0, pos] = clean[pos].to(h[0].device)
                return tuple(h) if is_tuple else h[0]
            return _h
        ph = self.arch._layer(li).register_forward_hook(_mk_patch(li))
        with torch.no_grad():
            patch_out = self.model(inputs_embeds=emb.clone())
        ph.remove()
        patch_prob = float(torch.softmax(patch_out.logits[0,-1], dim=-1)[target_id].item())
        ie = patch_prob - corr_prob
        results.append({
            "layer": li,
            "patch_prob": round(patch_prob, 6),
            "indirect_effect": round(ie, 6),
        })
    return {
        "clean_prob": round(clean_prob, 6),
        "corrupt_prob": round(corr_prob, 6),
        "subject_pos": subj_pos,
        "results": results,
    }
def smart_locate(self, prompt: str, subject: str, target: str,
                 alpha: float = 0.4, beta: float = 0.3, gamma: float = 0.3,
                 noise_std: float = 0.1) -> Dict:
    """Combined gate_sim + grad_norm + causal_effect layer/slot ranking.

    alpha = weight for gate cosine sim
    beta  = weight for gradient norm
    gamma = weight for causal indirect effect
    Each signal is max-normalised per layer before the weighted sum.
    """
    gate_data = self.locate(prompt, subject, target)
    grad_data = self.gradient_slot_scores(prompt, target)
    causal_data = self.causal_patch_trace(prompt, subject, target, noise_std=noise_std)
    # Per-layer scalar signal from each analysis; negative causal effects
    # are clamped to zero (they indicate no restoration).
    gate_map = {ls["layer"]: ls["max_sim"] for ls in gate_data["layer_scores"]}
    grad_map = {ls["layer"]: ls["max_grad"] for ls in grad_data["layer_scores"]}
    causal_map = {r["layer"]: max(0.0, r["indirect_effect"])
                  for r in causal_data["results"]}
    grad_slots = {ls["layer"]: ls["top_slots"] for ls in grad_data["layer_scores"]}
    layers = sorted(set(gate_map) | set(grad_map) | set(causal_map))
    def _norm(vals: List[float]) -> List[float]:
        # Max-normalise so the three signals are on comparable scales.
        m = max(vals) if vals else 1.0
        return [v/m if m > 0 else 0.0 for v in vals]
    gv = [gate_map.get(l, 0.0) for l in layers]
    dv = [grad_map.get(l, 0.0) for l in layers]
    cv = [causal_map.get(l, 0.0) for l in layers]
    gn, dn, cn = _norm(gv), _norm(dv), _norm(cv)
    ranked = []
    for i, l in enumerate(layers):
        score = alpha*gn[i] + beta*dn[i] + gamma*cn[i]
        ranked.append({
            "layer": l,
            "gate_sim": round(gv[i], 4),
            "grad_norm": round(dv[i], 6),
            "causal_effect": round(cv[i], 6),
            "gate_sim_n": round(gn[i], 4),
            "grad_norm_n": round(dn[i], 4),
            "causal_n": round(cn[i], 4),
            "combined": round(score, 4),
            "best_slots": (grad_slots.get(l) or [])[:5],
        })
    ranked.sort(key=lambda x: -x["combined"])
    return {
        "ranked_layers": ranked,
        "phase_layer": gate_data["phase_layer"],
        "subject_pos": gate_data["subject_pos"],
        "clean_prob": causal_data["clean_prob"],
        "corrupt_prob": causal_data["corrupt_prob"],
        "recommendation": ranked[0] if ranked else None,
        "weights": {"alpha": alpha, "beta": beta, "gamma": gamma},
    }
def smart_edit(self, prompt: str, subject: str, relation: str,
               old_target: str, new_target: str,
               top_layers: int = 3, slots_per_layer: int = 2,
               scale: float = 1.5, noise_std: float = 0.1,
               alpha: float = 0.4, beta: float = 0.4, gamma: float = 0.2,
               log: Optional[List[str]] = None) -> Dict:
    """Auto edit: runs smart_locate on (prompt, subject, old_target) to find
    the exact layer+slot targets via gradient+causal+gate consensus, then
    patches those W_down columns toward embed(new_target).

    old_target      = what the model currently predicts (used to locate)
    new_target      = what you want to inject
    top_layers      = how many top-ranked layers to patch
    slots_per_layer = gradient-identified slots to patch per layer
    scale           = col_norm multiplier (1.5-3.0 recommended)
    beta > alpha because grad_norm is more reliable than gate_sim for
    small models.
    """
    if log is None: log = []
    self._snapshot()
    log.append(f"SMART_EDIT: '{subject}' [{relation}] {old_target!r} β†’ {new_target!r}")
    log.append(f" Running smart_locate on prompt: {prompt!r}")
    log.append(f" Weights: Ξ±={alpha} Ξ²={beta} Ξ³={gamma} noise_std={noise_std}")
    sl = self.smart_locate(prompt, subject, old_target,
                           alpha=alpha, beta=beta, gamma=gamma,
                           noise_std=noise_std)
    log.append(f" clean_prob={sl['clean_prob']:.6f} corrupt_prob={sl['corrupt_prob']:.6f}")
    log.append(f" Phase layer: L{sl['phase_layer']} Subject pos: {sl['subject_pos']}")
    if sl["clean_prob"] < 1e-5:
        # Causal tracing is uninformative if the model never predicts
        # old_target; the gradient signal still ranks layers.
        log.append(" ⚠ clean_prob near zero β€” model barely knows this fact.")
        log.append(" Grad-norm signal still valid. Causal IE=0 is expected.")
        log.append(" Recommend: gpt2-medium or Qwen2.5-1.5B for stronger facts.")
    tv = self.embed(new_target)
    tv_n = F.normalize(tv, dim=0)
    ops = []
    used = []
    top_ranked = sl["ranked_layers"][:top_layers]
    for lr in top_ranked:
        li = lr["layer"]
        # Use gradient-identified slots β€” far more precise than gate cosine
        grad_slots = [s["slot"] for s in lr["best_slots"][:slots_per_layer]]
        if not grad_slots:
            log.append(f" L{li}: no grad slots, skipping")
            continue
        _, Wd = self.arch.get_ffn_weights(li)
        Wd = Wd.to(self.device)
        for slot in grad_slots:
            # Keep the existing column norm; redirect toward the new target.
            col_norm = Wd[:, slot].norm().item()
            new_col = (tv_n * col_norm * scale).cpu().tolist()
            ops.append({"op":"update_down","layer":li,"slot":slot,"down_col":new_col})
            log.append(f" βœ“ L{li} slot {slot}: combined={lr['combined']} "
                       f"grad_norm={lr['grad_norm']:.4f} col_norm={col_norm:.4f} "
                       f"inject={col_norm*scale:.4f}")
        used.append({"layer":li,"slots":grad_slots,"combined":lr["combined"]})
    self.patches.append({
        "type": "SMART_UPDATE",
        "entity": subject,
        "relation": relation,
        "new_target": new_target,
        "old_target": old_target,
        "smart_top": top_ranked,
        "ops": ops,
    })
    self._apply_all_patches()
    log.append(f"\n βœ“ {len(ops)} op(s) across {len(used)} layer(s), patch #{len(self.patches)}")
    return {
        "ops": ops,
        "used_layers": used,
        "smart_locate": sl,
        "log": log,
    }
def infer(self, prompt: str, top_k: int = 5):
probs = torch.softmax(self._forward(prompt), dim=-1)
top = probs.topk(top_k)
return [{"token": self.tok.decode([idx.item()]).strip(),
"prob": round(val.item(), 6)}
for idx, val in zip(top.indices, top.values)]
def describe(self, entity: str, top_k: int = 10):
    """List the strongest tokens written by FFN slots whose gate rows
    match embed(entity) — a rough 'what does the model associate with
    this entity' view.  Returns up to *top_k* edge dicts ranked by score.
    """
    ev_n = F.normalize(self.embed(entity), dim=0)
    all_edges = []
    for li in range(self.kb_start, self.kb_end):
        Wg, Wd = self.arch.get_ffn_weights(li)
        Wg = Wg.to(self.device); Wd = Wd.to(self.device)
        sims = F.normalize(Wg, dim=1) @ ev_n
        for fid in sims.topk(min(5,sims.shape[0])).indices:
            gsim = sims[fid].item()
            if gsim < 0.08: continue  # ignore weak gate matches
            for tok, score in self.decode_down_col(Wd[:,fid], 4):
                if tok: all_edges.append({"tok":tok,"score":score,"layer":li,"gate_sim":gsim})
    # Deduplicate by token, keeping the highest-scoring occurrence.
    best: Dict[str, Any] = {}
    for e in all_edges:
        t = e["tok"]
        if t not in best or e["score"] > best[t]["score"]:
            best[t] = e
    ranked = sorted(best.values(), key=lambda x: -x["score"])[:top_k]
    return ranked
def trace(self, prompt: str, target: str):
    """Logit-lens trace: rank and probability of *target* after every
    layer, obtained by projecting each layer's final-position residual
    straight through the unembedding matrix.
    """
    target_id = self.token_id(target)
    W_u = self.arch.get_unembedding().to(self.device)
    stats: List[Dict] = []
    handles = []
    def make_hook(li):
        def hook(m, inp, out):
            h = out[0] if isinstance(out,tuple) else out
            last = h[0,-1].detach()
            # Early-exit readout of the residual stream.
            p = torch.softmax(W_u @ last, dim=-1)
            rank = int((p > p[target_id]).sum().item()) + 1
            stats.append({"l":li,"rank":rank,"prob":round(p[target_id].item(),8)})
        return hook
    for li in range(self.arch.n_layers):
        handles.append(self.arch._layer(li).register_forward_hook(make_hook(li)))
    inputs = self.tok(prompt, return_tensors="pt").to(self.device)
    with torch.no_grad(): self.model(**inputs)
    for h in handles: h.remove()
    return stats
# ── patch management ───────────────────────────────────────
def _snapshot(self):
if self._base_weights is not None: return
self._base_weights = {}
for li in range(self.arch.n_layers):
Wg, Wd = self.arch.get_ffn_weights(li)
self._base_weights[li] = (Wg.clone().cpu(), Wd.clone().cpu())
def _restore_base(self):
if self._base_weights is None: return
for li,(Wg,Wd) in self._base_weights.items():
self.arch.set_ffn_weights(li, Wg.to(self.device), Wd.to(self.device))
def _apply_all_patches(self):
    """Rebuild weights from scratch: restore the pristine snapshot, then
    replay every recorded patch op in order (later patches win on
    overlapping slots)."""
    self._restore_base()
    for patch in self.patches:
        for op in patch.get("ops",[]):
            li=op["layer"]; slot=op["slot"]
            Wg, Wd = self.arch.get_ffn_weights(li)
            # Clone so stored base tensors are never mutated in place.
            Wg=Wg.clone(); Wd=Wd.clone()
            if op["op"] in ("insert","update_gate"):
                Wg[slot] = torch.tensor(op["gate_row"],dtype=Wg.dtype,device=self.device)
            if op["op"] in ("insert","update_down"):
                Wd[:,slot] = torch.tensor(op["down_col"],dtype=Wd.dtype,device=self.device)
            self.arch.set_ffn_weights(li, Wg, Wd)
def insert(self, entity, relation, target, alpha=0.25, spread=4, log=None):
if log is None: log = []
self._snapshot()
gate_dir = F.normalize(self.embed(entity), dim=0)
down_dir = F.normalize(self.embed(target), dim=0)
ls=self.kb_start; le=min(ls+spread, self.arch.n_layers)
log.append(f"INSERT: '{entity}' -[{relation}]-> '{target}' alpha={alpha}")
ops=[]
for li in range(ls,le):
Wg, Wd = self.arch.get_ffn_weights(li)
Wg=Wg.to(self.device); Wd=Wd.to(self.device)
norms_g=Wg.norm(dim=1); norms_d=Wd.norm(dim=0)
slot=norms_g.argmin().item()
wg_mean=norms_g.mean().item(); wd_mean=norms_d.mean().item()
ops.append({"op":"insert","layer":li,"slot":slot,
"gate_row":(gate_dir*wg_mean*alpha).cpu().tolist(),
"down_col":(down_dir*wd_mean*alpha).cpu().tolist()})
log.append(f" L{li}: slot={slot} inject={wg_mean*alpha:.4f}")
self.patches.append({"type":"INSERT","entity":entity,"relation":relation,
"target":target,"ops":ops})
self._apply_all_patches()
return ops
def update(self, entity, relation, new_target, top_k=3, scale=1.0, log=None):
    """Embed-guided edit: retarget the down columns of the slots whose
    gate rows best match embed(entity) toward embed(new_target).

    NOTE(review): the return type is heterogeneous — on weak similarity
    it returns (insert_ops, "INSERT_FALLBACK"); otherwise it returns the
    (layer, slot, sim) triple of the best patched slot.  Callers
    (multi_edit, precise_update) tolerate both shapes; do not "fix"
    without updating them.
    """
    if log is None: log = []
    self._snapshot()
    ev = self.embed(entity)
    tv = self.embed(new_target)
    ev_n = F.normalize(ev, dim=0)
    log.append(f"UPDATE: '{entity}' -[{relation}]-> '{new_target}' top_k={top_k} scale={scale}")
    candidates = []
    for li in range(self.kb_start, self.kb_end):
        Wg, _ = self.arch.get_ffn_weights(li)
        Wg = Wg.to(self.device)
        sims = F.normalize(Wg, dim=1) @ ev_n
        k = min(top_k, sims.shape[0])
        vals, idxs = sims.topk(k)
        for v, idx in zip(vals, idxs):
            candidates.append((v.item(), li, idx.item()))
        log.append(f" L{li}: " + " ".join(f"{v:.4f}@s{i}" for v,i in zip(vals,idxs)))
    candidates.sort(key=lambda x: -x[0])
    best_sim = candidates[0][0] if candidates else 0.0
    if best_sim < 0.08:
        # Nothing matches well enough to overwrite -> add a new fact instead.
        log.append(" ⚠ sim<0.08 β†’ INSERT fallback")
        return self.insert(entity, relation, new_target, log=log), "INSERT_FALLBACK"
    chosen = [c for c in candidates if c[0] >= 0.08][:top_k]
    ops = []
    for sim, li, slot in chosen:
        _, Wd = self.arch.get_ffn_weights(li)
        Wd = Wd.to(self.device)
        # Preserve the column norm; scale adjusts the injection strength.
        col_norm = Wd[:, slot].norm().item()
        new_col = (F.normalize(tv, dim=0)*col_norm*scale).cpu().tolist()
        ops.append({"op":"update_down","layer":li,"slot":slot,"down_col":new_col})
        log.append(f" βœ“ L{li} slot {slot}: sim={sim:.4f} norm={col_norm:.4f}")
    self.patches.append({"type":"UPDATE","entity":entity,"relation":relation,
                         "new_target":new_target,"ops":ops})
    self._apply_all_patches()
    li0,slot0,sim0 = chosen[0]
    return li0, slot0, sim0
def reset(self):
self.patches.clear()
self._restore_base()
def save_patch(self, path: str):
with open(path,"w") as f:
json.dump(self.patches, f, indent=2)
def compile_to(self, output_dir: str):
Path(output_dir).mkdir(parents=True, exist_ok=True)
self.model.save_pretrained(output_dir)
self.tok.save_pretrained(output_dir)
# ══════════════════════════════════════════════════════════════
# HTML FRONTEND (Phase 3)
# ══════════════════════════════════════════════════════════════
HTML_PAGE = r"""<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<title>VINDEX β€” LLM Knowledge Editor</title>
<script src="https://unpkg.com/d3@7"></script>
<style>
:root {
--bg: #0d1117; --bg2: #161b22; --bg3: #21262d;
--border: #30363d; --text: #e6edf3; --muted: #8b949e;
--blue: #58a6ff; --green: #3fb950; --red: #f85149;
--yellow: #d29922; --purple: #bc8cff; --cyan: #39d353;
--font: 'JetBrains Mono', 'Fira Code', 'Cascadia Code', monospace;
}
* { box-sizing: border-box; margin: 0; padding: 0; }
body { background: var(--bg); color: var(--text); font-family: var(--font);
font-size: 13px; min-height: 100vh; }
#app { max-width: 1400px; margin: 0 auto; padding: 16px; }
/* HEADER */
.header { background: var(--bg2); border: 1px solid var(--border);
border-radius: 10px; padding: 20px 28px; margin-bottom: 14px;
display: flex; align-items: center; justify-content: space-between; }
.header-title { font-size: 20px; font-weight: 700; letter-spacing: 1px; }
.header-title span { color: var(--blue); }
.status-pill { background: var(--bg3); border: 1px solid var(--border);
border-radius: 20px; padding: 6px 14px; font-size: 11px;
color: var(--muted); display: flex; align-items: center; gap: 8px; }
.status-dot { width: 8px; height: 8px; border-radius: 50%; background: var(--red); }
.status-dot.ok { background: var(--green); }
/* TABS */
.tabs { display: flex; gap: 2px; background: var(--bg2); border: 1px solid var(--border);
border-radius: 8px; padding: 4px; margin-bottom: 14px; flex-wrap: wrap; }
.tab-btn { padding: 7px 16px; border: none; background: transparent; color: var(--muted);
font-family: var(--font); font-size: 12px; cursor: pointer; border-radius: 6px;
letter-spacing: .5px; transition: all .15s; }
.tab-btn:hover { color: var(--text); background: var(--bg3); }
.tab-btn.active { background: var(--blue); color: #fff; }
/* PANELS */
.panel { display: none; }
.panel.active { display: block; }
.card { background: var(--bg2); border: 1px solid var(--border);
border-radius: 8px; padding: 18px; margin-bottom: 12px; }
.card h3 { font-size: 11px; letter-spacing: 2px; text-transform: uppercase;
color: var(--muted); margin-bottom: 14px; }
/* FORM ELEMENTS */
label { font-size: 11px; color: var(--muted); display: block; margin-bottom: 4px; }
input[type=text], textarea, select {
width: 100%; background: var(--bg3); border: 1px solid var(--border);
color: var(--text); font-family: var(--font); font-size: 12px;
padding: 8px 10px; border-radius: 6px; outline: none; }
input[type=text]:focus, textarea:focus { border-color: var(--blue); }
textarea { resize: vertical; min-height: 80px; }
input[type=range] { width: 100%; accent-color: var(--blue); }
.range-row { display: flex; align-items: center; gap: 10px; }
.range-row span { min-width: 40px; text-align: right; color: var(--blue); }
button {
background: var(--blue); color: #fff; border: none; font-family: var(--font);
font-size: 12px; padding: 8px 18px; border-radius: 6px; cursor: pointer;
font-weight: 600; letter-spacing: .5px; transition: opacity .15s; }
button:hover { opacity: .85; }
button.secondary { background: var(--bg3); color: var(--text); border: 1px solid var(--border); }
button.danger { background: var(--red); }
button.warn { background: var(--yellow); color: #000; }
.row { display: flex; gap: 12px; flex-wrap: wrap; }
.col { flex: 1; min-width: 200px; }
.col2 { flex: 2; min-width: 300px; }
/* RADIO GROUP */
.radio-group { display: flex; flex-wrap: wrap; gap: 6px; }
.radio-group label { display: flex; align-items: center; gap: 5px; cursor: pointer;
background: var(--bg3); border: 1px solid var(--border); border-radius: 5px;
padding: 5px 10px; color: var(--text); font-size: 11px; margin: 0; }
.radio-group input { accent-color: var(--blue); }
/* LOG BOX */
.log { background: var(--bg); border: 1px solid var(--border); border-radius: 6px;
padding: 12px; font-size: 11px; color: var(--muted); max-height: 200px;
overflow-y: auto; white-space: pre-wrap; margin-top: 8px; }
/* SVG CHARTS */
.chart-wrap { background: var(--bg); border: 1px solid var(--border);
border-radius: 6px; overflow: hidden; margin-top: 10px; }
svg text { font-family: var(--font); fill: var(--text); }
.tooltip {
position: absolute; background: var(--bg2); border: 1px solid var(--border);
border-radius: 6px; padding: 8px 12px; font-size: 11px; pointer-events: none;
opacity: 0; transition: opacity .1s; z-index: 100; }
/* PATCH CARDS */
.patch-card { background: var(--bg3); border: 1px solid var(--border);
border-radius: 6px; padding: 10px 14px; margin-bottom: 8px;
display: flex; align-items: center; justify-content: space-between; }
.patch-badge { padding: 2px 8px; border-radius: 4px; font-size: 10px;
font-weight: 700; letter-spacing: 1px; margin-right: 10px; }
.badge-UPDATE { background: var(--blue); color: #fff; }
.badge-INSERT { background: var(--green); color: #000; }
.badge-PRECISE { background: var(--purple); color: #fff; }
.badge-SUPPRESS{ background: var(--red); color: #fff; }
.badge-AMPLIFY { background: var(--yellow); color: #000; }
.badge-STYLE { background: var(--cyan); color: #000; }
.patch-del { background: transparent; border: none; color: var(--red);
font-size: 16px; cursor: pointer; padding: 0 6px; }
/* HEATMAP */
.hm-cell { stroke: var(--bg); stroke-width: 1; cursor: pointer; }
.hm-cell:hover { stroke: var(--blue); stroke-width: 2; }
/* LOCATE SECTION */
.locate-bar { height: 22px; background: var(--bg3); border-radius: 3px;
margin-bottom: 3px; display: flex; align-items: center;
padding: 0 8px; cursor: pointer; transition: background .1s; }
.locate-bar:hover { background: var(--bg2); }
.locate-bar .layer-lbl { width: 40px; color: var(--muted); font-size: 10px; }
.locate-bar .bar-fill { height: 10px; border-radius: 2px; background: var(--purple); }
.locate-bar .sim-val { margin-left: 8px; font-size: 10px; color: var(--blue); }
/* Scrollbar */
::-webkit-scrollbar { width: 6px; } ::-webkit-scrollbar-track { background: var(--bg2); }
::-webkit-scrollbar-thumb { background: var(--border); border-radius: 3px; }
</style>
</head>
<body>
<div id="app">
<!-- HEADER -->
<div class="header">
<div>
<div class="header-title">VINDEX <span>β—ˆ</span> LLM Knowledge Editor</div>
<div style="color:var(--muted);font-size:11px;margin-top:4px">
The model IS the database. Inspect Β· Edit Β· Locate Β· Compile.
</div>
</div>
<div class="status-pill">
<div class="status-dot" id="status-dot"></div>
<span id="status-text">No model loaded</span>
<span id="patch-count" style="color:var(--blue)"></span>
</div>
</div>
<!-- TABS -->
<div class="tabs">
<button class="tab-btn active" onclick="showTab('infer')">β‘  Infer</button>
<button class="tab-btn" onclick="showTab('describe')">β‘‘ Describe</button>
<button class="tab-btn" onclick="showTab('trace')">β‘’ Trace</button>
<button class="tab-btn" onclick="showTab('locate')">β‘£ Locate</button>
<button class="tab-btn" onclick="showTab('smartlocate')">β‘€ Smart Locate</button>
<button class="tab-btn" onclick="showTab('heatmap')">β‘₯ Heatmap</button>
<button class="tab-btn" onclick="showTab('edit')">⑦ Edit</button>
<button class="tab-btn" onclick="showTab('patches')">β‘§ Patches</button>
<button class="tab-btn" onclick="showTab('guide')" style="margin-left:auto;color:var(--green)">πŸ“– Guide</button>
<button class="tab-btn" onclick="showTab('load')">βš™ Load</button>
</div>
<!-- TOOLTIP -->
<div class="tooltip" id="tooltip"></div>
<!-- ══════════ LOAD PANEL ══════════ -->
<div id="panel-load" class="panel">
<div class="card">
<h3>Load Model</h3>
<div class="row">
<div class="col2">
<label>Model name or local path</label>
<input type="text" id="load-model" value="distilgpt2"
placeholder="distilgpt2 | Qwen/Qwen2.5-1.5B-Instruct | /local/path">
</div>
<div class="col">
<label>Device</label>
<select id="load-device">
<option value="auto">Auto</option>
<option value="cpu">CPU</option>
<option value="cuda">CUDA</option>
<option value="mps">MPS</option>
</select>
</div>
</div>
<div style="margin-top:12px">
<button onclick="loadModel()">⚑ Load Model</button>
</div>
<div id="load-log" class="log" style="margin-top:12px;display:none"></div>
</div>
<div class="card">
<h3>Quick Models</h3>
<div style="color:var(--muted);line-height:1.8">
<span style="color:var(--blue)">distilgpt2</span> β€” 350 MB, instant<br>
<span style="color:var(--blue)">gpt2</span> β€” 550 MB<br>
<span style="color:var(--blue)">gpt2-medium</span> β€” 1.5 GB<br>
<span style="color:var(--blue)">Qwen/Qwen2.5-1.5B-Instruct</span> β€” 3 GB, strong facts
</div>
</div>
</div>
<!-- ══════════ INFER PANEL ══════════ -->
<div id="panel-infer" class="panel active">
<div class="card">
<h3>Next-Token Prediction</h3>
<div class="row">
<div class="col2">
<label>Prompt</label>
<input type="text" id="infer-prompt" value="The capital of France is">
</div>
<div class="col">
<label>Top-K: <span id="infer-k-val">10</span></label>
<div class="range-row">
<input type="range" id="infer-k" min="1" max="20" value="10"
oninput="document.getElementById('infer-k-val').textContent=this.value">
</div>
</div>
</div>
<button onclick="runInfer()" style="margin-top:12px">β–Ά Run Infer</button>
</div>
<div class="card">
<h3>Results</h3>
<div class="chart-wrap" id="infer-chart" style="min-height:200px"></div>
</div>
</div>
<!-- ══════════ DESCRIBE PANEL ══════════ -->
<div id="panel-describe" class="panel">
<div class="card">
<h3>Entity Knowledge Graph (W_gate KNN β†’ W_down decode)</h3>
<div class="row">
<div class="col">
<label>Entity</label>
<input type="text" id="desc-entity" value="France">
</div>
<div class="col">
<label>Top edges: <span id="desc-k-val">10</span></label>
<div class="range-row">
<input type="range" id="desc-k" min="1" max="20" value="10"
oninput="document.getElementById('desc-k-val').textContent=this.value">
</div>
</div>
</div>
<button onclick="runDescribe()" style="margin-top:12px">β–Ά Run Describe</button>
</div>
<div class="card">
<h3>Force-Directed Graph</h3>
<div class="chart-wrap" id="desc-chart" style="height:420px"></div>
</div>
</div>
<!-- ══════════ TRACE PANEL ══════════ -->
<div id="panel-trace" class="panel">
<div class="card">
<h3>Layer-by-Layer Rank Trace</h3>
<div class="row">
<div class="col2">
<label>Prompt</label>
<input type="text" id="trace-prompt" value="The capital of France is">
</div>
<div class="col">
<label>Target token</label>
<input type="text" id="trace-target" value="Paris">
</div>
</div>
<button onclick="runTrace()" style="margin-top:12px">β–Ά Run Trace</button>
</div>
<div class="card">
<h3>Rank + Probability over Layers</h3>
<div class="chart-wrap" id="trace-chart" style="min-height:300px"></div>
</div>
</div>
<!-- ══════════ LOCATE PANEL ══════════ -->
<div id="panel-locate" class="panel">
<div class="card">
<h3>Locate β€” Diagnostic (Trace + Activation Similarity)</h3>
<div class="row">
<div class="col2">
<label>Prompt</label>
<input type="text" id="loc-prompt" value="The capital of France is">
</div>
<div class="col">
<label>Subject</label>
<input type="text" id="loc-subject" value="France">
</div>
<div class="col">
<label>Target</label>
<input type="text" id="loc-target" value="Paris">
</div>
</div>
<button onclick="runLocate()" style="margin-top:12px">β–Ά Locate</button>
</div>
<div class="row">
<div class="col2 card">
<h3>Trace (rank over layers)</h3>
<div class="chart-wrap" id="loc-trace-chart" style="min-height:220px"></div>
</div>
<div class="col card">
<h3>Activation Sim per KB Layer</h3>
<div id="loc-bars" style="margin-top:8px"></div>
<div id="loc-slots" class="log" style="margin-top:8px;display:none"></div>
</div>
</div>
<div class="card" id="loc-summary" style="display:none">
<h3>Summary</h3>
<div id="loc-summary-text" style="color:var(--blue)"></div>
</div>
</div>
<!-- ══════════ SMART LOCATE PANEL ══════════ -->
<div id="panel-smartlocate" class="panel">
<div class="card">
<h3>Smart Locate β€” gradient + causal + gate_sim combined</h3>
<div style="color:var(--muted);font-size:11px;margin-bottom:12px;line-height:1.7">
Three independent signals combined into one ranked layer list.<br>
<span style="color:var(--blue)">β–  gate_sim</span> β€” static embedding cosine (fast, weak proxy) &nbsp;
<span style="color:var(--green)">β–  grad_norm</span> β€” βˆ‚loss/βˆ‚W_down per slot (one backward pass) &nbsp;
<span style="color:var(--yellow)">β–  causal IE</span> β€” indirect effect via subject-corruption patching (N_layers passes, slow)
</div>
<div class="row">
<div class="col2">
<label>Prompt</label>
<input type="text" id="sl-prompt" value="The capital of France is">
</div>
<div class="col">
<label>Subject</label>
<input type="text" id="sl-subject" value="France">
</div>
<div class="col">
<label>Target</label>
<input type="text" id="sl-target" value="Paris">
</div>
</div>
<div class="row" style="margin-top:10px">
<div class="col">
<label>Ξ± gate_sim: <span id="sl-a-val">0.4</span></label>
<input type="range" id="sl-alpha" min="0" max="1" step="0.05" value="0.4"
oninput="document.getElementById('sl-a-val').textContent=this.value">
</div>
<div class="col">
<label>Ξ² grad_norm: <span id="sl-b-val">0.3</span></label>
<input type="range" id="sl-beta" min="0" max="1" step="0.05" value="0.3"
oninput="document.getElementById('sl-b-val').textContent=this.value">
</div>
<div class="col">
<label>Ξ³ causal: <span id="sl-g-val">0.3</span></label>
<input type="range" id="sl-gamma" min="0" max="1" step="0.05" value="0.3"
oninput="document.getElementById('sl-g-val').textContent=this.value">
</div>
<div class="col">
<label>Noise Οƒ: <span id="sl-noise-val">0.1</span></label>
<input type="range" id="sl-noise" min="0.02" max="0.5" step="0.02" value="0.1"
oninput="document.getElementById('sl-noise-val').textContent=this.value">
</div>
</div>
<div style="display:flex;gap:8px;margin-top:12px;flex-wrap:wrap">
<button onclick="runSmartLocate()">⚑ Smart Locate (full)</button>
<button class="secondary" onclick="runGradientOnly()">β–Ά Gradient only (fast)</button>
<button class="secondary" onclick="runCausalOnly()">β–Ά Causal trace only</button>
</div>
<div id="sl-status" style="color:var(--muted);font-size:11px;margin-top:8px"></div>
</div>
<div class="row">
<div class="col2 card">
<h3>Layer Rankings β€” 3-signal stacked bars</h3>
<div class="chart-wrap" id="sl-chart" style="min-height:320px"></div>
</div>
<div class="col card">
<h3>Recommendation</h3>
<div id="sl-rec" class="log">Run Smart Locate to see the best edit target.</div>
<h3 style="margin-top:14px">Collateral Probe</h3>
<div class="row" style="margin-top:8px">
<input type="text" id="sl-coll-prompt" value="Biggest cities in France"
style="flex:2" placeholder="Collateral prompt…">
<button class="secondary" onclick="runCollateral()" style="flex:0">β–Ά</button>
</div>
<div id="sl-coll-out" class="log" style="margin-top:8px">Probe a prompt to check collateral damage.</div>
</div>
</div>
<div class="card">
<h3>Per-Layer Detail</h3>
<div id="sl-table" style="overflow-x:auto">
<div style="color:var(--muted);font-size:11px">Run Smart Locate first.</div>
</div>
</div>
</div>
<!-- ══════════ HEATMAP PANEL ══════════ -->
<div id="panel-heatmap" class="panel">
<div class="card">
<h3>Gate Heatmap β€” Layer Γ— Slot Cosine Similarity</h3>
<div class="row">
<div class="col">
<label>Entity</label>
<input type="text" id="hm-entity" value="France">
</div>
<div class="col">
<label>Prompt (optional β€” for activation mode)</label>
<input type="text" id="hm-prompt" value="The capital of France is">
</div>
<div class="col">
<label>Mode</label>
<div class="radio-group">
<label><input type="radio" name="hm-mode" value="embed" checked> Embed</label>
<label><input type="radio" name="hm-mode" value="activation"> Activation</label>
</div>
</div>
</div>
<button onclick="runHeatmap()" style="margin-top:12px">β–Ά Run Heatmap</button>
</div>
<div class="row">
<div class="col2 card">
<h3>Heatmap</h3>
<div class="chart-wrap" id="hm-chart" style="min-height:300px"></div>
</div>
<div class="col card">
<h3>Selected Slot</h3>
<div id="hm-slot-detail" class="log">Click a cell to see decoded tokens.</div>
</div>
</div>
</div>
<!-- ══════════ EDIT PANEL ══════════ -->
<div id="panel-edit" class="panel">
<div class="card">
<h3>Edit</h3>
<div class="row">
<div class="col">
<label>Entity</label>
<input type="text" id="edit-entity" value="France">
<label style="margin-top:8px">Relation</label>
<input type="text" id="edit-relation" value="capital">
<label style="margin-top:8px">Old value (context only)</label>
<input type="text" id="edit-old" value="Paris">
<label style="margin-top:8px">New value (inject)</label>
<input type="text" id="edit-new" value="Lyon">
</div>
<div class="col">
<label>Mode</label>
<div class="radio-group" id="edit-mode-group">
<label><input type="radio" name="edit-mode" value="UPDATE" checked> UPDATE</label>
<label><input type="radio" name="edit-mode" value="PRECISE"> PRECISE</label>
<label><input type="radio" name="edit-mode" value="SMART"> β˜… SMART</label>
<label><input type="radio" name="edit-mode" value="INSERT"> INSERT</label>
<label><input type="radio" name="edit-mode" value="SUPPRESS"> SUPPRESS</label>
<label><input type="radio" name="edit-mode" value="AMPLIFY"> AMPLIFY</label>
<label><input type="radio" name="edit-mode" value="STYLE-SHIFT"> STYLE-SHIFT</label>
<label><input type="radio" name="edit-mode" value="MULTI-EDIT"> MULTI-EDIT</label>
</div>
<div id="precise-prompt-row" style="display:none;margin-top:8px">
<label>Prompt (PRECISE mode)</label>
<input type="text" id="edit-prompt" value="The capital of France is">
</div>
<div id="smart-row" style="display:none;margin-top:8px;background:var(--bg);border:1px solid var(--border);border-radius:6px;padding:10px">
<div style="color:var(--blue);font-size:11px;font-weight:700;margin-bottom:6px">β˜… SMART AUTO MODE</div>
<label>Prompt (used for locate + after-check)</label>
<input type="text" id="smart-prompt" value="The capital of France is">
<label style="margin-top:6px">Old value (what model currently says)</label>
<input type="text" id="smart-old" value="Paris">
<div class="row" style="margin-top:6px">
<div class="col">
<label>Top layers: <span id="smart-layers-val">3</span></label>
<input type="range" id="smart-layers" min="1" max="8" value="3"
oninput="document.getElementById('smart-layers-val').textContent=this.value">
</div>
<div class="col">
<label>Slots/layer: <span id="smart-slots-val">2</span></label>
<input type="range" id="smart-slots" min="1" max="5" value="2"
oninput="document.getElementById('smart-slots-val').textContent=this.value">
</div>
</div>
<div style="color:var(--muted);font-size:10px;margin-top:6px">
Runs smart_locate internally β†’ patches gradient-identified slots. No manual tuning needed.
</div>
</div>
<div id="style-shift-row" style="display:none;margin-top:8px">
<label>From concept</label>
<input type="text" id="ss-from" value="formal">
<label style="margin-top:6px">To concept</label>
<input type="text" id="ss-to" value="casual">
<label style="margin-top:6px">Strength: <span id="ss-str-val">0.5</span></label>
<div class="range-row">
<input type="range" id="ss-strength" min="0.1" max="2.0" step="0.1" value="0.5"
oninput="document.getElementById('ss-str-val').textContent=this.value">
</div>
</div>
<div id="multiedit-row" style="display:none;margin-top:8px">
<label>JSON array [{entity,relation,new_target,prompt?},...]</label>
<textarea id="multi-json" rows="5">[{"entity":"France","relation":"capital","new_target":"Lyon"}]</textarea>
</div>
<label style="margin-top:8px">Top-K slots: <span id="edit-k-val">3</span></label>
<div class="range-row">
<input type="range" id="edit-k" min="1" max="10" value="3"
oninput="document.getElementById('edit-k-val').textContent=this.value">
</div>
<label style="margin-top:6px">Scale: <span id="edit-scale-val">1.5</span></label>
<div class="range-row">
<input type="range" id="edit-scale" min="0.5" max="5.0" step="0.25" value="1.5"
oninput="document.getElementById('edit-scale-val').textContent=this.value">
</div>
<label style="margin-top:6px">Alpha (INSERT): <span id="edit-alpha-val">0.25</span></label>
<div class="range-row">
<input type="range" id="edit-alpha" min="0.05" max="1.0" step="0.05" value="0.25"
oninput="document.getElementById('edit-alpha-val').textContent=this.value">
</div>
<div style="display:flex;gap:8px;margin-top:14px;flex-wrap:wrap">
<button onclick="runDryRun()">πŸ” Dry Run</button>
<button onclick="runEdit()">⚑ Apply Edit</button>
</div>
</div>
</div>
</div>
<div class="row">
<div class="col card">
<h3>Before / After</h3>
<div class="chart-wrap" id="edit-chart" style="min-height:220px"></div>
</div>
<div class="col card">
<h3>Dry Run Preview</h3>
<div id="dryrun-log" class="log">Run dry-run first.</div>
</div>
</div>
<div class="card">
<h3>Edit Log + Delta</h3>
<div id="edit-log" class="log">No edit yet.</div>
</div>
</div>
<!-- ══════════ PATCHES PANEL ══════════ -->
<div id="panel-patches" class="panel">
<div class="card">
<h3>Active Patches</h3>
<div style="display:flex;gap:8px;margin-bottom:12px;flex-wrap:wrap">
<button onclick="loadPatches()">↻ Refresh</button>
<button class="secondary" onclick="savePatches()">πŸ’Ύ Save JSON</button>
<button class="secondary" onclick="compileModel()">πŸ“¦ Compile</button>
<button class="danger" onclick="resetModel()">⊘ Reset to Base</button>
</div>
<div id="patches-list"><div style="color:var(--muted)">No patches.</div></div>
</div>
<div class="card">
<h3>Concept Reference</h3>
<div style="color:var(--muted);line-height:1.9;font-size:11px">
<span style="color:var(--blue)">UPDATE</span> β€” rewrites W_down column β†’ different fact<br>
<span style="color:var(--purple)">PRECISE</span> β€” activation-guided UPDATE (3–5Γ— better sim)<br>
<span style="color:var(--green)">INSERT</span> β€” new (gate,down) pair in weakest slot<br>
<span style="color:var(--red)">SUPPRESS</span> β€” scale down β†’ model forgets entity<br>
<span style="color:var(--yellow)">AMPLIFY</span> β€” scale up β†’ stronger recall<br>
<span style="color:var(--cyan)">STYLE-SHIFT</span> β€” adds direction vector β†’ tone/bias shift
</div>
</div>
</div>
<!-- ══════════ GUIDE PANEL ══════════ -->
<div id="panel-guide" class="panel">
<div class="card">
<h3>What is VINDEX doing?</h3>
<div style="line-height:1.9;color:var(--muted)">
In a transformer, factual associations like <span style="color:var(--text)">"France β†’ capital β†’ Paris"</span>
are stored as direction vectors in the <span style="color:var(--blue)">W_down columns</span> of FFN layers.
The <span style="color:var(--blue)">W_gate rows</span> act as keys: when the residual stream resembles "France",
the matching gate fires, the down column adds "Paris" direction to the stream, and the unembedding reads out "Paris".
VINDEX surgically replaces those down columns without retraining.
</div>
</div>
<div class="card">
<h3>Quickstart β€” 5-step experiment</h3>
<div style="line-height:2;font-size:12px">
<div style="color:var(--yellow);margin-bottom:4px">Step 1 β€” Load a model that actually knows facts</div>
<div style="color:var(--muted);margin-left:16px;margin-bottom:10px">
βš™ Load tab β†’ <span style="color:var(--blue)">gpt2-medium</span> (1.5 GB, knows capitals) or
<span style="color:var(--blue)">Qwen/Qwen2.5-1.5B-Instruct</span> (3 GB, strong).<br>
distilgpt2 has clean_probβ‰ˆ0 for most facts β†’ causal IE=0 everywhere β†’ misleading results.
</div>
<div style="color:var(--yellow);margin-bottom:4px">Step 2 β€” Verify the model knows the fact</div>
<div style="color:var(--muted);margin-left:16px;margin-bottom:10px">
β‘  Infer: prompt = <code>"The capital of France is"</code><br>
βœ“ Good: "Paris" appears in top-3 with prob &gt; 0.05<br>
βœ— Bad: top tokens are "a", "the", "known" β†’ model doesn't know it β†’ skip to INSERT mode
</div>
<div style="color:var(--yellow);margin-bottom:4px">Step 3 β€” Find where the fact lives</div>
<div style="color:var(--muted);margin-left:16px;margin-bottom:10px">
β‘’ Trace: prompt = <code>"The capital of France is"</code>, target = <code>"Paris"</code><br>
β†’ Look for phase layer: where rank drops from ~30000 to &lt;100. That's where the fact materializes.<br>
β‘€ Smart Locate β†’ Gradient only (fast, 1 backward pass):<br>
<span style="margin-left:16px">subject = <code>France</code>, target = <code>Paris</code></span><br>
β†’ The layer with highest grad_norm bar = best edit target. Note the slot numbers.
</div>
<div style="color:var(--yellow);margin-bottom:4px">Step 4 β€” Edit with SMART mode</div>
<div style="color:var(--muted);margin-left:16px;margin-bottom:10px">
⑦ Edit tab β†’ mode = <span style="color:var(--blue)">β˜… SMART</span><br>
Entity = <code>France</code> | Relation = <code>capital</code><br>
Old value = <code>Paris</code> (what model says now β€” used for locate)<br>
New value = <code>Lyon</code> (what you want)<br>
Prompt = <code>"The capital of France is"</code><br>
Scale = <code>2.0</code> (start here; increase to 3.0 if effect is weak)<br>
β†’ Click <b>Apply Edit</b>. Smart locate runs internally, patches grad-identified slots.
</div>
<div style="color:var(--yellow);margin-bottom:4px">Step 5 β€” Check collateral damage</div>
<div style="color:var(--muted);margin-left:16px;margin-bottom:10px">
β‘  Infer: <code>"The capital of France is"</code> β†’ should now say Lyon<br>
β‘  Infer: <code>"Biggest cities in France"</code> β†’ should be unchanged (different slots)<br>
β‘  Infer: <code>"Paris is a city in"</code> β†’ should still say France<br>
β‘  Infer: <code>"Lyon is a city in"</code> β†’ might now also say France (collateral)<br>
β‘€ Smart Locate collateral probe β†’ run these prompts, compare slot lists in β‘§ Patches
</div>
</div>
</div>
<div class="card">
<h3>Interpreting Smart Locate results</h3>
<div style="font-size:11px;line-height:1.9">
<div class="row" style="gap:20px">
<div class="col">
<div style="color:var(--blue);font-weight:700;margin-bottom:6px">β–  gate_sim (blue)</div>
<div style="color:var(--muted)">
Cosine between W_gate[slot] and embed(subject).<br>
Fast, cheap, but <b>weak proxy</b> β€” measures embedding-space similarity,<br>
not causal contribution. Useful for finding <i>related</i> slots.<br>
<b>High gate_sim + low grad_norm</b> = slot activates for this entity<br>
but doesn't contribute much to this specific prediction.
</div>
</div>
<div class="col">
<div style="color:var(--green);font-weight:700;margin-bottom:6px">β–  grad_norm (green)</div>
<div style="color:var(--muted)">
β€–βˆ‚(-log p(target))/βˆ‚W_down[:,slot]β€– β€” how much changing this slot<br>
would affect the loss for this (prompt, target) pair.<br>
<b>Most reliable signal</b>, works even when clean_prob is tiny.<br>
One backward pass. Use Ξ² &gt; Ξ± to weight this higher.<br>
<b>High grad_norm</b> = this slot is causally upstream of the prediction.
</div>
</div>
<div class="col">
<div style="color:var(--yellow);font-weight:700;margin-bottom:6px">β–  causal IE (yellow)</div>
<div style="color:var(--muted)">
Indirect effect via noise-corruption patching (ROME-style).<br>
Measures: if I corrupt subject embeddings, how much does patching<br>
layer L's hidden state at subject pos <i>restore</i> the prediction?<br>
<b>Most interpretable</b> β€” true causal measurement. But:<br>
If clean_prob β‰ˆ 0, IE = 0 everywhere (nothing to restore).<br>
Needs a model that actually knows the fact.
</div>
</div>
</div>
<div style="margin-top:12px;padding:10px;background:var(--bg);border-radius:6px;border:1px solid var(--border)">
<span style="color:var(--yellow)">⚠ Your distilgpt2 result:</span>
<span style="color:var(--muted)"> clean_prob=0.000001 β†’ causal IE=0 everywhere (expected, not a bug).
grad_norm on L9/slot515 IS real signal β€” that slot responds to France+capital context in the gradient sense.
But the probability mass is too diffuse to show causal separation.
Switch to gpt2-medium for textbook causal results.</span>
</div>
</div>
</div>
<div class="card">
<h3>Edit modes β€” when to use which</h3>
<div style="font-size:11px">
<table style="width:100%;border-collapse:collapse">
<thead><tr style="border-bottom:1px solid var(--border);color:var(--muted)">
<th style="padding:6px 8px;text-align:left">Mode</th>
<th style="padding:6px 8px;text-align:left">Slot selection</th>
<th style="padding:6px 8px;text-align:left">Best for</th>
<th style="padding:6px 8px;text-align:left">Knobs</th>
</tr></thead>
<tbody style="color:var(--muted)">
<tr style="border-bottom:1px solid var(--border)">
<td style="padding:6px 8px;color:var(--blue)">UPDATE</td>
<td style="padding:6px 8px">gate cosine sim to embed(entity)</td>
<td style="padding:6px 8px">Quick experiment, model knows the fact well</td>
<td style="padding:6px 8px">Top-K=3-5, Scale=1.5-3</td>
</tr>
<tr style="border-bottom:1px solid var(--border)">
<td style="padding:6px 8px;color:var(--purple)">PRECISE</td>
<td style="padding:6px 8px">gate cosine sim to h_L[subject_pos]</td>
<td style="padding:6px 8px">In-context subject representation (3-5Γ— better than UPDATE)</td>
<td style="padding:6px 8px">+ Prompt field</td>
</tr>
<tr style="border-bottom:1px solid var(--border)">
<td style="padding:6px 8px;color:var(--yellow)">β˜… SMART</td>
<td style="padding:6px 8px">gradient norm β†’ exact slots, then patch</td>
<td style="padding:6px 8px"><b>Best overall.</b> Auto-locates, no manual tuning</td>
<td style="padding:6px 8px">Top layers=3, Slots/layer=2, Scale=1.5-2.5</td>
</tr>
<tr style="border-bottom:1px solid var(--border)">
<td style="padding:6px 8px;color:var(--green)">INSERT</td>
<td style="padding:6px 8px">weakest slot (norm-based)</td>
<td style="padding:6px 8px">Model has no knowledge of fact, build from scratch</td>
<td style="padding:6px 8px">Alpha=0.4-0.7, Spread=4-6</td>
</tr>
<tr style="border-bottom:1px solid var(--border)">
<td style="padding:6px 8px;color:var(--red)">SUPPRESS</td>
<td style="padding:6px 8px">gate cosine β†’ scale W_down to 0</td>
<td style="padding:6px 8px">Make model forget an entity (factor=0) or weaken (0.5)</td>
<td style="padding:6px 8px">Factor: 0=forget, 0.5=weaken</td>
</tr>
<tr style="border-bottom:1px solid var(--border)">
<td style="padding:6px 8px;color:var(--cyan)">STYLE-SHIFT</td>
<td style="padding:6px 8px">gate cosine β†’ add direction vector</td>
<td style="padding:6px 8px">Bias/tone shifts: CEO→less male-coded, Paris→darker</td>
<td style="padding:6px 8px">from/to concepts, strength=0.3-0.8</td>
</tr>
</tbody>
</table>
</div>
</div>
<div class="card">
<h3>Experiments to run</h3>
<div style="font-size:11px;line-height:1.9;color:var(--muted)">
<div style="color:var(--text);margin-bottom:4px">Experiment A β€” Capital swap (classic ROME benchmark)</div>
Model: gpt2-medium | Prompt: "The capital of France is" | Old: Paris | New: Lyon<br>
Check: "France's capital city" | "Lyon is now" | "Paris is in" | "Eiffel Tower is in"<br>
Insight: does it generalize (paraphrase) or is it prompt-specific?<br><br>
<div style="color:var(--text);margin-bottom:4px">Experiment B β€” Slot overlap analysis (your collateral question)</div>
1. SMART locate "The capital of France is" β†’ note slot numbers in recommendation<br>
2. SMART locate "The biggest city in France is" β†’ compare slot lists<br>
3. Overlap = slots that will be collaterally damaged<br>
4. No overlap = clean surgery βœ“<br><br>
<div style="color:var(--text);margin-bottom:4px">Experiment C β€” Suppression then INSERT</div>
SUPPRESS France β†’ then INSERT France capital Lyon β†’ Infer<br>
vs just UPDATE. Which gives cleaner, more confident result?<br><br>
<div style="color:var(--text);margin-bottom:4px">Experiment D β€” Style shift (no factual change)</div>
STYLE-SHIFT: anchor=CEO, from="male", to="female", strength=0.3<br>
Then Infer: "The CEO of the company is a" β€” does pronoun distribution shift?<br>
Insight: this is mechanical debiasing without retraining.<br><br>
<div style="color:var(--text);margin-bottom:4px">Experiment E β€” Compile and compare</div>
Edit 5 facts. Compile β†’ save as new model directory.<br>
Load compiled model fresh β†’ Infer same prompts β†’ edits should persist in weights.<br>
Then Trace on compiled model β†’ phase layers should shift or sharpen.
</div>
</div>
<div class="card">
<h3>Ξ± Ξ² Ξ³ tuning guide</h3>
<div style="font-size:11px;line-height:1.9;color:var(--muted)">
<b style="color:var(--text)">Default (0.4 / 0.3 / 0.3)</b> β€” balanced, works for unknown model quality<br>
<b style="color:var(--text)">Grad-heavy (0.1 / 0.7 / 0.2)</b> β€” clean_prob &gt; 0.01. Grad signal is sharp, trust it.<br>
<b style="color:var(--text)">Gate+Grad (0.4 / 0.4 / 0.2)</b> β€” recommended for smart_edit when causal IE is weak<br>
<b style="color:var(--text)">Causal-heavy (0.2 / 0.2 / 0.6)</b> β€” only when clean_prob &gt; 0.1. IE is the gold signal then.<br>
<b style="color:var(--text)">Gate-only (1.0 / 0.0 / 0.0)</b> β€” equivalent to basic locate(), sanity check<br>
<br>
<b style="color:var(--yellow)">Your distilgpt2 setting:</b> use (0.3 / 0.7 / 0.0) β€” gate+grad, skip causal (it's 0 anyway).
</div>
</div>
</div>
</div><!-- /app -->
<script>
// ═══════════════════════════════════════════════
// STATE
// ═══════════════════════════════════════════════
let modelLoaded = false;  // flipped to true after a successful /api/load or /api/status
const BASE = '';          // API base URL prefix; empty string = same origin as the page
// ═══════════════════════════════════════════════
// UTILS
// ═══════════════════════════════════════════════
// Activate the panel and tab button for `name`.
// FIX: the original read the implicit global `event` (non-standard outside
// event handlers). Accept an optional event argument, falling back to
// window.event so existing inline onclick="showTab('x')" callers still work.
function showTab(name, ev) {
  document.querySelectorAll('.panel').forEach(p => p.classList.remove('active'));
  document.querySelectorAll('.tab-btn').forEach(b => b.classList.remove('active'));
  document.getElementById('panel-'+name).classList.add('active');
  const e = ev || window.event;
  if (e && e.target) e.target.classList.add('active');
}
// Position the shared tooltip near (x, y) and fill it with `html`.
function showTooltip(html, x, y) {
  const tip = document.getElementById('tooltip');
  tip.innerHTML = html;
  tip.style.left = (x+14)+'px';
  tip.style.top = (y-10)+'px';
  tip.style.opacity = 1;
}
function hideTooltip() { document.getElementById('tooltip').style.opacity = 0; }
// Thin JSON fetch wrapper: GET when no body is given, POST otherwise.
// Throws Error(detail) on non-2xx responses; resolves to the parsed JSON body.
async function api(path, body) {
  const opts = body !== undefined
    ? { method:'POST', headers:{'Content-Type':'application/json'}, body:JSON.stringify(body) }
    : { method:'GET' };
  const r = await fetch(BASE+path, opts);
  if (!r.ok) {
    // FIX: error responses are usually FastAPI JSON ({detail: ...}) but can be
    // plain text (proxy/gateway errors). Don't let the error path itself throw
    // a JSON parse error that masks the real failure.
    let detail = r.statusText;
    try {
      const e = await r.json();
      detail = e.detail || detail;
    } catch (_) { /* non-JSON error body — keep statusText */ }
    throw new Error(detail);
  }
  return r.json();
}
// Update the header status indicator: dot colour, label text, patch counter.
// `patches` is optional — when omitted the counter is left untouched.
function setStatus(ok, text, patches) {
  document.getElementById('status-dot').classList.toggle('ok', ok);
  document.getElementById('status-text').textContent = text;
  if (patches === undefined) return;
  const label = patches ? ` Β· ${patches} patch(es)` : '';
  document.getElementById('patch-count').textContent = label;
}
// ═══════════════════════════════════════════════
// LOAD
// ═══════════════════════════════════════════════
// POST the load form to /api/load and report progress inline in the log box.
async function loadModel() {
  const logEl = document.getElementById('load-log');
  logEl.style.display = 'block';
  logEl.textContent = 'Loading…';
  const payload = {
    model_name: document.getElementById('load-model').value.trim(),
    device: document.getElementById('load-device').value
  };
  try {
    const r = await api('/api/load', payload);
    logEl.textContent = 'βœ“ ' + r.info;
    modelLoaded = true;
    setStatus(true, r.info.split(' | ')[0], 0);
  } catch(e) {
    logEl.textContent = 'βœ— ' + e.message;
  }
}
// On page load, ask the backend whether a model is already resident and,
// if so, restore the header state. Failures (backend down) are ignored.
async function checkStatus() {
  try {
    const r = await api('/api/status');
    if (!r.loaded) return;
    modelLoaded = true;
    setStatus(true, r.info.split(' | ')[0], r.patches_count);
  } catch(e) { /* backend not reachable yet — keep the default header */ }
}
// ═══════════════════════════════════════════════
// D3 HELPERS
// ═══════════════════════════════════════════════
// Shared colour palette (GitHub-dark inspired).
// FIX: added `border` — drawSmartLocateChart references C.border, which was
// previously undefined, so the divider stroke became the string "undefined".
const C = { bg:'#0d1117', bg2:'#161b22', bg3:'#21262d',
            blue:'#58a6ff', green:'#3fb950', red:'#f85149',
            yellow:'#d29922', purple:'#bc8cff', cyan:'#39d353',
            text:'#e6edf3', muted:'#8b949e', border:'#30363d' };
// Empty a chart container and hand it back, ready for re-rendering.
function clearChart(id) {
  const host = document.getElementById(id);
  host.innerHTML = '';
  return host;
}
// ═══════════════════════════════════════════════
// INFER
// ═══════════════════════════════════════════════
// Run next-token inference for the prompt box and chart the top-k results.
async function runInfer() {
  const payload = {
    prompt: document.getElementById('infer-prompt').value,
    top_k: +document.getElementById('infer-k').value
  };
  try {
    const data = await api('/api/infer', payload);
    drawInferChart(data.results);
  } catch(e) {
    alert(e.message);
  }
}
// Horizontal bar chart of top-k next-token probabilities (38px row pitch).
function drawInferChart(results) {
  const el = clearChart('infer-chart');
  const W = el.clientWidth || 700, H = 36 + results.length * 38;
  el.style.height = H+'px';
  const svg = d3.select(el).append('svg').attr('width','100%').attr('height',H);
  const margin = {left:120, right:70, top:16, bottom:16};
  const w = W - margin.left - margin.right;
  // Scale against the max probability so the top bar always fills the row.
  const maxP = d3.max(results, d=>d.prob);
  const x = d3.scaleLinear().domain([0, maxP]).range([0, w]);
  const g = svg.append('g').attr('transform',`translate(${margin.left},${margin.top})`);
  // Reversed domain: rank 0 gets the "hot" end of the interpolateCool palette.
  const color = d3.scaleSequential(d3.interpolateCool)
    .domain([results.length-1, 0]);
  results.forEach((d,i) => {
    const y = i * 38;
    // Token label, right-aligned against the bar area.
    g.append('text').attr('x',-8).attr('y',y+19).attr('text-anchor','end')
      .attr('fill',C.text).attr('font-size',13).text(d.token || '(empty)');
    // Probability bar with hover tooltip (6-decimal precision on hover).
    g.append('rect').attr('x',0).attr('y',y+8).attr('width', x(d.prob))
      .attr('height',20).attr('rx',3).attr('fill',color(i)).attr('opacity',.85)
      .on('mousemove',(ev)=>showTooltip(`${d.token}: ${d.prob.toFixed(6)}`,ev.pageX,ev.pageY))
      .on('mouseleave',hideTooltip);
    // Numeric value printed just past the end of the bar.
    g.append('text').attr('x',x(d.prob)+6).attr('y',y+19)
      .attr('fill',C.muted).attr('font-size',11).text(d.prob.toFixed(4));
  });
}
// ═══════════════════════════════════════════════
// DESCRIBE β€” force graph
// ═══════════════════════════════════════════════
// Fetch association edges for the entity and render the force graph.
async function runDescribe() {
  const entity = document.getElementById('desc-entity').value;
  const payload = {
    entity,
    top_k: +document.getElementById('desc-k').value
  };
  try {
    const data = await api('/api/describe', payload);
    drawDescribeGraph(entity, data.edges);
  } catch(e) {
    alert(e.message);
  }
}
// Force-directed graph: the queried entity pinned in the middle, one node per
// associated token, links coloured by the layer the association came from.
function drawDescribeGraph(center, edges) {
  const el = clearChart('desc-chart');
  const W = el.clientWidth || 700, H = 420;
  const svg = d3.select(el).append('svg').attr('width','100%').attr('height',H);
  if (!edges.length) {
    svg.append('text').attr('x',W/2).attr('y',H/2).attr('text-anchor','middle')
      .attr('fill',C.muted).text('No edges found (sim < 0.08)'); return;
  }
  const maxLayer = d3.max(edges, d=>d.layer);
  const layerColor = d3.scaleSequential(d3.interpolateRainbow).domain([0, maxLayer]);
  const maxScore = d3.max(edges, d=>d.score);
  // Node radius grows with sqrt(score) so circle area roughly tracks score.
  const nodes = [{id:'__center__', label:center, r:22, fixed:true}];
  const links = [];
  edges.forEach((e,i) => {
    const id = 'n'+i;
    nodes.push({id, label:e.tok, score:e.score, layer:e.layer, gate_sim:e.gate_sim,
                r: 8 + Math.sqrt(e.score/maxScore)*10});
    links.push({source:'__center__', target:id, layer:e.layer});
  });
  const sim = d3.forceSimulation(nodes)
    .force('link', d3.forceLink(links).id(d=>d.id).distance(100))
    .force('charge', d3.forceManyBody().strength(-120))
    .force('center', d3.forceCenter(W/2, H/2))
    .force('collision', d3.forceCollide(d=>d.r+4));
  const link = svg.append('g').selectAll('line').data(links).join('line')
    .attr('stroke', d=>layerColor(d.layer)).attr('stroke-opacity',.5).attr('stroke-width',1.5);
  // Drag pins a node (fx/fy) while dragging and releases it on drop.
  // NOTE(review): dragging the centre node also clears its fx/fy on 'end',
  // un-pinning it from the middle — confirm whether that is intended.
  const node = svg.append('g').selectAll('g').data(nodes).join('g')
    .call(d3.drag()
      .on('start',(ev,d)=>{if(!ev.active)sim.alphaTarget(.3).restart();d.fx=d.x;d.fy=d.y})
      .on('drag',(ev,d)=>{d.fx=ev.x;d.fy=ev.y})
      .on('end',(ev,d)=>{if(!ev.active)sim.alphaTarget(0);d.fx=null;d.fy=null}));
  node.append('circle').attr('r',d=>d.r)
    .attr('fill', d=>d.id==='__center__' ? C.blue : layerColor(d.layer||0))
    .attr('opacity',.85)
    .on('mousemove',(ev,d)=>{
      if(d.id==='__center__') return;  // no tooltip on the centre node
      showTooltip(`${d.label}<br>score: ${d.score?.toFixed(1)}<br>L${d.layer} sim:${d.gate_sim?.toFixed(3)}`,ev.pageX,ev.pageY);
    }).on('mouseleave',hideTooltip);
  // Labels are truncated to 12 chars and ignore pointer events so hover
  // always hits the circle underneath.
  node.append('text').attr('dy','0.35em').attr('text-anchor','middle')
    .attr('font-size', d=>d.id==='__center__'?13:10)
    .attr('fill',C.text).attr('pointer-events','none')
    .text(d=>d.label.substring(0,12));
  sim.on('tick',()=>{
    link.attr('x1',d=>d.source.x).attr('y1',d=>d.source.y)
      .attr('x2',d=>d.target.x).attr('y2',d=>d.target.y);
    node.attr('transform',d=>`translate(${d.x},${d.y})`);
  });
  // Pin the centre node to the middle of the canvas.
  nodes[0].fx = W/2; nodes[0].fy = H/2;
}
// ═══════════════════════════════════════════════
// TRACE β€” dual-axis line chart
// ═══════════════════════════════════════════════
// Run a layer-by-layer trace of the target token and draw the dual-axis chart.
async function runTrace() {
  const payload = {
    prompt: document.getElementById('trace-prompt').value,
    target: document.getElementById('trace-target').value
  };
  try {
    const data = await api('/api/trace', payload);
    drawTraceChart(data.stats, 'trace-chart');
  } catch(e) {
    alert(e.message);
  }
}
// Dual-axis per-layer trace of the target token: rank (log scale, left axis,
// blue) and probability (linear, right axis, dashed green), plus a vertical
// marker at the detected phase-transition layer.
function drawTraceChart(stats, chartId) {
  const el = clearChart(chartId);
  const W = el.clientWidth || 700, H = 300;
  el.style.height = H+'px';
  const svg = d3.select(el).append('svg').attr('width','100%').attr('height',H);
  const m = {left:55, right:55, top:20, bottom:30};
  const w = W-m.left-m.right, h = H-m.top-m.bottom;
  // Phase transition = layer with the largest *relative* rank drop,
  // considered only while the previous rank is still > 5.
  let phaseL = -1, maxDrop = 0, prevRank = null;
  stats.forEach(s=>{
    if(prevRank!==null && prevRank>5){
      const drop=(prevRank-s.rank)/prevRank;
      if(drop>maxDrop){maxDrop=drop;phaseL=s.l;}
    }
    prevRank=s.rank;
  });
  const x = d3.scaleLinear().domain([0,stats.length-1]).range([0,w]);
  // clamp(true) keeps rank values below 1 from escaping the log scale.
  const yRank = d3.scaleLog().domain([1, d3.max(stats,d=>d.rank)]).range([h,0]).clamp(true);
  const yProb = d3.scaleLinear().domain([0, d3.max(stats,d=>d.prob)]).range([h,0]);
  const g = svg.append('g').attr('transform',`translate(${m.left},${m.top})`);
  // Axes: bottom = layer index, left = rank, right = probability.
  g.append('g').attr('transform',`translate(0,${h})`)
    .call(d3.axisBottom(x).ticks(stats.length<=12?stats.length:8).tickFormat(d=>'L'+d))
    .selectAll('text').attr('fill',C.muted).attr('font-size',10);
  g.append('g').call(d3.axisLeft(yRank).ticks(5).tickFormat(d3.format('d')))
    .selectAll('text').attr('fill',C.muted).attr('font-size',10);
  g.append('g').attr('transform',`translate(${w},0)`)
    .call(d3.axisRight(yProb).ticks(5).tickFormat(d3.format('.2e')))
    .selectAll('text').attr('fill',C.green).attr('font-size',10);
  // Dashed vertical marker at the phase-transition layer (if one was found).
  if(phaseL>=0){
    const px = x(phaseL);
    g.append('line').attr('x1',px).attr('x2',px).attr('y1',0).attr('y2',h)
      .attr('stroke',C.yellow).attr('stroke-dasharray','4,2').attr('stroke-width',1.5);
    g.append('text').attr('x',px+4).attr('y',12).attr('fill',C.yellow).attr('font-size',10)
      .text('⚑ PHASE L'+phaseL);
  }
  // Rank line on the log scale (ranks floored at 1).
  const lineRank = d3.line().x((_,i)=>x(i)).y(d=>yRank(Math.max(1,d.rank)));
  g.append('path').datum(stats).attr('fill','none').attr('stroke',C.blue)
    .attr('stroke-width',2).attr('d',lineRank);
  // Probability line (dashed, against the right axis).
  const lineProb = d3.line().x((_,i)=>x(i)).y(d=>yProb(d.prob));
  g.append('path').datum(stats).attr('fill','none').attr('stroke',C.green)
    .attr('stroke-width',1.5).attr('stroke-dasharray','5,2').attr('d',lineProb);
  // Hoverable dots along the rank line.
  g.selectAll('.dot').data(stats).join('circle').attr('class','dot')
    .attr('cx',(_,i)=>x(i)).attr('cy',d=>yRank(Math.max(1,d.rank)))
    .attr('r',3).attr('fill',C.blue)
    .on('mousemove',(ev,d)=>showTooltip(`L${d.l} rank:${d.rank}<br>prob:${d.prob.toExponential(3)}`,ev.pageX,ev.pageY))
    .on('mouseleave',hideTooltip);
  // Legend in the top-right corner of the plot area.
  const leg = g.append('g').attr('transform',`translate(${w-120},6)`);
  leg.append('line').attr('x2',16).attr('stroke',C.blue).attr('stroke-width',2);
  leg.append('text').attr('x',20).attr('y',4).attr('fill',C.blue).attr('font-size',10).text('rank (log, left)');
  leg.append('line').attr('y1',14).attr('x2',16).attr('y2',14).attr('stroke',C.green).attr('stroke-width',1.5).attr('stroke-dasharray','5,2');
  leg.append('text').attr('x',20).attr('y',18).attr('fill',C.green).attr('font-size',10).text('prob (right)');
}
// ═══════════════════════════════════════════════
// LOCATE
// ═══════════════════════════════════════════════
// Run /api/locate, draw the trace chart + per-layer bars, and fill the summary.
async function runLocate() {
  try {
    const data = await api('/api/locate', {
      prompt: document.getElementById('loc-prompt').value,
      subject: document.getElementById('loc-subject').value,
      target: document.getElementById('loc-target').value
    });
    drawTraceChart(data.trace, 'loc-trace-chart');
    drawLocateBars(data.layer_scores, data.phase_layer);
    const sumEl = document.getElementById('loc-summary');
    sumEl.style.display='block';
    // FIX: the peak-layer reduce() was evaluated twice (once for the layer,
    // once for the sim) — compute it a single time.
    const peak = data.layer_scores.reduce(
      (a,b)=>b.max_sim>a.max_sim?b:a, data.layer_scores[0]);
    document.getElementById('loc-summary-text').textContent =
      `Phase transition at L${data.phase_layer}. Subject token pos: ${data.subject_pos}. ` +
      `Peak activation sim: L${peak?.layer} = ` +
      `${peak?.max_sim?.toFixed(4)}`;
  } catch(e) { alert(e.message); }
}
// One clickable bar per layer, width proportional to max activation sim.
// The phase-transition layer is outlined and filled in yellow.
function drawLocateBars(layerScores, phaseLayer) {
  const host = document.getElementById('loc-bars');
  host.innerHTML = '';
  const maxSim = d3.max(layerScores, d=>d.max_sim) || 1;
  for (const ls of layerScores) {
    const isPhase = ls.layer === phaseLayer;
    const pct = (ls.max_sim / maxSim * 100).toFixed(1);
    const row = document.createElement('div');
    row.className = 'locate-bar';
    row.style.border = isPhase ? '1px solid '+C.yellow : '1px solid transparent';
    row.innerHTML = `<span class="layer-lbl">L${ls.layer}</span>`+
      `<div class="bar-fill" style="width:${pct}%;background:${isPhase?C.yellow:C.purple}"></div>`+
      `<span class="sim-val">${ls.max_sim.toFixed(4)}</span>`;
    row.onclick = ()=>{ showSlotDetail(ls); };
    host.appendChild(row);
  }
}
// Reveal the slot-detail box for one layer and list its decoded top tokens.
function showSlotDetail(ls) {
  const box = document.getElementById('loc-slots');
  box.style.display = 'block';
  let out = `L${ls.layer} slot=${ls.best_slot} max_sim=${ls.max_sim}\n\nTop decoded tokens:\n`;
  for (const t of ls.top_tokens) {
    out += `  ${t.tok} (${t.score})\n`;
  }
  box.textContent = out;
}
// ═══════════════════════════════════════════════
// HEATMAP
// ═══════════════════════════════════════════════
// Fetch the gate heatmap for the entity (static gates or activation-based,
// depending on the selected radio mode) and render it.
async function runHeatmap() {
  try {
    const mode = document.querySelector('input[name="hm-mode"]:checked').value;
    const payload = {
      entity: document.getElementById('hm-entity').value,
      use_activation: mode==='activation',
      prompt: document.getElementById('hm-prompt').value
    };
    const data = await api('/api/gate_heatmap', payload);
    drawHeatmap(data.layers);
  } catch(e) {
    alert(e.message);
  }
}
// Layer-by-slot grid; cell colour encodes slot similarity on a shared scale.
// Cells are hoverable (tooltip) and clickable (slot detail panel).
function drawHeatmap(layers) {
  const el = clearChart('hm-chart');
  if(!layers.length){el.textContent='No data';return;}
  const nLayers = layers.length;
  const nSlots = d3.max(layers, l=>l.slots.length);
  const W = el.clientWidth || 700;
  const cellW = Math.max(18, Math.floor((W-60)/nSlots));
  const cellH = 24;
  const H = nLayers*cellH + 40;
  el.style.height = H+'px';
  const svg = d3.select(el).append('svg').attr('width','100%').attr('height',H);
  const g = svg.append('g').attr('transform','translate(50,20)');
  // One shared colour domain across all layers so rows are comparable.
  const allSims = layers.flatMap(l=>l.slots.map(s=>s.sim));
  const color = d3.scaleSequential(d3.interpolateYlOrRd).domain([0, d3.max(allSims)]);
  layers.forEach((layer,li)=>{
    layer.slots.forEach((slot,si)=>{
      // FIX: dropped the unused local `rect` the appended cell was bound to.
      g.append('rect')
        .attr('class','hm-cell')
        .attr('x', si*cellW).attr('y', li*cellH)
        .attr('width', cellW-2).attr('height', cellH-2)
        .attr('rx', 2)
        .attr('fill', color(slot.sim))
        .on('mousemove',(ev)=>showTooltip(
          `L${layer.layer} slot ${slot.slot}<br>sim: ${slot.sim}`,ev.pageX,ev.pageY))
        .on('mouseleave',hideTooltip)
        .on('click',()=>showHmSlotDetail(layer.layer, slot));
    });
    // Row label on the left edge.
    g.append('text').attr('x',-6).attr('y',li*cellH+cellH/2+4)
      .attr('text-anchor','end').attr('fill',C.muted).attr('font-size',9).text('L'+layer.layer);
  });
}
// Print the decoded top tokens for one heatmap cell into the detail panel.
function showHmSlotDetail(layer, slot) {
  const box = document.getElementById('hm-slot-detail');
  let out = `Layer ${layer} Slot ${slot.slot} sim=${slot.sim}\n\nDecoded top tokens:\n`;
  for (const t of slot.top_tokens) {
    out += `  "${t.tok}" score=${t.score}\n`;
  }
  box.textContent = out;
}
// ═══════════════════════════════════════════════
// EDIT
// ═══════════════════════════════════════════════
// Show/hide the mode-specific option rows whenever the edit-mode radio changes.
document.querySelectorAll('input[name="edit-mode"]').forEach(r=>{
  r.addEventListener('change', ()=>{
    document.getElementById('precise-prompt-row').style.display = r.value==='PRECISE'?'block':'none';
    document.getElementById('smart-row').style.display = r.value==='SMART'?'block':'none';
    document.getElementById('style-shift-row').style.display = r.value==='STYLE-SHIFT'?'block':'none';
    document.getElementById('multiedit-row').style.display = r.value==='MULTI-EDIT'?'block':'none';
  });
});
// Preview an edit without applying it: show candidate slots and their
// current decoded top tokens in the dry-run log.
async function runDryRun() {
  const payload = {
    entity: document.getElementById('edit-entity').value,
    new_target: document.getElementById('edit-new').value,
    top_k: +document.getElementById('edit-k').value,
    scale: +document.getElementById('edit-scale').value,
    prompt: document.getElementById('edit-prompt').value || null
  };
  try {
    const data = await api('/api/dry_run', payload);
    let txt = `Mode: ${data.mode} best_sim=${data.best_sim} would_patch=${data.would_patch}\n\n`;
    for (const c of data.candidates) {
      txt += `L${c.layer} slot ${c.slot} sim=${c.sim} col_norm=${c.col_norm}\n`;
      txt += `  current top: ${c.current_top.map(t=>t.tok).join(', ')}\n`;
    }
    document.getElementById('dryrun-log').textContent = txt;
  } catch(e) {
    alert(e.message);
  }
}
let _before = null;  // pre-edit distribution from the last standard edit (kept for comparison)
// Dispatch the edit form to the right endpoint based on the mode radio:
// SMART -> /api/smart_edit, MULTI-EDIT -> /api/multi_edit, otherwise /api/edit.
async function runEdit() {
  const mode = document.querySelector('input[name="edit-mode"]:checked').value;
  const body = {
    entity: document.getElementById('edit-entity').value,
    relation: document.getElementById('edit-relation').value,
    old_target: document.getElementById('edit-old').value,
    new_target: document.getElementById('edit-new').value,
    mode,
    alpha: +document.getElementById('edit-alpha').value,
    top_k: +document.getElementById('edit-k').value,
    scale: +document.getElementById('edit-scale').value,
    prompt: document.getElementById('edit-prompt').value || null,
    from_concept: document.getElementById('ss-from').value,
    to_concept: document.getElementById('ss-to').value,
    strength: +document.getElementById('ss-strength').value,
  };
  if (mode==='SMART') {
    try {
      const r = await api('/api/smart_edit', {
        prompt: document.getElementById('smart-prompt').value,
        subject: document.getElementById('edit-entity').value,
        relation: document.getElementById('edit-relation').value,
        old_target: document.getElementById('smart-old').value,
        new_target: document.getElementById('edit-new').value,
        top_layers: +document.getElementById('smart-layers').value,
        slots_per_layer: +document.getElementById('smart-slots').value,
        scale: +document.getElementById('edit-scale').value,
        noise_std: 0.1, alpha: 0.4, beta: 0.4, gamma: 0.2,
      });
      drawBeforeAfterChart(r.before, r.after);
      let log = r.debug_log.join('\n');
      log += '\n\nUsed layers:\n';
      r.used_layers.forEach(l=>{ log+=`  L${l.layer} slots=[${l.slots.join(',')}] combined=${l.combined}\n`; });
      log += '\nDelta:\n';
      r.delta.slice(0,8).forEach(d=>{ log+=`  ${d.token}: ${d.before.toFixed(4)} β†’ ${d.after.toFixed(4)} ${d.delta>0?'+':''}${d.delta.toFixed(4)}\n`; });
      document.getElementById('edit-log').textContent = log;
      updatePatchCount();
    } catch(e) { alert(e.message); }
    return;
  }
  if (mode==='MULTI-EDIT') {
    try {
      // FIX: /api/multi_edit takes the bare facts array — parse into a local
      // instead of stashing it on `body` only to read it straight back.
      const facts = JSON.parse(document.getElementById('multi-json').value);
      const r = await api('/api/multi_edit', facts);
      document.getElementById('edit-log').textContent = JSON.stringify(r, null, 2);
      updatePatchCount();
    } catch(e) { alert(e.message); }
    return;
  }
  // Standard path (UPDATE / PRECISE / STYLE-SHIFT ...): send the full body.
  try {
    const r = await api('/api/edit', body);
    _before = r.before;
    drawBeforeAfterChart(r.before, r.after);
    let log = r.debug_log.join('\n');
    log += '\n\nDelta:\n';
    r.delta.forEach(d=>{ log+=`  ${d.token}: ${d.before.toFixed(4)} β†’ ${d.after.toFixed(4)} ${d.delta>0?'+':''}${d.delta.toFixed(4)}\n`; });
    document.getElementById('edit-log').textContent = log;
    updatePatchCount();
  } catch(e) { alert(e.message); }
}
// Paired bars per token: faded blue = before, green/red = after
// (green when the probability rose, red when it fell), sorted by "after".
function drawBeforeAfterChart(before, after) {
  const el = clearChart('edit-chart');
  // Union of tokens from both distributions, with O(1) prob lookups.
  const tokens = [...new Set([...before.map(d=>d.token), ...after.map(d=>d.token)])];
  const bMap = Object.fromEntries(before.map(d=>[d.token,d.prob]));
  const aMap = Object.fromEntries(after.map(d=>[d.token,d.prob]));
  const data = tokens.map(t=>({token:t,before:bMap[t]||0,after:aMap[t]||0}))
    .sort((a,b)=>b.after-a.after);
  const W = el.clientWidth || 700, H = 40+data.length*40;
  el.style.height = H+'px';
  const svg = d3.select(el).append('svg').attr('width','100%').attr('height',H);
  const m = {left:110,right:10,top:20,bottom:10};
  const w = W-m.left-m.right;
  // Shared x-scale across both distributions so bar lengths are comparable.
  const maxP = d3.max([...before,...after],d=>d.prob);
  const x = d3.scaleLinear().domain([0,maxP]).range([0,w]);
  const g = svg.append('g').attr('transform',`translate(${m.left},${m.top})`);
  data.forEach((d,i)=>{
    const y = i*40;
    g.append('text').attr('x',-6).attr('y',y+14).attr('text-anchor','end')
      .attr('fill',C.text).attr('font-size',12).text(d.token||'(empty)');
    // "before" bar (faded blue, top of the pair)
    g.append('rect').attr('x',0).attr('y',y+2).attr('width',x(d.before))
      .attr('height',14).attr('rx',2).attr('fill',C.blue).attr('opacity',.5);
    // "after" bar, coloured by the direction of the change
    const delta = d.after - d.before;
    g.append('rect').attr('x',0).attr('y',y+18).attr('width',x(d.after))
      .attr('height',14).attr('rx',2).attr('fill',delta>=0?C.green:C.red).attr('opacity',.8);
    // Signed delta label past the end of the "after" bar.
    g.append('text').attr('x',x(d.after)+4).attr('y',y+30)
      .attr('fill',delta>=0?C.green:C.red).attr('font-size',10)
      .text((delta>0?'+':'')+delta.toFixed(4));
  });
}
// ═══════════════════════════════════════════════
// PATCHES
// ═══════════════════════════════════════════════
// Fetch the server-side patch list and render one card per patch,
// each with a type badge and a delete button wired to deletePatch(i).
async function loadPatches() {
  try {
    const r = await api('/api/patches');
    const el = document.getElementById('patches-list');
    if(!r.patches.length){ el.innerHTML='<div style="color:var(--muted)">No patches.</div>'; return; }
    el.innerHTML = r.patches.map((p,i)=>{
      const type = p.type.replace('_',' ');
      // Map the patch type onto a badge CSS class (e.g. X_UPDATE -> badge-X).
      const cls = 'badge-'+p.type.replace('_UPDATE','').replace('PRECISE_','PRECISE');
      return `<div class="patch-card">
<div style="display:flex;align-items:center">
<span class="patch-badge ${cls}">${type}</span>
<span>${p.entity}${p.relation?' Β· '+p.relation:''}${p.new_target?' β†’ '+p.new_target:''}</span>
<span style="color:var(--muted);margin-left:10px;font-size:10px">${p.ops_count} op(s)</span>
</div>
<button class="patch-del" onclick="deletePatch(${i})">βœ•</button>
</div>`;
    }).join('');
  } catch(e) { alert(e.message); }
}
// Delete patch `i` on the server, then refresh the list and header counter.
async function deletePatch(i) {
  try {
    const r = await fetch('/api/patches/'+i, {method:'DELETE'});
    // FIX: fetch only rejects on network failure — HTTP error statuses
    // (404, 500, ...) resolved silently before; surface them via the catch.
    if (!r.ok) throw new Error(r.statusText);
    loadPatches(); updatePatchCount();
  } catch(e) { alert(e.message); }
}
// Ask for a destination path and persist the patch set server-side.
async function savePatches() {
  const path = prompt('Save to path:', 'patches.json');
  if (!path) return;  // user cancelled
  try {
    await api('/api/save', {path});
    alert('Saved to '+path);
  } catch(e) {
    alert(e.message);
  }
}
// Ask for an output directory and bake the patches into a saved model.
async function compileModel() {
  const outDir = prompt('Output directory:', './vindex_compiled');
  if (!outDir) return;  // user cancelled
  try {
    await api('/api/compile', {output_dir: outDir});
    alert('Compiled to '+outDir);
  } catch(e) {
    alert(e.message);
  }
}
// Restore base weights after an explicit confirmation — this drops every patch.
async function resetModel() {
  const sure = confirm('Reset to base weights? All patches discarded.');
  if (!sure) return;
  try {
    await api('/api/reset', {});
    loadPatches();
    updatePatchCount();
    alert('Reset done.');
  } catch(e) {
    alert(e.message);
  }
}
// Refresh the header from /api/status (used after any mutating action).
// Failures are ignored — the header just keeps its previous state.
async function updatePatchCount() {
  try {
    const r = await api('/api/status');
    setStatus(r.loaded, r.info.split(' | ')[0], r.patches_count);
  } catch(e) { /* best-effort refresh */ }
}
// ═══════════════════════════════════════════════
// SMART LOCATE
// ═══════════════════════════════════════════════
let _slData = null;  // last /api/smart_locate response; drawSmartLocateChart reads it to highlight the recommended layer
// Full smart-locate: gate similarity + gradient pass + causal patch sweep,
// combined with the user-supplied alpha/beta/gamma weights.
async function runSmartLocate() {
  const st = document.getElementById('sl-status');
  st.textContent = '⏳ Running gradient pass + causal sweep (may take ~20s for large models)…';
  try {
    const data = await api('/api/smart_locate', {
      prompt: document.getElementById('sl-prompt').value,
      subject: document.getElementById('sl-subject').value,
      target: document.getElementById('sl-target').value,
      alpha: +document.getElementById('sl-alpha').value,
      beta: +document.getElementById('sl-beta').value,
      gamma: +document.getElementById('sl-gamma').value,
      noise_std: +document.getElementById('sl-noise').value,
    });
    _slData = data;
    st.textContent = `βœ“ Done. clean_prob=${data.clean_prob.toFixed(4)} corrupt_prob=${data.corrupt_prob.toFixed(4)}`;
    // Render all three views off the single response.
    drawSmartLocateChart(data.ranked_layers);
    showSmartRec(data);
    buildSlTable(data.ranked_layers);
  } catch(e) { st.textContent = 'βœ— '+e.message; }
}
// Run only the gradient signal and chart it (skips the causal sweep).
async function runGradientOnly() {
  const st = document.getElementById('sl-status');
  st.textContent = '⏳ Running gradient pass…';
  const payload = {
    prompt: document.getElementById('sl-prompt').value,
    target: document.getElementById('sl-target').value,
  };
  try {
    const data = await api('/api/gradient_scores', payload);
    st.textContent = `βœ“ Gradient done. ${data.layer_scores.length} KB layers.`;
    drawGradOnlyChart(data.layer_scores);
  } catch(e) {
    st.textContent = 'βœ— '+e.message;
  }
}
// Run only the causal patch trace and chart the per-layer indirect effects.
async function runCausalOnly() {
  const st = document.getElementById('sl-status');
  st.textContent = '⏳ Running causal patch trace…';
  const payload = {
    prompt: document.getElementById('sl-prompt').value,
    subject: document.getElementById('sl-subject').value,
    target: document.getElementById('sl-target').value,
    noise_std: +document.getElementById('sl-noise').value,
  };
  try {
    const data = await api('/api/causal_trace', payload);
    st.textContent = `βœ“ Causal done. clean=${data.clean_prob.toFixed(4)} corrupt=${data.corrupt_prob.toFixed(4)}`;
    drawCausalOnlyChart(data.results);
  } catch(e) {
    st.textContent = 'βœ— '+e.message;
  }
}
// Quick collateral check: run plain inference on a probe prompt and print
// the top-5 tokens so the user can eyeball damage from recent edits.
async function runCollateral() {
  const prompt = document.getElementById('sl-coll-prompt').value;
  const out = document.getElementById('sl-coll-out');
  try {
    const data = await api('/api/infer', { prompt, top_k: 5 });
    const rows = data.results.map(r=>`  ${r.token.padEnd(18)} ${r.prob.toFixed(4)}`);
    out.textContent = `"${prompt}"\n` + rows.join('\n');
  } catch(e) {
    out.textContent = 'βœ— '+e.message;
  }
}
// Three-lane bar chart, one row per layer (sorted by layer index): each
// normalized signal (gate_sim_n / grad_norm_n / causal_n) occupies one third
// of the width, so a full-length segment means "best layer for that signal".
function drawSmartLocateChart(ranked) {
  const byLayer = [...ranked].sort((a,b)=>a.layer-b.layer);
  const el = clearChart('sl-chart');
  const W = el.clientWidth || 700, H = 40 + byLayer.length * 34;
  el.style.height = H+'px';
  const svg = d3.select(el).append('svg').attr('width','100%').attr('height',H);
  const m = {left:50,right:110,top:20,bottom:20};
  const w = W-m.left-m.right;
  const g = svg.append('g').attr('transform',`translate(${m.left},${m.top})`);
  const segW = w / 3;  // one lane per signal
  // FIX: Math.max over `ranked` was recomputed inside the row loop.
  const bestCombined = Math.max(...ranked.map(r=>r.combined));
  byLayer.forEach((d,i)=>{
    const y = i*34;
    // Layer label; the recommended layer (from the last smart-locate) is yellow.
    g.append('text').attr('x',-6).attr('y',y+17).attr('text-anchor','end')
      .attr('fill', d.layer===(_slData?.recommendation?.layer) ? C.yellow : C.muted)
      .attr('font-size',10).text('L'+d.layer);
    // gate_sim segment (blue, lane 1)
    g.append('rect').attr('x',0).attr('y',y+4).attr('width',d.gate_sim_n*segW)
      .attr('height',12).attr('rx',2).attr('fill',C.blue).attr('opacity',.8)
      .on('mousemove',(ev)=>showTooltip(`L${d.layer} gate_sim: ${d.gate_sim}`,ev.pageX,ev.pageY))
      .on('mouseleave',hideTooltip);
    // grad_norm segment (green, lane 2)
    g.append('rect').attr('x',segW).attr('y',y+4).attr('width',d.grad_norm_n*segW)
      .attr('height',12).attr('rx',2).attr('fill',C.green).attr('opacity',.8)
      .on('mousemove',(ev)=>showTooltip(`L${d.layer} grad_norm: ${d.grad_norm}`,ev.pageX,ev.pageY))
      .on('mouseleave',hideTooltip);
    // causal segment (yellow, lane 3)
    g.append('rect').attr('x',segW*2).attr('y',y+4).attr('width',d.causal_n*segW)
      .attr('height',12).attr('rx',2).attr('fill',C.yellow).attr('opacity',.8)
      .on('mousemove',(ev)=>showTooltip(`L${d.layer} causal_IE: ${d.causal_effect}`,ev.pageX,ev.pageY))
      .on('mouseleave',hideTooltip);
    // Combined score, highlighted on the winning row.
    g.append('text').attr('x',w+6).attr('y',y+14)
      .attr('fill', d.combined===bestCombined ? C.yellow : C.muted)
      .attr('font-size',10).text(d.combined.toFixed(3));
  });
  // Lane captions under the chart.
  const ax = g.append('g').attr('transform',`translate(0,${byLayer.length*34})`);
  ax.append('text').attr('x',segW/2).attr('y',14).attr('text-anchor','middle')
    .attr('fill',C.blue).attr('font-size',9).text('gate_sim');
  ax.append('text').attr('x',segW*1.5).attr('y',14).attr('text-anchor','middle')
    .attr('fill',C.green).attr('font-size',9).text('grad_norm');
  ax.append('text').attr('x',segW*2.5).attr('y',14).attr('text-anchor','middle')
    .attr('fill',C.yellow).attr('font-size',9).text('causal IE');
  // Lane dividers. FIX: C.border did not exist in the palette, so the stroke
  // was the string "undefined"; keep a literal fallback in case it is absent.
  [segW,segW*2].forEach(x=>{
    g.append('line').attr('x1',x).attr('x2',x).attr('y1',0).attr('y2',byLayer.length*34)
      .attr('stroke',C.border || '#30363d').attr('stroke-width',1).attr('stroke-dasharray','3,2');
  });
}
// Horizontal bars of max gradient norm per KB layer (green, gradient-only view).
function drawGradOnlyChart(layerScores) {
  const el = clearChart('sl-chart');
  const W = el.clientWidth || 700, H = 40 + layerScores.length * 28;
  el.style.height = H+'px';
  const svg = d3.select(el).append('svg').attr('width','100%').attr('height',H);
  const m = {left:50,right:80,top:20,bottom:10};
  const w = W-m.left-m.right;
  // `|| 1` guards the scale domain against an all-zero gradient pass.
  const maxG = d3.max(layerScores, d=>d.max_grad) || 1;
  const x = d3.scaleLinear().domain([0,maxG]).range([0,w]);
  const g = svg.append('g').attr('transform',`translate(${m.left},${m.top})`);
  layerScores.forEach((d,i)=>{
    const y=i*28;
    // Layer label, bar, and exponential-notation value.
    g.append('text').attr('x',-6).attr('y',y+14).attr('text-anchor','end')
      .attr('fill',C.muted).attr('font-size',10).text('L'+d.layer);
    g.append('rect').attr('x',0).attr('y',y+2).attr('width',x(d.max_grad))
      .attr('height',16).attr('rx',2).attr('fill',C.green).attr('opacity',.8)
      .on('mousemove',(ev)=>showTooltip(`L${d.layer} max_grad: ${d.max_grad}`,ev.pageX,ev.pageY))
      .on('mouseleave',hideTooltip);
    g.append('text').attr('x',x(d.max_grad)+4).attr('y',y+14)
      .attr('fill',C.green).attr('font-size',9).text(d.max_grad.toExponential(2));
  });
}
// Horizontal bars of causal indirect effect per layer (yellow, causal-only view).
// Negative IEs are clamped to 0 for the bar width, but the label and tooltip
// still show the signed value.
function drawCausalOnlyChart(results) {
  const el = clearChart('sl-chart');
  const W = el.clientWidth || 700, H = 40 + results.length * 28;
  el.style.height = H+'px';
  const svg = d3.select(el).append('svg').attr('width','100%').attr('height',H);
  const m = {left:50,right:80,top:20,bottom:10};
  const w = W-m.left-m.right;
  // Floor the domain at 0.001 so a near-zero sweep still gets a usable scale.
  const maxIE = Math.max(d3.max(results, d=>d.indirect_effect), 0.001);
  const x = d3.scaleLinear().domain([0,maxIE]).range([0,w]);
  const g = svg.append('g').attr('transform',`translate(${m.left},${m.top})`);
  results.forEach((d,i)=>{
    const y=i*28; const ie=Math.max(0,d.indirect_effect);
    g.append('text').attr('x',-6).attr('y',y+14).attr('text-anchor','end')
      .attr('fill',C.muted).attr('font-size',10).text('L'+d.layer);
    g.append('rect').attr('x',0).attr('y',y+2).attr('width',x(ie))
      .attr('height',16).attr('rx',2).attr('fill',C.yellow).attr('opacity',.8)
      .on('mousemove',(ev)=>showTooltip(`L${d.layer} IE: ${d.indirect_effect} patch_p: ${d.patch_prob}`,ev.pageX,ev.pageY))
      .on('mouseleave',hideTooltip);
    g.append('text').attr('x',x(ie)+4).attr('y',y+14)
      .attr('fill',C.yellow).attr('font-size',9).text(d.indirect_effect.toFixed(5));
  });
}
// Pretty-print the combined-signal recommendation into the rec panel.
function showSmartRec(data) {
  const box = document.getElementById('sl-rec');
  const rec = data.recommendation;
  if (!rec) {
    box.textContent = 'No recommendation.';
    return;
  }
  let txt = `β˜… Best layer: L${rec.layer} combined=${rec.combined}\n\n`;
  txt += `  gate_sim: ${rec.gate_sim} (norm ${rec.gate_sim_n})\n`;
  txt += `  grad_norm: ${rec.grad_norm} (norm ${rec.grad_norm_n})\n`;
  txt += `  causal_effect: ${rec.causal_effect} (norm ${rec.causal_n})\n`;
  if (rec.best_slots.length) {
    txt += `\nTop gradient slots in L${rec.layer}:\n`;
    for (const s of rec.best_slots) {
      txt += `  slot ${s.slot} grad_norm=${s.grad_norm}\n`;
    }
  }
  txt += `\nPhase layer (trace): L${data.phase_layer}\n`;
  txt += `Subject pos: ${data.subject_pos}\n`;
  txt += `clean_prob: ${data.clean_prob} corrupt_prob: ${data.corrupt_prob}`;
  box.textContent = txt;
}
// Render the ranked-layer comparison table; the row with the best combined
// score gets a soft yellow background and a yellow layer label.
function buildSlTable(ranked) {
  const el = document.getElementById('sl-table');
  const maxC = Math.max(...ranked.map(r=>r.combined));
  let html = `<table style="width:100%;border-collapse:collapse;font-size:11px">
<thead><tr style="color:var(--muted);border-bottom:1px solid var(--border)">
<th style="padding:4px 8px;text-align:left">Layer</th>
<th style="padding:4px 8px;text-align:right;color:${C.blue}">gate_sim</th>
<th style="padding:4px 8px;text-align:right;color:${C.green}">grad_norm</th>
<th style="padding:4px 8px;text-align:right;color:${C.yellow}">causal IE</th>
<th style="padding:4px 8px;text-align:right">combined β˜…</th>
<th style="padding:4px 8px;text-align:left;color:var(--muted)">top grad slots</th>
</tr></thead><tbody>`;
  ranked.forEach(r=>{
    // Highlight the winning row; list at most 3 gradient slot indices.
    const hi = r.combined===maxC ? `background:rgba(210,153,34,0.08)` : '';
    const slots = r.best_slots.slice(0,3).map(s=>s.slot).join(', ');
    html+=`<tr style="${hi};border-bottom:1px solid var(--border)">
<td style="padding:4px 8px;color:${r.combined===maxC?C.yellow:C.text}">L${r.layer}</td>
<td style="padding:4px 8px;text-align:right;color:${C.blue}">${r.gate_sim}</td>
<td style="padding:4px 8px;text-align:right;color:${C.green}">${r.grad_norm.toExponential(2)}</td>
<td style="padding:4px 8px;text-align:right;color:${C.yellow}">${r.causal_effect.toFixed(5)}</td>
<td style="padding:4px 8px;text-align:right;font-weight:700">${r.combined}</td>
<td style="padding:4px 8px;color:var(--muted)">${slots}</td>
</tr>`;
  });
  html += '</tbody></table>';
  el.innerHTML = html;
}
// ═══════════════════════════════════════════════
// INIT
// ═══════════════════════════════════════════════
checkStatus();
</script>
</body>
</html>
"""
# ══════════════════════════════════════════════════════════════
# FASTAPI (Phase 2)
# ══════════════════════════════════════════════════════════════
from fastapi import FastAPI, HTTPException
from fastapi.responses import HTMLResponse
from pydantic import BaseModel
import uvicorn
app = FastAPI(title="VINDEX")
# Module-level singleton holding the loaded engine.
# None until a model is loaded (presumably by the /api/load endpoint — its body
# is below this view; confirm it assigns the global).
_vi: VIndex | None = None
def _require() -> "VIndex":
    """Return the global VIndex engine, or raise HTTP 400 if none is loaded."""
    if _vi is None:
        raise HTTPException(status_code=400, detail="No model loaded. POST /api/load first.")
    return _vi
# ── Request models ─────────────────────────────────────────────
class LoadReq(BaseModel):
    """Body for POST /api/load."""
    model_name: str = "distilgpt2"
    device: str = "auto"  # "auto" means let the server pick (mapped to None)
class InferReq(BaseModel):
    """Body for POST /api/infer."""
    prompt: str
    top_k: int = 10
class DescribeReq(BaseModel):
    """Body for POST /api/describe."""
    entity: str
    top_k: int = 10
class TraceReq(BaseModel):
    """Body for POST /api/trace."""
    prompt: str
    target: str
class LocateReq(BaseModel):
    """Body for POST /api/locate."""
    prompt: str
    subject: str
    target: str
class HeatmapReq(BaseModel):
    """Body for POST /api/gate_heatmap."""
    entity: str
    top_slots: int = 20
    use_activation: bool = False
    prompt: Optional[str] = None  # only meaningful with use_activation — TODO confirm
class GradientReq(BaseModel):
    """Body for POST /api/gradient_scores."""
    prompt: str
    target: str
class CausalTraceReq(BaseModel):
    """Body for POST /api/causal_trace."""
    prompt: str
    subject: str
    target: str
    noise_std: float = 0.1
class SmartLocateReq(BaseModel):
    """Body for POST /api/smart_locate.

    alpha/beta/gamma weight the combined localization score — verify exact
    semantics against VIndex.smart_locate.
    """
    prompt: str
    subject: str
    target: str
    alpha: float = 0.4
    beta: float = 0.3
    gamma: float = 0.3
    noise_std: float = 0.1
class SmartEditReq(BaseModel):
    """Body for POST /api/smart_edit."""
    prompt: str
    subject: str
    relation: str = ""
    old_target: str
    new_target: str
    top_layers: int = 3
    slots_per_layer: int = 2
    scale: float = 1.5
    noise_std: float = 0.1
    alpha: float = 0.4
    beta: float = 0.4
    gamma: float = 0.2
class DryRunReq(BaseModel):
    """Body for POST /api/dry_run."""
    entity: str
    new_target: str
    top_k: int = 3
    scale: float = 1.0
    prompt: Optional[str] = None
class EditReq(BaseModel):
    """Body for POST /api/edit.

    `mode` selects the edit strategy (UPDATE / PRECISE / INSERT / SUPPRESS /
    AMPLIFY / STYLE-SHIFT); unknown values fall back to UPDATE in the handler.
    from_concept/to_concept/strength are used only by STYLE-SHIFT.
    """
    entity: str
    relation: str = ""
    old_target: str = ""
    new_target: str
    mode: str = "UPDATE"
    alpha: float = 0.25
    top_k: int = 3
    scale: float = 1.0
    prompt: Optional[str] = None
    from_concept: str = ""
    to_concept: str = ""
    strength: float = 0.5
class SuppressReq(BaseModel):
    """Body for POST /api/suppress."""
    entity: str
    top_k: int = 3
    factor: float = 0.0
class AmplifyReq(BaseModel):
    """Body for POST /api/amplify."""
    entity: str
    top_k: int = 3
    factor: float = 2.0
class StyleShiftReq(BaseModel):
    """Body for POST /api/style_shift."""
    anchor: str
    from_concept: str
    to_concept: str
    top_k: int = 3
    strength: float = 0.5
class SaveReq(BaseModel):
    """Body for POST /api/save."""
    path: str = "patches.json"
class CompileReq(BaseModel):
    """Body for POST /api/compile."""
    output_dir: str = "./vindex_compiled"
class MultiEditFact(BaseModel):
    """One fact in the list body of POST /api/multi_edit."""
    entity: str
    relation: str = ""
    new_target: str
    prompt: Optional[str] = None
# ── Endpoints ──────────────────────────────────────────────────
@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve the embedded single-page UI (HTML_PAGE string defined above)."""
    return HTML_PAGE
@app.post("/api/load")
async def api_load(req: LoadReq):
    """Instantiate the engine for the requested model, replacing any prior one.

    Returns basic model geometry on success; maps engine failures to HTTP 500.
    """
    global _vi
    try:
        _vi = VIndex(req.model_name,
                     device=None if req.device == "auto" else req.device)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    return {
        "ok": True,
        "info": _vi.info,
        "n_layers": _vi.arch.n_layers,
        "kb_start": _vi.kb_start,
        "kb_end": _vi.kb_end,
    }
@app.get("/api/status")
async def api_status():
    """Report whether a model is loaded, plus patch count and layer bounds."""
    payload = {"loaded": False, "info": "", "patches_count": 0,
               "kb_start": 0, "kb_end": 0, "n_layers": 0}
    if _vi is not None:
        payload.update(loaded=True,
                       info=_vi.info,
                       patches_count=len(_vi.patches),
                       kb_start=_vi.kb_start,
                       kb_end=_vi.kb_end,
                       n_layers=_vi.arch.n_layers)
    return payload
@app.post("/api/infer")
async def api_infer(req: InferReq):
    """Delegate to the engine's infer() for a prompt."""
    return {"results": _require().infer(req.prompt, top_k=req.top_k)}
@app.post("/api/describe")
async def api_describe(req: DescribeReq):
    """List the engine's knowledge edges for an entity, rounded for display."""
    vi = _require()
    edges = []
    for edge in vi.describe(req.entity, top_k=req.top_k):
        edges.append({"tok": edge["tok"],
                      "score": round(edge["score"], 2),
                      "layer": edge["layer"],
                      "gate_sim": round(edge["gate_sim"], 4)})
    return {"edges": edges}
@app.post("/api/trace")
async def api_trace(req: TraceReq):
    """Delegate to the engine's trace() for a prompt/target pair."""
    return {"stats": _require().trace(req.prompt, req.target)}
@app.post("/api/locate")
async def api_locate(req: LocateReq):
    """Delegate to the engine's locate()."""
    return _require().locate(req.prompt, req.subject, req.target)
@app.post("/api/gradient_scores")
async def api_gradient_scores(req: GradientReq):
    """Delegate to the engine's gradient_slot_scores()."""
    return _require().gradient_slot_scores(req.prompt, req.target)
@app.post("/api/causal_trace")
async def api_causal_trace(req: CausalTraceReq):
    """Delegate to the engine's causal_patch_trace()."""
    return _require().causal_patch_trace(
        req.prompt, req.subject, req.target, noise_std=req.noise_std)
@app.post("/api/smart_locate")
async def api_smart_locate(req: SmartLocateReq):
    """Delegate to the engine's smart_locate() with the request's weights."""
    return _require().smart_locate(
        req.prompt, req.subject, req.target,
        alpha=req.alpha, beta=req.beta, gamma=req.gamma,
        noise_std=req.noise_std)
@app.post("/api/smart_edit")
async def api_smart_edit(req: SmartEditReq):
    """Run a smart edit and report before/after top-5 probability shifts.

    Measures the evaluation prompt (caller-supplied, or the default template)
    before and after the edit; engine failures become HTTP 500.
    """
    vi = _require()
    prompt_str = req.prompt or f"The {req.relation} of {req.subject} is"
    before = vi.infer(prompt_str, top_k=5)
    log: List[str] = []
    try:
        result = vi.smart_edit(
            prompt_str, req.subject, req.relation, req.old_target, req.new_target,
            top_layers=req.top_layers, slots_per_layer=req.slots_per_layer,
            scale=req.scale, noise_std=req.noise_std,
            alpha=req.alpha, beta=req.beta, gamma=req.gamma, log=log,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    after = vi.infer(prompt_str, top_k=5)
    # Per-token probability delta over the union of both top-5 sets,
    # largest absolute change first.
    before_p = {d["token"]: d["prob"] for d in before}
    after_p = {d["token"]: d["prob"] for d in after}
    delta = [{"token": tok,
              "before": before_p.get(tok, 0),
              "after": after_p.get(tok, 0),
              "delta": after_p.get(tok, 0) - before_p.get(tok, 0)}
             for tok in set(before_p) | set(after_p)]
    delta.sort(key=lambda d: abs(d["delta"]), reverse=True)
    return {"before": before, "after": after, "delta": delta,
            "debug_log": log, "used_layers": result["used_layers"],
            "smart_locate": result["smart_locate"]}
@app.post("/api/gate_heatmap")
async def api_gate_heatmap(req: HeatmapReq):
    """Delegate to the engine's gate_heatmap()."""
    return _require().gate_heatmap(
        req.entity, use_activation=req.use_activation, prompt=req.prompt)
@app.post("/api/dry_run")
async def api_dry_run(req: DryRunReq):
    """Delegate to the engine's dry_run() (no permanent edit)."""
    return _require().dry_run(
        req.entity, req.new_target,
        top_k=req.top_k, scale=req.scale, prompt=req.prompt)
@app.post("/api/edit")
async def api_edit(req: EditReq):
    """Apply a knowledge edit and report before/after probability shifts.

    Dispatches on ``req.mode`` (PRECISE / INSERT / SUPPRESS / AMPLIFY /
    STYLE-SHIFT; anything else falls back to UPDATE), measuring top-5 token
    probabilities on the evaluation prompt before and after the edit.
    Engine failures become HTTP 500 with the engine's error message.
    """
    vi = _require()
    # Fix: honor a caller-supplied prompt for BOTH the edit and the
    # before/after measurement (previously only the PRECISE branch saw
    # req.prompt while the measurement always used the template, so the
    # reported delta could be computed on the wrong prompt).  This mirrors
    # /api/smart_edit's behavior.
    prompt_str = req.prompt or f"The {req.relation} of {req.entity} is"
    before = vi.infer(prompt_str, top_k=5)
    log: List[str] = []
    try:
        mode = req.mode.upper()
        if mode == "PRECISE":
            vi.precise_update(prompt_str, req.entity, req.relation,
                              req.new_target, top_k=req.top_k, scale=req.scale, log=log)
        elif mode == "INSERT":
            vi.insert(req.entity, req.relation, req.new_target,
                      alpha=req.alpha, spread=req.top_k, log=log)
        elif mode == "SUPPRESS":
            vi.suppress(req.entity, top_k=req.top_k, log=log)
        elif mode == "AMPLIFY":
            vi.amplify(req.entity, top_k=req.top_k, log=log)
        elif mode == "STYLE-SHIFT":
            vi.style_shift(req.entity, req.from_concept, req.to_concept,
                           top_k=req.top_k, strength=req.strength, log=log)
        else:  # UPDATE (default)
            vi.update(req.entity, req.relation, req.new_target,
                      top_k=req.top_k, scale=req.scale, log=log)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    after = vi.infer(prompt_str, top_k=5)
    # Per-token probability delta over the union of both top-5 sets,
    # largest absolute change first.
    b_map = {d["token"]: d["prob"] for d in before}
    a_map = {d["token"]: d["prob"] for d in after}
    all_toks = set(b_map) | set(a_map)
    delta = sorted([{"token": t, "before": b_map.get(t, 0), "after": a_map.get(t, 0),
                     "delta": a_map.get(t, 0) - b_map.get(t, 0)} for t in all_toks],
                   key=lambda x: -abs(x["delta"]))
    return {"before": before, "after": after, "delta": delta, "debug_log": log,
            "ops": len(vi.patches[-1]["ops"]) if vi.patches else 0}
@app.post("/api/multi_edit")
async def api_multi_edit(facts: List[MultiEditFact]):
    """Apply a batch of fact edits via the engine's multi_edit()."""
    plain_facts = [fact.model_dump() for fact in facts]
    return _require().multi_edit(plain_facts)
@app.post("/api/suppress")
async def api_suppress(req: SuppressReq):
    """Delegate to the engine's suppress()."""
    return _require().suppress(req.entity, top_k=req.top_k, factor=req.factor)
@app.post("/api/amplify")
async def api_amplify(req: AmplifyReq):
    """Delegate to the engine's amplify()."""
    return _require().amplify(req.entity, top_k=req.top_k, factor=req.factor)
@app.post("/api/style_shift")
async def api_style_shift(req: StyleShiftReq):
    """Delegate to the engine's style_shift()."""
    return _require().style_shift(
        req.anchor, req.from_concept, req.to_concept,
        top_k=req.top_k, strength=req.strength)
@app.get("/api/patches")
async def api_patches():
    """Summarize every applied patch (index, type, entity, op count)."""
    vi = _require()
    summaries = [
        {"i": idx,
         "type": patch["type"],
         "entity": patch.get("entity", ""),
         "relation": patch.get("relation", ""),
         # some patch types store the value under "target" instead
         "new_target": patch.get("new_target", "") or patch.get("target", ""),
         "ops_count": len(patch.get("ops", []))}
        for idx, patch in enumerate(vi.patches)
    ]
    return {"patches": summaries}
@app.delete("/api/patches/{idx}")
async def api_delete_patch(idx: int):
    """Remove one patch by index and re-apply the remaining ones."""
    vi = _require()
    if not 0 <= idx < len(vi.patches):
        raise HTTPException(status_code=404, detail="Patch index out of range")
    del vi.patches[idx]
    vi._apply_all_patches()
    return {"ok": True}
@app.post("/api/reset")
async def api_reset():
    """Revert all patches via the engine's reset()."""
    _require().reset()
    return {"ok": True}
@app.post("/api/save")
async def api_save(req: SaveReq):
    """Persist the patch list to disk via the engine's save_patch()."""
    _require().save_patch(req.path)
    return {"ok": True, "path": req.path}
@app.post("/api/compile")
async def api_compile(req: CompileReq):
    """Export the edited model via the engine's compile_to()."""
    _require().compile_to(req.output_dir)
    return {"ok": True, "output_dir": req.output_dir}
# ══════════════════════════════════════════════════════════════
# ENTRY
# ══════════════════════════════════════════════════════════════
if __name__ == "__main__":
    import argparse
    import threading
    import time
    import webbrowser

    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=int, default=8787)
    parser.add_argument("--model", default=None)
    parser.add_argument("--device", default=None)
    parser.add_argument("--no-browser", action="store_true")
    # parse_known_args tolerates extra flags (e.g. from a launcher).
    args, _unknown = parser.parse_known_args()

    # Optionally pre-load a model so the UI starts ready.
    if args.model:
        print(f"Pre-loading {args.model}…")
        _vi = VIndex(args.model, device=args.device)
        print("Done.", _vi.info)

    if not args.no_browser:
        def _launch_browser():
            # Give uvicorn a moment to bind before opening the page.
            time.sleep(1.2)
            webbrowser.open(f"http://localhost:{args.port}")

        threading.Thread(target=_launch_browser, daemon=True).start()

    uvicorn.run(app, host="0.0.0.0", port=args.port)