#!/usr/bin/env python3 """ VINDEX — FastAPI + D3 single-file LLM knowledge editor Phase 1: Engine extensions | Phase 2: FastAPI | Phase 3: D3 frontend pip install transformers torch fastapi uvicorn pydantic python temp.py """ # ══════════════════════════════════════════════════════════════ # ENGINE # ══════════════════════════════════════════════════════════════ import re, json from pathlib import Path from typing import Any, Dict, List, Optional, Tuple import torch import torch.nn.functional as F from transformers import AutoTokenizer, AutoModelForCausalLM class ArchAdapter: def __init__(self, model): self.model = model self.style = self._detect_style() self.n_layers = self._count_layers() def _detect_style(self): t = self.model.config.model_type if t in ("gpt2","gpt_neo","gpt_neox"): return "gpt2" if t in ("llama","mistral","qwen2","gemma","gemma2","phi3", "falcon","codellama","deepseek","internlm2"): return "gated" try: if hasattr(self._layer(0).mlp, "gate_proj"): return "gated" except: pass return "gpt2" def _layer(self, i): m = self.model if hasattr(m,"transformer") and hasattr(m.transformer,"h"): return m.transformer.h[i] if hasattr(m,"model") and hasattr(m.model,"layers"): return m.model.layers[i] raise ValueError("Unknown model structure") def _count_layers(self): m = self.model if hasattr(m,"transformer") and hasattr(m.transformer,"h"): return len(m.transformer.h) if hasattr(m,"model") and hasattr(m.model,"layers"): return len(m.model.layers) raise ValueError("Cannot count layers") def get_ffn_weights(self, li): layer = self._layer(li) if self.style == "gpt2": return layer.mlp.c_fc.weight.detach().T, layer.mlp.c_proj.weight.detach().T return layer.mlp.gate_proj.weight.detach(), layer.mlp.down_proj.weight.detach() def set_ffn_weights(self, li, Wg, Wd): layer = self._layer(li) with torch.no_grad(): if self.style == "gpt2": layer.mlp.c_fc.weight.copy_(Wg.T) layer.mlp.c_proj.weight.copy_(Wd.T) else: layer.mlp.gate_proj.weight.copy_(Wg) layer.mlp.down_proj.weight.copy_(Wd) def get_embedding(self): return self.model.get_input_embeddings().weight.detach() def get_unembedding(self): if hasattr(self.model,"lm_head"): return self.model.lm_head.weight.detach() return self.get_embedding() class VIndex: def __init__(self, model_name: str, device: Optional[str] = None): self.model_name = model_name self.device = device or ("cuda" if torch.cuda.is_available() else "cpu") self.patches: List[Dict] = [] self._base_weights: Optional[Dict] = None self.tok = AutoTokenizer.from_pretrained(model_name) self.model = AutoModelForCausalLM.from_pretrained( model_name, torch_dtype=torch.float32).to(self.device) self.model.eval() if self.tok.pad_token is None: self.tok.pad_token = self.tok.eos_token self.arch = ArchAdapter(self.model) n = self.arch.n_layers self.kb_start = n // 3 self.kb_end = n @property def info(self): return (f"{self.model_name} | {self.arch.n_layers} layers | " f"style={self.arch.style} | kb=L{self.kb_start}-L{self.kb_end-1} | {self.device}") # ── utils ────────────────────────────────────────────────── def embed(self, text: str): ids = self.tok.encode(text, add_special_tokens=False) if not ids: raise ValueError(f"Cannot tokenize: {text!r}") E = self.arch.get_embedding().to(self.device) return E[ids].mean(0) def decode_down_col(self, col, top_k=5): scores = self.arch.get_unembedding().to(self.device) @ col.to(self.device) top = scores.topk(top_k) return [(self.tok.decode([i.item()]).strip(), v.item()) for i,v in zip(top.indices,top.values) if v.item()>0] def token_id(self, word: str) -> int: ids = self.tok.encode(word, add_special_tokens=False) return ids[0] if ids else 0 def _forward(self, prompt: str): inputs = self.tok(prompt, return_tensors="pt").to(self.device) with torch.no_grad(): out = self.model(**inputs) return out.logits[0,-1] # ── Phase 1: new engine methods ──────────────────────────── def _get_subject_activations(self, prompt: str, subject: str ) -> Tuple[Dict[int, torch.Tensor], int]: """Capture h_L[last_subject_token_pos] at every layer via forward hooks.""" enc = self.tok(prompt, return_tensors="pt").to(self.device) ids = enc["input_ids"][0].tolist() # Subsequence match to find subject position subj_ids = self.tok.encode(subject, add_special_tokens=False) subject_pos = 0 for start in range(len(ids) - len(subj_ids) + 1): if ids[start:start+len(subj_ids)] == subj_ids: subject_pos = start + len(subj_ids) - 1 break else: # Fallback: find any single token from subject_ids for si in subj_ids: if si in ids: subject_pos = ids.index(si) break activations: Dict[int, torch.Tensor] = {} handles = [] def make_hook(li): def hook(m, inp, out): h = out[0] if isinstance(out, tuple) else out activations[li] = h[0, subject_pos].detach().clone() return hook for li in range(self.arch.n_layers): handles.append(self.arch._layer(li).register_forward_hook(make_hook(li))) with torch.no_grad(): self.model(**enc) for h in handles: h.remove() return activations, subject_pos def locate(self, prompt: str, subject: str, target: str) -> Dict: """Diagnostic: combines trace + activation-guided similarity scan.""" trace_stats = self.trace(prompt, target) # Find phase_layer: biggest relative rank drop phase_layer = 0 best_drop = 0.0 prev_rank = None for s in trace_stats: if prev_rank is not None and prev_rank > 5: drop = (prev_rank - s["rank"]) / prev_rank if drop > best_drop: best_drop = drop phase_layer = s["l"] prev_rank = s["rank"] activations, subject_pos = self._get_subject_activations(prompt, subject) layer_scores = [] for li in range(self.kb_start, self.kb_end): h_L = activations.get(li) if h_L is None: layer_scores.append({"layer": li, "max_sim": 0.0, "best_slot": -1, "top_tokens": []}) continue Wg, Wd = self.arch.get_ffn_weights(li) Wg = Wg.to(self.device) h_n = F.normalize(h_L, dim=0) sims = F.normalize(Wg, dim=1) @ h_n best_slot = int(sims.argmax().item()) max_sim = float(sims[best_slot].item()) Wd = Wd.to(self.device) top_tokens = self.decode_down_col(Wd[:, best_slot], top_k=3) layer_scores.append({ "layer": li, "max_sim": round(max_sim, 4), "best_slot": best_slot, "top_tokens": [{"tok": t, "score": round(s,2)} for t,s in top_tokens] }) return { "phase_layer": phase_layer, "subject_pos": subject_pos, "layer_scores": layer_scores, "trace": trace_stats, } def precise_update(self, prompt: str, subject: str, relation: str, new_target: str, top_k: int = 3, scale: float = 1.0, log: Optional[List[str]] = None): """Activation-guided update: uses h_L[subject_pos] instead of embed(subject).""" if log is None: log = [] if not prompt: return self.update(subject, relation, new_target, top_k=top_k, scale=scale, log=log) self._snapshot() activations, subject_pos = self._get_subject_activations(prompt, subject) tv = self.embed(new_target) log.append(f"PRECISE_UPDATE: '{subject}' -[{relation}]-> '{new_target}'") log.append(f" subject_pos={subject_pos} scale={scale} top_k={top_k}") candidates = [] for li in range(self.kb_start, self.kb_end): h_L = activations.get(li) if h_L is None: continue Wg, _ = self.arch.get_ffn_weights(li) Wg = Wg.to(self.device) h_n = F.normalize(h_L, dim=0) sims = F.normalize(Wg, dim=1) @ h_n k = min(top_k, sims.shape[0]) vals, idxs = sims.topk(k) for v, idx in zip(vals, idxs): candidates.append((v.item(), li, idx.item())) log.append(f" L{li}: " + " ".join(f"{v:.4f}@s{i}" for v,i in zip(vals,idxs))) candidates.sort(key=lambda x: -x[0]) best_sim = candidates[0][0] if candidates else 0.0 log.append(f"\n Best activation_sim = {best_sim:.4f} (embed-based would be ~{best_sim/3.5:.4f})") if best_sim < 0.05: log.append(" ⚠ sim < 0.05 → INSERT fallback") return self.insert(subject, relation, new_target, log=log) chosen = [c for c in candidates if c[0] >= 0.05][:top_k] ops = [] for sim, li, slot in chosen: _, Wd = self.arch.get_ffn_weights(li) Wd = Wd.to(self.device) col_norm = Wd[:, slot].norm().item() new_col = (F.normalize(tv, dim=0)*col_norm*scale).cpu().tolist() ops.append({"op":"update_down","layer":li,"slot":slot, "down_col":new_col,"activation_sim":round(sim,4)}) log.append(f" ✓ L{li} slot {slot}: act_sim={sim:.4f} col_norm={col_norm:.4f}") self.patches.append({"type":"PRECISE_UPDATE","entity":subject,"relation":relation, "new_target":new_target,"ops":ops}) self._apply_all_patches() log.append(f"\n ✓ Applied {len(ops)} op(s), patch #{len(self.patches)}") return ops def suppress(self, entity: str, top_k: int = 3, factor: float = 0.0, log: Optional[List[str]] = None): """Scale down (or zero) matching down columns.""" if log is None: log = [] self._snapshot() ev_n = F.normalize(self.embed(entity), dim=0) log.append(f"SUPPRESS: '{entity}' factor={factor} top_k={top_k}") ops = [] candidates = [] for li in range(self.kb_start, self.kb_end): Wg, _ = self.arch.get_ffn_weights(li) Wg = Wg.to(self.device) sims = F.normalize(Wg, dim=1) @ ev_n k = min(top_k, sims.shape[0]) vals, idxs = sims.topk(k) for v, idx in zip(vals, idxs): candidates.append((v.item(), li, idx.item())) candidates.sort(key=lambda x: -x[0]) chosen = [c for c in candidates if c[0] >= 0.05][:top_k] for sim, li, slot in chosen: _, Wd = self.arch.get_ffn_weights(li) Wd = Wd.to(self.device) new_col = (Wd[:, slot] * factor).cpu().tolist() ops.append({"op":"update_down","layer":li,"slot":slot,"down_col":new_col}) log.append(f" L{li} slot {slot}: gate_sim={sim:.4f} factor={factor}") self.patches.append({"type":"SUPPRESS","entity":entity,"factor":factor,"ops":ops}) self._apply_all_patches() log.append(f" ✓ Suppressed {len(ops)} slot(s)") return {"ops": len(ops), "log": log} def amplify(self, entity: str, top_k: int = 3, factor: float = 2.0, log: Optional[List[str]] = None): """Scale up matching down columns.""" if log is None: log = [] self._snapshot() ev_n = F.normalize(self.embed(entity), dim=0) log.append(f"AMPLIFY: '{entity}' factor={factor} top_k={top_k}") ops = [] candidates = [] for li in range(self.kb_start, self.kb_end): Wg, _ = self.arch.get_ffn_weights(li) Wg = Wg.to(self.device) sims = F.normalize(Wg, dim=1) @ ev_n k = min(top_k, sims.shape[0]) vals, idxs = sims.topk(k) for v, idx in zip(vals, idxs): candidates.append((v.item(), li, idx.item())) candidates.sort(key=lambda x: -x[0]) chosen = [c for c in candidates if c[0] >= 0.05][:top_k] for sim, li, slot in chosen: _, Wd = self.arch.get_ffn_weights(li) Wd = Wd.to(self.device) new_col = (Wd[:, slot] * factor).cpu().tolist() ops.append({"op":"update_down","layer":li,"slot":slot,"down_col":new_col}) log.append(f" L{li} slot {slot}: gate_sim={sim:.4f} factor={factor}") self.patches.append({"type":"AMPLIFY","entity":entity,"factor":factor,"ops":ops}) self._apply_all_patches() log.append(f" ✓ Amplified {len(ops)} slot(s)") return {"ops": len(ops), "log": log} def style_shift(self, anchor_entity: str, from_concept: str, to_concept: str, top_k: int = 3, strength: float = 0.5, log: Optional[List[str]] = None): """Add a direction vector to matching down columns.""" if log is None: log = [] self._snapshot() ev_n = F.normalize(self.embed(anchor_entity), dim=0) from_v = self.embed(from_concept) to_v = self.embed(to_concept) dir_v = F.normalize(to_v - from_v, dim=0) log.append(f"STYLE_SHIFT: anchor='{anchor_entity}' {from_concept!r}→{to_concept!r} strength={strength}") ops = [] candidates = [] for li in range(self.kb_start, self.kb_end): Wg, _ = self.arch.get_ffn_weights(li) Wg = Wg.to(self.device) sims = F.normalize(Wg, dim=1) @ ev_n k = min(top_k, sims.shape[0]) vals, idxs = sims.topk(k) for v, idx in zip(vals, idxs): candidates.append((v.item(), li, idx.item())) candidates.sort(key=lambda x: -x[0]) chosen = [c for c in candidates if c[0] >= 0.05][:top_k] for sim, li, slot in chosen: _, Wd = self.arch.get_ffn_weights(li) Wd = Wd.to(self.device) col = Wd[:, slot] new_col = (col + dir_v * col.norm() * strength).cpu().tolist() ops.append({"op":"update_down","layer":li,"slot":slot,"down_col":new_col}) log.append(f" L{li} slot {slot}: gate_sim={sim:.4f} col_norm={col.norm():.4f}") self.patches.append({"type":"STYLE_SHIFT","entity":anchor_entity, "from":from_concept,"to":to_concept,"ops":ops}) self._apply_all_patches() log.append(f" ✓ Style-shifted {len(ops)} slot(s)") return {"ops": len(ops), "log": log} def multi_edit(self, facts: List[Dict], mode: str = "UPDATE", alpha: float = 0.25, top_k: int = 3, scale: float = 1.0): """Apply a batch of edits sequentially.""" results = [] for f in facts: log: List[str] = [] try: entity = f["entity"] relation = f.get("relation", "") new_target = f["new_target"] prompt = f.get("prompt", "") if mode == "PRECISE" and prompt: ops = self.precise_update(prompt, entity, relation, new_target, top_k=top_k, scale=scale, log=log) elif mode == "INSERT": ops = self.insert(entity, relation, new_target, alpha=alpha, spread=top_k, log=log) else: ops = self.update(entity, relation, new_target, top_k=top_k, scale=scale, log=log) results.append({"entity":entity,"status":"ok", "ops": len(ops) if isinstance(ops,(list,)) else 1, "log":log}) except Exception as e: results.append({"entity":f.get("entity","?"),"status":"error", "error":str(e),"log":log}) return results def gate_heatmap(self, entity: str, use_activation: bool = False, prompt: Optional[str] = None) -> Dict: """Full layer×slot similarity matrix with top token decoding.""" if use_activation and prompt: activations, _ = self._get_subject_activations(prompt, entity) else: activations = {} ev_n = F.normalize(self.embed(entity), dim=0) layers_out = [] for li in range(self.kb_start, self.kb_end): Wg, Wd = self.arch.get_ffn_weights(li) Wg = Wg.to(self.device); Wd = Wd.to(self.device) if use_activation and li in activations: q = F.normalize(activations[li], dim=0) else: q = ev_n sims = F.normalize(Wg, dim=1) @ q top_slots_count = min(20, sims.shape[0]) vals, idxs = sims.topk(top_slots_count) slots = [] for v, idx in zip(vals, idxs): top_toks = self.decode_down_col(Wd[:, idx], top_k=3) slots.append({ "slot": int(idx.item()), "sim": round(float(v.item()), 4), "top_tokens": [{"tok": t, "score": round(s,2)} for t,s in top_toks] }) layers_out.append({"layer": li, "slots": slots}) return {"layers": layers_out} def dry_run(self, entity: str, new_target: str, top_k: int = 3, scale: float = 1.0, prompt: Optional[str] = None) -> Dict: """Same logic as precise_update/update but does NOT mutate weights.""" if prompt: activations, subject_pos = self._get_subject_activations(prompt, entity) use_act = True else: use_act = False ev_n = F.normalize(self.embed(entity), dim=0) tv = self.embed(new_target) candidates = [] for li in range(self.kb_start, self.kb_end): Wg, Wd = self.arch.get_ffn_weights(li) Wg = Wg.to(self.device); Wd = Wd.to(self.device) if use_act and li in activations: q = F.normalize(activations[li], dim=0) else: q = ev_n sims = F.normalize(Wg, dim=1) @ q k = min(top_k, sims.shape[0]) vals, idxs = sims.topk(k) for v, idx in zip(vals, idxs): col_norm = Wd[:, idx].norm().item() top_toks = self.decode_down_col(Wd[:, idx], top_k=3) candidates.append({ "layer": li, "slot": int(idx.item()), "sim": round(float(v.item()), 4), "col_norm": round(col_norm, 4), "inject_norm": round(col_norm * scale, 4), "current_top": [{"tok":t,"score":round(s,2)} for t,s in top_toks] }) candidates.sort(key=lambda x: -x["sim"]) best_sim = candidates[0]["sim"] if candidates else 0.0 chosen = [c for c in candidates if c["sim"] >= 0.05][:top_k] return { "candidates": chosen, "best_sim": best_sim, "would_patch": len(chosen), "new_target": new_target, "mode": "activation-guided" if use_act else "embed-based" } # ── Phase 1+: mechanistic attribution ───────────────────── def gradient_slot_scores(self, prompt: str, target: str) -> Dict: """One backward pass: grad norm of ∂(-log p(target))/∂W_down[:,slot] per KB layer. Identifies which slots causally contributed to this prediction via gradient signal.""" target_id = self.token_id(target) # Temporarily enable grad on down-proj weights down_params: List[Tuple[int, torch.nn.Parameter]] = [] for li in range(self.arch.n_layers): layer = self.arch._layer(li) p = layer.mlp.c_proj.weight if self.arch.style == "gpt2" \ else layer.mlp.down_proj.weight p.requires_grad_(True) down_params.append((li, p)) self.model.zero_grad() inputs = self.tok(prompt, return_tensors="pt").to(self.device) out = self.model(**inputs) loss = -F.log_softmax(out.logits[0, -1], dim=-1)[target_id] loss.backward() layer_scores = [] for li, p in down_params: grad = p.grad p.requires_grad_(False) if grad is None: layer_scores.append({"layer": li, "max_grad": 0.0, "top_slots": []}) continue # gpt2: c_proj.weight [ffn_dim, hidden] → rows = slots # gated: down_proj.weight [hidden, ffn_dim] → cols = slots slot_norms = grad.norm(dim=1) if self.arch.style == "gpt2" \ else grad.norm(dim=0) # [ffn_dim] k = min(20, slot_norms.shape[0]) vals, idxs = slot_norms.topk(k) layer_scores.append({ "layer": li, "max_grad": round(float(vals[0].item()), 6), "top_slots": [{"slot": int(idx.item()), "grad_norm": round(float(v.item()), 6)} for idx, v in zip(idxs, vals)] }) self.model.zero_grad() return {"layer_scores": layer_scores} def causal_patch_trace(self, prompt: str, subject: str, target: str, noise_std: float = 0.1) -> Dict: """ROME-style causal tracing. Corrupts subject embeddings, then for each KB layer measures how much patching that layer's hidden state (at subject position) restores p(target). Expensive: O(n_layers) forward passes.""" target_id = self.token_id(target) W_u = self.arch.get_unembedding().to(self.device) inputs = self.tok(prompt, return_tensors="pt").to(self.device) ids = inputs["input_ids"][0].tolist() # Find subject token positions via subsequence match subj_ids = self.tok.encode(subject, add_special_tokens=False) subj_pos: List[int] = [] for start in range(len(ids) - len(subj_ids) + 1): if ids[start:start+len(subj_ids)] == subj_ids: subj_pos = list(range(start, start+len(subj_ids))) break if not subj_pos: for si in subj_ids: if si in ids: subj_pos = [ids.index(si)] break if not subj_pos: subj_pos = [0] # ── Clean forward — capture every layer's hidden states ── clean_hs: Dict[int, torch.Tensor] = {} clean_handles = [] def _mk_clean(li): def _h(m, inp, out): h = out[0] if isinstance(out, tuple) else out clean_hs[li] = h[0].detach().clone() # [seq, hidden] return _h for li in range(self.arch.n_layers): clean_handles.append(self.arch._layer(li).register_forward_hook(_mk_clean(li))) with torch.no_grad(): clean_out = self.model(**inputs) for h in clean_handles: h.remove() clean_prob = float(torch.softmax(clean_out.logits[0,-1], dim=-1)[target_id].item()) # ── Corrupted embeddings ── E = self.arch.get_embedding().to(self.device) emb = E[inputs["input_ids"][0]].unsqueeze(0).clone() # [1, seq, hidden] noise_scale = emb.std().item() * noise_std for pos in subj_pos: emb[0, pos] += torch.randn_like(emb[0, pos]) * noise_scale with torch.no_grad(): corr_out = self.model(inputs_embeds=emb) corr_prob = float(torch.softmax(corr_out.logits[0,-1], dim=-1)[target_id].item()) # ── Causal patch sweep ── results = [] for li in range(self.kb_start, self.kb_end): def _mk_patch(target_li): def _h(m, inp, out): if target_li not in clean_hs: return out is_tuple = isinstance(out, tuple) h = list(out) if is_tuple else [out] clean = clean_hs[target_li] for pos in subj_pos: if pos < clean.shape[0]: h[0][0, pos] = clean[pos].to(h[0].device) return tuple(h) if is_tuple else h[0] return _h ph = self.arch._layer(li).register_forward_hook(_mk_patch(li)) with torch.no_grad(): patch_out = self.model(inputs_embeds=emb.clone()) ph.remove() patch_prob = float(torch.softmax(patch_out.logits[0,-1], dim=-1)[target_id].item()) ie = patch_prob - corr_prob results.append({ "layer": li, "patch_prob": round(patch_prob, 6), "indirect_effect": round(ie, 6), }) return { "clean_prob": round(clean_prob, 6), "corrupt_prob": round(corr_prob, 6), "subject_pos": subj_pos, "results": results, } def smart_locate(self, prompt: str, subject: str, target: str, alpha: float = 0.4, beta: float = 0.3, gamma: float = 0.3, noise_std: float = 0.1) -> Dict: """Combined gate_sim + grad_norm + causal_effect → precise layer/slot ranking. alpha = weight for gate cosine sim beta = weight for gradient norm gamma = weight for causal indirect effect""" gate_data = self.locate(prompt, subject, target) grad_data = self.gradient_slot_scores(prompt, target) causal_data = self.causal_patch_trace(prompt, subject, target, noise_std=noise_std) gate_map = {ls["layer"]: ls["max_sim"] for ls in gate_data["layer_scores"]} grad_map = {ls["layer"]: ls["max_grad"] for ls in grad_data["layer_scores"]} causal_map = {r["layer"]: max(0.0, r["indirect_effect"]) for r in causal_data["results"]} grad_slots = {ls["layer"]: ls["top_slots"] for ls in grad_data["layer_scores"]} layers = sorted(set(gate_map) | set(grad_map) | set(causal_map)) def _norm(vals: List[float]) -> List[float]: m = max(vals) if vals else 1.0 return [v/m if m > 0 else 0.0 for v in vals] gv = [gate_map.get(l, 0.0) for l in layers] dv = [grad_map.get(l, 0.0) for l in layers] cv = [causal_map.get(l, 0.0) for l in layers] gn, dn, cn = _norm(gv), _norm(dv), _norm(cv) ranked = [] for i, l in enumerate(layers): score = alpha*gn[i] + beta*dn[i] + gamma*cn[i] ranked.append({ "layer": l, "gate_sim": round(gv[i], 4), "grad_norm": round(dv[i], 6), "causal_effect": round(cv[i], 6), "gate_sim_n": round(gn[i], 4), "grad_norm_n": round(dn[i], 4), "causal_n": round(cn[i], 4), "combined": round(score, 4), "best_slots": (grad_slots.get(l) or [])[:5], }) ranked.sort(key=lambda x: -x["combined"]) return { "ranked_layers": ranked, "phase_layer": gate_data["phase_layer"], "subject_pos": gate_data["subject_pos"], "clean_prob": causal_data["clean_prob"], "corrupt_prob": causal_data["corrupt_prob"], "recommendation": ranked[0] if ranked else None, "weights": {"alpha": alpha, "beta": beta, "gamma": gamma}, } def smart_edit(self, prompt: str, subject: str, relation: str, old_target: str, new_target: str, top_layers: int = 3, slots_per_layer: int = 2, scale: float = 1.5, noise_std: float = 0.1, alpha: float = 0.4, beta: float = 0.4, gamma: float = 0.2, log: Optional[List[str]] = None) -> Dict: """Auto edit: runs smart_locate on (prompt, subject, old_target) to find the exact layer+slot targets via gradient+causal+gate consensus, then patches those W_down columns toward embed(new_target). old_target = what the model currently predicts (used to locate) new_target = what you want to inject top_layers = how many top-ranked layers to patch slots_per_layer = gradient-identified slots to patch per layer scale = col_norm multiplier (1.5-3.0 recommended) beta > alpha because grad_norm is more reliable than gate_sim for small models.""" if log is None: log = [] self._snapshot() log.append(f"SMART_EDIT: '{subject}' [{relation}] {old_target!r} → {new_target!r}") log.append(f" Running smart_locate on prompt: {prompt!r}") log.append(f" Weights: α={alpha} β={beta} γ={gamma} noise_std={noise_std}") sl = self.smart_locate(prompt, subject, old_target, alpha=alpha, beta=beta, gamma=gamma, noise_std=noise_std) log.append(f" clean_prob={sl['clean_prob']:.6f} corrupt_prob={sl['corrupt_prob']:.6f}") log.append(f" Phase layer: L{sl['phase_layer']} Subject pos: {sl['subject_pos']}") if sl["clean_prob"] < 1e-5: log.append(" ⚠ clean_prob near zero — model barely knows this fact.") log.append(" Grad-norm signal still valid. Causal IE=0 is expected.") log.append(" Recommend: gpt2-medium or Qwen2.5-1.5B for stronger facts.") tv = self.embed(new_target) tv_n = F.normalize(tv, dim=0) ops = [] used = [] top_ranked = sl["ranked_layers"][:top_layers] for lr in top_ranked: li = lr["layer"] # Use gradient-identified slots — far more precise than gate cosine grad_slots = [s["slot"] for s in lr["best_slots"][:slots_per_layer]] if not grad_slots: log.append(f" L{li}: no grad slots, skipping") continue _, Wd = self.arch.get_ffn_weights(li) Wd = Wd.to(self.device) for slot in grad_slots: col_norm = Wd[:, slot].norm().item() new_col = (tv_n * col_norm * scale).cpu().tolist() ops.append({"op":"update_down","layer":li,"slot":slot,"down_col":new_col}) log.append(f" ✓ L{li} slot {slot}: combined={lr['combined']} " f"grad_norm={lr['grad_norm']:.4f} col_norm={col_norm:.4f} " f"inject={col_norm*scale:.4f}") used.append({"layer":li,"slots":grad_slots,"combined":lr["combined"]}) self.patches.append({ "type": "SMART_UPDATE", "entity": subject, "relation": relation, "new_target": new_target, "old_target": old_target, "smart_top": top_ranked, "ops": ops, }) self._apply_all_patches() log.append(f"\n ✓ {len(ops)} op(s) across {len(used)} layer(s), patch #{len(self.patches)}") return { "ops": ops, "used_layers": used, "smart_locate": sl, "log": log, } def infer(self, prompt: str, top_k: int = 5): probs = torch.softmax(self._forward(prompt), dim=-1) top = probs.topk(top_k) return [{"token": self.tok.decode([idx.item()]).strip(), "prob": round(val.item(), 6)} for idx, val in zip(top.indices, top.values)] def describe(self, entity: str, top_k: int = 10): ev_n = F.normalize(self.embed(entity), dim=0) all_edges = [] for li in range(self.kb_start, self.kb_end): Wg, Wd = self.arch.get_ffn_weights(li) Wg = Wg.to(self.device); Wd = Wd.to(self.device) sims = F.normalize(Wg, dim=1) @ ev_n for fid in sims.topk(min(5,sims.shape[0])).indices: gsim = sims[fid].item() if gsim < 0.08: continue for tok, score in self.decode_down_col(Wd[:,fid], 4): if tok: all_edges.append({"tok":tok,"score":score,"layer":li,"gate_sim":gsim}) best: Dict[str, Any] = {} for e in all_edges: t = e["tok"] if t not in best or e["score"] > best[t]["score"]: best[t] = e ranked = sorted(best.values(), key=lambda x: -x["score"])[:top_k] return ranked def trace(self, prompt: str, target: str): target_id = self.token_id(target) W_u = self.arch.get_unembedding().to(self.device) stats: List[Dict] = [] handles = [] def make_hook(li): def hook(m, inp, out): h = out[0] if isinstance(out,tuple) else out last = h[0,-1].detach() p = torch.softmax(W_u @ last, dim=-1) rank = int((p > p[target_id]).sum().item()) + 1 stats.append({"l":li,"rank":rank,"prob":round(p[target_id].item(),8)}) return hook for li in range(self.arch.n_layers): handles.append(self.arch._layer(li).register_forward_hook(make_hook(li))) inputs = self.tok(prompt, return_tensors="pt").to(self.device) with torch.no_grad(): self.model(**inputs) for h in handles: h.remove() return stats # ── patch management ─────────────────────────────────────── def _snapshot(self): if self._base_weights is not None: return self._base_weights = {} for li in range(self.arch.n_layers): Wg, Wd = self.arch.get_ffn_weights(li) self._base_weights[li] = (Wg.clone().cpu(), Wd.clone().cpu()) def _restore_base(self): if self._base_weights is None: return for li,(Wg,Wd) in self._base_weights.items(): self.arch.set_ffn_weights(li, Wg.to(self.device), Wd.to(self.device)) def _apply_all_patches(self): self._restore_base() for patch in self.patches: for op in patch.get("ops",[]): li=op["layer"]; slot=op["slot"] Wg, Wd = self.arch.get_ffn_weights(li) Wg=Wg.clone(); Wd=Wd.clone() if op["op"] in ("insert","update_gate"): Wg[slot] = torch.tensor(op["gate_row"],dtype=Wg.dtype,device=self.device) if op["op"] in ("insert","update_down"): Wd[:,slot] = torch.tensor(op["down_col"],dtype=Wd.dtype,device=self.device) self.arch.set_ffn_weights(li, Wg, Wd) def insert(self, entity, relation, target, alpha=0.25, spread=4, log=None): if log is None: log = [] self._snapshot() gate_dir = F.normalize(self.embed(entity), dim=0) down_dir = F.normalize(self.embed(target), dim=0) ls=self.kb_start; le=min(ls+spread, self.arch.n_layers) log.append(f"INSERT: '{entity}' -[{relation}]-> '{target}' alpha={alpha}") ops=[] for li in range(ls,le): Wg, Wd = self.arch.get_ffn_weights(li) Wg=Wg.to(self.device); Wd=Wd.to(self.device) norms_g=Wg.norm(dim=1); norms_d=Wd.norm(dim=0) slot=norms_g.argmin().item() wg_mean=norms_g.mean().item(); wd_mean=norms_d.mean().item() ops.append({"op":"insert","layer":li,"slot":slot, "gate_row":(gate_dir*wg_mean*alpha).cpu().tolist(), "down_col":(down_dir*wd_mean*alpha).cpu().tolist()}) log.append(f" L{li}: slot={slot} inject={wg_mean*alpha:.4f}") self.patches.append({"type":"INSERT","entity":entity,"relation":relation, "target":target,"ops":ops}) self._apply_all_patches() return ops def update(self, entity, relation, new_target, top_k=3, scale=1.0, log=None): if log is None: log = [] self._snapshot() ev = self.embed(entity) tv = self.embed(new_target) ev_n = F.normalize(ev, dim=0) log.append(f"UPDATE: '{entity}' -[{relation}]-> '{new_target}' top_k={top_k} scale={scale}") candidates = [] for li in range(self.kb_start, self.kb_end): Wg, _ = self.arch.get_ffn_weights(li) Wg = Wg.to(self.device) sims = F.normalize(Wg, dim=1) @ ev_n k = min(top_k, sims.shape[0]) vals, idxs = sims.topk(k) for v, idx in zip(vals, idxs): candidates.append((v.item(), li, idx.item())) log.append(f" L{li}: " + " ".join(f"{v:.4f}@s{i}" for v,i in zip(vals,idxs))) candidates.sort(key=lambda x: -x[0]) best_sim = candidates[0][0] if candidates else 0.0 if best_sim < 0.08: log.append(" ⚠ sim<0.08 → INSERT fallback") return self.insert(entity, relation, new_target, log=log), "INSERT_FALLBACK" chosen = [c for c in candidates if c[0] >= 0.08][:top_k] ops = [] for sim, li, slot in chosen: _, Wd = self.arch.get_ffn_weights(li) Wd = Wd.to(self.device) col_norm = Wd[:, slot].norm().item() new_col = (F.normalize(tv, dim=0)*col_norm*scale).cpu().tolist() ops.append({"op":"update_down","layer":li,"slot":slot,"down_col":new_col}) log.append(f" ✓ L{li} slot {slot}: sim={sim:.4f} norm={col_norm:.4f}") self.patches.append({"type":"UPDATE","entity":entity,"relation":relation, "new_target":new_target,"ops":ops}) self._apply_all_patches() li0,slot0,sim0 = chosen[0] return li0, slot0, sim0 def reset(self): self.patches.clear() self._restore_base() def save_patch(self, path: str): with open(path,"w") as f: json.dump(self.patches, f, indent=2) def compile_to(self, output_dir: str): Path(output_dir).mkdir(parents=True, exist_ok=True) self.model.save_pretrained(output_dir) self.tok.save_pretrained(output_dir) # ══════════════════════════════════════════════════════════════ # HTML FRONTEND (Phase 3) # ══════════════════════════════════════════════════════════════ HTML_PAGE = r""" VINDEX — LLM Knowledge Editor
VINDEX LLM Knowledge Editor
The model IS the database. Inspect · Edit · Locate · Compile.
No model loaded

