| |
| """ |
| VINDEX — FastAPI + D3 single-file LLM knowledge editor |
| Phase 1: Engine extensions | Phase 2: FastAPI | Phase 3: D3 frontend |
| |
| pip install transformers torch fastapi uvicorn pydantic |
| python temp.py |
| """ |
|
|
| |
| |
| |
|
|
| import re, json |
| from pathlib import Path |
| from typing import Any, Dict, List, Optional, Tuple |
| import torch |
| import torch.nn.functional as F |
| from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
|
|
class ArchAdapter:
    """Normalizes access to FFN weights across GPT-2-style and gated (LLaMA-style) stacks.

    style == "gpt2": layers at model.transformer.h, MLP uses c_fc/c_proj
    (HF Conv1D stores weights transposed relative to nn.Linear).
    style == "gated": layers at model.model.layers, MLP uses gate_proj/down_proj.
    """

    def __init__(self, model):
        self.model = model
        self.style = self._detect_style()
        self.n_layers = self._count_layers()

    def _detect_style(self):
        """Classify the architecture from config.model_type, probing layer 0's MLP as a fallback."""
        t = self.model.config.model_type
        if t in ("gpt2", "gpt_neo", "gpt_neox"):
            return "gpt2"
        if t in ("llama", "mistral", "qwen2", "gemma", "gemma2", "phi3",
                 "falcon", "codellama", "deepseek", "internlm2"):
            return "gated"
        try:
            if hasattr(self._layer(0).mlp, "gate_proj"):
                return "gated"
        except (AttributeError, IndexError, ValueError):
            # Unknown layout or no layers; fall through to the gpt2 default.
            # (Was a bare `except:` that also swallowed SystemExit/KeyboardInterrupt.)
            pass
        return "gpt2"

    def _layer(self, i):
        """Return decoder block `i`, whichever attribute path the model uses.

        Raises ValueError when neither known layout is present.
        """
        m = self.model
        if hasattr(m, "transformer") and hasattr(m.transformer, "h"):
            return m.transformer.h[i]
        if hasattr(m, "model") and hasattr(m.model, "layers"):
            return m.model.layers[i]
        raise ValueError("Unknown model structure")

    def _count_layers(self):
        """Number of decoder blocks; raises ValueError on unknown layouts."""
        m = self.model
        if hasattr(m, "transformer") and hasattr(m.transformer, "h"):
            return len(m.transformer.h)
        if hasattr(m, "model") and hasattr(m.model, "layers"):
            return len(m.model.layers)
        raise ValueError("Cannot count layers")

    def get_ffn_weights(self, li):
        """Return (W_gate, W_down) for layer `li`, detached.

        Normalized orientation: W_gate rows index FFN slots, W_down columns
        index FFN slots. GPT-2 Conv1D weights are transposed to match.
        """
        layer = self._layer(li)
        if self.style == "gpt2":
            return layer.mlp.c_fc.weight.detach().T, layer.mlp.c_proj.weight.detach().T
        return layer.mlp.gate_proj.weight.detach(), layer.mlp.down_proj.weight.detach()

    def set_ffn_weights(self, li, Wg, Wd):
        """Write (W_gate, W_down) back into the model, undoing the orientation normalization."""
        layer = self._layer(li)
        with torch.no_grad():
            if self.style == "gpt2":
                layer.mlp.c_fc.weight.copy_(Wg.T)
                layer.mlp.c_proj.weight.copy_(Wd.T)
            else:
                layer.mlp.gate_proj.weight.copy_(Wg)
                layer.mlp.down_proj.weight.copy_(Wd)

    def get_embedding(self):
        """Input embedding matrix, detached."""
        return self.model.get_input_embeddings().weight.detach()

    def get_unembedding(self):
        """Output projection matrix; falls back to tied input embeddings when no lm_head."""
        if hasattr(self.model, "lm_head"):
            return self.model.lm_head.weight.detach()
        return self.get_embedding()
|
|
|
|
| class VIndex: |
    def __init__(self, model_name: str, device: Optional[str] = None):
        """Load tokenizer + causal LM and set up the editable ("KB") layer range.

        model_name: HF hub id or local path.
        device: explicit torch device string; defaults to cuda when available.
        """
        self.model_name = model_name
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.patches: List[Dict] = []  # ordered edit records, replayed by _apply_all_patches
        self._base_weights: Optional[Dict] = None  # lazy snapshot of pristine FFN weights

        self.tok = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=torch.float32).to(self.device)
        self.model.eval()
        if self.tok.pad_token is None:
            # Some tokenizers (e.g. GPT-2) ship without a pad token.
            self.tok.pad_token = self.tok.eos_token

        self.arch = ArchAdapter(self.model)
        n = self.arch.n_layers
        # Knowledge edits target the upper two thirds of the stack.
        self.kb_start = n // 3
        self.kb_end = n
|
|
| @property |
| def info(self): |
| return (f"{self.model_name} | {self.arch.n_layers} layers | " |
| f"style={self.arch.style} | kb=L{self.kb_start}-L{self.kb_end-1} | {self.device}") |
|
|
| |
|
|
| def embed(self, text: str): |
| ids = self.tok.encode(text, add_special_tokens=False) |
| if not ids: raise ValueError(f"Cannot tokenize: {text!r}") |
| E = self.arch.get_embedding().to(self.device) |
| return E[ids].mean(0) |
|
|
| def decode_down_col(self, col, top_k=5): |
| scores = self.arch.get_unembedding().to(self.device) @ col.to(self.device) |
| top = scores.topk(top_k) |
| return [(self.tok.decode([i.item()]).strip(), v.item()) |
| for i,v in zip(top.indices,top.values) if v.item()>0] |
|
|
| def token_id(self, word: str) -> int: |
| ids = self.tok.encode(word, add_special_tokens=False) |
| return ids[0] if ids else 0 |
|
|
| def _forward(self, prompt: str): |
| inputs = self.tok(prompt, return_tensors="pt").to(self.device) |
| with torch.no_grad(): |
| out = self.model(**inputs) |
| return out.logits[0,-1] |
|
|
| |
|
|
    def _get_subject_activations(self, prompt: str, subject: str
                                 ) -> Tuple[Dict[int, torch.Tensor], int]:
        """Capture h_L[last_subject_token_pos] at every layer via forward hooks.

        Returns ({layer_index: hidden_vector}, subject_pos), where subject_pos
        is the prompt position of the LAST subject token (0 when not found).
        """
        enc = self.tok(prompt, return_tensors="pt").to(self.device)
        ids = enc["input_ids"][0].tolist()

        # Exact subsequence match of the subject's token ids inside the prompt.
        subj_ids = self.tok.encode(subject, add_special_tokens=False)
        subject_pos = 0
        for start in range(len(ids) - len(subj_ids) + 1):
            if ids[start:start+len(subj_ids)] == subj_ids:
                subject_pos = start + len(subj_ids) - 1
                break
        else:
            # Fallback: first occurrence of any single subject token.
            for si in subj_ids:
                if si in ids:
                    subject_pos = ids.index(si)
                    break

        activations: Dict[int, torch.Tensor] = {}
        handles = []

        def make_hook(li):
            # Factory binds `li` per layer (avoids late-binding closure bug).
            def hook(m, inp, out):
                # Decoder blocks may return a tuple; hidden states are element 0.
                h = out[0] if isinstance(out, tuple) else out
                activations[li] = h[0, subject_pos].detach().clone()
            return hook

        for li in range(self.arch.n_layers):
            handles.append(self.arch._layer(li).register_forward_hook(make_hook(li)))

        with torch.no_grad():
            self.model(**enc)

        # Always detach the hooks so later forwards are clean.
        for h in handles:
            h.remove()

        return activations, subject_pos
|
|
    def locate(self, prompt: str, subject: str, target: str) -> Dict:
        """Diagnostic: combines trace + activation-guided similarity scan."""
        trace_stats = self.trace(prompt, target)

        # Phase layer: layer with the largest relative rank drop for the target
        # token, only considered while the previous rank was still poor (> 5).
        phase_layer = 0
        best_drop = 0.0
        prev_rank = None
        for s in trace_stats:
            if prev_rank is not None and prev_rank > 5:
                drop = (prev_rank - s["rank"]) / prev_rank
                if drop > best_drop:
                    best_drop = drop
                    phase_layer = s["l"]
            prev_rank = s["rank"]

        activations, subject_pos = self._get_subject_activations(prompt, subject)

        # Per KB layer: best-matching gate slot for the subject's hidden state.
        layer_scores = []
        for li in range(self.kb_start, self.kb_end):
            h_L = activations.get(li)
            if h_L is None:
                layer_scores.append({"layer": li, "max_sim": 0.0, "best_slot": -1, "top_tokens": []})
                continue
            Wg, Wd = self.arch.get_ffn_weights(li)
            Wg = Wg.to(self.device)
            h_n = F.normalize(h_L, dim=0)
            sims = F.normalize(Wg, dim=1) @ h_n  # cosine sim against every gate row
            best_slot = int(sims.argmax().item())
            max_sim = float(sims[best_slot].item())
            Wd = Wd.to(self.device)
            top_tokens = self.decode_down_col(Wd[:, best_slot], top_k=3)
            layer_scores.append({
                "layer": li, "max_sim": round(max_sim, 4),
                "best_slot": best_slot,
                "top_tokens": [{"tok": t, "score": round(s,2)} for t,s in top_tokens]
            })

        return {
            "phase_layer": phase_layer,
            "subject_pos": subject_pos,
            "layer_scores": layer_scores,
            "trace": trace_stats,
        }
|
|
    def precise_update(self, prompt: str, subject: str, relation: str,
                       new_target: str, top_k: int = 3, scale: float = 1.0,
                       log: Optional[List[str]] = None):
        """Activation-guided update: uses h_L[subject_pos] instead of embed(subject).

        Falls back to update() when no prompt is given, and to insert() when no
        slot matches the subject activation well enough. Returns the applied ops
        (or the fallback's return value).
        """
        if log is None: log = []
        if not prompt:
            return self.update(subject, relation, new_target, top_k=top_k, scale=scale, log=log)

        self._snapshot()
        activations, subject_pos = self._get_subject_activations(prompt, subject)
        tv = self.embed(new_target)

        log.append(f"PRECISE_UPDATE: '{subject}' -[{relation}]-> '{new_target}'")
        log.append(f" subject_pos={subject_pos} scale={scale} top_k={top_k}")

        # Rank (layer, slot) pairs by cosine similarity between that layer's
        # subject activation and each gate row.
        candidates = []
        for li in range(self.kb_start, self.kb_end):
            h_L = activations.get(li)
            if h_L is None: continue
            Wg, _ = self.arch.get_ffn_weights(li)
            Wg = Wg.to(self.device)
            h_n = F.normalize(h_L, dim=0)
            sims = F.normalize(Wg, dim=1) @ h_n
            k = min(top_k, sims.shape[0])
            vals, idxs = sims.topk(k)
            for v, idx in zip(vals, idxs):
                candidates.append((v.item(), li, idx.item()))
            log.append(f" L{li}: " + " ".join(f"{v:.4f}@s{i}" for v,i in zip(vals,idxs)))

        candidates.sort(key=lambda x: -x[0])
        best_sim = candidates[0][0] if candidates else 0.0
        log.append(f"\n Best activation_sim = {best_sim:.4f} (embed-based would be ~{best_sim/3.5:.4f})")

        if best_sim < 0.05:
            log.append(" β sim < 0.05 β INSERT fallback")
            return self.insert(subject, relation, new_target, log=log)

        # Overwrite each chosen slot's down-projection column with the
        # new-target direction, preserving the column's original norm.
        chosen = [c for c in candidates if c[0] >= 0.05][:top_k]
        ops = []
        for sim, li, slot in chosen:
            _, Wd = self.arch.get_ffn_weights(li)
            Wd = Wd.to(self.device)
            col_norm = Wd[:, slot].norm().item()
            new_col = (F.normalize(tv, dim=0)*col_norm*scale).cpu().tolist()
            ops.append({"op":"update_down","layer":li,"slot":slot,
                        "down_col":new_col,"activation_sim":round(sim,4)})
            log.append(f" β L{li} slot {slot}: act_sim={sim:.4f} col_norm={col_norm:.4f}")

        self.patches.append({"type":"PRECISE_UPDATE","entity":subject,"relation":relation,
                             "new_target":new_target,"ops":ops})
        self._apply_all_patches()
        log.append(f"\n β Applied {len(ops)} op(s), patch #{len(self.patches)}")
        return ops
|
|
    def suppress(self, entity: str, top_k: int = 3, factor: float = 0.0,
                 log: Optional[List[str]] = None):
        """Scale down (or zero, with factor=0) down-columns whose gate row matches `entity`."""
        if log is None: log = []
        self._snapshot()
        ev_n = F.normalize(self.embed(entity), dim=0)
        log.append(f"SUPPRESS: '{entity}' factor={factor} top_k={top_k}")
        ops = []
        candidates = []
        # Collect the best-matching (sim, layer, slot) candidates across KB layers.
        for li in range(self.kb_start, self.kb_end):
            Wg, _ = self.arch.get_ffn_weights(li)
            Wg = Wg.to(self.device)
            sims = F.normalize(Wg, dim=1) @ ev_n
            k = min(top_k, sims.shape[0])
            vals, idxs = sims.topk(k)
            for v, idx in zip(vals, idxs):
                candidates.append((v.item(), li, idx.item()))
        candidates.sort(key=lambda x: -x[0])
        chosen = [c for c in candidates if c[0] >= 0.05][:top_k]  # 0.05 = min gate similarity
        for sim, li, slot in chosen:
            _, Wd = self.arch.get_ffn_weights(li)
            Wd = Wd.to(self.device)
            new_col = (Wd[:, slot] * factor).cpu().tolist()
            ops.append({"op":"update_down","layer":li,"slot":slot,"down_col":new_col})
            log.append(f" L{li} slot {slot}: gate_sim={sim:.4f} factor={factor}")
        self.patches.append({"type":"SUPPRESS","entity":entity,"factor":factor,"ops":ops})
        self._apply_all_patches()
        log.append(f" β Suppressed {len(ops)} slot(s)")
        return {"ops": len(ops), "log": log}
|
|
    def amplify(self, entity: str, top_k: int = 3, factor: float = 2.0,
                log: Optional[List[str]] = None):
        """Scale up down-columns whose gate row matches `entity` (mirror of suppress)."""
        if log is None: log = []
        self._snapshot()
        ev_n = F.normalize(self.embed(entity), dim=0)
        log.append(f"AMPLIFY: '{entity}' factor={factor} top_k={top_k}")
        ops = []
        candidates = []
        # Collect the best-matching (sim, layer, slot) candidates across KB layers.
        for li in range(self.kb_start, self.kb_end):
            Wg, _ = self.arch.get_ffn_weights(li)
            Wg = Wg.to(self.device)
            sims = F.normalize(Wg, dim=1) @ ev_n
            k = min(top_k, sims.shape[0])
            vals, idxs = sims.topk(k)
            for v, idx in zip(vals, idxs):
                candidates.append((v.item(), li, idx.item()))
        candidates.sort(key=lambda x: -x[0])
        chosen = [c for c in candidates if c[0] >= 0.05][:top_k]  # 0.05 = min gate similarity
        for sim, li, slot in chosen:
            _, Wd = self.arch.get_ffn_weights(li)
            Wd = Wd.to(self.device)
            new_col = (Wd[:, slot] * factor).cpu().tolist()
            ops.append({"op":"update_down","layer":li,"slot":slot,"down_col":new_col})
            log.append(f" L{li} slot {slot}: gate_sim={sim:.4f} factor={factor}")
        self.patches.append({"type":"AMPLIFY","entity":entity,"factor":factor,"ops":ops})
        self._apply_all_patches()
        log.append(f" β Amplified {len(ops)} slot(s)")
        return {"ops": len(ops), "log": log}
|
|
    def style_shift(self, anchor_entity: str, from_concept: str, to_concept: str,
                    top_k: int = 3, strength: float = 0.5,
                    log: Optional[List[str]] = None):
        """Add a (to_concept - from_concept) direction to down-columns matching the anchor."""
        if log is None: log = []
        self._snapshot()
        ev_n = F.normalize(self.embed(anchor_entity), dim=0)
        from_v = self.embed(from_concept)
        to_v = self.embed(to_concept)
        dir_v = F.normalize(to_v - from_v, dim=0)  # unit shift direction in embedding space
        log.append(f"STYLE_SHIFT: anchor='{anchor_entity}' {from_concept!r}β{to_concept!r} strength={strength}")
        ops = []
        candidates = []
        for li in range(self.kb_start, self.kb_end):
            Wg, _ = self.arch.get_ffn_weights(li)
            Wg = Wg.to(self.device)
            sims = F.normalize(Wg, dim=1) @ ev_n
            k = min(top_k, sims.shape[0])
            vals, idxs = sims.topk(k)
            for v, idx in zip(vals, idxs):
                candidates.append((v.item(), li, idx.item()))
        candidates.sort(key=lambda x: -x[0])
        chosen = [c for c in candidates if c[0] >= 0.05][:top_k]
        for sim, li, slot in chosen:
            _, Wd = self.arch.get_ffn_weights(li)
            Wd = Wd.to(self.device)
            col = Wd[:, slot]
            # Shift proportional to the column's own norm so the edit is scale-aware.
            new_col = (col + dir_v * col.norm() * strength).cpu().tolist()
            ops.append({"op":"update_down","layer":li,"slot":slot,"down_col":new_col})
            log.append(f" L{li} slot {slot}: gate_sim={sim:.4f} col_norm={col.norm():.4f}")
        self.patches.append({"type":"STYLE_SHIFT","entity":anchor_entity,
                             "from":from_concept,"to":to_concept,"ops":ops})
        self._apply_all_patches()
        log.append(f" β Style-shifted {len(ops)} slot(s)")
        return {"ops": len(ops), "log": log}
