#!/usr/bin/env python3 """ VINDEX — FastAPI + D3 single-file LLM knowledge editor Phase 1: Engine extensions | Phase 2: FastAPI | Phase 3: D3 frontend pip install transformers torch fastapi uvicorn pydantic python temp.py """ # ══════════════════════════════════════════════════════════════ # ENGINE # ══════════════════════════════════════════════════════════════ import re, json from pathlib import Path from typing import Any, Dict, List, Optional, Tuple import torch import torch.nn.functional as F from transformers import AutoTokenizer, AutoModelForCausalLM class ArchAdapter: def __init__(self, model): self.model = model self.style = self._detect_style() self.n_layers = self._count_layers() def _detect_style(self): t = self.model.config.model_type if t in ("gpt2","gpt_neo","gpt_neox"): return "gpt2" if t in ("llama","mistral","qwen2","gemma","gemma2","phi3", "falcon","codellama","deepseek","internlm2"): return "gated" try: if hasattr(self._layer(0).mlp, "gate_proj"): return "gated" except: pass return "gpt2" def _layer(self, i): m = self.model if hasattr(m,"transformer") and hasattr(m.transformer,"h"): return m.transformer.h[i] if hasattr(m,"model") and hasattr(m.model,"layers"): return m.model.layers[i] raise ValueError("Unknown model structure") def _count_layers(self): m = self.model if hasattr(m,"transformer") and hasattr(m.transformer,"h"): return len(m.transformer.h) if hasattr(m,"model") and hasattr(m.model,"layers"): return len(m.model.layers) raise ValueError("Cannot count layers") def get_ffn_weights(self, li): layer = self._layer(li) if self.style == "gpt2": return layer.mlp.c_fc.weight.detach().T, layer.mlp.c_proj.weight.detach().T return layer.mlp.gate_proj.weight.detach(), layer.mlp.down_proj.weight.detach() def set_ffn_weights(self, li, Wg, Wd): layer = self._layer(li) with torch.no_grad(): if self.style == "gpt2": layer.mlp.c_fc.weight.copy_(Wg.T) layer.mlp.c_proj.weight.copy_(Wd.T) else: layer.mlp.gate_proj.weight.copy_(Wg) layer.mlp.down_proj.weight.copy_(Wd) def get_embedding(self): return self.model.get_input_embeddings().weight.detach() def get_unembedding(self): if hasattr(self.model,"lm_head"): return self.model.lm_head.weight.detach() return self.get_embedding() class VIndex: def __init__(self, model_name: str, device: Optional[str] = None): self.model_name = model_name self.device = device or ("cuda" if torch.cuda.is_available() else "cpu") self.patches: List[Dict] = [] self._base_weights: Optional[Dict] = None self.tok = AutoTokenizer.from_pretrained(model_name) self.model = AutoModelForCausalLM.from_pretrained( model_name, torch_dtype=torch.float32).to(self.device) self.model.eval() if self.tok.pad_token is None: self.tok.pad_token = self.tok.eos_token self.arch = ArchAdapter(self.model) n = self.arch.n_layers self.kb_start = n // 3 self.kb_end = n @property def info(self): return (f"{self.model_name} | {self.arch.n_layers} layers | " f"style={self.arch.style} | kb=L{self.kb_start}-L{self.kb_end-1} | {self.device}") # ── utils ────────────────────────────────────────────────── def embed(self, text: str): ids = self.tok.encode(text, add_special_tokens=False) if not ids: raise ValueError(f"Cannot tokenize: {text!r}") E = self.arch.get_embedding().to(self.device) return E[ids].mean(0) def decode_down_col(self, col, top_k=5): scores = self.arch.get_unembedding().to(self.device) @ col.to(self.device) top = scores.topk(top_k) return [(self.tok.decode([i.item()]).strip(), v.item()) for i,v in zip(top.indices,top.values) if v.item()>0] def token_id(self, word: str) -> int: ids = self.tok.encode(word, add_special_tokens=False) return ids[0] if ids else 0 def _forward(self, prompt: str): inputs = self.tok(prompt, return_tensors="pt").to(self.device) with torch.no_grad(): out = self.model(**inputs) return out.logits[0,-1] # ── Phase 1: new engine methods ──────────────────────────── def _get_subject_activations(self, prompt: str, subject: str ) -> Tuple[Dict[int, torch.Tensor], int]: """Capture h_L[last_subject_token_pos] at every layer via forward hooks.""" enc = self.tok(prompt, return_tensors="pt").to(self.device) ids = enc["input_ids"][0].tolist() # Subsequence match to find subject position subj_ids = self.tok.encode(subject, add_special_tokens=False) subject_pos = 0 for start in range(len(ids) - len(subj_ids) + 1): if ids[start:start+len(subj_ids)] == subj_ids: subject_pos = start + len(subj_ids) - 1 break else: # Fallback: find any single token from subject_ids for si in subj_ids: if si in ids: subject_pos = ids.index(si) break activations: Dict[int, torch.Tensor] = {} handles = [] def make_hook(li): def hook(m, inp, out): h = out[0] if isinstance(out, tuple) else out activations[li] = h[0, subject_pos].detach().clone() return hook for li in range(self.arch.n_layers): handles.append(self.arch._layer(li).register_forward_hook(make_hook(li))) with torch.no_grad(): self.model(**enc) for h in handles: h.remove() return activations, subject_pos def locate(self, prompt: str, subject: str, target: str) -> Dict: """Diagnostic: combines trace + activation-guided similarity scan.""" trace_stats = self.trace(prompt, target) # Find phase_layer: biggest relative rank drop phase_layer = 0 best_drop = 0.0 prev_rank = None for s in trace_stats: if prev_rank is not None and prev_rank > 5: drop = (prev_rank - s["rank"]) / prev_rank if drop > best_drop: best_drop = drop phase_layer = s["l"] prev_rank = s["rank"] activations, subject_pos = self._get_subject_activations(prompt, subject) layer_scores = [] for li in range(self.kb_start, self.kb_end): h_L = activations.get(li) if h_L is None: layer_scores.append({"layer": li, "max_sim": 0.0, "best_slot": -1, "top_tokens": []}) continue Wg, Wd = self.arch.get_ffn_weights(li) Wg = Wg.to(self.device) h_n = F.normalize(h_L, dim=0) sims = F.normalize(Wg, dim=1) @ h_n best_slot = int(sims.argmax().item()) max_sim = float(sims[best_slot].item()) Wd = Wd.to(self.device) top_tokens = self.decode_down_col(Wd[:, best_slot], top_k=3) layer_scores.append({ "layer": li, "max_sim": round(max_sim, 4), "best_slot": best_slot, "top_tokens": [{"tok": t, "score": round(s,2)} for t,s in top_tokens] }) return { "phase_layer": phase_layer, "subject_pos": subject_pos, "layer_scores": layer_scores, "trace": trace_stats, } def precise_update(self, prompt: str, subject: str, relation: str, new_target: str, top_k: int = 3, scale: float = 1.0, log: Optional[List[str]] = None): """Activation-guided update: uses h_L[subject_pos] instead of embed(subject).""" if log is None: log = [] if not prompt: return self.update(subject, relation, new_target, top_k=top_k, scale=scale, log=log) self._snapshot() activations, subject_pos = self._get_subject_activations(prompt, subject) tv = self.embed(new_target) log.append(f"PRECISE_UPDATE: '{subject}' -[{relation}]-> '{new_target}'") log.append(f" subject_pos={subject_pos} scale={scale} top_k={top_k}") candidates = [] for li in range(self.kb_start, self.kb_end): h_L = activations.get(li) if h_L is None: continue Wg, _ = self.arch.get_ffn_weights(li) Wg = Wg.to(self.device) h_n = F.normalize(h_L, dim=0) sims = F.normalize(Wg, dim=1) @ h_n k = min(top_k, sims.shape[0]) vals, idxs = sims.topk(k) for v, idx in zip(vals, idxs): candidates.append((v.item(), li, idx.item())) log.append(f" L{li}: " + " ".join(f"{v:.4f}@s{i}" for v,i in zip(vals,idxs))) candidates.sort(key=lambda x: -x[0]) best_sim = candidates[0][0] if candidates else 0.0 log.append(f"\n Best activation_sim = {best_sim:.4f} (embed-based would be ~{best_sim/3.5:.4f})") if best_sim < 0.05: log.append(" ⚠ sim < 0.05 → INSERT fallback") return self.insert(subject, relation, new_target, log=log) chosen = [c for c in candidates if c[0] >= 0.05][:top_k] ops = [] for sim, li, slot in chosen: _, Wd = self.arch.get_ffn_weights(li) Wd = Wd.to(self.device) col_norm = Wd[:, slot].norm().item() new_col = (F.normalize(tv, dim=0)*col_norm*scale).cpu().tolist() ops.append({"op":"update_down","layer":li,"slot":slot, "down_col":new_col,"activation_sim":round(sim,4)}) log.append(f" ✓ L{li} slot {slot}: act_sim={sim:.4f} col_norm={col_norm:.4f}") self.patches.append({"type":"PRECISE_UPDATE","entity":subject,"relation":relation, "new_target":new_target,"ops":ops}) self._apply_all_patches() log.append(f"\n ✓ Applied {len(ops)} op(s), patch #{len(self.patches)}") return ops def suppress(self, entity: str, top_k: int = 3, factor: float = 0.0, log: Optional[List[str]] = None): """Scale down (or zero) matching down columns.""" if log is None: log = [] self._snapshot() ev_n = F.normalize(self.embed(entity), dim=0) log.append(f"SUPPRESS: '{entity}' factor={factor} top_k={top_k}") ops = [] candidates = [] for li in range(self.kb_start, self.kb_end): Wg, _ = self.arch.get_ffn_weights(li) Wg = Wg.to(self.device) sims = F.normalize(Wg, dim=1) @ ev_n k = min(top_k, sims.shape[0]) vals, idxs = sims.topk(k) for v, idx in zip(vals, idxs): candidates.append((v.item(), li, idx.item())) candidates.sort(key=lambda x: -x[0]) chosen = [c for c in candidates if c[0] >= 0.05][:top_k] for sim, li, slot in chosen: _, Wd = self.arch.get_ffn_weights(li) Wd = Wd.to(self.device) new_col = (Wd[:, slot] * factor).cpu().tolist() ops.append({"op":"update_down","layer":li,"slot":slot,"down_col":new_col}) log.append(f" L{li} slot {slot}: gate_sim={sim:.4f} factor={factor}") self.patches.append({"type":"SUPPRESS","entity":entity,"factor":factor,"ops":ops}) self._apply_all_patches() log.append(f" ✓ Suppressed {len(ops)} slot(s)") return {"ops": len(ops), "log": log} def amplify(self, entity: str, top_k: int = 3, factor: float = 2.0, log: Optional[List[str]] = None): """Scale up matching down columns.""" if log is None: log = [] self._snapshot() ev_n = F.normalize(self.embed(entity), dim=0) log.append(f"AMPLIFY: '{entity}' factor={factor} top_k={top_k}") ops = [] candidates = [] for li in range(self.kb_start, self.kb_end): Wg, _ = self.arch.get_ffn_weights(li) Wg = Wg.to(self.device) sims = F.normalize(Wg, dim=1) @ ev_n k = min(top_k, sims.shape[0]) vals, idxs = sims.topk(k) for v, idx in zip(vals, idxs): candidates.append((v.item(), li, idx.item())) candidates.sort(key=lambda x: -x[0]) chosen = [c for c in candidates if c[0] >= 0.05][:top_k] for sim, li, slot in chosen: _, Wd = self.arch.get_ffn_weights(li) Wd = Wd.to(self.device) new_col = (Wd[:, slot] * factor).cpu().tolist() ops.append({"op":"update_down","layer":li,"slot":slot,"down_col":new_col}) log.append(f" L{li} slot {slot}: gate_sim={sim:.4f} factor={factor}") self.patches.append({"type":"AMPLIFY","entity":entity,"factor":factor,"ops":ops}) self._apply_all_patches() log.append(f" ✓ Amplified {len(ops)} slot(s)") return {"ops": len(ops), "log": log} def style_shift(self, anchor_entity: str, from_concept: str, to_concept: str, top_k: int = 3, strength: float = 0.5, log: Optional[List[str]] = None): """Add a direction vector to matching down columns.""" if log is None: log = [] self._snapshot() ev_n = F.normalize(self.embed(anchor_entity), dim=0) from_v = self.embed(from_concept) to_v = self.embed(to_concept) dir_v = F.normalize(to_v - from_v, dim=0) log.append(f"STYLE_SHIFT: anchor='{anchor_entity}' {from_concept!r}→{to_concept!r} strength={strength}") ops = [] candidates = [] for li in range(self.kb_start, self.kb_end): Wg, _ = self.arch.get_ffn_weights(li) Wg = Wg.to(self.device) sims = F.normalize(Wg, dim=1) @ ev_n k = min(top_k, sims.shape[0]) vals, idxs = sims.topk(k) for v, idx in zip(vals, idxs): candidates.append((v.item(), li, idx.item())) candidates.sort(key=lambda x: -x[0]) chosen = [c for c in candidates if c[0] >= 0.05][:top_k] for sim, li, slot in chosen: _, Wd = self.arch.get_ffn_weights(li) Wd = Wd.to(self.device) col = Wd[:, slot] new_col = (col + dir_v * col.norm() * strength).cpu().tolist() ops.append({"op":"update_down","layer":li,"slot":slot,"down_col":new_col}) log.append(f" L{li} slot {slot}: gate_sim={sim:.4f} col_norm={col.norm():.4f}") self.patches.append({"type":"STYLE_SHIFT","entity":anchor_entity, "from":from_concept,"to":to_concept,"ops":ops}) self._apply_all_patches() log.append(f" ✓ Style-shifted {len(ops)} slot(s)") return {"ops": len(ops), "log": log} def multi_edit(self, facts: List[Dict], mode: str = "UPDATE", alpha: float = 0.25, top_k: int = 3, scale: float = 1.0): """Apply a batch of edits sequentially.""" results = [] for f in facts: log: List[str] = [] try: entity = f["entity"] relation = f.get("relation", "") new_target = f["new_target"] prompt = f.get("prompt", "") if mode == "PRECISE" and prompt: ops = self.precise_update(prompt, entity, relation, new_target, top_k=top_k, scale=scale, log=log) elif mode == "INSERT": ops = self.insert(entity, relation, new_target, alpha=alpha, spread=top_k, log=log) else: ops = self.update(entity, relation, new_target, top_k=top_k, scale=scale, log=log) results.append({"entity":entity,"status":"ok", "ops": len(ops) if isinstance(ops,(list,)) else 1, "log":log}) except Exception as e: results.append({"entity":f.get("entity","?"),"status":"error", "error":str(e),"log":log}) return results def gate_heatmap(self, entity: str, use_activation: bool = False, prompt: Optional[str] = None) -> Dict: """Full layer×slot similarity matrix with top token decoding.""" if use_activation and prompt: activations, _ = self._get_subject_activations(prompt, entity) else: activations = {} ev_n = F.normalize(self.embed(entity), dim=0) layers_out = [] for li in range(self.kb_start, self.kb_end): Wg, Wd = self.arch.get_ffn_weights(li) Wg = Wg.to(self.device); Wd = Wd.to(self.device) if use_activation and li in activations: q = F.normalize(activations[li], dim=0) else: q = ev_n sims = F.normalize(Wg, dim=1) @ q top_slots_count = min(20, sims.shape[0]) vals, idxs = sims.topk(top_slots_count) slots = [] for v, idx in zip(vals, idxs): top_toks = self.decode_down_col(Wd[:, idx], top_k=3) slots.append({ "slot": int(idx.item()), "sim": round(float(v.item()), 4), "top_tokens": [{"tok": t, "score": round(s,2)} for t,s in top_toks] }) layers_out.append({"layer": li, "slots": slots}) return {"layers": layers_out} def dry_run(self, entity: str, new_target: str, top_k: int = 3, scale: float = 1.0, prompt: Optional[str] = None) -> Dict: """Same logic as precise_update/update but does NOT mutate weights.""" if prompt: activations, subject_pos = self._get_subject_activations(prompt, entity) use_act = True else: use_act = False ev_n = F.normalize(self.embed(entity), dim=0) tv = self.embed(new_target) candidates = [] for li in range(self.kb_start, self.kb_end): Wg, Wd = self.arch.get_ffn_weights(li) Wg = Wg.to(self.device); Wd = Wd.to(self.device) if use_act and li in activations: q = F.normalize(activations[li], dim=0) else: q = ev_n sims = F.normalize(Wg, dim=1) @ q k = min(top_k, sims.shape[0]) vals, idxs = sims.topk(k) for v, idx in zip(vals, idxs): col_norm = Wd[:, idx].norm().item() top_toks = self.decode_down_col(Wd[:, idx], top_k=3) candidates.append({ "layer": li, "slot": int(idx.item()), "sim": round(float(v.item()), 4), "col_norm": round(col_norm, 4), "inject_norm": round(col_norm * scale, 4), "current_top": [{"tok":t,"score":round(s,2)} for t,s in top_toks] }) candidates.sort(key=lambda x: -x["sim"]) best_sim = candidates[0]["sim"] if candidates else 0.0 chosen = [c for c in candidates if c["sim"] >= 0.05][:top_k] return { "candidates": chosen, "best_sim": best_sim, "would_patch": len(chosen), "new_target": new_target, "mode": "activation-guided" if use_act else "embed-based" } # ── Phase 1+: mechanistic attribution ───────────────────── def gradient_slot_scores(self, prompt: str, target: str) -> Dict: """One backward pass: grad norm of ∂(-log p(target))/∂W_down[:,slot] per KB layer. Identifies which slots causally contributed to this prediction via gradient signal.""" target_id = self.token_id(target) # Temporarily enable grad on down-proj weights down_params: List[Tuple[int, torch.nn.Parameter]] = [] for li in range(self.arch.n_layers): layer = self.arch._layer(li) p = layer.mlp.c_proj.weight if self.arch.style == "gpt2" \ else layer.mlp.down_proj.weight p.requires_grad_(True) down_params.append((li, p)) self.model.zero_grad() inputs = self.tok(prompt, return_tensors="pt").to(self.device) out = self.model(**inputs) loss = -F.log_softmax(out.logits[0, -1], dim=-1)[target_id] loss.backward() layer_scores = [] for li, p in down_params: grad = p.grad p.requires_grad_(False) if grad is None: layer_scores.append({"layer": li, "max_grad": 0.0, "top_slots": []}) continue # gpt2: c_proj.weight [ffn_dim, hidden] → rows = slots # gated: down_proj.weight [hidden, ffn_dim] → cols = slots slot_norms = grad.norm(dim=1) if self.arch.style == "gpt2" \ else grad.norm(dim=0) # [ffn_dim] k = min(20, slot_norms.shape[0]) vals, idxs = slot_norms.topk(k) layer_scores.append({ "layer": li, "max_grad": round(float(vals[0].item()), 6), "top_slots": [{"slot": int(idx.item()), "grad_norm": round(float(v.item()), 6)} for idx, v in zip(idxs, vals)] }) self.model.zero_grad() return {"layer_scores": layer_scores} def causal_patch_trace(self, prompt: str, subject: str, target: str, noise_std: float = 0.1) -> Dict: """ROME-style causal tracing. Corrupts subject embeddings, then for each KB layer measures how much patching that layer's hidden state (at subject position) restores p(target). Expensive: O(n_layers) forward passes.""" target_id = self.token_id(target) W_u = self.arch.get_unembedding().to(self.device) inputs = self.tok(prompt, return_tensors="pt").to(self.device) ids = inputs["input_ids"][0].tolist() # Find subject token positions via subsequence match subj_ids = self.tok.encode(subject, add_special_tokens=False) subj_pos: List[int] = [] for start in range(len(ids) - len(subj_ids) + 1): if ids[start:start+len(subj_ids)] == subj_ids: subj_pos = list(range(start, start+len(subj_ids))) break if not subj_pos: for si in subj_ids: if si in ids: subj_pos = [ids.index(si)] break if not subj_pos: subj_pos = [0] # ── Clean forward — capture every layer's hidden states ── clean_hs: Dict[int, torch.Tensor] = {} clean_handles = [] def _mk_clean(li): def _h(m, inp, out): h = out[0] if isinstance(out, tuple) else out clean_hs[li] = h[0].detach().clone() # [seq, hidden] return _h for li in range(self.arch.n_layers): clean_handles.append(self.arch._layer(li).register_forward_hook(_mk_clean(li))) with torch.no_grad(): clean_out = self.model(**inputs) for h in clean_handles: h.remove() clean_prob = float(torch.softmax(clean_out.logits[0,-1], dim=-1)[target_id].item()) # ── Corrupted embeddings ── E = self.arch.get_embedding().to(self.device) emb = E[inputs["input_ids"][0]].unsqueeze(0).clone() # [1, seq, hidden] noise_scale = emb.std().item() * noise_std for pos in subj_pos: emb[0, pos] += torch.randn_like(emb[0, pos]) * noise_scale with torch.no_grad(): corr_out = self.model(inputs_embeds=emb) corr_prob = float(torch.softmax(corr_out.logits[0,-1], dim=-1)[target_id].item()) # ── Causal patch sweep ── results = [] for li in range(self.kb_start, self.kb_end): def _mk_patch(target_li): def _h(m, inp, out): if target_li not in clean_hs: return out is_tuple = isinstance(out, tuple) h = list(out) if is_tuple else [out] clean = clean_hs[target_li] for pos in subj_pos: if pos < clean.shape[0]: h[0][0, pos] = clean[pos].to(h[0].device) return tuple(h) if is_tuple else h[0] return _h ph = self.arch._layer(li).register_forward_hook(_mk_patch(li)) with torch.no_grad(): patch_out = self.model(inputs_embeds=emb.clone()) ph.remove() patch_prob = float(torch.softmax(patch_out.logits[0,-1], dim=-1)[target_id].item()) ie = patch_prob - corr_prob results.append({ "layer": li, "patch_prob": round(patch_prob, 6), "indirect_effect": round(ie, 6), }) return { "clean_prob": round(clean_prob, 6), "corrupt_prob": round(corr_prob, 6), "subject_pos": subj_pos, "results": results, } def smart_locate(self, prompt: str, subject: str, target: str, alpha: float = 0.4, beta: float = 0.3, gamma: float = 0.3, noise_std: float = 0.1) -> Dict: """Combined gate_sim + grad_norm + causal_effect → precise layer/slot ranking. alpha = weight for gate cosine sim beta = weight for gradient norm gamma = weight for causal indirect effect""" gate_data = self.locate(prompt, subject, target) grad_data = self.gradient_slot_scores(prompt, target) causal_data = self.causal_patch_trace(prompt, subject, target, noise_std=noise_std) gate_map = {ls["layer"]: ls["max_sim"] for ls in gate_data["layer_scores"]} grad_map = {ls["layer"]: ls["max_grad"] for ls in grad_data["layer_scores"]} causal_map = {r["layer"]: max(0.0, r["indirect_effect"]) for r in causal_data["results"]} grad_slots = {ls["layer"]: ls["top_slots"] for ls in grad_data["layer_scores"]} layers = sorted(set(gate_map) | set(grad_map) | set(causal_map)) def _norm(vals: List[float]) -> List[float]: m = max(vals) if vals else 1.0 return [v/m if m > 0 else 0.0 for v in vals] gv = [gate_map.get(l, 0.0) for l in layers] dv = [grad_map.get(l, 0.0) for l in layers] cv = [causal_map.get(l, 0.0) for l in layers] gn, dn, cn = _norm(gv), _norm(dv), _norm(cv) ranked = [] for i, l in enumerate(layers): score = alpha*gn[i] + beta*dn[i] + gamma*cn[i] ranked.append({ "layer": l, "gate_sim": round(gv[i], 4), "grad_norm": round(dv[i], 6), "causal_effect": round(cv[i], 6), "gate_sim_n": round(gn[i], 4), "grad_norm_n": round(dn[i], 4), "causal_n": round(cn[i], 4), "combined": round(score, 4), "best_slots": (grad_slots.get(l) or [])[:5], }) ranked.sort(key=lambda x: -x["combined"]) return { "ranked_layers": ranked, "phase_layer": gate_data["phase_layer"], "subject_pos": gate_data["subject_pos"], "clean_prob": causal_data["clean_prob"], "corrupt_prob": causal_data["corrupt_prob"], "recommendation": ranked[0] if ranked else None, "weights": {"alpha": alpha, "beta": beta, "gamma": gamma}, } def smart_edit(self, prompt: str, subject: str, relation: str, old_target: str, new_target: str, top_layers: int = 3, slots_per_layer: int = 2, scale: float = 1.5, noise_std: float = 0.1, alpha: float = 0.4, beta: float = 0.4, gamma: float = 0.2, log: Optional[List[str]] = None) -> Dict: """Auto edit: runs smart_locate on (prompt, subject, old_target) to find the exact layer+slot targets via gradient+causal+gate consensus, then patches those W_down columns toward embed(new_target). old_target = what the model currently predicts (used to locate) new_target = what you want to inject top_layers = how many top-ranked layers to patch slots_per_layer = gradient-identified slots to patch per layer scale = col_norm multiplier (1.5-3.0 recommended) beta > alpha because grad_norm is more reliable than gate_sim for small models.""" if log is None: log = [] self._snapshot() log.append(f"SMART_EDIT: '{subject}' [{relation}] {old_target!r} → {new_target!r}") log.append(f" Running smart_locate on prompt: {prompt!r}") log.append(f" Weights: α={alpha} β={beta} γ={gamma} noise_std={noise_std}") sl = self.smart_locate(prompt, subject, old_target, alpha=alpha, beta=beta, gamma=gamma, noise_std=noise_std) log.append(f" clean_prob={sl['clean_prob']:.6f} corrupt_prob={sl['corrupt_prob']:.6f}") log.append(f" Phase layer: L{sl['phase_layer']} Subject pos: {sl['subject_pos']}") if sl["clean_prob"] < 1e-5: log.append(" ⚠ clean_prob near zero — model barely knows this fact.") log.append(" Grad-norm signal still valid. Causal IE=0 is expected.") log.append(" Recommend: gpt2-medium or Qwen2.5-1.5B for stronger facts.") tv = self.embed(new_target) tv_n = F.normalize(tv, dim=0) ops = [] used = [] top_ranked = sl["ranked_layers"][:top_layers] for lr in top_ranked: li = lr["layer"] # Use gradient-identified slots — far more precise than gate cosine grad_slots = [s["slot"] for s in lr["best_slots"][:slots_per_layer]] if not grad_slots: log.append(f" L{li}: no grad slots, skipping") continue _, Wd = self.arch.get_ffn_weights(li) Wd = Wd.to(self.device) for slot in grad_slots: col_norm = Wd[:, slot].norm().item() new_col = (tv_n * col_norm * scale).cpu().tolist() ops.append({"op":"update_down","layer":li,"slot":slot,"down_col":new_col}) log.append(f" ✓ L{li} slot {slot}: combined={lr['combined']} " f"grad_norm={lr['grad_norm']:.4f} col_norm={col_norm:.4f} " f"inject={col_norm*scale:.4f}") used.append({"layer":li,"slots":grad_slots,"combined":lr["combined"]}) self.patches.append({ "type": "SMART_UPDATE", "entity": subject, "relation": relation, "new_target": new_target, "old_target": old_target, "smart_top": top_ranked, "ops": ops, }) self._apply_all_patches() log.append(f"\n ✓ {len(ops)} op(s) across {len(used)} layer(s), patch #{len(self.patches)}") return { "ops": ops, "used_layers": used, "smart_locate": sl, "log": log, } def infer(self, prompt: str, top_k: int = 5): probs = torch.softmax(self._forward(prompt), dim=-1) top = probs.topk(top_k) return [{"token": self.tok.decode([idx.item()]).strip(), "prob": round(val.item(), 6)} for idx, val in zip(top.indices, top.values)] def describe(self, entity: str, top_k: int = 10): ev_n = F.normalize(self.embed(entity), dim=0) all_edges = [] for li in range(self.kb_start, self.kb_end): Wg, Wd = self.arch.get_ffn_weights(li) Wg = Wg.to(self.device); Wd = Wd.to(self.device) sims = F.normalize(Wg, dim=1) @ ev_n for fid in sims.topk(min(5,sims.shape[0])).indices: gsim = sims[fid].item() if gsim < 0.08: continue for tok, score in self.decode_down_col(Wd[:,fid], 4): if tok: all_edges.append({"tok":tok,"score":score,"layer":li,"gate_sim":gsim}) best: Dict[str, Any] = {} for e in all_edges: t = e["tok"] if t not in best or e["score"] > best[t]["score"]: best[t] = e ranked = sorted(best.values(), key=lambda x: -x["score"])[:top_k] return ranked def trace(self, prompt: str, target: str): target_id = self.token_id(target) W_u = self.arch.get_unembedding().to(self.device) stats: List[Dict] = [] handles = [] def make_hook(li): def hook(m, inp, out): h = out[0] if isinstance(out,tuple) else out last = h[0,-1].detach() p = torch.softmax(W_u @ last, dim=-1) rank = int((p > p[target_id]).sum().item()) + 1 stats.append({"l":li,"rank":rank,"prob":round(p[target_id].item(),8)}) return hook for li in range(self.arch.n_layers): handles.append(self.arch._layer(li).register_forward_hook(make_hook(li))) inputs = self.tok(prompt, return_tensors="pt").to(self.device) with torch.no_grad(): self.model(**inputs) for h in handles: h.remove() return stats # ── patch management ─────────────────────────────────────── def _snapshot(self): if self._base_weights is not None: return self._base_weights = {} for li in range(self.arch.n_layers): Wg, Wd = self.arch.get_ffn_weights(li) self._base_weights[li] = (Wg.clone().cpu(), Wd.clone().cpu()) def _restore_base(self): if self._base_weights is None: return for li,(Wg,Wd) in self._base_weights.items(): self.arch.set_ffn_weights(li, Wg.to(self.device), Wd.to(self.device)) def _apply_all_patches(self): self._restore_base() for patch in self.patches: for op in patch.get("ops",[]): li=op["layer"]; slot=op["slot"] Wg, Wd = self.arch.get_ffn_weights(li) Wg=Wg.clone(); Wd=Wd.clone() if op["op"] in ("insert","update_gate"): Wg[slot] = torch.tensor(op["gate_row"],dtype=Wg.dtype,device=self.device) if op["op"] in ("insert","update_down"): Wd[:,slot] = torch.tensor(op["down_col"],dtype=Wd.dtype,device=self.device) self.arch.set_ffn_weights(li, Wg, Wd) def insert(self, entity, relation, target, alpha=0.25, spread=4, log=None): if log is None: log = [] self._snapshot() gate_dir = F.normalize(self.embed(entity), dim=0) down_dir = F.normalize(self.embed(target), dim=0) ls=self.kb_start; le=min(ls+spread, self.arch.n_layers) log.append(f"INSERT: '{entity}' -[{relation}]-> '{target}' alpha={alpha}") ops=[] for li in range(ls,le): Wg, Wd = self.arch.get_ffn_weights(li) Wg=Wg.to(self.device); Wd=Wd.to(self.device) norms_g=Wg.norm(dim=1); norms_d=Wd.norm(dim=0) slot=norms_g.argmin().item() wg_mean=norms_g.mean().item(); wd_mean=norms_d.mean().item() ops.append({"op":"insert","layer":li,"slot":slot, "gate_row":(gate_dir*wg_mean*alpha).cpu().tolist(), "down_col":(down_dir*wd_mean*alpha).cpu().tolist()}) log.append(f" L{li}: slot={slot} inject={wg_mean*alpha:.4f}") self.patches.append({"type":"INSERT","entity":entity,"relation":relation, "target":target,"ops":ops}) self._apply_all_patches() return ops def update(self, entity, relation, new_target, top_k=3, scale=1.0, log=None): if log is None: log = [] self._snapshot() ev = self.embed(entity) tv = self.embed(new_target) ev_n = F.normalize(ev, dim=0) log.append(f"UPDATE: '{entity}' -[{relation}]-> '{new_target}' top_k={top_k} scale={scale}") candidates = [] for li in range(self.kb_start, self.kb_end): Wg, _ = self.arch.get_ffn_weights(li) Wg = Wg.to(self.device) sims = F.normalize(Wg, dim=1) @ ev_n k = min(top_k, sims.shape[0]) vals, idxs = sims.topk(k) for v, idx in zip(vals, idxs): candidates.append((v.item(), li, idx.item())) log.append(f" L{li}: " + " ".join(f"{v:.4f}@s{i}" for v,i in zip(vals,idxs))) candidates.sort(key=lambda x: -x[0]) best_sim = candidates[0][0] if candidates else 0.0 if best_sim < 0.08: log.append(" ⚠ sim<0.08 → INSERT fallback") return self.insert(entity, relation, new_target, log=log), "INSERT_FALLBACK" chosen = [c for c in candidates if c[0] >= 0.08][:top_k] ops = [] for sim, li, slot in chosen: _, Wd = self.arch.get_ffn_weights(li) Wd = Wd.to(self.device) col_norm = Wd[:, slot].norm().item() new_col = (F.normalize(tv, dim=0)*col_norm*scale).cpu().tolist() ops.append({"op":"update_down","layer":li,"slot":slot,"down_col":new_col}) log.append(f" ✓ L{li} slot {slot}: sim={sim:.4f} norm={col_norm:.4f}") self.patches.append({"type":"UPDATE","entity":entity,"relation":relation, "new_target":new_target,"ops":ops}) self._apply_all_patches() li0,slot0,sim0 = chosen[0] return li0, slot0, sim0 def reset(self): self.patches.clear() self._restore_base() def save_patch(self, path: str): with open(path,"w") as f: json.dump(self.patches, f, indent=2) def compile_to(self, output_dir: str): Path(output_dir).mkdir(parents=True, exist_ok=True) self.model.save_pretrained(output_dir) self.tok.save_pretrained(output_dir) # ══════════════════════════════════════════════════════════════ # HTML FRONTEND (Phase 3) # ══════════════════════════════════════════════════════════════ HTML_PAGE = r"""
"The capital of France is""The capital of France is", target = "Paris"France, target = ParisFrance | Relation = capitalParis (what model says now — used for locate)Lyon (what you want)"The capital of France is"2.0 (start here; increase to 3.0 if effect is weak)"The capital of France is" → should now say Lyon"Biggest cities in France" → should be unchanged (different slots)"Paris is a city in" → should still say France"Lyon is a city in" → might now also say France (collateral)| Mode | Slot selection | Best for | Knobs |
|---|---|---|---|
| UPDATE | gate cosine sim to embed(entity) | Quick experiment, model knows the fact well | Top-K=3-5, Scale=1.5-3 |
| PRECISE | gate cosine sim to h_L[subject_pos] | In-context subject representation (3-5× better than UPDATE) | + Prompt field |
| ★ SMART | gradient norm → exact slots, then patch | Best overall. Auto-locates, no manual tuning | Top layers=3, Slots/layer=2, Scale=1.5-2.5 |
| INSERT | weakest slot (norm-based) | Model has no knowledge of fact, build from scratch | Alpha=0.4-0.7, Spread=4-6 |
| SUPPRESS | gate cosine → scale W_down to 0 | Make model forget an entity (factor=0) or weaken (0.5) | Factor: 0=forget, 0.5=weaken |
| STYLE-SHIFT | gate cosine → add direction vector | Bias/tone shifts: CEO→less male-coded, Paris→darker | from/to concepts, strength=0.3-0.8 |