Load Model

Quick Models

distilgpt2 — 350 MB, instant
gpt2 — 550 MB
gpt2-medium — 1.5 GB
Qwen/Qwen2.5-1.5B-Instruct — 3 GB, strong facts

Next-Token Prediction

Results

Entity Knowledge Graph (W_gate KNN → W_down decode)

Force-Directed Graph

Layer-by-Layer Rank Trace

Rank + Probability over Layers

Locate — Diagnostic (Trace + Activation Similarity)

Trace (rank over layers)

Activation Sim per KB Layer

Smart Locate — gradient + causal + gate_sim combined

Three independent signals combined into one ranked layer list.
■ gate_sim — static embedding cosine (fast, weak proxy)   ■ grad_norm — ∂loss/∂W_down per slot (one backward pass)   ■ causal IE — indirect effect via subject-corruption patching (N_layers passes, slow)

Layer Rankings — 3-signal stacked bars

Recommendation

Run Smart Locate to see the best edit target.

Collateral Probe

Probe a prompt to check collateral damage.

Per-Layer Detail

Run Smart Locate first.

Gate Heatmap — Layer × Slot Cosine Similarity

Heatmap

Selected Slot

Click a cell to see decoded tokens.

Edit

Before / After

Dry Run Preview

Run dry-run first.