|
|
| def multi_edit(self, facts: List[Dict], mode: str = "UPDATE", |
| alpha: float = 0.25, top_k: int = 3, scale: float = 1.0): |
| """Apply a batch of edits sequentially.""" |
| results = [] |
| for f in facts: |
| log: List[str] = [] |
| try: |
| entity = f["entity"] |
| relation = f.get("relation", "") |
| new_target = f["new_target"] |
| prompt = f.get("prompt", "") |
| if mode == "PRECISE" and prompt: |
| ops = self.precise_update(prompt, entity, relation, new_target, |
| top_k=top_k, scale=scale, log=log) |
| elif mode == "INSERT": |
| ops = self.insert(entity, relation, new_target, |
| alpha=alpha, spread=top_k, log=log) |
| else: |
| ops = self.update(entity, relation, new_target, |
| top_k=top_k, scale=scale, log=log) |
| results.append({"entity":entity,"status":"ok", |
| "ops": len(ops) if isinstance(ops,(list,)) else 1, |
| "log":log}) |
| except Exception as e: |
| results.append({"entity":f.get("entity","?"),"status":"error", |
| "error":str(e),"log":log}) |
| return results |
|
|
    def gate_heatmap(self, entity: str, use_activation: bool = False,
                     prompt: Optional[str] = None) -> Dict:
        """Full layer x slot similarity matrix with top token decoding.

        Query vector per layer is the entity embedding, or — when
        use_activation and a prompt are given — that layer's subject activation.
        """
        if use_activation and prompt:
            activations, _ = self._get_subject_activations(prompt, entity)
        else:
            activations = {}

        ev_n = F.normalize(self.embed(entity), dim=0)
        layers_out = []
        for li in range(self.kb_start, self.kb_end):
            Wg, Wd = self.arch.get_ffn_weights(li)
            Wg = Wg.to(self.device); Wd = Wd.to(self.device)
            if use_activation and li in activations:
                q = F.normalize(activations[li], dim=0)
            else:
                q = ev_n
            sims = F.normalize(Wg, dim=1) @ q
            top_slots_count = min(20, sims.shape[0])  # cap payload at 20 slots/layer
            vals, idxs = sims.topk(top_slots_count)
            slots = []
            for v, idx in zip(vals, idxs):
                top_toks = self.decode_down_col(Wd[:, idx], top_k=3)
                slots.append({
                    "slot": int(idx.item()),
                    "sim": round(float(v.item()), 4),
                    "top_tokens": [{"tok": t, "score": round(s,2)} for t,s in top_toks]
                })
            layers_out.append({"layer": li, "slots": slots})
        return {"layers": layers_out}
|
|
    def dry_run(self, entity: str, new_target: str, top_k: int = 3,
                scale: float = 1.0, prompt: Optional[str] = None) -> Dict:
        """Same slot-selection logic as precise_update/update but does NOT mutate weights."""
        if prompt:
            activations, subject_pos = self._get_subject_activations(prompt, entity)
            use_act = True
        else:
            use_act = False

        ev_n = F.normalize(self.embed(entity), dim=0)
        tv = self.embed(new_target)  # NOTE(review): computed but unused below — verify intent

        candidates = []
        for li in range(self.kb_start, self.kb_end):
            Wg, Wd = self.arch.get_ffn_weights(li)
            Wg = Wg.to(self.device); Wd = Wd.to(self.device)
            # Activation query per layer when available, else entity embedding.
            if use_act and li in activations:
                q = F.normalize(activations[li], dim=0)
            else:
                q = ev_n
            sims = F.normalize(Wg, dim=1) @ q
            k = min(top_k, sims.shape[0])
            vals, idxs = sims.topk(k)
            for v, idx in zip(vals, idxs):
                col_norm = Wd[:, idx].norm().item()
                top_toks = self.decode_down_col(Wd[:, idx], top_k=3)
                candidates.append({
                    "layer": li, "slot": int(idx.item()),
                    "sim": round(float(v.item()), 4),
                    "col_norm": round(col_norm, 4),
                    "inject_norm": round(col_norm * scale, 4),
                    "current_top": [{"tok":t,"score":round(s,2)} for t,s in top_toks]
                })

        candidates.sort(key=lambda x: -x["sim"])
        best_sim = candidates[0]["sim"] if candidates else 0.0
        chosen = [c for c in candidates if c["sim"] >= 0.05][:top_k]

        return {
            "candidates": chosen,
            "best_sim": best_sim,
            "would_patch": len(chosen),
            "new_target": new_target,
            "mode": "activation-guided" if use_act else "embed-based"
        }
|
|
| |
|
|
    def gradient_slot_scores(self, prompt: str, target: str) -> Dict:
        """One backward pass: per-slot grad norm of -log p(target) w.r.t. W_down, per layer.

        Identifies which slots causally contributed to this prediction via
        gradient signal.
        """
        target_id = self.token_id(target)

        # Temporarily enable grads on every down-projection weight.
        down_params: List[Tuple[int, torch.nn.Parameter]] = []
        for li in range(self.arch.n_layers):
            layer = self.arch._layer(li)
            p = (layer.mlp.c_proj.weight if self.arch.style == "gpt2"
                 else layer.mlp.down_proj.weight)
            p.requires_grad_(True)
            down_params.append((li, p))

        self.model.zero_grad()
        inputs = self.tok(prompt, return_tensors="pt").to(self.device)
        out = self.model(**inputs)
        loss = -F.log_softmax(out.logits[0, -1], dim=-1)[target_id]
        loss.backward()

        layer_scores = []
        for li, p in down_params:
            grad = p.grad
            p.requires_grad_(False)  # restore inference-only state
            if grad is None:
                layer_scores.append({"layer": li, "max_grad": 0.0, "top_slots": []})
                continue

            # Per-slot gradient norm; the slot axis differs by weight layout
            # (gpt2 Conv1D stores slots along dim 0, Linear along dim 1).
            slot_norms = (grad.norm(dim=1) if self.arch.style == "gpt2"
                          else grad.norm(dim=0))
            k = min(20, slot_norms.shape[0])
            vals, idxs = slot_norms.topk(k)
            layer_scores.append({
                "layer": li,
                "max_grad": round(float(vals[0].item()), 6),
                "top_slots": [{"slot": int(idx.item()),
                               "grad_norm": round(float(v.item()), 6)}
                              for idx, v in zip(idxs, vals)]
            })

        self.model.zero_grad()  # leave no stale grads behind
        return {"layer_scores": layer_scores}
|
|
    def causal_patch_trace(self, prompt: str, subject: str, target: str,
                           noise_std: float = 0.1) -> Dict:
        """ROME-style causal tracing.

        Corrupts subject embeddings, then for each KB layer measures how much
        patching that layer's hidden state (at subject position) restores
        p(target). Expensive: O(n_layers) forward passes.
        """
        target_id = self.token_id(target)
        W_u = self.arch.get_unembedding().to(self.device)  # NOTE(review): unused here — verify
        inputs = self.tok(prompt, return_tensors="pt").to(self.device)
        ids = inputs["input_ids"][0].tolist()

        # Locate subject token positions: exact subsequence match first, then
        # single-token fallback, finally position 0.
        subj_ids = self.tok.encode(subject, add_special_tokens=False)
        subj_pos: List[int] = []
        for start in range(len(ids) - len(subj_ids) + 1):
            if ids[start:start+len(subj_ids)] == subj_ids:
                subj_pos = list(range(start, start+len(subj_ids)))
                break
        if not subj_pos:
            for si in subj_ids:
                if si in ids:
                    subj_pos = [ids.index(si)]
                    break
        if not subj_pos:
            subj_pos = [0]

        # Clean run: record every layer's hidden states and the baseline p(target).
        clean_hs: Dict[int, torch.Tensor] = {}
        clean_handles = []
        def _mk_clean(li):
            def _h(m, inp, out):
                h = out[0] if isinstance(out, tuple) else out
                clean_hs[li] = h[0].detach().clone()
            return _h
        for li in range(self.arch.n_layers):
            clean_handles.append(self.arch._layer(li).register_forward_hook(_mk_clean(li)))
        with torch.no_grad():
            clean_out = self.model(**inputs)
        for h in clean_handles: h.remove()
        clean_prob = float(torch.softmax(clean_out.logits[0,-1], dim=-1)[target_id].item())

        # Corrupted run: add Gaussian noise (scaled to the embedding std) to the
        # subject positions of the input embeddings.
        E = self.arch.get_embedding().to(self.device)
        emb = E[inputs["input_ids"][0]].unsqueeze(0).clone()
        noise_scale = emb.std().item() * noise_std
        for pos in subj_pos:
            emb[0, pos] += torch.randn_like(emb[0, pos]) * noise_scale

        with torch.no_grad():
            corr_out = self.model(inputs_embeds=emb)
        corr_prob = float(torch.softmax(corr_out.logits[0,-1], dim=-1)[target_id].item())

        # Patched runs: restore ONE layer's clean hidden state at the subject
        # positions and measure how much p(target) recovers (indirect effect).
        results = []
        for li in range(self.kb_start, self.kb_end):
            def _mk_patch(target_li):
                # Factory binds target_li per iteration (avoids late binding).
                def _h(m, inp, out):
                    if target_li not in clean_hs:
                        return out
                    is_tuple = isinstance(out, tuple)
                    h = list(out) if is_tuple else [out]
                    clean = clean_hs[target_li]
                    for pos in subj_pos:
                        if pos < clean.shape[0]:
                            h[0][0, pos] = clean[pos].to(h[0].device)
                    return tuple(h) if is_tuple else h[0]
                return _h
            ph = self.arch._layer(li).register_forward_hook(_mk_patch(li))
            with torch.no_grad():
                patch_out = self.model(inputs_embeds=emb.clone())
            ph.remove()
            patch_prob = float(torch.softmax(patch_out.logits[0,-1], dim=-1)[target_id].item())
            ie = patch_prob - corr_prob
            results.append({
                "layer": li,
                "patch_prob": round(patch_prob, 6),
                "indirect_effect": round(ie, 6),
            })

        return {
            "clean_prob": round(clean_prob, 6),
            "corrupt_prob": round(corr_prob, 6),
            "subject_pos": subj_pos,
            "results": results,
        }
|
|
    def smart_locate(self, prompt: str, subject: str, target: str,
                     alpha: float = 0.4, beta: float = 0.3, gamma: float = 0.3,
                     noise_std: float = 0.1) -> Dict:
        """Combined gate_sim + grad_norm + causal_effect -> precise layer/slot ranking.

        alpha = weight for gate cosine sim
        beta  = weight for gradient norm
        gamma = weight for causal indirect effect
        """
        gate_data = self.locate(prompt, subject, target)
        grad_data = self.gradient_slot_scores(prompt, target)
        causal_data = self.causal_patch_trace(prompt, subject, target, noise_std=noise_std)

        gate_map = {ls["layer"]: ls["max_sim"] for ls in gate_data["layer_scores"]}
        grad_map = {ls["layer"]: ls["max_grad"] for ls in grad_data["layer_scores"]}
        # Negative indirect effects are clamped to 0 before normalization.
        causal_map = {r["layer"]: max(0.0, r["indirect_effect"])
                      for r in causal_data["results"]}
        grad_slots = {ls["layer"]: ls["top_slots"] for ls in grad_data["layer_scores"]}

        layers = sorted(set(gate_map) | set(grad_map) | set(causal_map))

        def _norm(vals: List[float]) -> List[float]:
            # Max-normalize so the three heterogeneous signals are comparable.
            m = max(vals) if vals else 1.0
            return [v/m if m > 0 else 0.0 for v in vals]

        gv = [gate_map.get(l, 0.0) for l in layers]
        dv = [grad_map.get(l, 0.0) for l in layers]
        cv = [causal_map.get(l, 0.0) for l in layers]
        gn, dn, cn = _norm(gv), _norm(dv), _norm(cv)

        ranked = []
        for i, l in enumerate(layers):
            score = alpha*gn[i] + beta*dn[i] + gamma*cn[i]
            ranked.append({
                "layer": l,
                "gate_sim": round(gv[i], 4),
                "grad_norm": round(dv[i], 6),
                "causal_effect": round(cv[i], 6),
                "gate_sim_n": round(gn[i], 4),
                "grad_norm_n": round(dn[i], 4),
                "causal_n": round(cn[i], 4),
                "combined": round(score, 4),
                "best_slots": (grad_slots.get(l) or [])[:5],
            })

        ranked.sort(key=lambda x: -x["combined"])

        return {
            "ranked_layers": ranked,
            "phase_layer": gate_data["phase_layer"],
            "subject_pos": gate_data["subject_pos"],
            "clean_prob": causal_data["clean_prob"],
            "corrupt_prob": causal_data["corrupt_prob"],
            "recommendation": ranked[0] if ranked else None,
            "weights": {"alpha": alpha, "beta": beta, "gamma": gamma},
        }
|
|
    def smart_edit(self, prompt: str, subject: str, relation: str,
                   old_target: str, new_target: str,
                   top_layers: int = 3, slots_per_layer: int = 2,
                   scale: float = 1.5, noise_std: float = 0.1,
                   alpha: float = 0.4, beta: float = 0.4, gamma: float = 0.2,
                   log: Optional[List[str]] = None) -> Dict:
        """Auto edit: runs smart_locate on (prompt, subject, old_target) to find
        the exact layer+slot targets via gradient+causal+gate consensus, then
        patches those W_down columns toward embed(new_target).

        old_target = what the model currently predicts (used to locate)
        new_target = what you want to inject
        top_layers = how many top-ranked layers to patch
        slots_per_layer = gradient-identified slots to patch per layer
        scale = col_norm multiplier (1.5-3.0 recommended)
        beta > alpha because grad_norm is more reliable than gate_sim for small models.
        """
        if log is None: log = []
        self._snapshot()

        log.append(f"SMART_EDIT: '{subject}' [{relation}] {old_target!r} β {new_target!r}")
        log.append(f" Running smart_locate on prompt: {prompt!r}")
        log.append(f" Weights: Ξ±={alpha} Ξ²={beta} Ξ³={gamma} noise_std={noise_std}")

        sl = self.smart_locate(prompt, subject, old_target,
                               alpha=alpha, beta=beta, gamma=gamma,
                               noise_std=noise_std)

        log.append(f" clean_prob={sl['clean_prob']:.6f} corrupt_prob={sl['corrupt_prob']:.6f}")
        log.append(f" Phase layer: L{sl['phase_layer']} Subject pos: {sl['subject_pos']}")

        if sl["clean_prob"] < 1e-5:
            # Model barely predicts old_target; warn but still proceed on grad signal.
            log.append(" β clean_prob near zero β model barely knows this fact.")
            log.append(" Grad-norm signal still valid. Causal IE=0 is expected.")
            log.append(" Recommend: gpt2-medium or Qwen2.5-1.5B for stronger facts.")

        tv = self.embed(new_target)
        tv_n = F.normalize(tv, dim=0)
        ops = []
        used = []

        top_ranked = sl["ranked_layers"][:top_layers]
        for lr in top_ranked:
            li = lr["layer"]
            # Patch the gradient-identified slots, not the gate-similarity ones.
            grad_slots = [s["slot"] for s in lr["best_slots"][:slots_per_layer]]
            if not grad_slots:
                log.append(f" L{li}: no grad slots, skipping")
                continue
            _, Wd = self.arch.get_ffn_weights(li)
            Wd = Wd.to(self.device)
            for slot in grad_slots:
                # Replace the column with the new-target direction at the
                # original norm times `scale`.
                col_norm = Wd[:, slot].norm().item()
                new_col = (tv_n * col_norm * scale).cpu().tolist()
                ops.append({"op":"update_down","layer":li,"slot":slot,"down_col":new_col})
                log.append(f" β L{li} slot {slot}: combined={lr['combined']} "
                           f"grad_norm={lr['grad_norm']:.4f} col_norm={col_norm:.4f} "
                           f"inject={col_norm*scale:.4f}")
            used.append({"layer":li,"slots":grad_slots,"combined":lr["combined"]})

        self.patches.append({
            "type": "SMART_UPDATE",
            "entity": subject,
            "relation": relation,
            "new_target": new_target,
            "old_target": old_target,
            "smart_top": top_ranked,
            "ops": ops,
        })
        self._apply_all_patches()
        log.append(f"\n β {len(ops)} op(s) across {len(used)} layer(s), patch #{len(self.patches)}")

        return {
            "ops": ops,
            "used_layers": used,
            "smart_locate": sl,
            "log": log,
        }
|
|
|
|
|
|
| def infer(self, prompt: str, top_k: int = 5): |
| probs = torch.softmax(self._forward(prompt), dim=-1) |
| top = probs.topk(top_k) |
| return [{"token": self.tok.decode([idx.item()]).strip(), |
| "prob": round(val.item(), 6)} |
| for idx, val in zip(top.indices, top.values)] |
|
|
    def describe(self, entity: str, top_k: int = 10):
        """List tokens most strongly written out by FFN slots whose gate row matches `entity`."""
        ev_n = F.normalize(self.embed(entity), dim=0)
        all_edges = []
        for li in range(self.kb_start, self.kb_end):
            Wg, Wd = self.arch.get_ffn_weights(li)
            Wg = Wg.to(self.device); Wd = Wd.to(self.device)
            sims = F.normalize(Wg, dim=1) @ ev_n
            for fid in sims.topk(min(5,sims.shape[0])).indices:
                gsim = sims[fid].item()
                if gsim < 0.08: continue  # ignore weak gate matches
                for tok, score in self.decode_down_col(Wd[:,fid], 4):
                    if tok: all_edges.append({"tok":tok,"score":score,"layer":li,"gate_sim":gsim})
        # Deduplicate: keep the best-scoring occurrence of each token.
        best: Dict[str, Any] = {}
        for e in all_edges:
            t = e["tok"]
            if t not in best or e["score"] > best[t]["score"]:
                best[t] = e
        ranked = sorted(best.values(), key=lambda x: -x["score"])[:top_k]
        return ranked
|
|
    def trace(self, prompt: str, target: str):
        """Logit-lens trace: rank/prob of `target` after every layer's hidden state."""
        target_id = self.token_id(target)
        W_u = self.arch.get_unembedding().to(self.device)
        stats: List[Dict] = []
        handles = []

        def make_hook(li):
            def hook(m, inp, out):
                h = out[0] if isinstance(out,tuple) else out
                last = h[0,-1].detach()  # final-position hidden state
                p = torch.softmax(W_u @ last, dim=-1)
                # rank = 1 + number of tokens strictly more probable than target
                rank = int((p > p[target_id]).sum().item()) + 1
                stats.append({"l":li,"rank":rank,"prob":round(p[target_id].item(),8)})
            return hook

        for li in range(self.arch.n_layers):
            handles.append(self.arch._layer(li).register_forward_hook(make_hook(li)))
        inputs = self.tok(prompt, return_tensors="pt").to(self.device)
        with torch.no_grad(): self.model(**inputs)
        for h in handles: h.remove()
        return stats