Edit Log + Delta

No edit yet.

Active Patches

No patches.

Concept Reference

UPDATE — rewrites W_down column → different fact
PRECISE — activation-guided UPDATE (3–5× better sim)
INSERT — new (gate,down) pair in weakest slot
SUPPRESS — scale down → model forgets entity
AMPLIFY — scale up → stronger recall
STYLE-SHIFT — adds direction vector → tone/bias shift

What is VINDEX doing?

In a transformer, factual associations like "France → capital → Paris" are stored as direction vectors in the W_down columns of FFN layers. The W_gate rows act as keys: when the residual stream resembles "France", the matching gate fires, the down column adds "Paris" direction to the stream, and the unembedding reads out "Paris". VINDEX surgically replaces those down columns without retraining.

Quickstart — 5-step experiment

Step 1 — Load a model that actually knows facts
⚙ Load tab → gpt2-medium (1.5 GB, knows capitals) or Qwen/Qwen2.5-1.5B-Instruct (3 GB, strong).
distilgpt2 has clean_prob≈0 for most facts → causal IE=0 everywhere → misleading results.
Step 2 — Verify the model knows the fact
① Infer: prompt = "The capital of France is"
✓ Good: "Paris" appears in top-3 with prob > 0.05
✗ Bad: top tokens are "a", "the", "known" → model doesn't know it → skip to INSERT mode
Step 3 — Find where the fact lives
③ Trace: prompt = "The capital of France is", target = "Paris"
→ Look for phase layer: where rank drops from ~30000 to <100. That's where the fact materializes.
⑤ Smart Locate → Gradient only (fast, 1 backward pass):
subject = France, target = Paris
→ The layer with highest grad_norm bar = best edit target. Note the slot numbers.
Step 4 — Edit with SMART mode
⑦ Edit tab → mode = ★ SMART
Entity = France | Relation = capital
Old value = Paris (what model says now — used for locate)
New value = Lyon (what you want)
Prompt = "The capital of France is"
Scale = 2.0 (start here; increase to 3.0 if effect is weak)
→ Click Apply Edit. Smart locate runs internally, patches grad-identified slots.
Step 5 — Check collateral damage
① Infer: "The capital of France is" → should now say Lyon
① Infer: "Biggest cities in France" → should be unchanged (different slots)
① Infer: "Paris is a city in" → should still say France
① Infer: "Lyon is a city in" → might now also say France (collateral)
⑤ Smart Locate collateral probe → run these prompts, compare slot lists in ⑧ Patches

Interpreting Smart Locate results

■ gate_sim (blue)
Cosine between W_gate[slot] and embed(subject).
Fast, cheap, but weak proxy — measures embedding-space similarity,
not causal contribution. Useful for finding related slots.
High gate_sim + low grad_norm = slot activates for this entity
but doesn't contribute much to this specific prediction.
■ grad_norm (green)
‖∂(-log p(target))/∂W_down[:,slot]‖ — how much changing this slot
would affect the loss for this (prompt, target) pair.
Most reliable signal, works even when clean_prob is tiny.
One backward pass. Use β > α to weight this higher.
High grad_norm = this slot is causally upstream of the prediction.
■ causal IE (yellow)
Indirect effect via noise-corruption patching (ROME-style).
Measures: if I corrupt subject embeddings, how much does patching
layer L's hidden state at subject pos restore the prediction?
Most interpretable — true causal measurement. But:
If clean_prob ≈ 0, IE = 0 everywhere (nothing to restore).
Needs a model that actually knows the fact.
⚠ Your distilgpt2 result: clean_prob=0.000001 → causal IE=0 everywhere (expected, not a bug). grad_norm on L9/slot515 IS real signal — that slot responds to France+capital context in the gradient sense. But the probability mass is too diffuse to show causal separation. Switch to gpt2-medium for textbook causal results.