|
|
| |
|
|
| def _snapshot(self): |
| if self._base_weights is not None: return |
| self._base_weights = {} |
| for li in range(self.arch.n_layers): |
| Wg, Wd = self.arch.get_ffn_weights(li) |
| self._base_weights[li] = (Wg.clone().cpu(), Wd.clone().cpu()) |
|
|
| def _restore_base(self): |
| if self._base_weights is None: return |
| for li,(Wg,Wd) in self._base_weights.items(): |
| self.arch.set_ffn_weights(li, Wg.to(self.device), Wd.to(self.device)) |
|
|
    def _apply_all_patches(self):
        """Rebuild weights from the pristine snapshot, then replay every patch op in order."""
        self._restore_base()
        for patch in self.patches:
            for op in patch.get("ops",[]):
                li=op["layer"]; slot=op["slot"]
                Wg, Wd = self.arch.get_ffn_weights(li)
                Wg=Wg.clone(); Wd=Wd.clone()  # never mutate the adapter's returned views in place
                if op["op"] in ("insert","update_gate"):
                    # Gate row addressed by slot index along dim 0.
                    Wg[slot] = torch.tensor(op["gate_row"],dtype=Wg.dtype,device=self.device)
                if op["op"] in ("insert","update_down"):
                    # Down-projection column addressed by slot index along dim 1.
                    Wd[:,slot] = torch.tensor(op["down_col"],dtype=Wd.dtype,device=self.device)
                self.arch.set_ffn_weights(li, Wg, Wd)
|
|
    def insert(self, entity, relation, target, alpha=0.25, spread=4, log=None):
        """Write a new fact by overwriting the lowest-norm slot in the first `spread` KB layers.

        gate row <- entity direction, down column <- target direction, each
        scaled by alpha times the layer's mean row/column norm.
        """
        if log is None: log = []
        self._snapshot()
        gate_dir = F.normalize(self.embed(entity), dim=0)
        down_dir = F.normalize(self.embed(target), dim=0)
        ls=self.kb_start; le=min(ls+spread, self.arch.n_layers)
        log.append(f"INSERT: '{entity}' -[{relation}]-> '{target}' alpha={alpha}")
        ops=[]
        for li in range(ls,le):
            Wg, Wd = self.arch.get_ffn_weights(li)
            Wg=Wg.to(self.device); Wd=Wd.to(self.device)
            norms_g=Wg.norm(dim=1); norms_d=Wd.norm(dim=0)
            slot=norms_g.argmin().item()  # weakest gate row = least-used slot
            wg_mean=norms_g.mean().item(); wd_mean=norms_d.mean().item()
            ops.append({"op":"insert","layer":li,"slot":slot,
                        "gate_row":(gate_dir*wg_mean*alpha).cpu().tolist(),
                        "down_col":(down_dir*wd_mean*alpha).cpu().tolist()})
            log.append(f" L{li}: slot={slot} inject={wg_mean*alpha:.4f}")
        self.patches.append({"type":"INSERT","entity":entity,"relation":relation,
                             "target":target,"ops":ops})
        self._apply_all_patches()
        return ops
|
|
def update(self, entity, relation, new_target, top_k=3, scale=1.0, log=None,
           sim_threshold=0.08):
    """Redirect an existing fact: entity -[relation]-> new_target.

    Scans the KB layers for the W_gate rows most similar (cosine) to the
    entity embedding, then overwrites the matching W_down columns with the
    new-target direction, preserving each column's original norm.

    Args:
        entity: subject string whose fact is being rewritten.
        relation: relation label (stored in patch metadata only).
        new_target: object string injected as the new fact value.
        top_k: maximum number of (layer, slot) candidates to patch.
        scale: multiplier applied to the injected column norm.
        log: optional list collecting human-readable progress lines.
        sim_threshold: minimum gate cosine; below it, falls back to insert().
            Default 0.08 matches the previous hard-coded behavior.

    Returns:
        (layer, slot, sim) for the best patched slot, or
        (insert_ops, "INSERT_FALLBACK") when no slot clears the threshold.
    """
    if log is None: log = []
    self._snapshot()
    ev = self.embed(entity)
    tv = self.embed(new_target)
    ev_n = F.normalize(ev, dim=0)
    log.append(f"UPDATE: '{entity}' -[{relation}]-> '{new_target}' top_k={top_k} scale={scale}")
    candidates = []
    for li in range(self.kb_start, self.kb_end):
        Wg, _ = self.arch.get_ffn_weights(li)
        Wg = Wg.to(self.device)
        # Cosine similarity of every gate row against the entity direction.
        sims = F.normalize(Wg, dim=1) @ ev_n
        k = min(top_k, sims.shape[0])
        vals, idxs = sims.topk(k)
        for v, idx in zip(vals, idxs):
            candidates.append((v.item(), li, idx.item()))
        log.append(f" L{li}: " + " ".join(f"{v:.4f}@s{i}" for v,i in zip(vals,idxs)))
    candidates.sort(key=lambda x: -x[0])
    best_sim = candidates[0][0] if candidates else 0.0
    if best_sim < sim_threshold:
        # No slot resembles the entity enough to rewrite — create a fresh one.
        log.append(f" β sim<{sim_threshold} β INSERT fallback")
        return self.insert(entity, relation, new_target, log=log), "INSERT_FALLBACK"
    chosen = [c for c in candidates if c[0] >= sim_threshold][:top_k]
    ops = []
    for sim, li, slot in chosen:
        _, Wd = self.arch.get_ffn_weights(li)
        Wd = Wd.to(self.device)
        # Keep the original column magnitude so the edit stays in-distribution.
        col_norm = Wd[:, slot].norm().item()
        new_col = (F.normalize(tv, dim=0)*col_norm*scale).cpu().tolist()
        ops.append({"op":"update_down","layer":li,"slot":slot,"down_col":new_col})
        log.append(f" β L{li} slot {slot}: sim={sim:.4f} norm={col_norm:.4f}")
    self.patches.append({"type":"UPDATE","entity":entity,"relation":relation,
                         "new_target":new_target,"ops":ops})
    self._apply_all_patches()
    # BUG FIX: candidate tuples are (sim, layer, slot); the original code
    # unpacked them as (layer, slot, sim), returning mislabeled values.
    sim0, li0, slot0 = chosen[0]
    return li0, slot0, sim0
|
|
def reset(self):
    """Discard every recorded patch and restore the base model weights."""
    # Empty the list in place (callers may hold a reference to it).
    del self.patches[:]
    self._restore_base()
|
|
def save_patch(self, path: str):
    """Serialize the accumulated patch list as pretty-printed JSON at *path*."""
    serialized = json.dumps(self.patches, indent=2)
    with open(path, "w") as fh:
        fh.write(serialized)
|
|
def compile_to(self, output_dir: str):
    """Materialize the patched model and tokenizer as a standalone checkpoint."""
    target = Path(output_dir)
    # Create the full directory tree; a pre-existing directory is fine.
    target.mkdir(parents=True, exist_ok=True)
    self.model.save_pretrained(output_dir)
    self.tok.save_pretrained(output_dir)
|
|
|
|
| |
| |
| |
|
|
| HTML_PAGE = r"""<!DOCTYPE html> |
| <html lang="en"> |
| <head> |
| <meta charset="UTF-8"> |
| <title>VINDEX β LLM Knowledge Editor</title> |
| <script src="https://unpkg.com/d3@7"></script> |
| <style> |
| :root { |
| --bg: #0d1117; --bg2: #161b22; --bg3: #21262d; |
| --border: #30363d; --text: #e6edf3; --muted: #8b949e; |
| --blue: #58a6ff; --green: #3fb950; --red: #f85149; |
| --yellow: #d29922; --purple: #bc8cff; --cyan: #39d353; |
| --font: 'JetBrains Mono', 'Fira Code', 'Cascadia Code', monospace; |
| } |
| * { box-sizing: border-box; margin: 0; padding: 0; } |
| body { background: var(--bg); color: var(--text); font-family: var(--font); |
| font-size: 13px; min-height: 100vh; } |
| #app { max-width: 1400px; margin: 0 auto; padding: 16px; } |
| |
| /* HEADER */ |
| .header { background: var(--bg2); border: 1px solid var(--border); |
| border-radius: 10px; padding: 20px 28px; margin-bottom: 14px; |
| display: flex; align-items: center; justify-content: space-between; } |
| .header-title { font-size: 20px; font-weight: 700; letter-spacing: 1px; } |
| .header-title span { color: var(--blue); } |
| .status-pill { background: var(--bg3); border: 1px solid var(--border); |
| border-radius: 20px; padding: 6px 14px; font-size: 11px; |
| color: var(--muted); display: flex; align-items: center; gap: 8px; } |
| .status-dot { width: 8px; height: 8px; border-radius: 50%; background: var(--red); } |
| .status-dot.ok { background: var(--green); } |
| |
| /* TABS */ |
| .tabs { display: flex; gap: 2px; background: var(--bg2); border: 1px solid var(--border); |
| border-radius: 8px; padding: 4px; margin-bottom: 14px; flex-wrap: wrap; } |
| .tab-btn { padding: 7px 16px; border: none; background: transparent; color: var(--muted); |
| font-family: var(--font); font-size: 12px; cursor: pointer; border-radius: 6px; |
| letter-spacing: .5px; transition: all .15s; } |
| .tab-btn:hover { color: var(--text); background: var(--bg3); } |
| .tab-btn.active { background: var(--blue); color: #fff; } |
| |
| /* PANELS */ |
| .panel { display: none; } |
| .panel.active { display: block; } |
| .card { background: var(--bg2); border: 1px solid var(--border); |
| border-radius: 8px; padding: 18px; margin-bottom: 12px; } |
| .card h3 { font-size: 11px; letter-spacing: 2px; text-transform: uppercase; |
| color: var(--muted); margin-bottom: 14px; } |
| |
| /* FORM ELEMENTS */ |
| label { font-size: 11px; color: var(--muted); display: block; margin-bottom: 4px; } |
| input[type=text], textarea, select { |
| width: 100%; background: var(--bg3); border: 1px solid var(--border); |
| color: var(--text); font-family: var(--font); font-size: 12px; |
| padding: 8px 10px; border-radius: 6px; outline: none; } |
| input[type=text]:focus, textarea:focus { border-color: var(--blue); } |
| textarea { resize: vertical; min-height: 80px; } |
| input[type=range] { width: 100%; accent-color: var(--blue); } |
| .range-row { display: flex; align-items: center; gap: 10px; } |
| .range-row span { min-width: 40px; text-align: right; color: var(--blue); } |
| button { |
| background: var(--blue); color: #fff; border: none; font-family: var(--font); |
| font-size: 12px; padding: 8px 18px; border-radius: 6px; cursor: pointer; |
| font-weight: 600; letter-spacing: .5px; transition: opacity .15s; } |
| button:hover { opacity: .85; } |
| button.secondary { background: var(--bg3); color: var(--text); border: 1px solid var(--border); } |
| button.danger { background: var(--red); } |
| button.warn { background: var(--yellow); color: #000; } |
| .row { display: flex; gap: 12px; flex-wrap: wrap; } |
| .col { flex: 1; min-width: 200px; } |
| .col2 { flex: 2; min-width: 300px; } |
| |
| /* RADIO GROUP */ |
| .radio-group { display: flex; flex-wrap: wrap; gap: 6px; } |
| .radio-group label { display: flex; align-items: center; gap: 5px; cursor: pointer; |
| background: var(--bg3); border: 1px solid var(--border); border-radius: 5px; |
| padding: 5px 10px; color: var(--text); font-size: 11px; margin: 0; } |
| .radio-group input { accent-color: var(--blue); } |
| |
| /* LOG BOX */ |
| .log { background: var(--bg); border: 1px solid var(--border); border-radius: 6px; |
| padding: 12px; font-size: 11px; color: var(--muted); max-height: 200px; |
| overflow-y: auto; white-space: pre-wrap; margin-top: 8px; } |
| |
| /* SVG CHARTS */ |
| .chart-wrap { background: var(--bg); border: 1px solid var(--border); |
| border-radius: 6px; overflow: hidden; margin-top: 10px; } |
| svg text { font-family: var(--font); fill: var(--text); } |
| .tooltip { |
| position: absolute; background: var(--bg2); border: 1px solid var(--border); |
| border-radius: 6px; padding: 8px 12px; font-size: 11px; pointer-events: none; |
| opacity: 0; transition: opacity .1s; z-index: 100; } |
| |
| /* PATCH CARDS */ |
| .patch-card { background: var(--bg3); border: 1px solid var(--border); |
| border-radius: 6px; padding: 10px 14px; margin-bottom: 8px; |
| display: flex; align-items: center; justify-content: space-between; } |
| .patch-badge { padding: 2px 8px; border-radius: 4px; font-size: 10px; |
| font-weight: 700; letter-spacing: 1px; margin-right: 10px; } |
| .badge-UPDATE { background: var(--blue); color: #fff; } |
| .badge-INSERT { background: var(--green); color: #000; } |
| .badge-PRECISE { background: var(--purple); color: #fff; } |
| .badge-SUPPRESS{ background: var(--red); color: #fff; } |
| .badge-AMPLIFY { background: var(--yellow); color: #000; } |
| .badge-STYLE { background: var(--cyan); color: #000; } |
| .patch-del { background: transparent; border: none; color: var(--red); |
| font-size: 16px; cursor: pointer; padding: 0 6px; } |
| |
| /* HEATMAP */ |
| .hm-cell { stroke: var(--bg); stroke-width: 1; cursor: pointer; } |
| .hm-cell:hover { stroke: var(--blue); stroke-width: 2; } |
| |
| /* LOCATE SECTION */ |
| .locate-bar { height: 22px; background: var(--bg3); border-radius: 3px; |
| margin-bottom: 3px; display: flex; align-items: center; |
| padding: 0 8px; cursor: pointer; transition: background .1s; } |
| .locate-bar:hover { background: var(--bg2); } |
| .locate-bar .layer-lbl { width: 40px; color: var(--muted); font-size: 10px; } |
| .locate-bar .bar-fill { height: 10px; border-radius: 2px; background: var(--purple); } |
| .locate-bar .sim-val { margin-left: 8px; font-size: 10px; color: var(--blue); } |
| |
| /* Scrollbar */ |
| ::-webkit-scrollbar { width: 6px; } ::-webkit-scrollbar-track { background: var(--bg2); } |
| ::-webkit-scrollbar-thumb { background: var(--border); border-radius: 3px; } |
| </style> |
| </head> |
| <body> |
| <div id="app"> |
| |
| <!-- HEADER --> |
| <div class="header"> |
| <div> |
| <div class="header-title">VINDEX <span>β</span> LLM Knowledge Editor</div> |
| <div style="color:var(--muted);font-size:11px;margin-top:4px"> |
| The model IS the database. Inspect Β· Edit Β· Locate Β· Compile. |
| </div> |
| </div> |
| <div class="status-pill"> |
| <div class="status-dot" id="status-dot"></div> |
| <span id="status-text">No model loaded</span> |
| <span id="patch-count" style="color:var(--blue)"></span> |
| </div> |
| </div> |
| |
| <!-- TABS --> |
| <div class="tabs"> |
| <button class="tab-btn active" onclick="showTab('infer')">β Infer</button> |
| <button class="tab-btn" onclick="showTab('describe')">β‘ Describe</button> |
| <button class="tab-btn" onclick="showTab('trace')">β’ Trace</button> |
| <button class="tab-btn" onclick="showTab('locate')">β£ Locate</button> |
| <button class="tab-btn" onclick="showTab('smartlocate')">β€ Smart Locate</button> |
| <button class="tab-btn" onclick="showTab('heatmap')">β₯ Heatmap</button> |
| <button class="tab-btn" onclick="showTab('edit')">β¦ Edit</button> |
| <button class="tab-btn" onclick="showTab('patches')">β§ Patches</button> |
| <button class="tab-btn" onclick="showTab('guide')" style="margin-left:auto;color:var(--green)">π Guide</button> |
| <button class="tab-btn" onclick="showTab('load')">β Load</button> |
| </div> |
| |
| <!-- TOOLTIP --> |
| <div class="tooltip" id="tooltip"></div> |
| |
| <!-- ββββββββββ LOAD PANEL ββββββββββ --> |
| <div id="panel-load" class="panel"> |
| <div class="card"> |
| <h3>Load Model</h3> |
| <div class="row"> |
| <div class="col2"> |
| <label>Model name or local path</label> |
| <input type="text" id="load-model" value="distilgpt2" |
| placeholder="distilgpt2 | Qwen/Qwen2.5-1.5B-Instruct | /local/path"> |
| </div> |
| <div class="col"> |
| <label>Device</label> |
| <select id="load-device"> |
| <option value="auto">Auto</option> |
| <option value="cpu">CPU</option> |
| <option value="cuda">CUDA</option> |
| <option value="mps">MPS</option> |
| </select> |
| </div> |
| </div> |
| <div style="margin-top:12px"> |
| <button onclick="loadModel()">β‘ Load Model</button> |
| </div> |
| <div id="load-log" class="log" style="margin-top:12px;display:none"></div> |
| </div> |
| <div class="card"> |
| <h3>Quick Models</h3> |
| <div style="color:var(--muted);line-height:1.8"> |
| <span style="color:var(--blue)">distilgpt2</span> β 350 MB, instant<br> |
| <span style="color:var(--blue)">gpt2</span> β 550 MB<br> |
| <span style="color:var(--blue)">gpt2-medium</span> β 1.5 GB<br> |
| <span style="color:var(--blue)">Qwen/Qwen2.5-1.5B-Instruct</span> β 3 GB, strong facts |
| </div> |
| </div> |
| </div> |
| |
| <!-- ββββββββββ INFER PANEL ββββββββββ --> |
| <div id="panel-infer" class="panel active"> |
| <div class="card"> |
| <h3>Next-Token Prediction</h3> |
| <div class="row"> |
| <div class="col2"> |
| <label>Prompt</label> |
| <input type="text" id="infer-prompt" value="The capital of France is"> |
| </div> |
| <div class="col"> |
| <label>Top-K: <span id="infer-k-val">10</span></label> |
| <div class="range-row"> |
| <input type="range" id="infer-k" min="1" max="20" value="10" |
| oninput="document.getElementById('infer-k-val').textContent=this.value"> |
| </div> |
| </div> |
| </div> |
| <button onclick="runInfer()" style="margin-top:12px">βΆ Run Infer</button> |
| </div> |
| <div class="card"> |
| <h3>Results</h3> |
| <div class="chart-wrap" id="infer-chart" style="min-height:200px"></div> |
| </div> |
| </div> |
| |
| <!-- ββββββββββ DESCRIBE PANEL ββββββββββ --> |
| <div id="panel-describe" class="panel"> |
| <div class="card"> |
| <h3>Entity Knowledge Graph (W_gate KNN β W_down decode)</h3> |
| <div class="row"> |
| <div class="col"> |
| <label>Entity</label> |
| <input type="text" id="desc-entity" value="France"> |
| </div> |
| <div class="col"> |
| <label>Top edges: <span id="desc-k-val">10</span></label> |
| <div class="range-row"> |
| <input type="range" id="desc-k" min="1" max="20" value="10" |
| oninput="document.getElementById('desc-k-val').textContent=this.value"> |
| </div> |
| </div> |
| </div> |
| <button onclick="runDescribe()" style="margin-top:12px">βΆ Run Describe</button> |
| </div> |
| <div class="card"> |
| <h3>Force-Directed Graph</h3> |
| <div class="chart-wrap" id="desc-chart" style="height:420px"></div> |
| </div> |
| </div> |
| |
| <!-- ββββββββββ TRACE PANEL ββββββββββ --> |
| <div id="panel-trace" class="panel"> |
| <div class="card"> |
| <h3>Layer-by-Layer Rank Trace</h3> |
| <div class="row"> |
| <div class="col2"> |
| <label>Prompt</label> |
| <input type="text" id="trace-prompt" value="The capital of France is"> |
| </div> |
| <div class="col"> |
| <label>Target token</label> |
| <input type="text" id="trace-target" value="Paris"> |
| </div> |
| </div> |
| <button onclick="runTrace()" style="margin-top:12px">βΆ Run Trace</button> |
| </div> |
| <div class="card"> |
| <h3>Rank + Probability over Layers</h3> |
| <div class="chart-wrap" id="trace-chart" style="min-height:300px"></div> |
| </div> |
| </div> |
| |
| <!-- ββββββββββ LOCATE PANEL ββββββββββ --> |
| <div id="panel-locate" class="panel"> |
| <div class="card"> |
| <h3>Locate β Diagnostic (Trace + Activation Similarity)</h3> |
| <div class="row"> |
| <div class="col2"> |
| <label>Prompt</label> |
| <input type="text" id="loc-prompt" value="The capital of France is"> |
| </div> |
| <div class="col"> |
| <label>Subject</label> |
| <input type="text" id="loc-subject" value="France"> |
| </div> |
| <div class="col"> |
| <label>Target</label> |
| <input type="text" id="loc-target" value="Paris"> |
| </div> |
| </div> |
| <button onclick="runLocate()" style="margin-top:12px">βΆ Locate</button> |
| </div> |
| <div class="row"> |
| <div class="col2 card"> |
| <h3>Trace (rank over layers)</h3> |
| <div class="chart-wrap" id="loc-trace-chart" style="min-height:220px"></div> |
| </div> |
| <div class="col card"> |
| <h3>Activation Sim per KB Layer</h3> |
| <div id="loc-bars" style="margin-top:8px"></div> |
| <div id="loc-slots" class="log" style="margin-top:8px;display:none"></div> |
| </div> |
| </div> |
| <div class="card" id="loc-summary" style="display:none"> |
| <h3>Summary</h3> |
| <div id="loc-summary-text" style="color:var(--blue)"></div> |
| </div> |
| </div> |
| |
| <!-- ββββββββββ SMART LOCATE PANEL ββββββββββ --> |
| <div id="panel-smartlocate" class="panel"> |
| <div class="card"> |
| <h3>Smart Locate β gradient + causal + gate_sim combined</h3> |
| <div style="color:var(--muted);font-size:11px;margin-bottom:12px;line-height:1.7"> |
| Three independent signals combined into one ranked layer list.<br> |
| <span style="color:var(--blue)">β gate_sim</span> β static embedding cosine (fast, weak proxy) |
| <span style="color:var(--green)">β grad_norm</span> β βloss/βW_down per slot (one backward pass) |
| <span style="color:var(--yellow)">β causal IE</span> β indirect effect via subject-corruption patching (N_layers passes, slow) |
| </div> |
| <div class="row"> |
| <div class="col2"> |
| <label>Prompt</label> |
| <input type="text" id="sl-prompt" value="The capital of France is"> |
| </div> |
| <div class="col"> |
| <label>Subject</label> |
| <input type="text" id="sl-subject" value="France"> |
| </div> |
| <div class="col"> |
| <label>Target</label> |
| <input type="text" id="sl-target" value="Paris"> |
| </div> |
| </div> |
| <div class="row" style="margin-top:10px"> |
| <div class="col"> |
| <label>Ξ± gate_sim: <span id="sl-a-val">0.4</span></label> |
| <input type="range" id="sl-alpha" min="0" max="1" step="0.05" value="0.4" |
| oninput="document.getElementById('sl-a-val').textContent=this.value"> |
| </div> |
| <div class="col"> |
| <label>Ξ² grad_norm: <span id="sl-b-val">0.3</span></label> |
| <input type="range" id="sl-beta" min="0" max="1" step="0.05" value="0.3" |
| oninput="document.getElementById('sl-b-val').textContent=this.value"> |
| </div> |
| <div class="col"> |
| <label>Ξ³ causal: <span id="sl-g-val">0.3</span></label> |
| <input type="range" id="sl-gamma" min="0" max="1" step="0.05" value="0.3" |
| oninput="document.getElementById('sl-g-val').textContent=this.value"> |
| </div> |
| <div class="col"> |
| <label>Noise Ο: <span id="sl-noise-val">0.1</span></label> |
| <input type="range" id="sl-noise" min="0.02" max="0.5" step="0.02" value="0.1" |
| oninput="document.getElementById('sl-noise-val').textContent=this.value"> |
| </div> |
| </div> |
| <div style="display:flex;gap:8px;margin-top:12px;flex-wrap:wrap"> |
| <button onclick="runSmartLocate()">β‘ Smart Locate (full)</button> |
| <button class="secondary" onclick="runGradientOnly()">βΆ Gradient only (fast)</button> |
| <button class="secondary" onclick="runCausalOnly()">βΆ Causal trace only</button> |
| </div> |
| <div id="sl-status" style="color:var(--muted);font-size:11px;margin-top:8px"></div> |
| </div> |
| |
| <div class="row"> |
| <div class="col2 card"> |
| <h3>Layer Rankings β 3-signal stacked bars</h3> |
| <div class="chart-wrap" id="sl-chart" style="min-height:320px"></div> |
| </div> |
| <div class="col card"> |
| <h3>Recommendation</h3> |
| <div id="sl-rec" class="log">Run Smart Locate to see the best edit target.</div> |
| <h3 style="margin-top:14px">Collateral Probe</h3> |
| <div class="row" style="margin-top:8px"> |
| <input type="text" id="sl-coll-prompt" value="Biggest cities in France" |
| style="flex:2" placeholder="Collateral promptβ¦"> |
| <button class="secondary" onclick="runCollateral()" style="flex:0">βΆ</button> |
| </div> |
| <div id="sl-coll-out" class="log" style="margin-top:8px">Probe a prompt to check collateral damage.</div> |
| </div> |
| </div> |
| |
| <div class="card"> |
| <h3>Per-Layer Detail</h3> |
| <div id="sl-table" style="overflow-x:auto"> |
| <div style="color:var(--muted);font-size:11px">Run Smart Locate first.</div> |
| </div> |
| </div> |
| </div> |
| |
| <!-- ββββββββββ HEATMAP PANEL ββββββββββ --> |
| <div id="panel-heatmap" class="panel"> |
| <div class="card"> |
| <h3>Gate Heatmap β Layer Γ Slot Cosine Similarity</h3> |
| <div class="row"> |
| <div class="col"> |
| <label>Entity</label> |
| <input type="text" id="hm-entity" value="France"> |
| </div> |
| <div class="col"> |
| <label>Prompt (optional β for activation mode)</label> |
| <input type="text" id="hm-prompt" value="The capital of France is"> |
| </div> |
| <div class="col"> |
| <label>Mode</label> |
| <div class="radio-group"> |
| <label><input type="radio" name="hm-mode" value="embed" checked> Embed</label> |
| <label><input type="radio" name="hm-mode" value="activation"> Activation</label> |
| </div> |
| </div> |
| </div> |
| <button onclick="runHeatmap()" style="margin-top:12px">βΆ Run Heatmap</button> |
| </div> |
| <div class="row"> |
| <div class="col2 card"> |
| <h3>Heatmap</h3> |
| <div class="chart-wrap" id="hm-chart" style="min-height:300px"></div> |
| </div> |
| <div class="col card"> |
| <h3>Selected Slot</h3> |
| <div id="hm-slot-detail" class="log">Click a cell to see decoded tokens.</div> |
| </div> |
| </div> |
| </div> |
| |
| <!-- ββββββββββ EDIT PANEL ββββββββββ --> |
| <div id="panel-edit" class="panel"> |
| <div class="card"> |
| <h3>Edit</h3> |
| <div class="row"> |
| <div class="col"> |
| <label>Entity</label> |
| <input type="text" id="edit-entity" value="France"> |
| <label style="margin-top:8px">Relation</label> |
| <input type="text" id="edit-relation" value="capital"> |
| <label style="margin-top:8px">Old value (context only)</label> |
| <input type="text" id="edit-old" value="Paris"> |
| <label style="margin-top:8px">New value (inject)</label> |
| <input type="text" id="edit-new" value="Lyon"> |
| </div> |
| <div class="col"> |
| <label>Mode</label> |
| <div class="radio-group" id="edit-mode-group"> |
| <label><input type="radio" name="edit-mode" value="UPDATE" checked> UPDATE</label> |
| <label><input type="radio" name="edit-mode" value="PRECISE"> PRECISE</label> |
| <label><input type="radio" name="edit-mode" value="SMART"> β
SMART</label> |
| <label><input type="radio" name="edit-mode" value="INSERT"> INSERT</label> |
| <label><input type="radio" name="edit-mode" value="SUPPRESS"> SUPPRESS</label> |
| <label><input type="radio" name="edit-mode" value="AMPLIFY"> AMPLIFY</label> |
| <label><input type="radio" name="edit-mode" value="STYLE-SHIFT"> STYLE-SHIFT</label> |
| <label><input type="radio" name="edit-mode" value="MULTI-EDIT"> MULTI-EDIT</label> |
| </div> |
| <div id="precise-prompt-row" style="display:none;margin-top:8px"> |
| <label>Prompt (PRECISE mode)</label> |
| <input type="text" id="edit-prompt" value="The capital of France is"> |
| </div> |
| <div id="smart-row" style="display:none;margin-top:8px;background:var(--bg);border:1px solid var(--border);border-radius:6px;padding:10px"> |
| <div style="color:var(--blue);font-size:11px;font-weight:700;margin-bottom:6px">β
SMART AUTO MODE</div> |
| <label>Prompt (used for locate + after-check)</label> |
| <input type="text" id="smart-prompt" value="The capital of France is"> |
| <label style="margin-top:6px">Old value (what model currently says)</label> |
| <input type="text" id="smart-old" value="Paris"> |
| <div class="row" style="margin-top:6px"> |
| <div class="col"> |
| <label>Top layers: <span id="smart-layers-val">3</span></label> |
| <input type="range" id="smart-layers" min="1" max="8" value="3" |
| oninput="document.getElementById('smart-layers-val').textContent=this.value"> |
| </div> |
| <div class="col"> |
| <label>Slots/layer: <span id="smart-slots-val">2</span></label> |
| <input type="range" id="smart-slots" min="1" max="5" value="2" |
| oninput="document.getElementById('smart-slots-val').textContent=this.value"> |
| </div> |
| </div> |
| <div style="color:var(--muted);font-size:10px;margin-top:6px"> |
| Runs smart_locate internally β patches gradient-identified slots. No manual tuning needed. |
| </div> |
| </div> |
| <div id="style-shift-row" style="display:none;margin-top:8px"> |
| <label>From concept</label> |
| <input type="text" id="ss-from" value="formal"> |
| <label style="margin-top:6px">To concept</label> |
| <input type="text" id="ss-to" value="casual"> |
| <label style="margin-top:6px">Strength: <span id="ss-str-val">0.5</span></label> |
| <div class="range-row"> |
| <input type="range" id="ss-strength" min="0.1" max="2.0" step="0.1" value="0.5" |
| oninput="document.getElementById('ss-str-val').textContent=this.value"> |
| </div> |
| </div> |
| <div id="multiedit-row" style="display:none;margin-top:8px"> |
| <label>JSON array [{entity,relation,new_target,prompt?},...]</label> |
| <textarea id="multi-json" rows="5">[{"entity":"France","relation":"capital","new_target":"Lyon"}]</textarea> |
| </div> |
| <label style="margin-top:8px">Top-K slots: <span id="edit-k-val">3</span></label> |
| <div class="range-row"> |
| <input type="range" id="edit-k" min="1" max="10" value="3" |
| oninput="document.getElementById('edit-k-val').textContent=this.value"> |
| </div> |
| <label style="margin-top:6px">Scale: <span id="edit-scale-val">1.5</span></label> |
| <div class="range-row"> |
| <input type="range" id="edit-scale" min="0.5" max="5.0" step="0.25" value="1.5" |
| oninput="document.getElementById('edit-scale-val').textContent=this.value"> |
| </div> |
| <label style="margin-top:6px">Alpha (INSERT): <span id="edit-alpha-val">0.25</span></label> |
| <div class="range-row"> |
| <input type="range" id="edit-alpha" min="0.05" max="1.0" step="0.05" value="0.25" |
| oninput="document.getElementById('edit-alpha-val').textContent=this.value"> |
| </div> |
| <div style="display:flex;gap:8px;margin-top:14px;flex-wrap:wrap"> |
| <button onclick="runDryRun()">π Dry Run</button> |
| <button onclick="runEdit()">β‘ Apply Edit</button> |
| </div> |
| </div> |
| </div> |
| </div> |
| |
| <div class="row"> |
| <div class="col card"> |
| <h3>Before / After</h3> |
| <div class="chart-wrap" id="edit-chart" style="min-height:220px"></div> |
| </div> |
| <div class="col card"> |
| <h3>Dry Run Preview</h3> |
| <div id="dryrun-log" class="log">Run dry-run first.</div> |
| </div> |
| </div> |
| <div class="card"> |
| <h3>Edit Log + Delta</h3> |
| <div id="edit-log" class="log">No edit yet.</div> |
| </div> |
| </div> |
| |
| <!-- ββββββββββ PATCHES PANEL ββββββββββ --> |
| <div id="panel-patches" class="panel"> |
| <div class="card"> |
| <h3>Active Patches</h3> |
| <div style="display:flex;gap:8px;margin-bottom:12px;flex-wrap:wrap"> |
| <button onclick="loadPatches()">β» Refresh</button> |
| <button class="secondary" onclick="savePatches()">πΎ Save JSON</button> |
| <button class="secondary" onclick="compileModel()">π¦ Compile</button> |
| <button class="danger" onclick="resetModel()">β Reset to Base</button> |
| </div> |
| <div id="patches-list"><div style="color:var(--muted)">No patches.</div></div> |
| </div> |
| <div class="card"> |
| <h3>Concept Reference</h3> |
| <div style="color:var(--muted);line-height:1.9;font-size:11px"> |
| <span style="color:var(--blue)">UPDATE</span> β rewrites W_down column β different fact<br> |
| <span style="color:var(--purple)">PRECISE</span> β activation-guided UPDATE (3β5Γ better sim)<br> |
| <span style="color:var(--green)">INSERT</span> β new (gate,down) pair in weakest slot<br> |
| <span style="color:var(--red)">SUPPRESS</span> β scale down β model forgets entity<br> |
| <span style="color:var(--yellow)">AMPLIFY</span> β scale up β stronger recall<br> |
| <span style="color:var(--cyan)">STYLE-SHIFT</span> β adds direction vector β tone/bias shift |
| </div> |
| </div> |
| </div> |
| |
| <!-- ββββββββββ GUIDE PANEL ββββββββββ --> |
| <div id="panel-guide" class="panel"> |
| <div class="card"> |
| <h3>What is VINDEX doing?</h3> |
| <div style="line-height:1.9;color:var(--muted)"> |
| In a transformer, factual associations like <span style="color:var(--text)">"France β capital β Paris"</span> |
| are stored as direction vectors in the <span style="color:var(--blue)">W_down columns</span> of FFN layers. |
| The <span style="color:var(--blue)">W_gate rows</span> act as keys: when the residual stream resembles "France", |
| the matching gate fires, the down column adds "Paris" direction to the stream, and the unembedding reads out "Paris". |
| VINDEX surgically replaces those down columns without retraining. |
| </div> |
| </div> |
| |
| <div class="card"> |
| <h3>Quickstart β 5-step experiment</h3> |
| <div style="line-height:2;font-size:12px"> |
| <div style="color:var(--yellow);margin-bottom:4px">Step 1 β Load a model that actually knows facts</div> |
| <div style="color:var(--muted);margin-left:16px;margin-bottom:10px"> |
| β Load tab β <span style="color:var(--blue)">gpt2-medium</span> (1.5 GB, knows capitals) or |
| <span style="color:var(--blue)">Qwen/Qwen2.5-1.5B-Instruct</span> (3 GB, strong).<br> |
| distilgpt2 has clean_probβ0 for most facts β causal IE=0 everywhere β misleading results. |
| </div> |
| |
| <div style="color:var(--yellow);margin-bottom:4px">Step 2 β Verify the model knows the fact</div> |
| <div style="color:var(--muted);margin-left:16px;margin-bottom:10px"> |
| β Infer: prompt = <code>"The capital of France is"</code><br> |
| β Good: "Paris" appears in top-3 with prob > 0.05<br> |
| β Bad: top tokens are "a", "the", "known" β model doesn't know it β skip to INSERT mode |
| </div> |
| |
| <div style="color:var(--yellow);margin-bottom:4px">Step 3 β Find where the fact lives</div> |
| <div style="color:var(--muted);margin-left:16px;margin-bottom:10px"> |
| β’ Trace: prompt = <code>"The capital of France is"</code>, target = <code>"Paris"</code><br> |
| β Look for phase layer: where rank drops from ~30000 to <100. That's where the fact materializes.<br> |
| β€ Smart Locate β Gradient only (fast, 1 backward pass):<br> |
| <span style="margin-left:16px">subject = <code>France</code>, target = <code>Paris</code></span><br> |
| β The layer with highest grad_norm bar = best edit target. Note the slot numbers. |
| </div> |
| |
| <div style="color:var(--yellow);margin-bottom:4px">Step 4 β Edit with SMART mode</div> |
| <div style="color:var(--muted);margin-left:16px;margin-bottom:10px"> |
| β¦ Edit tab β mode = <span style="color:var(--blue)">β
SMART</span><br> |
| Entity = <code>France</code> | Relation = <code>capital</code><br> |
| Old value = <code>Paris</code> (what model says now β used for locate)<br> |
| New value = <code>Lyon</code> (what you want)<br> |
| Prompt = <code>"The capital of France is"</code><br> |
| Scale = <code>2.0</code> (start here; increase to 3.0 if effect is weak)<br> |
| β Click <b>Apply Edit</b>. Smart locate runs internally, patches grad-identified slots. |
| </div> |
| |
| <div style="color:var(--yellow);margin-bottom:4px">Step 5 β Check collateral damage</div> |
| <div style="color:var(--muted);margin-left:16px;margin-bottom:10px"> |
| β Infer: <code>"The capital of France is"</code> β should now say Lyon<br> |
| β Infer: <code>"Biggest cities in France"</code> β should be unchanged (different slots)<br> |
| β Infer: <code>"Paris is a city in"</code> β should still say France<br> |
| β Infer: <code>"Lyon is a city in"</code> β might now also say France (collateral)<br> |
| β€ Smart Locate collateral probe β run these prompts, compare slot lists in β§ Patches |
| </div> |
| </div> |
| </div> |
| |
| <div class="card"> |
| <h3>Interpreting Smart Locate results</h3> |
| <div style="font-size:11px;line-height:1.9"> |
| <div class="row" style="gap:20px"> |
| <div class="col"> |
| <div style="color:var(--blue);font-weight:700;margin-bottom:6px">β gate_sim (blue)</div> |
| <div style="color:var(--muted)"> |
| Cosine between W_gate[slot] and embed(subject).<br> |
| Fast, cheap, but <b>weak proxy</b> β measures embedding-space similarity,<br> |
| not causal contribution. Useful for finding <i>related</i> slots.<br> |
| <b>High gate_sim + low grad_norm</b> = slot activates for this entity<br> |
| but doesn't contribute much to this specific prediction. |
| </div> |
| </div> |
| <div class="col"> |
| <div style="color:var(--green);font-weight:700;margin-bottom:6px">β grad_norm (green)</div> |
| <div style="color:var(--muted)"> |
| ββ(-log p(target))/βW_down[:,slot]β β how much changing this slot<br> |
| would affect the loss for this (prompt, target) pair.<br> |
| <b>Most reliable signal</b>, works even when clean_prob is tiny.<br> |
| One backward pass. Use Ξ² > Ξ± to weight this higher.<br> |
| <b>High grad_norm</b> = this slot is causally upstream of the prediction. |
| </div> |
| </div> |
| <div class="col"> |
| <div style="color:var(--yellow);font-weight:700;margin-bottom:6px">β causal IE (yellow)</div> |
| <div style="color:var(--muted)"> |
| Indirect effect via noise-corruption patching (ROME-style).<br> |
| Measures: if I corrupt subject embeddings, how much does patching<br> |
| layer L's hidden state at subject pos <i>restore</i> the prediction?<br> |
| <b>Most interpretable</b> β true causal measurement. But:<br> |
| If clean_prob β 0, IE = 0 everywhere (nothing to restore).<br> |
| Needs a model that actually knows the fact. |
| </div> |
| </div> |
| </div> |
| <div style="margin-top:12px;padding:10px;background:var(--bg);border-radius:6px;border:1px solid var(--border)"> |
| <span style="color:var(--yellow)">β Your distilgpt2 result:</span> |
| <span style="color:var(--muted)"> clean_prob=0.000001 β causal IE=0 everywhere (expected, not a bug). |
| grad_norm on L9/slot515 IS real signal β that slot responds to France+capital context in the gradient sense. |
| But the probability mass is too diffuse to show causal separation. |
| Switch to gpt2-medium for textbook causal results.</span> |
| </div> |
| </div> |
| </div> |
| |
| <div class="card"> |
| <h3>Edit modes β when to use which</h3> |
| <div style="font-size:11px"> |
| <table style="width:100%;border-collapse:collapse"> |
| <thead><tr style="border-bottom:1px solid var(--border);color:var(--muted)"> |
| <th style="padding:6px 8px;text-align:left">Mode</th> |
| <th style="padding:6px 8px;text-align:left">Slot selection</th> |
| <th style="padding:6px 8px;text-align:left">Best for</th> |
| <th style="padding:6px 8px;text-align:left">Knobs</th> |
| </tr></thead> |
| <tbody style="color:var(--muted)"> |
| <tr style="border-bottom:1px solid var(--border)"> |
| <td style="padding:6px 8px;color:var(--blue)">UPDATE</td> |
| <td style="padding:6px 8px">gate cosine sim to embed(entity)</td> |
| <td style="padding:6px 8px">Quick experiment, model knows the fact well</td> |
| <td style="padding:6px 8px">Top-K=3-5, Scale=1.5-3</td> |
| </tr> |
| <tr style="border-bottom:1px solid var(--border)"> |
| <td style="padding:6px 8px;color:var(--purple)">PRECISE</td> |
| <td style="padding:6px 8px">gate cosine sim to h_L[subject_pos]</td> |
| <td style="padding:6px 8px">In-context subject representation (3-5Γ better than UPDATE)</td> |
| <td style="padding:6px 8px">+ Prompt field</td> |
| </tr> |
| <tr style="border-bottom:1px solid var(--border)"> |
| <td style="padding:6px 8px;color:var(--yellow)">β
SMART</td> |
| <td style="padding:6px 8px">gradient norm β exact slots, then patch</td> |
| <td style="padding:6px 8px"><b>Best overall.</b> Auto-locates, no manual tuning</td> |
| <td style="padding:6px 8px">Top layers=3, Slots/layer=2, Scale=1.5-2.5</td> |
| </tr> |
| <tr style="border-bottom:1px solid var(--border)"> |
| <td style="padding:6px 8px;color:var(--green)">INSERT</td> |
| <td style="padding:6px 8px">weakest slot (norm-based)</td> |
<td style="padding:6px 8px">Model has no knowledge of the fact; build it from scratch</td>
| <td style="padding:6px 8px">Alpha=0.4-0.7, Spread=4-6</td> |
| </tr> |
| <tr style="border-bottom:1px solid var(--border)"> |
| <td style="padding:6px 8px;color:var(--red)">SUPPRESS</td> |
| <td style="padding:6px 8px">gate cosine β scale W_down to 0</td> |
| <td style="padding:6px 8px">Make model forget an entity (factor=0) or weaken (0.5)</td> |
| <td style="padding:6px 8px">Factor: 0=forget, 0.5=weaken</td> |
| </tr> |
| <tr style="border-bottom:1px solid var(--border)"> |
| <td style="padding:6px 8px;color:var(--cyan)">STYLE-SHIFT</td> |
| <td style="padding:6px 8px">gate cosine β add direction vector</td> |
| <td style="padding:6px 8px">Bias/tone shifts: CEOβless male-coded, Parisβdarker</td> |
| <td style="padding:6px 8px">from/to concepts, strength=0.3-0.8</td> |
| </tr> |
| </tbody> |
| </table> |
| </div> |
| </div> |
| |
| <div class="card"> |
| <h3>Experiments to run</h3> |
| <div style="font-size:11px;line-height:1.9;color:var(--muted)"> |
| <div style="color:var(--text);margin-bottom:4px">Experiment A β Capital swap (classic ROME benchmark)</div> |
| Model: gpt2-medium | Prompt: "The capital of France is" | Old: Paris | New: Lyon<br> |
| Check: "France's capital city" | "Lyon is now" | "Paris is in" | "Eiffel Tower is in"<br> |
| Insight: does it generalize (paraphrase) or is it prompt-specific?<br><br> |
| |
| <div style="color:var(--text);margin-bottom:4px">Experiment B β Slot overlap analysis (your collateral question)</div> |
| 1. SMART locate "The capital of France is" β note slot numbers in recommendation<br> |
| 2. SMART locate "The biggest city in France is" β compare slot lists<br> |
| 3. Overlap = slots that will be collaterally damaged<br> |
| 4. No overlap = clean surgery β<br><br> |
| |
| <div style="color:var(--text);margin-bottom:4px">Experiment C β Suppression then INSERT</div> |
| SUPPRESS France β then INSERT France capital Lyon β Infer<br> |
| vs just UPDATE. Which gives cleaner, more confident result?<br><br> |
| |
| <div style="color:var(--text);margin-bottom:4px">Experiment D β Style shift (no factual change)</div> |
| STYLE-SHIFT: anchor=CEO, from="male", to="female", strength=0.3<br> |
| Then Infer: "The CEO of the company is a" β does pronoun distribution shift?<br> |
| Insight: this is mechanical debiasing without retraining.<br><br> |
| |
| <div style="color:var(--text);margin-bottom:4px">Experiment E β Compile and compare</div> |
| Edit 5 facts. Compile β save as new model directory.<br> |
| Load compiled model fresh β Infer same prompts β edits should persist in weights.<br> |
| Then Trace on compiled model β phase layers should shift or sharpen. |
| </div> |
| </div> |
| |
| <div class="card"> |
| <h3>Ξ± Ξ² Ξ³ tuning guide</h3> |
| <div style="font-size:11px;line-height:1.9;color:var(--muted)"> |
| <b style="color:var(--text)">Default (0.4 / 0.3 / 0.3)</b> β balanced, works for unknown model quality<br> |
| <b style="color:var(--text)">Grad-heavy (0.1 / 0.7 / 0.2)</b> β clean_prob > 0.01. Grad signal is sharp, trust it.<br> |
| <b style="color:var(--text)">Gate+Grad (0.4 / 0.4 / 0.2)</b> β recommended for smart_edit when causal IE is weak<br> |
| <b style="color:var(--text)">Causal-heavy (0.2 / 0.2 / 0.6)</b> β only when clean_prob > 0.1. IE is the gold signal then.<br> |
| <b style="color:var(--text)">Gate-only (1.0 / 0.0 / 0.0)</b> β equivalent to basic locate(), sanity check<br> |
| <br> |
| <b style="color:var(--yellow)">Your distilgpt2 setting:</b> use (0.3 / 0.7 / 0.0) β gate+grad, skip causal (it's 0 anyway). |
| </div> |
| </div> |
| </div> |
| |
| </div><!-- /app --> |
| |
| <script> |
| // βββββββββββββββββββββββββββββββββββββββββββββββ |
| // STATE |
| // βββββββββββββββββββββββββββββββββββββββββββββββ |
let modelLoaded = false;   // flipped to true once /api/load or /api/status reports a model
const BASE = ''; // same origin — prefix for every API path passed to fetch()
| |
| // βββββββββββββββββββββββββββββββββββββββββββββββ |
| // UTILS |
| // βββββββββββββββββββββββββββββββββββββββββββββββ |
// Switch the visible panel and highlight the clicked tab button.
// fix: the original read the implicit global `event`, which is undefined
// when showTab is invoked programmatically (or in a module/strict context)
// and would throw before highlighting — guard it explicitly.
function showTab(name) {
  document.querySelectorAll('.panel').forEach(p => p.classList.remove('active'));
  document.querySelectorAll('.tab-btn').forEach(b => b.classList.remove('active'));
  document.getElementById('panel-'+name).classList.add('active');
  const ev = (typeof event !== 'undefined' && event) || window.event;
  if (ev && ev.target) ev.target.classList.add('active');
}
| |
// Place the shared tooltip near (x, y) and fill it with the given HTML.
function showTooltip(html, x, y) {
  const tip = document.getElementById('tooltip');
  tip.innerHTML = html;
  tip.style.opacity = 1;
  tip.style.left = `${x + 14}px`;
  tip.style.top = `${y - 10}px`;
}
// Fade the shared tooltip out (it keeps its last position/content).
function hideTooltip() {
  const tip = document.getElementById('tooltip');
  tip.style.opacity = 0;
}
| |
// Thin JSON client: POST when a body is supplied, GET otherwise.
// Throws Error(detail) on non-2xx responses.
// fix: the original unconditionally did `await r.json()` on error responses;
// a non-JSON error body (proxy/HTML error page) made *that* call throw,
// masking the real HTTP failure. Fall back to statusText instead.
async function api(path, body) {
  const opts = body !== undefined
    ? { method:'POST', headers:{'Content-Type':'application/json'}, body:JSON.stringify(body) }
    : { method:'GET' };
  const r = await fetch(BASE+path, opts);
  if (!r.ok) {
    let detail = r.statusText;
    try {
      const e = await r.json();
      if (e && e.detail) detail = e.detail;
    } catch (_) { /* body was not JSON — keep statusText */ }
    throw new Error(detail);
  }
  return r.json();
}
| |
// Update the header status light, text, and (optionally) the patch counter.
// `patches` undefined => leave the counter untouched.
function setStatus(ok, text, patches) {
  document.getElementById('status-dot').classList.toggle('ok', ok);
  document.getElementById('status-text').textContent = text;
  if (patches === undefined) return;
  const label = patches ? ` Β· ${patches} patch(es)` : '';
  document.getElementById('patch-count').textContent = label;
}
| |
| // βββββββββββββββββββββββββββββββββββββββββββββββ |
| // LOAD |
| // βββββββββββββββββββββββββββββββββββββββββββββββ |
// Ask the backend to load the model named in the form; log progress inline.
async function loadModel() {
  const logEl = document.getElementById('load-log');
  logEl.style.display = 'block';
  logEl.textContent = 'Loadingβ¦';
  const payload = {
    model_name: document.getElementById('load-model').value.trim(),
    device: document.getElementById('load-device').value
  };
  try {
    const r = await api('/api/load', payload);
    logEl.textContent = 'β ' + r.info;
    modelLoaded = true;
    setStatus(true, r.info.split(' | ')[0], 0);
  } catch (e) {
    logEl.textContent = 'β '+e.message;
  }
}
| |
// On page load: if the server already has a model, reflect it in the header.
// Failures are swallowed on purpose (server may still be starting).
async function checkStatus() {
  try {
    const s = await api('/api/status');
    if (!s.loaded) return;
    modelLoaded = true;
    setStatus(true, s.info.split(' | ')[0], s.patches_count);
  } catch (e) {
    // best-effort — stay in the "not loaded" state
  }
}
| |
| // βββββββββββββββββββββββββββββββββββββββββββββββ |
| // D3 HELPERS |
| // βββββββββββββββββββββββββββββββββββββββββββββββ |
// Dark-theme color palette shared by every D3 chart (mirrors the CSS vars).
// NOTE(review): drawSmartLocateChart references C.border, which is NOT
// defined here — that stroke attribute resolves to undefined; confirm intent.
const C = { bg:'#0d1117', bg2:'#161b22', bg3:'#21262d',
            blue:'#58a6ff', green:'#3fb950', red:'#f85149',
            yellow:'#d29922', purple:'#bc8cff', cyan:'#39d353',
            text:'#e6edf3', muted:'#8b949e' };
| |
// Empty a chart container and hand it back for re-rendering.
function clearChart(id) {
  const node = document.getElementById(id);
  node.innerHTML = '';
  return node;
}
| |
| // βββββββββββββββββββββββββββββββββββββββββββββββ |
| // INFER |
| // βββββββββββββββββββββββββββββββββββββββββββββββ |
// Run next-token inference for the prompt in the form and chart the result.
async function runInfer() {
  const payload = {
    prompt: document.getElementById('infer-prompt').value,
    top_k: +document.getElementById('infer-k').value
  };
  try {
    const data = await api('/api/infer', payload);
    drawInferChart(data.results);
  } catch (e) {
    alert(e.message);
  }
}
| |
// Horizontal bar chart of the top-k next-token probabilities.
function drawInferChart(results) {
  const host = clearChart('infer-chart');
  const rowH = 38;
  const fullW = host.clientWidth || 700;
  const height = 36 + results.length * rowH;
  host.style.height = `${height}px`;
  const margin = {left:120, right:70, top:16, bottom:16};
  const innerW = fullW - margin.left - margin.right;
  const svg = d3.select(host).append('svg').attr('width','100%').attr('height',height);
  const xScale = d3.scaleLinear()
    .domain([0, d3.max(results, r=>r.prob)])
    .range([0, innerW]);
  const root = svg.append('g').attr('transform',`translate(${margin.left},${margin.top})`);
  const barColor = d3.scaleSequential(d3.interpolateCool)
    .domain([results.length-1, 0]);

  for (const [idx, row] of results.entries()) {
    const top = idx * rowH;
    // token label on the left
    root.append('text').attr('x',-8).attr('y',top+19).attr('text-anchor','end')
      .attr('fill',C.text).attr('font-size',13).text(row.token || '(empty)');
    // probability bar with hover tooltip
    root.append('rect').attr('x',0).attr('y',top+8).attr('width', xScale(row.prob))
      .attr('height',20).attr('rx',3).attr('fill',barColor(idx)).attr('opacity',.85)
      .on('mousemove',(ev)=>showTooltip(`${row.token}: ${row.prob.toFixed(6)}`,ev.pageX,ev.pageY))
      .on('mouseleave',hideTooltip);
    // numeric value to the right of the bar
    root.append('text').attr('x',xScale(row.prob)+6).attr('y',top+19)
      .attr('fill',C.muted).attr('font-size',11).text(row.prob.toFixed(4));
  }
}
| |
| // βββββββββββββββββββββββββββββββββββββββββββββββ |
| // DESCRIBE β force graph |
| // βββββββββββββββββββββββββββββββββββββββββββββββ |
// Fetch the concept neighborhood for an entity and draw the force graph.
async function runDescribe() {
  const entity = document.getElementById('desc-entity').value;
  const payload = {
    entity,
    top_k: +document.getElementById('desc-k').value
  };
  try {
    const data = await api('/api/describe', payload);
    drawDescribeGraph(entity, data.edges);
  } catch (e) {
    alert(e.message);
  }
}
| |
// Render the DESCRIBE result as a D3 force-directed graph: one fixed center
// node (the queried entity) linked to one node per returned edge, colored by
// the layer the edge came from.
function drawDescribeGraph(center, edges) {
  const el = clearChart('desc-chart');
  const W = el.clientWidth || 700, H = 420;  // 700px fallback if host is hidden
  const svg = d3.select(el).append('svg').attr('width','100%').attr('height',H);
  if (!edges.length) {
    // Empty result: placeholder message instead of an empty simulation.
    svg.append('text').attr('x',W/2).attr('y',H/2).attr('text-anchor','middle')
      .attr('fill',C.muted).text('No edges found (sim < 0.08)'); return;
  }

  const maxLayer = d3.max(edges, d=>d.layer);
  const layerColor = d3.scaleSequential(d3.interpolateRainbow).domain([0, maxLayer]);
  const maxScore = d3.max(edges, d=>d.score);

  // Node radius grows with sqrt(score) so node *area* tracks score.
  const nodes = [{id:'__center__', label:center, r:22, fixed:true}];
  const links = [];
  edges.forEach((e,i) => {
    const id = 'n'+i;
    nodes.push({id, label:e.tok, score:e.score, layer:e.layer, gate_sim:e.gate_sim,
                r: 8 + Math.sqrt(e.score/maxScore)*10});
    links.push({source:'__center__', target:id, layer:e.layer});
  });

  const sim = d3.forceSimulation(nodes)
    .force('link', d3.forceLink(links).id(d=>d.id).distance(100))
    .force('charge', d3.forceManyBody().strength(-120))
    .force('center', d3.forceCenter(W/2, H/2))
    .force('collision', d3.forceCollide(d=>d.r+4));

  const link = svg.append('g').selectAll('line').data(links).join('line')
    .attr('stroke', d=>layerColor(d.layer)).attr('stroke-opacity',.5).attr('stroke-width',1.5);

  // Standard D3 drag pattern: pin the node while dragging, release on end.
  const node = svg.append('g').selectAll('g').data(nodes).join('g')
    .call(d3.drag()
      .on('start',(ev,d)=>{if(!ev.active)sim.alphaTarget(.3).restart();d.fx=d.x;d.fy=d.y})
      .on('drag',(ev,d)=>{d.fx=ev.x;d.fy=ev.y})
      .on('end',(ev,d)=>{if(!ev.active)sim.alphaTarget(0);d.fx=null;d.fy=null}));

  node.append('circle').attr('r',d=>d.r)
    .attr('fill', d=>d.id==='__center__' ? C.blue : layerColor(d.layer||0))
    .attr('opacity',.85)
    .on('mousemove',(ev,d)=>{
      if(d.id==='__center__') return;  // center node carries no score/layer
      showTooltip(`${d.label}<br>score: ${d.score?.toFixed(1)}<br>L${d.layer} sim:${d.gate_sim?.toFixed(3)}`,ev.pageX,ev.pageY);
    }).on('mouseleave',hideTooltip);

  node.append('text').attr('dy','0.35em').attr('text-anchor','middle')
    .attr('font-size', d=>d.id==='__center__'?13:10)
    .attr('fill',C.text).attr('pointer-events','none')
    .text(d=>d.label.substring(0,12));  // truncate long token labels

  sim.on('tick',()=>{
    link.attr('x1',d=>d.source.x).attr('y1',d=>d.source.y)
        .attr('x2',d=>d.target.x).attr('y2',d=>d.target.y);
    node.attr('transform',d=>`translate(${d.x},${d.y})`);
  });

  // Fix center node
  nodes[0].fx = W/2; nodes[0].fy = H/2;
}
| |
| // βββββββββββββββββββββββββββββββββββββββββββββββ |
| // TRACE β dual-axis line chart |
| // βββββββββββββββββββββββββββββββββββββββββββββββ |
// Run the logit-lens trace for prompt/target and plot it.
async function runTrace() {
  const payload = {
    prompt: document.getElementById('trace-prompt').value,
    target: document.getElementById('trace-target').value
  };
  try {
    const data = await api('/api/trace', payload);
    drawTraceChart(data.stats, 'trace-chart');
  } catch (e) {
    alert(e.message);
  }
}
| |
// Dual-axis per-layer chart for a single target token:
//   blue solid line      = target rank per layer (log scale, left axis)
//   green dashed line    = target probability (linear, right axis)
//   yellow dashed v-line = detected "phase transition" layer.
function drawTraceChart(stats, chartId) {
  const el = clearChart(chartId);
  const W = el.clientWidth || 700, H = 300;
  el.style.height = H+'px';
  const svg = d3.select(el).append('svg').attr('width','100%').attr('height',H);
  const m = {left:55, right:55, top:20, bottom:30};
  const w = W-m.left-m.right, h = H-m.top-m.bottom;

  // Find phase transition: the layer with the largest *relative* rank drop,
  // considered only while the previous rank was still above 5.
  let phaseL = -1, maxDrop = 0, prevRank = null;
  stats.forEach(s=>{
    if(prevRank!==null && prevRank>5){
      const drop=(prevRank-s.rank)/prevRank;
      if(drop>maxDrop){maxDrop=drop;phaseL=s.l;}
    }
    prevRank=s.rank;
  });

  const x = d3.scaleLinear().domain([0,stats.length-1]).range([0,w]);
  // clamp(true) keeps rank values below 1 from escaping the log scale.
  const yRank = d3.scaleLog().domain([1, d3.max(stats,d=>d.rank)]).range([h,0]).clamp(true);
  const yProb = d3.scaleLinear().domain([0, d3.max(stats,d=>d.prob)]).range([h,0]);

  const g = svg.append('g').attr('transform',`translate(${m.left},${m.top})`);

  // Axes (at most 8 x-ticks unless there are few layers)
  g.append('g').attr('transform',`translate(0,${h})`)
    .call(d3.axisBottom(x).ticks(stats.length<=12?stats.length:8).tickFormat(d=>'L'+d))
    .selectAll('text').attr('fill',C.muted).attr('font-size',10);
  g.append('g').call(d3.axisLeft(yRank).ticks(5).tickFormat(d3.format('d')))
    .selectAll('text').attr('fill',C.muted).attr('font-size',10);
  g.append('g').attr('transform',`translate(${w},0)`)
    .call(d3.axisRight(yProb).ticks(5).tickFormat(d3.format('.2e')))
    .selectAll('text').attr('fill',C.green).attr('font-size',10);

  // Phase marker
  if(phaseL>=0){
    const px = x(phaseL);
    g.append('line').attr('x1',px).attr('x2',px).attr('y1',0).attr('y2',h)
      .attr('stroke',C.yellow).attr('stroke-dasharray','4,2').attr('stroke-width',1.5);
    g.append('text').attr('x',px+4).attr('y',12).attr('fill',C.yellow).attr('font-size',10)
      .text('β‘ PHASE L'+phaseL);
  }

  // Rank line (log); Math.max(1, .) guards rank 0 before the log scale.
  const lineRank = d3.line().x((_,i)=>x(i)).y(d=>yRank(Math.max(1,d.rank)));
  g.append('path').datum(stats).attr('fill','none').attr('stroke',C.blue)
    .attr('stroke-width',2).attr('d',lineRank);

  // Prob line
  const lineProb = d3.line().x((_,i)=>x(i)).y(d=>yProb(d.prob));
  g.append('path').datum(stats).attr('fill','none').attr('stroke',C.green)
    .attr('stroke-width',1.5).attr('stroke-dasharray','5,2').attr('d',lineProb);

  // Dots: hover targets carrying per-layer tooltips
  g.selectAll('.dot').data(stats).join('circle').attr('class','dot')
    .attr('cx',(_,i)=>x(i)).attr('cy',d=>yRank(Math.max(1,d.rank)))
    .attr('r',3).attr('fill',C.blue)
    .on('mousemove',(ev,d)=>showTooltip(`L${d.l} rank:${d.rank}<br>prob:${d.prob.toExponential(3)}`,ev.pageX,ev.pageY))
    .on('mouseleave',hideTooltip);

  // Legend
  const leg = g.append('g').attr('transform',`translate(${w-120},6)`);
  leg.append('line').attr('x2',16).attr('stroke',C.blue).attr('stroke-width',2);
  leg.append('text').attr('x',20).attr('y',4).attr('fill',C.blue).attr('font-size',10).text('rank (log, left)');
  leg.append('line').attr('y1',14).attr('x2',16).attr('y2',14).attr('stroke',C.green).attr('stroke-width',1.5).attr('stroke-dasharray','5,2');
  leg.append('text').attr('x',20).attr('y',18).attr('fill',C.green).attr('font-size',10).text('prob (right)');
}
| |
| // βββββββββββββββββββββββββββββββββββββββββββββββ |
| // LOCATE |
| // βββββββββββββββββββββββββββββββββββββββββββββββ |
// Run the LOCATE analysis: trace chart + per-layer bars + text summary.
// fix: the peak-layer arg-max reduce() was evaluated twice in the template
// string — hoist it into one pass over layer_scores.
async function runLocate() {
  try {
    const data = await api('/api/locate', {
      prompt: document.getElementById('loc-prompt').value,
      subject: document.getElementById('loc-subject').value,
      target: document.getElementById('loc-target').value
    });
    drawTraceChart(data.trace, 'loc-trace-chart');
    drawLocateBars(data.layer_scores, data.phase_layer);
    const sumEl = document.getElementById('loc-summary');
    sumEl.style.display='block';
    // Layer with the highest activation similarity (computed once).
    const peak = data.layer_scores.reduce(
      (a,b)=>b.max_sim>a.max_sim?b:a, data.layer_scores[0]);
    document.getElementById('loc-summary-text').textContent =
      `Phase transition at L${data.phase_layer}. Subject token pos: ${data.subject_pos}. ` +
      `Peak activation sim: L${peak?.layer} = ${peak?.max_sim?.toFixed(4)}`;
  } catch(e) { alert(e.message); }
}
| |
// One horizontal bar per layer, width = max_sim relative to the best layer.
// The phase-transition layer is outlined and filled in yellow. Clicking a
// row shows that layer's slot detail.
function drawLocateBars(layerScores, phaseLayer) {
  const host = document.getElementById('loc-bars');
  host.innerHTML = '';
  const maxSim = d3.max(layerScores, d=>d.max_sim) || 1;
  for (const ls of layerScores) {
    const isPhase = ls.layer === phaseLayer;
    const row = document.createElement('div');
    row.className = 'locate-bar';
    row.style.border = isPhase ? '1px solid '+C.yellow : '1px solid transparent';
    const pct = (ls.max_sim / maxSim * 100).toFixed(1);
    row.innerHTML =
      `<span class="layer-lbl">L${ls.layer}</span>`+
      `<div class="bar-fill" style="width:${pct}%;background:${isPhase?C.yellow:C.purple}"></div>`+
      `<span class="sim-val">${ls.max_sim.toFixed(4)}</span>`;
    row.onclick = ()=>{ showSlotDetail(ls); };
    host.appendChild(row);
  }
}
| |
// Dump the selected layer's best slot and its decoded top tokens.
function showSlotDetail(ls) {
  const out = document.getElementById('loc-slots');
  out.style.display = 'block';
  const lines = [`L${ls.layer} slot=${ls.best_slot} max_sim=${ls.max_sim}`, '', 'Top decoded tokens:'];
  for (const t of ls.top_tokens) lines.push(` ${t.tok} (${t.score})`);
  out.textContent = lines.join('\n') + '\n';
}
| |
| // βββββββββββββββββββββββββββββββββββββββββββββββ |
| // HEATMAP |
| // βββββββββββββββββββββββββββββββββββββββββββββββ |
// Fetch the gate-similarity heat map (embedding or activation mode).
async function runHeatmap() {
  try {
    const mode = document.querySelector('input[name="hm-mode"]:checked').value;
    const payload = {
      entity: document.getElementById('hm-entity').value,
      use_activation: mode === 'activation',
      prompt: document.getElementById('hm-prompt').value
    };
    const data = await api('/api/gate_heatmap', payload);
    drawHeatmap(data.layers);
  } catch (e) {
    alert(e.message);
  }
}
| |
// Layer-by-slot heat map of gate similarities. Rows = layers, cells = slots,
// one shared color domain so cells are comparable across layers. Clicking a
// cell shows the slot's decoded tokens.
// fix: dropped the unused `rect` local — the chained selection is configured
// in place and never referenced again.
function drawHeatmap(layers) {
  const el = clearChart('hm-chart');
  if(!layers.length){el.textContent='No data';return;}
  const nLayers = layers.length;
  const nSlots = d3.max(layers, l=>l.slots.length);
  const W = el.clientWidth || 700;
  // Cell width adapts to the container but never shrinks below 18px.
  const cellW = Math.max(18, Math.floor((W-60)/nSlots));
  const cellH = 24;
  const H = nLayers*cellH + 40;
  el.style.height = H+'px';

  const svg = d3.select(el).append('svg').attr('width','100%').attr('height',H);
  const g = svg.append('g').attr('transform','translate(50,20)');

  // Single color domain over every slot similarity.
  const allSims = layers.flatMap(l=>l.slots.map(s=>s.sim));
  const color = d3.scaleSequential(d3.interpolateYlOrRd).domain([0, d3.max(allSims)]);

  layers.forEach((layer,li)=>{
    layer.slots.forEach((slot,si)=>{
      g.append('rect')
        .attr('class','hm-cell')
        .attr('x', si*cellW).attr('y', li*cellH)
        .attr('width', cellW-2).attr('height', cellH-2)
        .attr('rx', 2)
        .attr('fill', color(slot.sim))
        .on('mousemove',(ev)=>showTooltip(
          `L${layer.layer} slot ${slot.slot}<br>sim: ${slot.sim}`,ev.pageX,ev.pageY))
        .on('mouseleave',hideTooltip)
        .on('click',()=>showHmSlotDetail(layer.layer, slot));
    });
    // Row label
    g.append('text').attr('x',-6).attr('y',li*cellH+cellH/2+4)
      .attr('text-anchor','end').attr('fill',C.muted).attr('font-size',9).text('L'+layer.layer);
  });
}
| |
// Show decoded tokens for one heat-map cell (layer, slot).
function showHmSlotDetail(layer, slot) {
  const out = document.getElementById('hm-slot-detail');
  let body = `Layer ${layer} Slot ${slot.slot} sim=${slot.sim}\n\nDecoded top tokens:\n`;
  for (const t of slot.top_tokens) body += ` "${t.tok}" score=${t.score}\n`;
  out.textContent = body;
}
| |
| // βββββββββββββββββββββββββββββββββββββββββββββββ |
| // EDIT |
| // βββββββββββββββββββββββββββββββββββββββββββββββ |
// Show/hide the mode-specific option rows when the edit-mode radio changes.
document.querySelectorAll('input[name="edit-mode"]').forEach(radio=>{
  radio.addEventListener('change', ()=>{
    const modeRows = {
      'precise-prompt-row': 'PRECISE',
      'smart-row': 'SMART',
      'style-shift-row': 'STYLE-SHIFT',
      'multiedit-row': 'MULTI-EDIT',
    };
    for (const [rowId, mode] of Object.entries(modeRows)) {
      document.getElementById(rowId).style.display = radio.value===mode ? 'block' : 'none';
    }
  });
});
| |
// Preview which slots an UPDATE edit would touch, without patching weights.
async function runDryRun() {
  const payload = {
    entity: document.getElementById('edit-entity').value,
    new_target: document.getElementById('edit-new').value,
    top_k: +document.getElementById('edit-k').value,
    scale: +document.getElementById('edit-scale').value,
    prompt: document.getElementById('edit-prompt').value || null
  };
  try {
    const data = await api('/api/dry_run', payload);
    let out = `Mode: ${data.mode} best_sim=${data.best_sim} would_patch=${data.would_patch}\n\n`;
    for (const c of data.candidates) {
      out += `L${c.layer} slot ${c.slot} sim=${c.sim} col_norm=${c.col_norm}\n`;
      out += ` current top: ${c.current_top.map(t=>t.tok).join(', ')}\n`;
    }
    document.getElementById('dryrun-log').textContent = out;
  } catch (e) {
    alert(e.message);
  }
}
| |
// "Before" distribution from the last plain /api/edit call (snapshot kept
// for potential re-plotting; not read elsewhere in this script).
let _before = null;

// Dispatch the EDIT form. Three paths:
//   SMART      -> /api/smart_edit (gradient+causal slot location, own payload)
//   MULTI-EDIT -> /api/multi_edit (POSTs the parsed facts JSON array)
//   otherwise  -> /api/edit with the shared `body` payload below.
async function runEdit() {
  const mode = document.querySelector('input[name="edit-mode"]:checked').value;
  // Shared payload; fields irrelevant to a given mode are simply unused.
  const body = {
    entity: document.getElementById('edit-entity').value,
    relation: document.getElementById('edit-relation').value,
    old_target: document.getElementById('edit-old').value,
    new_target: document.getElementById('edit-new').value,
    mode,
    alpha: +document.getElementById('edit-alpha').value,
    top_k: +document.getElementById('edit-k').value,
    scale: +document.getElementById('edit-scale').value,
    prompt: document.getElementById('edit-prompt').value || null,
    from_concept:document.getElementById('ss-from').value,
    to_concept: document.getElementById('ss-to').value,
    strength: +document.getElementById('ss-strength').value,
  };
  if(mode==='SMART'){
    try {
      const r = await api('/api/smart_edit', {
        prompt: document.getElementById('smart-prompt').value,
        subject: document.getElementById('edit-entity').value,
        relation: document.getElementById('edit-relation').value,
        old_target: document.getElementById('smart-old').value,
        new_target: document.getElementById('edit-new').value,
        top_layers: +document.getElementById('smart-layers').value,
        slots_per_layer: +document.getElementById('smart-slots').value,
        scale: +document.getElementById('edit-scale').value,
        // Fixed smart-locate mixing weights (gate/grad/causal) and noise.
        noise_std: 0.1, alpha: 0.4, beta: 0.4, gamma: 0.2,
      });
      drawBeforeAfterChart(r.before, r.after);
      let log = r.debug_log.join('\n');
      log += '\n\nUsed layers:\n';
      r.used_layers.forEach(l=>{ log+=` L${l.layer} slots=[${l.slots.join(',')}] combined=${l.combined}\n`; });
      log += '\nDelta:\n';
      // Only the first 8 delta rows, to keep the log readable.
      r.delta.slice(0,8).forEach(d=>{ log+=` ${d.token}: ${d.before.toFixed(4)} β ${d.after.toFixed(4)} ${d.delta>0?'+':''}${d.delta.toFixed(4)}\n`; });
      document.getElementById('edit-log').textContent = log;
      updatePatchCount();
    } catch(e) { alert(e.message); }
    return;
  }
  if(mode==='MULTI-EDIT'){
    try {
      // The textarea holds a JSON array of facts; it is POSTed as-is.
      body.facts = JSON.parse(document.getElementById('multi-json').value);
      const r = await api('/api/multi_edit', body.facts);
      document.getElementById('edit-log').textContent = JSON.stringify(r, null, 2);
      updatePatchCount();
    } catch(e) { alert(e.message); }
    return;
  }
  try {
    const r = await api('/api/edit', body);
    _before = r.before;  // only the plain-edit path records the snapshot
    drawBeforeAfterChart(r.before, r.after);
    let log = r.debug_log.join('\n');
    log += '\n\nDelta:\n';
    r.delta.forEach(d=>{ log+=` ${d.token}: ${d.before.toFixed(4)} β ${d.after.toFixed(4)} ${d.delta>0?'+':''}${d.delta.toFixed(4)}\n`; });
    document.getElementById('edit-log').textContent = log;
    updatePatchCount();
  } catch(e) { alert(e.message); }
}
| |
// Paired-bar chart over the union of tokens: faded blue bar = before,
// green/red bar = after (colored by the sign of the change), sorted by
// the after-probability descending.
function drawBeforeAfterChart(before, after) {
  const host = clearChart('edit-chart');
  const beforeByTok = Object.fromEntries(before.map(d=>[d.token,d.prob]));
  const afterByTok = Object.fromEntries(after.map(d=>[d.token,d.prob]));
  const tokens = [...new Set([...before.map(d=>d.token), ...after.map(d=>d.token)])];
  const rows = tokens
    .map(t=>({token:t, before:beforeByTok[t]||0, after:afterByTok[t]||0}))
    .sort((a,b)=>b.after-a.after);

  const rowH = 40;
  const fullW = host.clientWidth || 700;
  const height = 40 + rows.length*rowH;
  host.style.height = `${height}px`;
  const pad = {left:110, right:10, top:20, bottom:10};
  const innerW = fullW - pad.left - pad.right;
  const xScale = d3.scaleLinear()
    .domain([0, d3.max([...before,...after], d=>d.prob)])
    .range([0, innerW]);
  const svg = d3.select(host).append('svg').attr('width','100%').attr('height',height);
  const root = svg.append('g').attr('transform',`translate(${pad.left},${pad.top})`);

  rows.forEach((row, idx)=>{
    const top = idx*rowH;
    root.append('text').attr('x',-6).attr('y',top+14).attr('text-anchor','end')
      .attr('fill',C.text).attr('font-size',12).text(row.token||'(empty)');
    // before (faded blue)
    root.append('rect').attr('x',0).attr('y',top+2).attr('width',xScale(row.before))
      .attr('height',14).attr('rx',2).attr('fill',C.blue).attr('opacity',.5);
    // after (green if prob rose, red if it fell)
    const change = row.after - row.before;
    const changeColor = change>=0 ? C.green : C.red;
    root.append('rect').attr('x',0).attr('y',top+18).attr('width',xScale(row.after))
      .attr('height',14).attr('rx',2).attr('fill',changeColor).attr('opacity',.8);
    root.append('text').attr('x',xScale(row.after)+4).attr('y',top+30)
      .attr('fill',changeColor).attr('font-size',10)
      .text((change>0?'+':'')+change.toFixed(4));
  });
}
| |
| // βββββββββββββββββββββββββββββββββββββββββββββββ |
| // PATCHES |
| // βββββββββββββββββββββββββββββββββββββββββββββββ |
// Refresh the patch list panel from the server.
async function loadPatches() {
  try {
    const resp = await api('/api/patches');
    const listEl = document.getElementById('patches-list');
    if (!resp.patches.length) {
      listEl.innerHTML = '<div style="color:var(--muted)">No patches.</div>';
      return;
    }
    const cards = resp.patches.map((p,i)=>{
      const type = p.type.replace('_',' ');
      const cls = 'badge-'+p.type.replace('_UPDATE','').replace('PRECISE_','PRECISE');
      const label = `${p.entity}${p.relation?' Β· '+p.relation:''}${p.new_target?' β '+p.new_target:''}`;
      return `<div class="patch-card">
  <div style="display:flex;align-items:center">
    <span class="patch-badge ${cls}">${type}</span>
    <span>${label}</span>
    <span style="color:var(--muted);margin-left:10px;font-size:10px">${p.ops_count} op(s)</span>
  </div>
  <button class="patch-del" onclick="deletePatch(${i})">β</button>
</div>`;
    });
    listEl.innerHTML = cards.join('');
  } catch (e) {
    alert(e.message);
  }
}
| |
// Delete patch #i, then refresh the list and the header counter.
// api() only supports GET/POST, so DELETE goes through fetch directly —
// fix: the original never checked r.ok, so a failed delete refreshed the
// list silently as if it had succeeded. Surface HTTP errors like api() does.
async function deletePatch(i) {
  try {
    const r = await fetch('/api/patches/'+i, {method:'DELETE'});
    if (!r.ok) throw new Error(r.statusText);
    loadPatches(); updatePatchCount();
  } catch(e) { alert(e.message); }
}
| |
// Prompt for a path and ask the server to persist the patch list there.
async function savePatches() {
  const path = prompt('Save to path:', 'patches.json');
  if (!path) return;
  try {
    await api('/api/save', {path});
    alert('Saved to '+path);
  } catch (e) {
    alert(e.message);
  }
}
| |
// Prompt for a directory and bake the current patches into a saved model.
async function compileModel() {
  const dir = prompt('Output directory:', './vindex_compiled');
  if (!dir) return;
  try {
    await api('/api/compile', {output_dir:dir});
    alert('Compiled to '+dir);
  } catch (e) {
    alert(e.message);
  }
}
| |
// Restore base weights after confirmation; discards every in-memory patch.
async function resetModel() {
  const sure = confirm('Reset to base weights? All patches discarded.');
  if (!sure) return;
  try {
    await api('/api/reset', {});
    loadPatches();
    updatePatchCount();
    alert('Reset done.');
  } catch (e) {
    alert(e.message);
  }
}
| |
// Best-effort refresh of the header status + patch counter.
async function updatePatchCount() {
  try {
    const s = await api('/api/status');
    setStatus(s.loaded, s.info.split(' | ')[0], s.patches_count);
  } catch (e) {
    // ignore — the header simply stays stale
  }
}
| |
| // βββββββββββββββββββββββββββββββββββββββββββββββ |
| // SMART LOCATE |
| // βββββββββββββββββββββββββββββββββββββββββββββββ |
let _slData = null;   // last /api/smart_locate response (read by the chart)

// Run the combined gate/gradient/causal locate and render chart + summary.
async function runSmartLocate() {
  const statusEl = document.getElementById('sl-status');
  statusEl.textContent = 'β³ Running gradient pass + causal sweep (may take ~20s for large models)β¦';
  const payload = {
    prompt: document.getElementById('sl-prompt').value,
    subject: document.getElementById('sl-subject').value,
    target: document.getElementById('sl-target').value,
    alpha: +document.getElementById('sl-alpha').value,
    beta: +document.getElementById('sl-beta').value,
    gamma: +document.getElementById('sl-gamma').value,
    noise_std: +document.getElementById('sl-noise').value,
  };
  try {
    const data = await api('/api/smart_locate', payload);
    _slData = data;
    statusEl.textContent = `β Done. clean_prob=${data.clean_prob.toFixed(4)} corrupt_prob=${data.corrupt_prob.toFixed(4)}`;
    drawSmartLocateChart(data.ranked_layers);
    showSmartRec(data);
    buildSlTable(data.ranked_layers);
  } catch (e) {
    statusEl.textContent = 'β '+e.message;
  }
}
| |
// Run only the gradient pass and chart per-layer gradient scores.
async function runGradientOnly() {
  const statusEl = document.getElementById('sl-status');
  statusEl.textContent = 'β³ Running gradient passβ¦';
  const payload = {
    prompt: document.getElementById('sl-prompt').value,
    target: document.getElementById('sl-target').value,
  };
  try {
    const data = await api('/api/gradient_scores', payload);
    statusEl.textContent = `β Gradient done. ${data.layer_scores.length} KB layers.`;
    drawGradOnlyChart(data.layer_scores);
  } catch (e) {
    statusEl.textContent = 'β '+e.message;
  }
}
| |
// Run only the causal (activation-patching) trace and chart its results.
async function runCausalOnly() {
  const statusEl = document.getElementById('sl-status');
  statusEl.textContent = 'β³ Running causal patch traceβ¦';
  const payload = {
    prompt: document.getElementById('sl-prompt').value,
    subject: document.getElementById('sl-subject').value,
    target: document.getElementById('sl-target').value,
    noise_std: +document.getElementById('sl-noise').value,
  };
  try {
    const data = await api('/api/causal_trace', payload);
    statusEl.textContent = `β Causal done. clean=${data.clean_prob.toFixed(4)} corrupt=${data.corrupt_prob.toFixed(4)}`;
    drawCausalOnlyChart(data.results);
  } catch (e) {
    statusEl.textContent = 'β '+e.message;
  }
}
| |
// Quick top-5 inference on a side prompt to spot collateral damage.
async function runCollateral() {
  const prompt = document.getElementById('sl-coll-prompt').value;
  const outEl = document.getElementById('sl-coll-out');
  try {
    const data = await api('/api/infer', { prompt, top_k: 5 });
    const rows = data.results.map(r=>` ${r.token.padEnd(18)} ${r.prob.toFixed(4)}`);
    outEl.textContent = `"${prompt}"\n` + rows.join('\n');
  } catch (e) {
    outEl.textContent = 'β '+e.message;
  }
}
| |
// Per-layer bars made of three normalized segments — gate_sim / grad_norm /
// causal IE — each scaled into its own third of the width, with the combined
// score printed on the right. The recommended layer and the best combined
// score are highlighted in yellow.
// fixes: (1) the divider stroke used C.border, which is not defined in the
// palette (resolved to undefined) — use the defined dark tone C.bg3 instead;
// (2) Math.max over all combined scores was recomputed for every row — hoist.
function drawSmartLocateChart(ranked) {
  // Sort by layer for chart display
  const byLayer = [...ranked].sort((a,b)=>a.layer-b.layer);
  const el = clearChart('sl-chart');
  const W = el.clientWidth || 700, H = 40 + byLayer.length * 34;
  el.style.height = H+'px';
  const svg = d3.select(el).append('svg').attr('width','100%').attr('height',H);
  const m = {left:50,right:110,top:20,bottom:20};
  const w = W-m.left-m.right;
  const g = svg.append('g').attr('transform',`translate(${m.left},${m.top})`);

  // Each segment width = signal_n * (w/3), so each signal's max fills a third.
  const segW = w / 3;
  // Best combined score (hoisted out of the row loop).
  const bestCombined = Math.max(...ranked.map(r=>r.combined));

  byLayer.forEach((d,i)=>{
    const y = i*34;
    // Label — yellow if this is the recommended layer
    g.append('text').attr('x',-6).attr('y',y+17).attr('text-anchor','end')
      .attr('fill', d.layer===(_slData?.recommendation?.layer) ? C.yellow : C.muted)
      .attr('font-size',10).text('L'+d.layer);

    // gate_sim segment
    g.append('rect').attr('x',0).attr('y',y+4).attr('width',d.gate_sim_n*segW)
      .attr('height',12).attr('rx',2).attr('fill',C.blue).attr('opacity',.8)
      .on('mousemove',(ev)=>showTooltip(`L${d.layer} gate_sim: ${d.gate_sim}`,ev.pageX,ev.pageY))
      .on('mouseleave',hideTooltip);
    // grad_norm segment
    g.append('rect').attr('x',segW).attr('y',y+4).attr('width',d.grad_norm_n*segW)
      .attr('height',12).attr('rx',2).attr('fill',C.green).attr('opacity',.8)
      .on('mousemove',(ev)=>showTooltip(`L${d.layer} grad_norm: ${d.grad_norm}`,ev.pageX,ev.pageY))
      .on('mouseleave',hideTooltip);
    // causal segment
    g.append('rect').attr('x',segW*2).attr('y',y+4).attr('width',d.causal_n*segW)
      .attr('height',12).attr('rx',2).attr('fill',C.yellow).attr('opacity',.8)
      .on('mousemove',(ev)=>showTooltip(`L${d.layer} causal_IE: ${d.causal_effect}`,ev.pageX,ev.pageY))
      .on('mouseleave',hideTooltip);
    // combined score label
    g.append('text').attr('x',w+6).attr('y',y+14)
      .attr('fill', d.combined===bestCombined ? C.yellow : C.muted)
      .attr('font-size',10).text(d.combined.toFixed(3));
  });

  // Axis labels
  const ax = g.append('g').attr('transform',`translate(0,${byLayer.length*34})`);
  ax.append('text').attr('x',segW/2).attr('y',14).attr('text-anchor','middle')
    .attr('fill',C.blue).attr('font-size',9).text('gate_sim');
  ax.append('text').attr('x',segW*1.5).attr('y',14).attr('text-anchor','middle')
    .attr('fill',C.green).attr('font-size',9).text('grad_norm');
  ax.append('text').attr('x',segW*2.5).attr('y',14).attr('text-anchor','middle')
    .attr('fill',C.yellow).attr('font-size',9).text('causal IE');

  // Section dividers
  [segW,segW*2].forEach(x=>{
    g.append('line').attr('x1',x).attr('x2',x).attr('y1',0).attr('y2',byLayer.length*34)
      .attr('stroke',C.bg3).attr('stroke-width',1).attr('stroke-dasharray','3,2');
  });
}
| |
// Render gradient-only localization: one horizontal bar per layer, scaled
// by that layer's max slot-gradient norm.
function drawGradOnlyChart(layerScores) {
  const container = clearChart('sl-chart');
  const width = container.clientWidth || 700;
  const height = 40 + layerScores.length * 28;
  container.style.height = height + 'px';

  const margin = {left:50, right:80, top:20, bottom:10};
  const innerW = width - margin.left - margin.right;
  const peak = d3.max(layerScores, d => d.max_grad) || 1;
  const xScale = d3.scaleLinear().domain([0, peak]).range([0, innerW]);

  const root = d3.select(container)
    .append('svg').attr('width','100%').attr('height',height)
    .append('g').attr('transform',`translate(${margin.left},${margin.top})`);

  layerScores.forEach((d, i) => {
    const rowY = i * 28;
    const barW = xScale(d.max_grad);
    root.append('text')
      .attr('x',-6).attr('y',rowY+14).attr('text-anchor','end')
      .attr('fill',C.muted).attr('font-size',10).text('L'+d.layer);
    root.append('rect')
      .attr('x',0).attr('y',rowY+2).attr('width',barW)
      .attr('height',16).attr('rx',2).attr('fill',C.green).attr('opacity',.8)
      .on('mousemove',(ev)=>showTooltip(`L${d.layer} max_grad: ${d.max_grad}`,ev.pageX,ev.pageY))
      .on('mouseleave',hideTooltip);
    root.append('text')
      .attr('x',barW+4).attr('y',rowY+14)
      .attr('fill',C.green).attr('font-size',9).text(d.max_grad.toExponential(2));
  });
}
| |
// Render causal-trace results: one horizontal bar per layer, scaled by the
// indirect effect (negative effects are clamped to zero-width bars).
function drawCausalOnlyChart(results) {
  const el = clearChart('sl-chart');
  const W = el.clientWidth || 700, H = 40 + results.length * 28;
  el.style.height = H+'px';
  const svg = d3.select(el).append('svg').attr('width','100%').attr('height',H);
  const m = {left:50,right:80,top:20,bottom:10};
  const w = W-m.left-m.right;
  // d3.max returns undefined on an empty array; `?? 0` keeps Math.max from
  // yielding NaN (which would poison the x scale). The 0.001 floor keeps the
  // domain valid when every indirect effect is <= 0.
  const maxIE = Math.max(d3.max(results, d=>d.indirect_effect) ?? 0, 0.001);
  const x = d3.scaleLinear().domain([0,maxIE]).range([0,w]);
  const g = svg.append('g').attr('transform',`translate(${m.left},${m.top})`);
  results.forEach((d,i)=>{
    const y=i*28; const ie=Math.max(0,d.indirect_effect);
    g.append('text').attr('x',-6).attr('y',y+14).attr('text-anchor','end')
      .attr('fill',C.muted).attr('font-size',10).text('L'+d.layer);
    g.append('rect').attr('x',0).attr('y',y+2).attr('width',x(ie))
      .attr('height',16).attr('rx',2).attr('fill',C.yellow).attr('opacity',.8)
      .on('mousemove',(ev)=>showTooltip(`L${d.layer} IE: ${d.indirect_effect} patch_p: ${d.patch_prob}`,ev.pageX,ev.pageY))
      .on('mouseleave',hideTooltip);
    g.append('text').attr('x',x(ie)+4).attr('y',y+14)
      .attr('fill',C.yellow).attr('font-size',9).text(d.indirect_effect.toFixed(5));
  });
}
| |
// Print the smart-locate recommendation (best layer, its three signals,
// top gradient slots, and trace context) into the #sl-rec panel.
function showSmartRec(data) {
  const rec = data.recommendation;
  if(!rec){ document.getElementById('sl-rec').textContent='No recommendation.'; return; }
  let txt = `β
Best layer: L${rec.layer} combined=${rec.combined}\n\n`;
  txt += `  gate_sim:      ${rec.gate_sim} (norm ${rec.gate_sim_n})\n`;
  txt += `  grad_norm:     ${rec.grad_norm} (norm ${rec.grad_norm_n})\n`;
  txt += `  causal_effect: ${rec.causal_effect} (norm ${rec.causal_n})\n`;
  // Optional-chain: don't throw if the backend omits best_slots entirely.
  if(rec.best_slots?.length){
    txt += `\nTop gradient slots in L${rec.layer}:\n`;
    rec.best_slots.forEach(s=>{ txt+=`  slot ${s.slot} grad_norm=${s.grad_norm}\n`; });
  }
  txt += `\nPhase layer (trace): L${data.phase_layer}\n`;
  txt += `Subject pos: ${data.subject_pos}\n`;
  txt += `clean_prob: ${data.clean_prob} corrupt_prob: ${data.corrupt_prob}`;
  document.getElementById('sl-rec').textContent = txt;
}
| |
// Build the smart-locate results table (#sl-table): one row per layer with
// the three raw signals, the combined score, and the top-3 gradient slots.
// The row holding the best combined score gets a highlight background.
function buildSlTable(ranked) {
  const el = document.getElementById('sl-table');
  // Best combined score across layers — decides which row is highlighted.
  const maxC = Math.max(...ranked.map(r=>r.combined));
  let html = `<table style="width:100%;border-collapse:collapse;font-size:11px">
<thead><tr style="color:var(--muted);border-bottom:1px solid var(--border)">
<th style="padding:4px 8px;text-align:left">Layer</th>
<th style="padding:4px 8px;text-align:right;color:${C.blue}">gate_sim</th>
<th style="padding:4px 8px;text-align:right;color:${C.green}">grad_norm</th>
<th style="padding:4px 8px;text-align:right;color:${C.yellow}">causal IE</th>
<th style="padding:4px 8px;text-align:right">combined β
</th>
<th style="padding:4px 8px;text-align:left;color:var(--muted)">top grad slots</th>
</tr></thead><tbody>`;
  ranked.forEach(r=>{
    // Highlight the winning layer's row.
    const hi = r.combined===maxC ? `background:rgba(210,153,34,0.08)` : '';
    // Show at most the first three slot indices, comma-separated.
    const slots = r.best_slots.slice(0,3).map(s=>s.slot).join(', ');
    html+=`<tr style="${hi};border-bottom:1px solid var(--border)">
<td style="padding:4px 8px;color:${r.combined===maxC?C.yellow:C.text}">L${r.layer}</td>
<td style="padding:4px 8px;text-align:right;color:${C.blue}">${r.gate_sim}</td>
<td style="padding:4px 8px;text-align:right;color:${C.green}">${r.grad_norm.toExponential(2)}</td>
<td style="padding:4px 8px;text-align:right;color:${C.yellow}">${r.causal_effect.toFixed(5)}</td>
<td style="padding:4px 8px;text-align:right;font-weight:700">${r.combined}</td>
<td style="padding:4px 8px;color:var(--muted)">${slots}</td>
</tr>`;
  });
  html += '</tbody></table>';
  el.innerHTML = html;
}
| |
| // βββββββββββββββββββββββββββββββββββββββββββββββ |
| // INIT |
| // βββββββββββββββββββββββββββββββββββββββββββββββ |
| checkStatus(); |
| </script> |
| </body> |
| </html> |
| """ |
|
|
| |
| |
| |
|
|
| from fastapi import FastAPI, HTTPException |
| from fastapi.responses import HTMLResponse |
| from pydantic import BaseModel |
| import uvicorn |
|
|
app = FastAPI(title="VINDEX")