Edit modes — when to use which

Mode Slot selection Best for Knobs
UPDATE gate cosine sim to embed(entity) Quick experiment, model knows the fact well Top-K=3-5, Scale=1.5-3
PRECISE gate cosine sim to h_L[subject_pos] In-context subject representation (3-5× better than UPDATE) + Prompt field
★ SMART gradient norm → exact slots, then patch Best overall. Auto-locates, no manual tuning Top layers=3, Slots/layer=2, Scale=1.5-2.5
INSERT weakest slot (norm-based) Model has no knowledge of fact, build from scratch Alpha=0.4-0.7, Spread=4-6
SUPPRESS gate cosine → scale W_down to 0 Make model forget an entity (factor=0) or weaken (0.5) Factor: 0=forget, 0.5=weaken
STYLE-SHIFT gate cosine → add direction vector Bias/tone shifts: CEO→less male-coded, Paris→darker from/to concepts, strength=0.3-0.8

Experiments to run

Experiment A — Capital swap (classic ROME benchmark)
Model: gpt2-medium | Prompt: "The capital of France is" | Old: Paris | New: Lyon
Check: "France's capital city" | "Lyon is now" | "Paris is in" | "Eiffel Tower is in"
Insight: does it generalize (paraphrase) or is it prompt-specific?