# Single global engine instance; set by POST /api/load (or by the __main__
# launcher's --model flag) and read by every handler via _require().
_vi: VIndex | None = None
|
|
|
|
def _require():
    """Return the active VIndex engine, or fail with HTTP 400 if none is loaded."""
    engine = _vi
    if engine is None:
        raise HTTPException(status_code=400, detail="No model loaded. POST /api/load first.")
    return engine
|
|
|
|
| |
|
|
class LoadReq(BaseModel):
    """Body for POST /api/load."""
    model_name: str = "distilgpt2"
    device: str = "auto"  # "auto" passes None through so VIndex picks a device


class InferReq(BaseModel):
    """Body for POST /api/infer: next-token distribution for a prompt."""
    prompt: str
    top_k: int = 10


class DescribeReq(BaseModel):
    """Body for POST /api/describe: top knowledge edges for an entity."""
    entity: str
    top_k: int = 10


class TraceReq(BaseModel):
    """Body for POST /api/trace."""
    prompt: str
    target: str


class LocateReq(BaseModel):
    """Body for POST /api/locate."""
    prompt: str
    subject: str
    target: str


class HeatmapReq(BaseModel):
    """Body for POST /api/gate_heatmap."""
    entity: str
    top_slots: int = 20
    use_activation: bool = False
    prompt: Optional[str] = None  # only used when use_activation is set