Experiment B — Slot overlap analysis (your collateral question)
1. SMART locate "The capital of France is" → note slot numbers in recommendation
2. SMART locate "The biggest city in France is" → compare slot lists
3. Overlap = slots that will be collaterally damaged
4. No overlap = clean surgery ✓

Experiment C — Suppression then INSERT
SUPPRESS France → then INSERT France capital Lyon → Infer
vs just UPDATE. Which gives cleaner, more confident result?

Experiment D — Style shift (no factual change)
STYLE-SHIFT: anchor=CEO, from="male", to="female", strength=0.3
Then Infer: "The CEO of the company is a" — does pronoun distribution shift?
Insight: this is mechanical debiasing without retraining.

Experiment E — Compile and compare
Edit 5 facts. Compile → save as new model directory.
Load compiled model fresh → Infer same prompts → edits should persist in weights.
Then Trace on compiled model → phase layers should shift or sharpen.

α β γ tuning guide

Default (0.4 / 0.3 / 0.3) — balanced, works for unknown model quality
Grad-heavy (0.1 / 0.7 / 0.2) — clean_prob > 0.01. Grad signal is sharp, trust it.
Gate+Grad (0.4 / 0.4 / 0.2) — recommended for smart_edit when causal IE is weak
Causal-heavy (0.2 / 0.2 / 0.6) — only when clean_prob > 0.1. IE is the gold signal then.
Gate-only (1.0 / 0.0 / 0.0) — equivalent to basic locate(), sanity check

Your distilgpt2 setting: use (0.3 / 0.7 / 0.0) — gate+grad, skip causal (it's 0 anyway).
""" # ══════════════════════════════════════════════════════════════ # FASTAPI (Phase 2) # ══════════════════════════════════════════════════════════════ from fastapi import FastAPI, HTTPException from fastapi.responses import HTMLResponse from pydantic import BaseModel import uvicorn app = FastAPI(title="VINDEX") _vi: VIndex | None = None def _require(): if _vi is None: raise HTTPException(status_code=400, detail="No model loaded. POST /api/load first.") return _vi # ── Request models ───────────────────────────────────────────── class LoadReq(BaseModel): model_name: str = "distilgpt2" device: str = "auto" class InferReq(BaseModel): prompt: str top_k: int = 10 class DescribeReq(BaseModel): entity: str top_k: int = 10 class TraceReq(BaseModel): prompt: str target: str class LocateReq(BaseModel): prompt: str subject: str target: str class HeatmapReq(BaseModel): entity: str top_slots: int = 20 use_activation: bool = False prompt: Optional[str] = None class GradientReq(BaseModel): prompt: str target: str class CausalTraceReq(BaseModel): prompt: str subject: str target: str noise_std: float = 0.1 class SmartLocateReq(BaseModel): prompt: str subject: str target: str alpha: float = 0.4 beta: float = 0.3 gamma: float = 0.3 noise_std: float = 0.1 class SmartEditReq(BaseModel): prompt: str subject: str relation: str = "" old_target: str new_target: str top_layers: int = 3 slots_per_layer: int = 2 scale: float = 1.5 noise_std: float = 0.1 alpha: float = 0.4 beta: float = 0.4 gamma: float = 0.2 class DryRunReq(BaseModel): entity: str new_target: str top_k: int = 3 scale: float = 1.0 prompt: Optional[str] = None class EditReq(BaseModel): entity: str relation: str = "" old_target: str = "" new_target: str mode: str = "UPDATE" alpha: float = 0.25 top_k: int = 3 scale: float = 1.0 prompt: Optional[str] = None from_concept: str = "" to_concept: str = "" strength: float = 0.5 class SuppressReq(BaseModel): entity: str top_k: int = 3 factor: float = 0.0 class AmplifyReq(BaseModel): entity: str top_k: int = 3 factor: float = 2.0 class StyleShiftReq(BaseModel): anchor: str from_concept: str to_concept: str top_k: int = 3 strength: float = 0.5 class SaveReq(BaseModel): path: str = "patches.json" class CompileReq(BaseModel): output_dir: str = "./vindex_compiled" class MultiEditFact(BaseModel): entity: str relation: str = "" new_target: str prompt: Optional[str] = None # ── Endpoints ────────────────────────────────────────────────── @app.get("/", response_class=HTMLResponse) async def root(): return HTML_PAGE @app.post("/api/load") async def api_load(req: LoadReq): global _vi device = None if req.device == "auto" else req.device try: _vi = VIndex(req.model_name, device=device) return {"ok": True, "info": _vi.info, "n_layers": _vi.arch.n_layers, "kb_start": _vi.kb_start, "kb_end": _vi.kb_end} except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @app.get("/api/status") async def api_status(): if _vi is None: return {"loaded": False, "info": "", "patches_count": 0, "kb_start": 0, "kb_end": 0, "n_layers": 0} return {"loaded": True, "info": _vi.info, "patches_count": len(_vi.patches), "kb_start": _vi.kb_start, "kb_end": _vi.kb_end, "n_layers": _vi.arch.n_layers} @app.post("/api/infer") async def api_infer(req: InferReq): vi = _require() return {"results": vi.infer(req.prompt, top_k=req.top_k)} @app.post("/api/describe") async def api_describe(req: DescribeReq): vi = _require() edges = vi.describe(req.entity, top_k=req.top_k) out = [{"tok":e["tok"],"score":round(e["score"],2),"layer":e["layer"],"gate_sim":round(e["gate_sim"],4)} for e in edges] return {"edges": out} @app.post("/api/trace") async def api_trace(req: TraceReq): vi = _require() return {"stats": vi.trace(req.prompt, req.target)} @app.post("/api/locate") async def api_locate(req: LocateReq): vi = _require() return vi.locate(req.prompt, req.subject, req.target) @app.post("/api/gradient_scores") async def api_gradient_scores(req: GradientReq): vi = _require() return vi.gradient_slot_scores(req.prompt, req.target) @app.post("/api/causal_trace") async def api_causal_trace(req: CausalTraceReq): vi = _require() return vi.causal_patch_trace(req.prompt, req.subject, req.target, noise_std=req.noise_std) @app.post("/api/smart_locate") async def api_smart_locate(req: SmartLocateReq): vi = _require() return vi.smart_locate(req.prompt, req.subject, req.target, alpha=req.alpha, beta=req.beta, gamma=req.gamma, noise_std=req.noise_std) @app.post("/api/smart_edit") async def api_smart_edit(req: SmartEditReq): vi = _require() prompt_str = req.prompt or f"The {req.relation} of {req.subject} is" before = vi.infer(prompt_str, top_k=5) log: List[str] = [] try: result = vi.smart_edit( prompt_str, req.subject, req.relation, req.old_target, req.new_target, top_layers=req.top_layers, slots_per_layer=req.slots_per_layer, scale=req.scale, noise_std=req.noise_std, alpha=req.alpha, beta=req.beta, gamma=req.gamma, log=log ) except Exception as e: raise HTTPException(status_code=500, detail=str(e)) after = vi.infer(prompt_str, top_k=5) b_map = {d["token"]: d["prob"] for d in before} a_map = {d["token"]: d["prob"] for d in after} all_toks = set(b_map) | set(a_map) delta = sorted([{"token":t,"before":b_map.get(t,0),"after":a_map.get(t,0), "delta":a_map.get(t,0)-b_map.get(t,0)} for t in all_toks], key=lambda x: -abs(x["delta"])) return {"before": before, "after": after, "delta": delta, "debug_log": log, "used_layers": result["used_layers"], "smart_locate": result["smart_locate"]} @app.post("/api/gate_heatmap") async def api_gate_heatmap(req: HeatmapReq): vi = _require() return vi.gate_heatmap(req.entity, use_activation=req.use_activation, prompt=req.prompt) @app.post("/api/dry_run") async def api_dry_run(req: DryRunReq): vi = _require() return vi.dry_run(req.entity, req.new_target, top_k=req.top_k, scale=req.scale, prompt=req.prompt) @app.post("/api/edit") async def api_edit(req: EditReq): vi = _require() prompt_str = f"The {req.relation} of {req.entity} is" before = vi.infer(prompt_str, top_k=5) log: List[str] = [] try: mode = req.mode.upper() if mode == "PRECISE": vi.precise_update(req.prompt or prompt_str, req.entity, req.relation, req.new_target, top_k=req.top_k, scale=req.scale, log=log) elif mode == "INSERT": vi.insert(req.entity, req.relation, req.new_target, alpha=req.alpha, spread=req.top_k, log=log) elif mode == "SUPPRESS": vi.suppress(req.entity, top_k=req.top_k, log=log) elif mode == "AMPLIFY": vi.amplify(req.entity, top_k=req.top_k, log=log) elif mode == "STYLE-SHIFT": vi.style_shift(req.entity, req.from_concept, req.to_concept, top_k=req.top_k, strength=req.strength, log=log) else: # UPDATE vi.update(req.entity, req.relation, req.new_target, top_k=req.top_k, scale=req.scale, log=log) except Exception as e: raise HTTPException(status_code=500, detail=str(e)) after = vi.infer(prompt_str, top_k=5) b_map = {d["token"]: d["prob"] for d in before} a_map = {d["token"]: d["prob"] for d in after} all_toks = set(b_map) | set(a_map) delta = sorted([{"token":t,"before":b_map.get(t,0),"after":a_map.get(t,0), "delta":a_map.get(t,0)-b_map.get(t,0)} for t in all_toks], key=lambda x: -abs(x["delta"])) return {"before": before, "after": after, "delta": delta, "debug_log": log, "ops": len(vi.patches[-1]["ops"]) if vi.patches else 0} @app.post("/api/multi_edit") async def api_multi_edit(facts: List[MultiEditFact]): vi = _require() return vi.multi_edit([f.model_dump() for f in facts]) @app.post("/api/suppress") async def api_suppress(req: SuppressReq): vi = _require() return vi.suppress(req.entity, top_k=req.top_k, factor=req.factor) @app.post("/api/amplify") async def api_amplify(req: AmplifyReq): vi = _require() return vi.amplify(req.entity, top_k=req.top_k, factor=req.factor) @app.post("/api/style_shift") async def api_style_shift(req: StyleShiftReq): vi = _require() return vi.style_shift(req.anchor, req.from_concept, req.to_concept, top_k=req.top_k, strength=req.strength) @app.get("/api/patches") async def api_patches(): vi = _require() out = [] for i, p in enumerate(vi.patches): out.append({ "i": i, "type": p["type"], "entity": p.get("entity",""), "relation": p.get("relation",""), "new_target": p.get("new_target","") or p.get("target",""), "ops_count": len(p.get("ops",[])) }) return {"patches": out} @app.delete("/api/patches/{idx}") async def api_delete_patch(idx: int): vi = _require() if idx < 0 or idx >= len(vi.patches): raise HTTPException(status_code=404, detail="Patch index out of range") vi.patches.pop(idx) vi._apply_all_patches() return {"ok": True} @app.post("/api/reset") async def api_reset(): vi = _require() vi.reset() return {"ok": True} @app.post("/api/save") async def api_save(req: SaveReq): vi = _require() vi.save_patch(req.path) return {"ok": True, "path": req.path} @app.post("/api/compile") async def api_compile(req: CompileReq): vi = _require() vi.compile_to(req.output_dir) return {"ok": True, "output_dir": req.output_dir} # ══════════════════════════════════════════════════════════════ # ENTRY # ══════════════════════════════════════════════════════════════ if __name__ == "__main__": import argparse, webbrowser, threading, time ap = argparse.ArgumentParser() ap.add_argument("--port", type=int, default=8787) ap.add_argument("--model", default=None) ap.add_argument("--device", default=None) ap.add_argument("--no-browser", action="store_true") args, _ = ap.parse_known_args() if args.model: print(f"Pre-loading {args.model}…") _vi = VIndex(args.model, device=args.device) print("Done.", _vi.info) if not args.no_browser: def _open(): time.sleep(1.2) webbrowser.open(f"http://localhost:{args.port}") threading.Thread(target=_open, daemon=True).start() uvicorn.run(app, host="0.0.0.0", port=args.port)