class GradientReq(BaseModel):
    """Body for POST /api/gradient_scores."""
    prompt: str
    target: str


class CausalTraceReq(BaseModel):
    """Body for POST /api/causal_trace."""
    prompt: str
    subject: str
    target: str
    noise_std: float = 0.1  # std of the corruption noise on subject tokens


class SmartLocateReq(BaseModel):
    """Body for POST /api/smart_locate.

    alpha/beta/gamma weight the gate-similarity, gradient-norm and causal
    signals in the combined layer score.
    """
    prompt: str
    subject: str
    target: str
    alpha: float = 0.4
    beta: float = 0.3
    gamma: float = 0.3
    noise_std: float = 0.1


class SmartEditReq(BaseModel):
    """Body for POST /api/smart_edit."""
    prompt: str
    subject: str
    relation: str = ""
    old_target: str
    new_target: str
    top_layers: int = 3       # number of best layers to edit
    slots_per_layer: int = 2  # slots edited within each chosen layer
    scale: float = 1.5
    noise_std: float = 0.1
    alpha: float = 0.4
    beta: float = 0.4
    gamma: float = 0.2


class DryRunReq(BaseModel):
    """Body for POST /api/dry_run: preview an edit without applying it."""
    entity: str
    new_target: str
    top_k: int = 3
    scale: float = 1.0
    prompt: Optional[str] = None


class EditReq(BaseModel):
    """Body for POST /api/edit; `mode` selects the edit strategy.

    from_concept / to_concept / strength only apply to STYLE-SHIFT mode.
    """
    entity: str
    relation: str = ""
    old_target: str = ""
    new_target: str
    mode: str = "UPDATE"  # UPDATE | PRECISE | INSERT | SUPPRESS | AMPLIFY | STYLE-SHIFT
    alpha: float = 0.25   # INSERT strength
    top_k: int = 3
    scale: float = 1.0
    prompt: Optional[str] = None
    from_concept: str = ""
    to_concept: str = ""
    strength: float = 0.5


class SuppressReq(BaseModel):
    """Body for POST /api/suppress."""
    entity: str
    top_k: int = 3
    factor: float = 0.0  # 0.0 fully mutes the located slots


class AmplifyReq(BaseModel):
    """Body for POST /api/amplify."""
    entity: str
    top_k: int = 3
    factor: float = 2.0


class StyleShiftReq(BaseModel):
    """Body for POST /api/style_shift."""
    anchor: str
    from_concept: str
    to_concept: str
    top_k: int = 3
    strength: float = 0.5


class SaveReq(BaseModel):
    """Body for POST /api/save."""
    path: str = "patches.json"


class CompileReq(BaseModel):
    """Body for POST /api/compile: bake patches into saved weights."""
    output_dir: str = "./vindex_compiled"


class MultiEditFact(BaseModel):
    """One fact in the POST /api/multi_edit batch."""
    entity: str
    relation: str = ""
    new_target: str
    prompt: Optional[str] = None
|
|
|
|
| |
|
|
@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve the embedded single-page D3 frontend."""
    return HTML_PAGE
|
|
|
|
@app.post("/api/load")
async def api_load(req: LoadReq):
    """Load a model into the global engine; any failure becomes HTTP 500."""
    global _vi
    device = req.device if req.device != "auto" else None
    try:
        engine = VIndex(req.model_name, device=device)
        _vi = engine
        return {
            "ok": True,
            "info": engine.info,
            "n_layers": engine.arch.n_layers,
            "kb_start": engine.kb_start,
            "kb_end": engine.kb_end,
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
@app.get("/api/status")
async def api_status():
    """Report whether a model is loaded plus basic engine stats.

    Returns zeroed placeholder fields when nothing is loaded so the
    frontend can render the same shape either way.
    """
    if _vi is None:
        return {"loaded": False, "info": "", "patches_count": 0, "kb_start": 0, "kb_end": 0, "n_layers": 0}
    return {"loaded": True, "info": _vi.info, "patches_count": len(_vi.patches),
            "kb_start": _vi.kb_start, "kb_end": _vi.kb_end, "n_layers": _vi.arch.n_layers}
|
|
|
|
@app.post("/api/infer")
async def api_infer(req: InferReq):
    """Return the top-k next-token distribution for a prompt."""
    vi = _require()
    return {"results": vi.infer(req.prompt, top_k=req.top_k)}
|
|
|
|
@app.post("/api/describe")
async def api_describe(req: DescribeReq):
    """Return the entity's top-k knowledge edges with rounded scores."""
    vi = _require()
    return {"edges": [
        {"tok": e["tok"],
         "score": round(e["score"], 2),
         "layer": e["layer"],
         "gate_sim": round(e["gate_sim"], 4)}
        for e in vi.describe(req.entity, top_k=req.top_k)
    ]}
|
|
|
|
@app.post("/api/trace")
async def api_trace(req: TraceReq):
    """Per-layer trace statistics for a prompt/target pair."""
    vi = _require()
    return {"stats": vi.trace(req.prompt, req.target)}


@app.post("/api/locate")
async def api_locate(req: LocateReq):
    """Locate the slots most associated with (prompt, subject, target)."""
    vi = _require()
    return vi.locate(req.prompt, req.subject, req.target)


@app.post("/api/gradient_scores")
async def api_gradient_scores(req: GradientReq):
    """Gradient-based slot scores for a prompt/target pair."""
    vi = _require()
    return vi.gradient_slot_scores(req.prompt, req.target)


@app.post("/api/causal_trace")
async def api_causal_trace(req: CausalTraceReq):
    """Causal (corrupt-and-patch) trace over layers."""
    vi = _require()
    return vi.causal_patch_trace(req.prompt, req.subject, req.target,
                                 noise_std=req.noise_std)


@app.post("/api/smart_locate")
async def api_smart_locate(req: SmartLocateReq):
    """Combined localization: gate similarity + gradients + causal trace,
    weighted by alpha/beta/gamma."""
    vi = _require()
    return vi.smart_locate(req.prompt, req.subject, req.target,
                           alpha=req.alpha, beta=req.beta, gamma=req.gamma,
                           noise_std=req.noise_std)
|
|
|
|
@app.post("/api/smart_edit")
async def api_smart_edit(req: SmartEditReq):
    """Run smart_edit and report the probability shift it caused.

    Measures the top-5 next-token distribution on the probe prompt before
    and after the edit, then returns per-token deltas sorted by magnitude,
    the engine's debug log, and the smart-locate details used for the edit.
    """
    vi = _require()
    # Fall back to a templated probe prompt when the caller gives none.
    prompt_str = req.prompt or f"The {req.relation} of {req.subject} is"
    before = vi.infer(prompt_str, top_k=5)
    log: List[str] = []
    try:
        result = vi.smart_edit(
            prompt_str, req.subject, req.relation, req.old_target, req.new_target,
            top_layers=req.top_layers, slots_per_layer=req.slots_per_layer,
            scale=req.scale, noise_std=req.noise_std,
            alpha=req.alpha, beta=req.beta, gamma=req.gamma, log=log
        )
    except Exception as e:
        # Surface engine failures as HTTP 500 with the raw message.
        raise HTTPException(status_code=500, detail=str(e))
    after = vi.infer(prompt_str, top_k=5)
    # Union of tokens seen before/after; deltas sorted by absolute change.
    b_map = {d["token"]: d["prob"] for d in before}
    a_map = {d["token"]: d["prob"] for d in after}
    all_toks = set(b_map) | set(a_map)
    delta = sorted([{"token":t,"before":b_map.get(t,0),"after":a_map.get(t,0),
                     "delta":a_map.get(t,0)-b_map.get(t,0)} for t in all_toks],
                   key=lambda x: -abs(x["delta"]))
    return {"before": before, "after": after, "delta": delta,
            "debug_log": log, "used_layers": result["used_layers"],
            "smart_locate": result["smart_locate"]}
|
|
|
|
@app.post("/api/gate_heatmap")
async def api_gate_heatmap(req: HeatmapReq):
    """Gate-similarity (or activation, if requested) heatmap for an entity."""
    vi = _require()
    return vi.gate_heatmap(req.entity, use_activation=req.use_activation, prompt=req.prompt)


@app.post("/api/dry_run")
async def api_dry_run(req: DryRunReq):
    """Preview an edit's effect without modifying the model."""
    vi = _require()
    return vi.dry_run(req.entity, req.new_target, top_k=req.top_k, scale=req.scale, prompt=req.prompt)
|
|
|
|
@app.post("/api/edit")
async def api_edit(req: EditReq):
    """Apply a knowledge edit and report the probability shift it caused.

    Dispatches on req.mode (PRECISE / INSERT / SUPPRESS / AMPLIFY /
    STYLE-SHIFT; anything else falls through to a plain UPDATE), measures
    the top-5 distribution on the probe prompt before and after, and
    returns per-token deltas sorted by magnitude, the engine debug log,
    and the op count of the most recent patch.
    """
    vi = _require()
    # Probe with the caller-supplied prompt when given, so before/after is
    # measured on the same prompt the edit targets (consistent with
    # /api/smart_edit; previously a custom prompt was ignored here).
    prompt_str = req.prompt or f"The {req.relation} of {req.entity} is"
    before = vi.infer(prompt_str, top_k=5)
    log: List[str] = []

    try:
        mode = req.mode.upper()
        if mode == "PRECISE":
            vi.precise_update(prompt_str, req.entity, req.relation,
                              req.new_target, top_k=req.top_k, scale=req.scale, log=log)
        elif mode == "INSERT":
            vi.insert(req.entity, req.relation, req.new_target,
                      alpha=req.alpha, spread=req.top_k, log=log)
        elif mode == "SUPPRESS":
            vi.suppress(req.entity, top_k=req.top_k, log=log)
        elif mode == "AMPLIFY":
            vi.amplify(req.entity, top_k=req.top_k, log=log)
        elif mode == "STYLE-SHIFT":
            vi.style_shift(req.entity, req.from_concept, req.to_concept,
                           top_k=req.top_k, strength=req.strength, log=log)
        else:
            vi.update(req.entity, req.relation, req.new_target,
                      top_k=req.top_k, scale=req.scale, log=log)
    except Exception as e:
        # Surface engine failures as HTTP 500 with the raw message.
        raise HTTPException(status_code=500, detail=str(e))

    after = vi.infer(prompt_str, top_k=5)
    # Union of tokens seen before/after; deltas sorted by absolute change.
    b_map = {d["token"]: d["prob"] for d in before}
    a_map = {d["token"]: d["prob"] for d in after}
    all_toks = set(b_map) | set(a_map)
    delta = sorted([{"token": t, "before": b_map.get(t, 0), "after": a_map.get(t, 0),
                     "delta": a_map.get(t, 0) - b_map.get(t, 0)} for t in all_toks],
                   key=lambda x: -abs(x["delta"]))
    return {"before": before, "after": after, "delta": delta, "debug_log": log,
            "ops": len(vi.patches[-1]["ops"]) if vi.patches else 0}
|
|
|
|
@app.post("/api/multi_edit")
async def api_multi_edit(facts: List[MultiEditFact]):
    """Apply a batch of fact edits in one call."""
    vi = _require()
    return vi.multi_edit([f.model_dump() for f in facts])
|
|
|
|
@app.post("/api/suppress")
async def api_suppress(req: SuppressReq):
    """Scale the entity's located slots down by `factor` (0.0 = mute)."""
    vi = _require()
    return vi.suppress(req.entity, top_k=req.top_k, factor=req.factor)


@app.post("/api/amplify")
async def api_amplify(req: AmplifyReq):
    """Scale the entity's located slots up by `factor`."""
    vi = _require()
    return vi.amplify(req.entity, top_k=req.top_k, factor=req.factor)


@app.post("/api/style_shift")
async def api_style_shift(req: StyleShiftReq):
    """Shift an anchor's associations from one concept toward another."""
    vi = _require()
    return vi.style_shift(req.anchor, req.from_concept, req.to_concept,
                          top_k=req.top_k, strength=req.strength)
|
|
|
|
@app.get("/api/patches")
async def api_patches():
    """List applied patches as a compact summary per entry."""
    vi = _require()
    summaries = [
        {
            "i": i,
            "type": p["type"],
            "entity": p.get("entity", ""),
            "relation": p.get("relation", ""),
            "new_target": p.get("new_target", "") or p.get("target", ""),
            "ops_count": len(p.get("ops", [])),
        }
        for i, p in enumerate(vi.patches)
    ]
    return {"patches": summaries}
|
|
|
|
@app.delete("/api/patches/{idx}")
async def api_delete_patch(idx: int):
    """Remove patch `idx` and re-apply the remaining patches from scratch."""
    vi = _require()
    if not 0 <= idx < len(vi.patches):
        raise HTTPException(status_code=404, detail="Patch index out of range")
    vi.patches.pop(idx)
    # Weights must be rebuilt from the base model with the survivors applied.
    vi._apply_all_patches()
    return {"ok": True}
|
|
|
|
@app.post("/api/reset")
async def api_reset():
    """Discard all patches and restore the original model weights."""
    vi = _require()
    vi.reset()
    return {"ok": True}


@app.post("/api/save")
async def api_save(req: SaveReq):
    """Serialize the current patch list to a JSON file."""
    vi = _require()
    vi.save_patch(req.path)
    return {"ok": True, "path": req.path}


@app.post("/api/compile")
async def api_compile(req: CompileReq):
    """Bake the current patches into model weights saved under output_dir."""
    vi = _require()
    vi.compile_to(req.output_dir)
    return {"ok": True, "output_dir": req.output_dir}
|
|
|
|
| |
| |
| |
|
|
if __name__ == "__main__":
    import argparse, webbrowser, threading, time

    ap = argparse.ArgumentParser()
    ap.add_argument("--port", type=int, default=8787)
    ap.add_argument("--model", default=None)    # optional model to pre-load at startup
    ap.add_argument("--device", default=None)   # forwarded to VIndex unchanged
    ap.add_argument("--no-browser", action="store_true")
    # parse_known_args: ignore unrelated flags instead of aborting.
    args, _ = ap.parse_known_args()

    if args.model:
        print(f"Pre-loading {args.model}β¦")
        # Rebinds the module-level global read by the API handlers.
        _vi = VIndex(args.model, device=args.device)
        print("Done.", _vi.info)

    if not args.no_browser:
        # Open the UI shortly after the server comes up; daemon thread so
        # it never blocks process shutdown.
        def _open():
            time.sleep(1.2)
            webbrowser.open(f"http://localhost:{args.port}")
        threading.Thread(target=_open, daemon=True).start()

    uvicorn.run(app, host="0.0.0.0", port=args.port)