Chris4K commited on
Commit
bdbe22a
Β·
verified Β·
1 Parent(s): 82946f7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +899 -5
app.py CHANGED
@@ -466,7 +466,272 @@ class VIndex:
466
  "mode": "activation-guided" if use_act else "embed-based"
467
  }
468
 
469
- # ── base ops (unchanged) ───────────────────────────────────
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
470
 
471
  def infer(self, prompt: str, top_k: int = 5):
472
  probs = torch.softmax(self._forward(prompt), dim=-1)
@@ -771,10 +1036,12 @@ svg text { font-family: var(--font); fill: var(--text); }
771
  <button class="tab-btn" onclick="showTab('describe')">β‘‘ Describe</button>
772
  <button class="tab-btn" onclick="showTab('trace')">β‘’ Trace</button>
773
  <button class="tab-btn" onclick="showTab('locate')">β‘£ Locate</button>
774
- <button class="tab-btn" onclick="showTab('heatmap')">β‘€ Heatmap</button>
775
- <button class="tab-btn" onclick="showTab('edit')">β‘₯ Edit</button>
776
- <button class="tab-btn" onclick="showTab('patches')">⑦ Patches</button>
777
- <button class="tab-btn" onclick="showTab('load')" style="margin-left:auto">βš™ Load</button>
 
 
778
  </div>
779
 
780
  <!-- TOOLTIP -->
@@ -925,6 +1192,86 @@ svg text { font-family: var(--font); fill: var(--text); }
925
  </div>
926
  </div>
927
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
928
  <!-- ══════════ HEATMAP PANEL ══════════ -->
929
  <div id="panel-heatmap" class="panel">
930
  <div class="card">
@@ -980,6 +1327,7 @@ svg text { font-family: var(--font); fill: var(--text); }
980
  <div class="radio-group" id="edit-mode-group">
981
  <label><input type="radio" name="edit-mode" value="UPDATE" checked> UPDATE</label>
982
  <label><input type="radio" name="edit-mode" value="PRECISE"> PRECISE</label>
 
983
  <label><input type="radio" name="edit-mode" value="INSERT"> INSERT</label>
984
  <label><input type="radio" name="edit-mode" value="SUPPRESS"> SUPPRESS</label>
985
  <label><input type="radio" name="edit-mode" value="AMPLIFY"> AMPLIFY</label>
@@ -990,6 +1338,28 @@ svg text { font-family: var(--font); fill: var(--text); }
990
  <label>Prompt (PRECISE mode)</label>
991
  <input type="text" id="edit-prompt" value="The capital of France is">
992
  </div>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
993
  <div id="style-shift-row" style="display:none;margin-top:8px">
994
  <label>From concept</label>
995
  <input type="text" id="ss-from" value="formal">
@@ -1069,6 +1439,209 @@ svg text { font-family: var(--font); fill: var(--text); }
1069
  </div>
1070
  </div>
1071
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1072
  </div><!-- /app -->
1073
 
1074
  <script>
@@ -1459,6 +2032,7 @@ function showHmSlotDetail(layer, slot) {
1459
  document.querySelectorAll('input[name="edit-mode"]').forEach(r=>{
1460
  r.addEventListener('change', ()=>{
1461
  document.getElementById('precise-prompt-row').style.display = r.value==='PRECISE'?'block':'none';
 
1462
  document.getElementById('style-shift-row').style.display = r.value==='STYLE-SHIFT'?'block':'none';
1463
  document.getElementById('multiedit-row').style.display = r.value==='MULTI-EDIT'?'block':'none';
1464
  });
@@ -1500,6 +2074,30 @@ async function runEdit() {
1500
  to_concept: document.getElementById('ss-to').value,
1501
  strength: +document.getElementById('ss-strength').value,
1502
  };
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1503
  if(mode==='MULTI-EDIT'){
1504
  try {
1505
  body.facts = JSON.parse(document.getElementById('multi-json').value);
@@ -1614,6 +2212,221 @@ async function updatePatchCount() {
1614
  } catch(e){}
1615
  }
1616
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1617
  // ═══════════════════════════════════════════════
1618
  // INIT
1619
  // ═══════════════════════════════════════════════
@@ -1672,6 +2485,39 @@ class HeatmapReq(BaseModel):
1672
  use_activation: bool = False
1673
  prompt: Optional[str] = None
1674
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1675
  class DryRunReq(BaseModel):
1676
  entity: str
1677
  new_target: str
@@ -1778,6 +2624,54 @@ async def api_locate(req: LocateReq):
1778
  return vi.locate(req.prompt, req.subject, req.target)
1779
 
1780
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1781
  @app.post("/api/gate_heatmap")
1782
  async def api_gate_heatmap(req: HeatmapReq):
1783
  vi = _require()
 
466
  "mode": "activation-guided" if use_act else "embed-based"
467
  }
468
 
469
+ # ── Phase 1+: mechanistic attribution ─────────────────────
470
+
471
+ def gradient_slot_scores(self, prompt: str, target: str) -> Dict:
472
+ """One backward pass: grad norm of βˆ‚(-log p(target))/βˆ‚W_down[:,slot] per KB layer.
473
+ Identifies which slots causally contributed to this prediction via gradient signal."""
474
+ target_id = self.token_id(target)
475
+
476
+ # Temporarily enable grad on down-proj weights
477
+ down_params: List[Tuple[int, torch.nn.Parameter]] = []
478
+ for li in range(self.arch.n_layers):
479
+ layer = self.arch._layer(li)
480
+ p = layer.mlp.c_proj.weight if self.arch.style == "gpt2" \
481
+ else layer.mlp.down_proj.weight
482
+ p.requires_grad_(True)
483
+ down_params.append((li, p))
484
+
485
+ self.model.zero_grad()
486
+ inputs = self.tok(prompt, return_tensors="pt").to(self.device)
487
+ out = self.model(**inputs)
488
+ loss = -F.log_softmax(out.logits[0, -1], dim=-1)[target_id]
489
+ loss.backward()
490
+
491
+ layer_scores = []
492
+ for li, p in down_params:
493
+ grad = p.grad
494
+ p.requires_grad_(False)
495
+ if grad is None:
496
+ layer_scores.append({"layer": li, "max_grad": 0.0, "top_slots": []})
497
+ continue
498
+ # gpt2: c_proj.weight [ffn_dim, hidden] β†’ rows = slots
499
+ # gated: down_proj.weight [hidden, ffn_dim] β†’ cols = slots
500
+ slot_norms = grad.norm(dim=1) if self.arch.style == "gpt2" \
501
+ else grad.norm(dim=0) # [ffn_dim]
502
+ k = min(20, slot_norms.shape[0])
503
+ vals, idxs = slot_norms.topk(k)
504
+ layer_scores.append({
505
+ "layer": li,
506
+ "max_grad": round(float(vals[0].item()), 6),
507
+ "top_slots": [{"slot": int(idx.item()),
508
+ "grad_norm": round(float(v.item()), 6)}
509
+ for idx, v in zip(idxs, vals)]
510
+ })
511
+
512
+ self.model.zero_grad()
513
+ return {"layer_scores": layer_scores}
514
+
515
+ def causal_patch_trace(self, prompt: str, subject: str, target: str,
516
+ noise_std: float = 0.1) -> Dict:
517
+ """ROME-style causal tracing.
518
+ Corrupts subject embeddings, then for each KB layer measures how much
519
+ patching that layer's hidden state (at subject position) restores p(target).
520
+ Expensive: O(n_layers) forward passes."""
521
+ target_id = self.token_id(target)
522
+ W_u = self.arch.get_unembedding().to(self.device)
523
+ inputs = self.tok(prompt, return_tensors="pt").to(self.device)
524
+ ids = inputs["input_ids"][0].tolist()
525
+
526
+ # Find subject token positions via subsequence match
527
+ subj_ids = self.tok.encode(subject, add_special_tokens=False)
528
+ subj_pos: List[int] = []
529
+ for start in range(len(ids) - len(subj_ids) + 1):
530
+ if ids[start:start+len(subj_ids)] == subj_ids:
531
+ subj_pos = list(range(start, start+len(subj_ids)))
532
+ break
533
+ if not subj_pos:
534
+ for si in subj_ids:
535
+ if si in ids:
536
+ subj_pos = [ids.index(si)]
537
+ break
538
+ if not subj_pos:
539
+ subj_pos = [0]
540
+
541
+ # ── Clean forward β€” capture every layer's hidden states ──
542
+ clean_hs: Dict[int, torch.Tensor] = {}
543
+ clean_handles = []
544
+ def _mk_clean(li):
545
+ def _h(m, inp, out):
546
+ h = out[0] if isinstance(out, tuple) else out
547
+ clean_hs[li] = h[0].detach().clone() # [seq, hidden]
548
+ return _h
549
+ for li in range(self.arch.n_layers):
550
+ clean_handles.append(self.arch._layer(li).register_forward_hook(_mk_clean(li)))
551
+ with torch.no_grad():
552
+ clean_out = self.model(**inputs)
553
+ for h in clean_handles: h.remove()
554
+ clean_prob = float(torch.softmax(clean_out.logits[0,-1], dim=-1)[target_id].item())
555
+
556
+ # ── Corrupted embeddings ──
557
+ E = self.arch.get_embedding().to(self.device)
558
+ emb = E[inputs["input_ids"][0]].unsqueeze(0).clone() # [1, seq, hidden]
559
+ noise_scale = emb.std().item() * noise_std
560
+ for pos in subj_pos:
561
+ emb[0, pos] += torch.randn_like(emb[0, pos]) * noise_scale
562
+
563
+ with torch.no_grad():
564
+ corr_out = self.model(inputs_embeds=emb)
565
+ corr_prob = float(torch.softmax(corr_out.logits[0,-1], dim=-1)[target_id].item())
566
+
567
+ # ── Causal patch sweep ──
568
+ results = []
569
+ for li in range(self.kb_start, self.kb_end):
570
+ def _mk_patch(target_li):
571
+ def _h(m, inp, out):
572
+ if target_li not in clean_hs:
573
+ return out
574
+ is_tuple = isinstance(out, tuple)
575
+ h = list(out) if is_tuple else [out]
576
+ clean = clean_hs[target_li]
577
+ for pos in subj_pos:
578
+ if pos < clean.shape[0]:
579
+ h[0][0, pos] = clean[pos].to(h[0].device)
580
+ return tuple(h) if is_tuple else h[0]
581
+ return _h
582
+ ph = self.arch._layer(li).register_forward_hook(_mk_patch(li))
583
+ with torch.no_grad():
584
+ patch_out = self.model(inputs_embeds=emb.clone())
585
+ ph.remove()
586
+ patch_prob = float(torch.softmax(patch_out.logits[0,-1], dim=-1)[target_id].item())
587
+ ie = patch_prob - corr_prob
588
+ results.append({
589
+ "layer": li,
590
+ "patch_prob": round(patch_prob, 6),
591
+ "indirect_effect": round(ie, 6),
592
+ })
593
+
594
+ return {
595
+ "clean_prob": round(clean_prob, 6),
596
+ "corrupt_prob": round(corr_prob, 6),
597
+ "subject_pos": subj_pos,
598
+ "results": results,
599
+ }
600
+
601
+ def smart_locate(self, prompt: str, subject: str, target: str,
602
+ alpha: float = 0.4, beta: float = 0.3, gamma: float = 0.3,
603
+ noise_std: float = 0.1) -> Dict:
604
+ """Combined gate_sim + grad_norm + causal_effect β†’ precise layer/slot ranking.
605
+ alpha = weight for gate cosine sim
606
+ beta = weight for gradient norm
607
+ gamma = weight for causal indirect effect"""
608
+ gate_data = self.locate(prompt, subject, target)
609
+ grad_data = self.gradient_slot_scores(prompt, target)
610
+ causal_data = self.causal_patch_trace(prompt, subject, target, noise_std=noise_std)
611
+
612
+ gate_map = {ls["layer"]: ls["max_sim"] for ls in gate_data["layer_scores"]}
613
+ grad_map = {ls["layer"]: ls["max_grad"] for ls in grad_data["layer_scores"]}
614
+ causal_map = {r["layer"]: max(0.0, r["indirect_effect"])
615
+ for r in causal_data["results"]}
616
+ grad_slots = {ls["layer"]: ls["top_slots"] for ls in grad_data["layer_scores"]}
617
+
618
+ layers = sorted(set(gate_map) | set(grad_map) | set(causal_map))
619
+
620
+ def _norm(vals: List[float]) -> List[float]:
621
+ m = max(vals) if vals else 1.0
622
+ return [v/m if m > 0 else 0.0 for v in vals]
623
+
624
+ gv = [gate_map.get(l, 0.0) for l in layers]
625
+ dv = [grad_map.get(l, 0.0) for l in layers]
626
+ cv = [causal_map.get(l, 0.0) for l in layers]
627
+ gn, dn, cn = _norm(gv), _norm(dv), _norm(cv)
628
+
629
+ ranked = []
630
+ for i, l in enumerate(layers):
631
+ score = alpha*gn[i] + beta*dn[i] + gamma*cn[i]
632
+ ranked.append({
633
+ "layer": l,
634
+ "gate_sim": round(gv[i], 4),
635
+ "grad_norm": round(dv[i], 6),
636
+ "causal_effect": round(cv[i], 6),
637
+ "gate_sim_n": round(gn[i], 4),
638
+ "grad_norm_n": round(dn[i], 4),
639
+ "causal_n": round(cn[i], 4),
640
+ "combined": round(score, 4),
641
+ "best_slots": (grad_slots.get(l) or [])[:5],
642
+ })
643
+
644
+ ranked.sort(key=lambda x: -x["combined"])
645
+
646
+ return {
647
+ "ranked_layers": ranked,
648
+ "phase_layer": gate_data["phase_layer"],
649
+ "subject_pos": gate_data["subject_pos"],
650
+ "clean_prob": causal_data["clean_prob"],
651
+ "corrupt_prob": causal_data["corrupt_prob"],
652
+ "recommendation": ranked[0] if ranked else None,
653
+ "weights": {"alpha": alpha, "beta": beta, "gamma": gamma},
654
+ }
655
+
656
+ def smart_edit(self, prompt: str, subject: str, relation: str,
657
+ old_target: str, new_target: str,
658
+ top_layers: int = 3, slots_per_layer: int = 2,
659
+ scale: float = 1.5, noise_std: float = 0.1,
660
+ alpha: float = 0.4, beta: float = 0.4, gamma: float = 0.2,
661
+ log: Optional[List[str]] = None) -> Dict:
662
+ """Auto edit: runs smart_locate on (prompt, subject, old_target) to find
663
+ the exact layer+slot targets via gradient+causal+gate consensus, then
664
+ patches those W_down columns toward embed(new_target).
665
+
666
+ old_target = what the model currently predicts (used to locate)
667
+ new_target = what you want to inject
668
+ top_layers = how many top-ranked layers to patch
669
+ slots_per_layer = gradient-identified slots to patch per layer
670
+ scale = col_norm multiplier (1.5-3.0 recommended)
671
+ beta > alpha because grad_norm is more reliable than gate_sim for small models."""
672
+ if log is None: log = []
673
+ self._snapshot()
674
+
675
+ log.append(f"SMART_EDIT: '{subject}' [{relation}] {old_target!r} β†’ {new_target!r}")
676
+ log.append(f" Running smart_locate on prompt: {prompt!r}")
677
+ log.append(f" Weights: οΏ½οΏ½={alpha} Ξ²={beta} Ξ³={gamma} noise_std={noise_std}")
678
+
679
+ sl = self.smart_locate(prompt, subject, old_target,
680
+ alpha=alpha, beta=beta, gamma=gamma,
681
+ noise_std=noise_std)
682
+
683
+ log.append(f" clean_prob={sl['clean_prob']:.6f} corrupt_prob={sl['corrupt_prob']:.6f}")
684
+ log.append(f" Phase layer: L{sl['phase_layer']} Subject pos: {sl['subject_pos']}")
685
+
686
+ if sl["clean_prob"] < 1e-5:
687
+ log.append(" ⚠ clean_prob near zero β€” model barely knows this fact.")
688
+ log.append(" Grad-norm signal still valid. Causal IE=0 is expected.")
689
+ log.append(" Recommend: gpt2-medium or Qwen2.5-1.5B for stronger facts.")
690
+
691
+ tv = self.embed(new_target)
692
+ tv_n = F.normalize(tv, dim=0)
693
+ ops = []
694
+ used = []
695
+
696
+ top_ranked = sl["ranked_layers"][:top_layers]
697
+ for lr in top_ranked:
698
+ li = lr["layer"]
699
+ # Use gradient-identified slots β€” far more precise than gate cosine
700
+ grad_slots = [s["slot"] for s in lr["best_slots"][:slots_per_layer]]
701
+ if not grad_slots:
702
+ log.append(f" L{li}: no grad slots, skipping")
703
+ continue
704
+ _, Wd = self.arch.get_ffn_weights(li)
705
+ Wd = Wd.to(self.device)
706
+ for slot in grad_slots:
707
+ col_norm = Wd[:, slot].norm().item()
708
+ new_col = (tv_n * col_norm * scale).cpu().tolist()
709
+ ops.append({"op":"update_down","layer":li,"slot":slot,"down_col":new_col})
710
+ log.append(f" βœ“ L{li} slot {slot}: combined={lr['combined']} "
711
+ f"grad_norm={lr['grad_norm']:.4f} col_norm={col_norm:.4f} "
712
+ f"inject={col_norm*scale:.4f}")
713
+ used.append({"layer":li,"slots":grad_slots,"combined":lr["combined"]})
714
+
715
+ self.patches.append({
716
+ "type": "SMART_UPDATE",
717
+ "entity": subject,
718
+ "relation": relation,
719
+ "new_target": new_target,
720
+ "old_target": old_target,
721
+ "smart_top": top_ranked,
722
+ "ops": ops,
723
+ })
724
+ self._apply_all_patches()
725
+ log.append(f"\n βœ“ {len(ops)} op(s) across {len(used)} layer(s), patch #{len(self.patches)}")
726
+
727
+ return {
728
+ "ops": ops,
729
+ "used_layers": used,
730
+ "smart_locate": sl,
731
+ "log": log,
732
+ }
733
+
734
+
735
 
736
  def infer(self, prompt: str, top_k: int = 5):
737
  probs = torch.softmax(self._forward(prompt), dim=-1)
 
1036
  <button class="tab-btn" onclick="showTab('describe')">β‘‘ Describe</button>
1037
  <button class="tab-btn" onclick="showTab('trace')">β‘’ Trace</button>
1038
  <button class="tab-btn" onclick="showTab('locate')">β‘£ Locate</button>
1039
+ <button class="tab-btn" onclick="showTab('smartlocate')">β‘€ Smart Locate</button>
1040
+ <button class="tab-btn" onclick="showTab('heatmap')">β‘₯ Heatmap</button>
1041
+ <button class="tab-btn" onclick="showTab('edit')">⑦ Edit</button>
1042
+ <button class="tab-btn" onclick="showTab('patches')">β‘§ Patches</button>
1043
+ <button class="tab-btn" onclick="showTab('guide')" style="margin-left:auto;color:var(--green)">πŸ“– Guide</button>
1044
+ <button class="tab-btn" onclick="showTab('load')">βš™ Load</button>
1045
  </div>
1046
 
1047
  <!-- TOOLTIP -->
 
1192
  </div>
1193
  </div>
1194
 
1195
+ <!-- ══════════ SMART LOCATE PANEL ══════════ -->
1196
+ <div id="panel-smartlocate" class="panel">
1197
+ <div class="card">
1198
+ <h3>Smart Locate β€” gradient + causal + gate_sim combined</h3>
1199
+ <div style="color:var(--muted);font-size:11px;margin-bottom:12px;line-height:1.7">
1200
+ Three independent signals combined into one ranked layer list.<br>
1201
+ <span style="color:var(--blue)">β–  gate_sim</span> β€” static embedding cosine (fast, weak proxy) &nbsp;
1202
+ <span style="color:var(--green)">β–  grad_norm</span> β€” βˆ‚loss/βˆ‚W_down per slot (one backward pass) &nbsp;
1203
+ <span style="color:var(--yellow)">β–  causal IE</span> β€” indirect effect via subject-corruption patching (N_layers passes, slow)
1204
+ </div>
1205
+ <div class="row">
1206
+ <div class="col2">
1207
+ <label>Prompt</label>
1208
+ <input type="text" id="sl-prompt" value="The capital of France is">
1209
+ </div>
1210
+ <div class="col">
1211
+ <label>Subject</label>
1212
+ <input type="text" id="sl-subject" value="France">
1213
+ </div>
1214
+ <div class="col">
1215
+ <label>Target</label>
1216
+ <input type="text" id="sl-target" value="Paris">
1217
+ </div>
1218
+ </div>
1219
+ <div class="row" style="margin-top:10px">
1220
+ <div class="col">
1221
+ <label>Ξ± gate_sim: <span id="sl-a-val">0.4</span></label>
1222
+ <input type="range" id="sl-alpha" min="0" max="1" step="0.05" value="0.4"
1223
+ oninput="document.getElementById('sl-a-val').textContent=this.value">
1224
+ </div>
1225
+ <div class="col">
1226
+ <label>Ξ² grad_norm: <span id="sl-b-val">0.3</span></label>
1227
+ <input type="range" id="sl-beta" min="0" max="1" step="0.05" value="0.3"
1228
+ oninput="document.getElementById('sl-b-val').textContent=this.value">
1229
+ </div>
1230
+ <div class="col">
1231
+ <label>Ξ³ causal: <span id="sl-g-val">0.3</span></label>
1232
+ <input type="range" id="sl-gamma" min="0" max="1" step="0.05" value="0.3"
1233
+ oninput="document.getElementById('sl-g-val').textContent=this.value">
1234
+ </div>
1235
+ <div class="col">
1236
+ <label>Noise Οƒ: <span id="sl-noise-val">0.1</span></label>
1237
+ <input type="range" id="sl-noise" min="0.02" max="0.5" step="0.02" value="0.1"
1238
+ oninput="document.getElementById('sl-noise-val').textContent=this.value">
1239
+ </div>
1240
+ </div>
1241
+ <div style="display:flex;gap:8px;margin-top:12px;flex-wrap:wrap">
1242
+ <button onclick="runSmartLocate()">⚑ Smart Locate (full)</button>
1243
+ <button class="secondary" onclick="runGradientOnly()">β–Ά Gradient only (fast)</button>
1244
+ <button class="secondary" onclick="runCausalOnly()">β–Ά Causal trace only</button>
1245
+ </div>
1246
+ <div id="sl-status" style="color:var(--muted);font-size:11px;margin-top:8px"></div>
1247
+ </div>
1248
+
1249
+ <div class="row">
1250
+ <div class="col2 card">
1251
+ <h3>Layer Rankings β€” 3-signal stacked bars</h3>
1252
+ <div class="chart-wrap" id="sl-chart" style="min-height:320px"></div>
1253
+ </div>
1254
+ <div class="col card">
1255
+ <h3>Recommendation</h3>
1256
+ <div id="sl-rec" class="log">Run Smart Locate to see the best edit target.</div>
1257
+ <h3 style="margin-top:14px">Collateral Probe</h3>
1258
+ <div class="row" style="margin-top:8px">
1259
+ <input type="text" id="sl-coll-prompt" value="Biggest cities in France"
1260
+ style="flex:2" placeholder="Collateral prompt…">
1261
+ <button class="secondary" onclick="runCollateral()" style="flex:0">β–Ά</button>
1262
+ </div>
1263
+ <div id="sl-coll-out" class="log" style="margin-top:8px">Probe a prompt to check collateral damage.</div>
1264
+ </div>
1265
+ </div>
1266
+
1267
+ <div class="card">
1268
+ <h3>Per-Layer Detail</h3>
1269
+ <div id="sl-table" style="overflow-x:auto">
1270
+ <div style="color:var(--muted);font-size:11px">Run Smart Locate first.</div>
1271
+ </div>
1272
+ </div>
1273
+ </div>
1274
+
1275
  <!-- ══════════ HEATMAP PANEL ══════════ -->
1276
  <div id="panel-heatmap" class="panel">
1277
  <div class="card">
 
1327
  <div class="radio-group" id="edit-mode-group">
1328
  <label><input type="radio" name="edit-mode" value="UPDATE" checked> UPDATE</label>
1329
  <label><input type="radio" name="edit-mode" value="PRECISE"> PRECISE</label>
1330
+ <label><input type="radio" name="edit-mode" value="SMART"> β˜… SMART</label>
1331
  <label><input type="radio" name="edit-mode" value="INSERT"> INSERT</label>
1332
  <label><input type="radio" name="edit-mode" value="SUPPRESS"> SUPPRESS</label>
1333
  <label><input type="radio" name="edit-mode" value="AMPLIFY"> AMPLIFY</label>
 
1338
  <label>Prompt (PRECISE mode)</label>
1339
  <input type="text" id="edit-prompt" value="The capital of France is">
1340
  </div>
1341
+ <div id="smart-row" style="display:none;margin-top:8px;background:var(--bg);border:1px solid var(--border);border-radius:6px;padding:10px">
1342
+ <div style="color:var(--blue);font-size:11px;font-weight:700;margin-bottom:6px">β˜… SMART AUTO MODE</div>
1343
+ <label>Prompt (used for locate + after-check)</label>
1344
+ <input type="text" id="smart-prompt" value="The capital of France is">
1345
+ <label style="margin-top:6px">Old value (what model currently says)</label>
1346
+ <input type="text" id="smart-old" value="Paris">
1347
+ <div class="row" style="margin-top:6px">
1348
+ <div class="col">
1349
+ <label>Top layers: <span id="smart-layers-val">3</span></label>
1350
+ <input type="range" id="smart-layers" min="1" max="8" value="3"
1351
+ oninput="document.getElementById('smart-layers-val').textContent=this.value">
1352
+ </div>
1353
+ <div class="col">
1354
+ <label>Slots/layer: <span id="smart-slots-val">2</span></label>
1355
+ <input type="range" id="smart-slots" min="1" max="5" value="2"
1356
+ oninput="document.getElementById('smart-slots-val').textContent=this.value">
1357
+ </div>
1358
+ </div>
1359
+ <div style="color:var(--muted);font-size:10px;margin-top:6px">
1360
+ Runs smart_locate internally β†’ patches gradient-identified slots. No manual tuning needed.
1361
+ </div>
1362
+ </div>
1363
  <div id="style-shift-row" style="display:none;margin-top:8px">
1364
  <label>From concept</label>
1365
  <input type="text" id="ss-from" value="formal">
 
1439
  </div>
1440
  </div>
1441
 
1442
+ <!-- ══════════ GUIDE PANEL ══════════ -->
1443
+ <div id="panel-guide" class="panel">
1444
+ <div class="card">
1445
+ <h3>What is VINDEX doing?</h3>
1446
+ <div style="line-height:1.9;color:var(--muted)">
1447
+ In a transformer, factual associations like <span style="color:var(--text)">"France β†’ capital β†’ Paris"</span>
1448
+ are stored as direction vectors in the <span style="color:var(--blue)">W_down columns</span> of FFN layers.
1449
+ The <span style="color:var(--blue)">W_gate rows</span> act as keys: when the residual stream resembles "France",
1450
+ the matching gate fires, the down column adds "Paris" direction to the stream, and the unembedding reads out "Paris".
1451
+ VINDEX surgically replaces those down columns without retraining.
1452
+ </div>
1453
+ </div>
1454
+
1455
+ <div class="card">
1456
+ <h3>Quickstart β€” 5-step experiment</h3>
1457
+ <div style="line-height:2;font-size:12px">
1458
+ <div style="color:var(--yellow);margin-bottom:4px">Step 1 β€” Load a model that actually knows facts</div>
1459
+ <div style="color:var(--muted);margin-left:16px;margin-bottom:10px">
1460
+ βš™ Load tab β†’ <span style="color:var(--blue)">gpt2-medium</span> (1.5 GB, knows capitals) or
1461
+ <span style="color:var(--blue)">Qwen/Qwen2.5-1.5B-Instruct</span> (3 GB, strong).<br>
1462
+ distilgpt2 has clean_probβ‰ˆ0 for most facts β†’ causal IE=0 everywhere β†’ misleading results.
1463
+ </div>
1464
+
1465
+ <div style="color:var(--yellow);margin-bottom:4px">Step 2 β€” Verify the model knows the fact</div>
1466
+ <div style="color:var(--muted);margin-left:16px;margin-bottom:10px">
1467
+ β‘  Infer: prompt = <code>"The capital of France is"</code><br>
1468
+ βœ“ Good: "Paris" appears in top-3 with prob &gt; 0.05<br>
1469
+ βœ— Bad: top tokens are "a", "the", "known" β†’ model doesn't know it β†’ skip to INSERT mode
1470
+ </div>
1471
+
1472
+ <div style="color:var(--yellow);margin-bottom:4px">Step 3 β€” Find where the fact lives</div>
1473
+ <div style="color:var(--muted);margin-left:16px;margin-bottom:10px">
1474
+ β‘’ Trace: prompt = <code>"The capital of France is"</code>, target = <code>"Paris"</code><br>
1475
+ β†’ Look for phase layer: where rank drops from ~30000 to &lt;100. That's where the fact materializes.<br>
1476
+ β‘€ Smart Locate β†’ Gradient only (fast, 1 backward pass):<br>
1477
+ <span style="margin-left:16px">subject = <code>France</code>, target = <code>Paris</code></span><br>
1478
+ β†’ The layer with highest grad_norm bar = best edit target. Note the slot numbers.
1479
+ </div>
1480
+
1481
+ <div style="color:var(--yellow);margin-bottom:4px">Step 4 β€” Edit with SMART mode</div>
1482
+ <div style="color:var(--muted);margin-left:16px;margin-bottom:10px">
1483
+ ⑦ Edit tab β†’ mode = <span style="color:var(--blue)">β˜… SMART</span><br>
1484
+ Entity = <code>France</code> | Relation = <code>capital</code><br>
1485
+ Old value = <code>Paris</code> (what model says now β€” used for locate)<br>
1486
+ New value = <code>Lyon</code> (what you want)<br>
1487
+ Prompt = <code>"The capital of France is"</code><br>
1488
+ Scale = <code>2.0</code> (start here; increase to 3.0 if effect is weak)<br>
1489
+ β†’ Click <b>Apply Edit</b>. Smart locate runs internally, patches grad-identified slots.
1490
+ </div>
1491
+
1492
+ <div style="color:var(--yellow);margin-bottom:4px">Step 5 β€” Check collateral damage</div>
1493
+ <div style="color:var(--muted);margin-left:16px;margin-bottom:10px">
1494
+ β‘  Infer: <code>"The capital of France is"</code> β†’ should now say Lyon<br>
1495
+ β‘  Infer: <code>"Biggest cities in France"</code> β†’ should be unchanged (different slots)<br>
1496
+ β‘  Infer: <code>"Paris is a city in"</code> β†’ should still say France<br>
1497
+ β‘  Infer: <code>"Lyon is a city in"</code> β†’ might now also say France (collateral)<br>
1498
+ β‘€ Smart Locate collateral probe β†’ run these prompts, compare slot lists in β‘§ Patches
1499
+ </div>
1500
+ </div>
1501
+ </div>
1502
+
1503
+ <div class="card">
1504
+ <h3>Interpreting Smart Locate results</h3>
1505
+ <div style="font-size:11px;line-height:1.9">
1506
+ <div class="row" style="gap:20px">
1507
+ <div class="col">
1508
+ <div style="color:var(--blue);font-weight:700;margin-bottom:6px">β–  gate_sim (blue)</div>
1509
+ <div style="color:var(--muted)">
1510
+ Cosine between W_gate[slot] and embed(subject).<br>
1511
+ Fast, cheap, but <b>weak proxy</b> β€” measures embedding-space similarity,<br>
1512
+ not causal contribution. Useful for finding <i>related</i> slots.<br>
1513
+ <b>High gate_sim + low grad_norm</b> = slot activates for this entity<br>
1514
+ but doesn't contribute much to this specific prediction.
1515
+ </div>
1516
+ </div>
1517
+ <div class="col">
1518
+ <div style="color:var(--green);font-weight:700;margin-bottom:6px">β–  grad_norm (green)</div>
1519
+ <div style="color:var(--muted)">
1520
+ β€–βˆ‚(-log p(target))/βˆ‚W_down[:,slot]β€– β€” how much changing this slot<br>
1521
+ would affect the loss for this (prompt, target) pair.<br>
1522
+ <b>Most reliable signal</b>, works even when clean_prob is tiny.<br>
1523
+ One backward pass. Use Ξ² &gt; Ξ± to weight this higher.<br>
1524
+ <b>High grad_norm</b> = this slot is causally upstream of the prediction.
1525
+ </div>
1526
+ </div>
1527
+ <div class="col">
1528
+ <div style="color:var(--yellow);font-weight:700;margin-bottom:6px">β–  causal IE (yellow)</div>
1529
+ <div style="color:var(--muted)">
1530
+ Indirect effect via noise-corruption patching (ROME-style).<br>
1531
+ Measures: if I corrupt subject embeddings, how much does patching<br>
1532
+ layer L's hidden state at subject pos <i>restore</i> the prediction?<br>
1533
+ <b>Most interpretable</b> β€” true causal measurement. But:<br>
1534
+ If clean_prob β‰ˆ 0, IE = 0 everywhere (nothing to restore).<br>
1535
+ Needs a model that actually knows the fact.
1536
+ </div>
1537
+ </div>
1538
+ </div>
1539
+ <div style="margin-top:12px;padding:10px;background:var(--bg);border-radius:6px;border:1px solid var(--border)">
1540
+ <span style="color:var(--yellow)">⚠ Your distilgpt2 result:</span>
1541
+ <span style="color:var(--muted)"> clean_prob=0.000001 β†’ causal IE=0 everywhere (expected, not a bug).
1542
+ grad_norm on L9/slot515 IS real signal β€” that slot responds to France+capital context in the gradient sense.
1543
+ But the probability mass is too diffuse to show causal separation.
1544
+ Switch to gpt2-medium for textbook causal results.</span>
1545
+ </div>
1546
+ </div>
1547
+ </div>
1548
+
1549
+ <div class="card">
1550
+ <h3>Edit modes β€” when to use which</h3>
1551
+ <div style="font-size:11px">
1552
+ <table style="width:100%;border-collapse:collapse">
1553
+ <thead><tr style="border-bottom:1px solid var(--border);color:var(--muted)">
1554
+ <th style="padding:6px 8px;text-align:left">Mode</th>
1555
+ <th style="padding:6px 8px;text-align:left">Slot selection</th>
1556
+ <th style="padding:6px 8px;text-align:left">Best for</th>
1557
+ <th style="padding:6px 8px;text-align:left">Knobs</th>
1558
+ </tr></thead>
1559
+ <tbody style="color:var(--muted)">
1560
+ <tr style="border-bottom:1px solid var(--border)">
1561
+ <td style="padding:6px 8px;color:var(--blue)">UPDATE</td>
1562
+ <td style="padding:6px 8px">gate cosine sim to embed(entity)</td>
1563
+ <td style="padding:6px 8px">Quick experiment, model knows the fact well</td>
1564
+ <td style="padding:6px 8px">Top-K=3-5, Scale=1.5-3</td>
1565
+ </tr>
1566
+ <tr style="border-bottom:1px solid var(--border)">
1567
+ <td style="padding:6px 8px;color:var(--purple)">PRECISE</td>
1568
+ <td style="padding:6px 8px">gate cosine sim to h_L[subject_pos]</td>
1569
+ <td style="padding:6px 8px">In-context subject representation (3-5Γ— better than UPDATE)</td>
1570
+ <td style="padding:6px 8px">+ Prompt field</td>
1571
+ </tr>
1572
+ <tr style="border-bottom:1px solid var(--border)">
1573
+ <td style="padding:6px 8px;color:var(--yellow)">β˜… SMART</td>
1574
+ <td style="padding:6px 8px">gradient norm β†’ exact slots, then patch</td>
1575
+ <td style="padding:6px 8px"><b>Best overall.</b> Auto-locates, no manual tuning</td>
1576
+ <td style="padding:6px 8px">Top layers=3, Slots/layer=2, Scale=1.5-2.5</td>
1577
+ </tr>
1578
+ <tr style="border-bottom:1px solid var(--border)">
1579
+ <td style="padding:6px 8px;color:var(--green)">INSERT</td>
1580
+ <td style="padding:6px 8px">weakest slot (norm-based)</td>
1581
+ <td style="padding:6px 8px">Model has no knowledge of fact, build from scratch</td>
1582
+ <td style="padding:6px 8px">Alpha=0.4-0.7, Spread=4-6</td>
1583
+ </tr>
1584
+ <tr style="border-bottom:1px solid var(--border)">
1585
+ <td style="padding:6px 8px;color:var(--red)">SUPPRESS</td>
1586
+ <td style="padding:6px 8px">gate cosine β†’ scale W_down to 0</td>
1587
+ <td style="padding:6px 8px">Make model forget an entity (factor=0) or weaken (0.5)</td>
1588
+ <td style="padding:6px 8px">Factor: 0=forget, 0.5=weaken</td>
1589
+ </tr>
1590
+ <tr style="border-bottom:1px solid var(--border)">
1591
+ <td style="padding:6px 8px;color:var(--cyan)">STYLE-SHIFT</td>
1592
+ <td style="padding:6px 8px">gate cosine β†’ add direction vector</td>
1593
+ <td style="padding:6px 8px">Bias/tone shifts: CEO→less male-coded, Paris→darker</td>
1594
+ <td style="padding:6px 8px">from/to concepts, strength=0.3-0.8</td>
1595
+ </tr>
1596
+ </tbody>
1597
+ </table>
1598
+ </div>
1599
+ </div>
1600
+
1601
+ <div class="card">
1602
+ <h3>Experiments to run</h3>
1603
+ <div style="font-size:11px;line-height:1.9;color:var(--muted)">
1604
+ <div style="color:var(--text);margin-bottom:4px">Experiment A β€” Capital swap (classic ROME benchmark)</div>
1605
+ Model: gpt2-medium | Prompt: "The capital of France is" | Old: Paris | New: Lyon<br>
1606
+ Check: "France's capital city" | "Lyon is now" | "Paris is in" | "Eiffel Tower is in"<br>
1607
+ Insight: does it generalize (paraphrase) or is it prompt-specific?<br><br>
1608
+
1609
+ <div style="color:var(--text);margin-bottom:4px">Experiment B β€” Slot overlap analysis (your collateral question)</div>
1610
+ 1. SMART locate "The capital of France is" β†’ note slot numbers in recommendation<br>
1611
+ 2. SMART locate "The biggest city in France is" β†’ compare slot lists<br>
1612
+ 3. Overlap = slots that will be collaterally damaged<br>
1613
+ 4. No overlap = clean surgery βœ“<br><br>
1614
+
1615
+ <div style="color:var(--text);margin-bottom:4px">Experiment C β€” Suppression then INSERT</div>
1616
+ SUPPRESS France β†’ then INSERT France capital Lyon β†’ Infer<br>
1617
+ vs just UPDATE. Which gives cleaner, more confident result?<br><br>
1618
+
1619
+ <div style="color:var(--text);margin-bottom:4px">Experiment D β€” Style shift (no factual change)</div>
1620
+ STYLE-SHIFT: anchor=CEO, from="male", to="female", strength=0.3<br>
1621
+ Then Infer: "The CEO of the company is a" β€” does pronoun distribution shift?<br>
1622
+ Insight: this is mechanical debiasing without retraining.<br><br>
1623
+
1624
+ <div style="color:var(--text);margin-bottom:4px">Experiment E β€” Compile and compare</div>
1625
+ Edit 5 facts. Compile β†’ save as new model directory.<br>
1626
+ Load compiled model fresh β†’ Infer same prompts β†’ edits should persist in weights.<br>
1627
+ Then Trace on compiled model β†’ phase layers should shift or sharpen.
1628
+ </div>
1629
+ </div>
1630
+
1631
+ <div class="card">
1632
+ <h3>Ξ± Ξ² Ξ³ tuning guide</h3>
1633
+ <div style="font-size:11px;line-height:1.9;color:var(--muted)">
1634
+ <b style="color:var(--text)">Default (0.4 / 0.3 / 0.3)</b> β€” balanced, works for unknown model quality<br>
1635
+ <b style="color:var(--text)">Grad-heavy (0.1 / 0.7 / 0.2)</b> β€” clean_prob &gt; 0.01. Grad signal is sharp, trust it.<br>
1636
+ <b style="color:var(--text)">Gate+Grad (0.4 / 0.4 / 0.2)</b> β€” recommended for smart_edit when causal IE is weak<br>
1637
+ <b style="color:var(--text)">Causal-heavy (0.2 / 0.2 / 0.6)</b> β€” only when clean_prob &gt; 0.1. IE is the gold signal then.<br>
1638
+ <b style="color:var(--text)">Gate-only (1.0 / 0.0 / 0.0)</b> β€” equivalent to basic locate(), sanity check<br>
1639
+ <br>
1640
+ <b style="color:var(--yellow)">Your distilgpt2 setting:</b> use (0.3 / 0.7 / 0.0) β€” gate+grad, skip causal (it's 0 anyway).
1641
+ </div>
1642
+ </div>
1643
+ </div>
1644
+
1645
  </div><!-- /app -->
1646
 
1647
  <script>
 
2032
  document.querySelectorAll('input[name="edit-mode"]').forEach(r=>{
2033
  r.addEventListener('change', ()=>{
2034
  document.getElementById('precise-prompt-row').style.display = r.value==='PRECISE'?'block':'none';
2035
+ document.getElementById('smart-row').style.display = r.value==='SMART'?'block':'none';
2036
  document.getElementById('style-shift-row').style.display = r.value==='STYLE-SHIFT'?'block':'none';
2037
  document.getElementById('multiedit-row').style.display = r.value==='MULTI-EDIT'?'block':'none';
2038
  });
 
2074
  to_concept: document.getElementById('ss-to').value,
2075
  strength: +document.getElementById('ss-strength').value,
2076
  };
2077
+ if(mode==='SMART'){
2078
+ try {
2079
+ const r = await api('/api/smart_edit', {
2080
+ prompt: document.getElementById('smart-prompt').value,
2081
+ subject: document.getElementById('edit-entity').value,
2082
+ relation: document.getElementById('edit-relation').value,
2083
+ old_target: document.getElementById('smart-old').value,
2084
+ new_target: document.getElementById('edit-new').value,
2085
+ top_layers: +document.getElementById('smart-layers').value,
2086
+ slots_per_layer: +document.getElementById('smart-slots').value,
2087
+ scale: +document.getElementById('edit-scale').value,
2088
+ noise_std: 0.1, alpha: 0.4, beta: 0.4, gamma: 0.2,
2089
+ });
2090
+ drawBeforeAfterChart(r.before, r.after);
2091
+ let log = r.debug_log.join('\n');
2092
+ log += '\n\nUsed layers:\n';
2093
+ r.used_layers.forEach(l=>{ log+=` L${l.layer} slots=[${l.slots.join(',')}] combined=${l.combined}\n`; });
2094
+ log += '\nDelta:\n';
2095
+ r.delta.slice(0,8).forEach(d=>{ log+=` ${d.token}: ${d.before.toFixed(4)} β†’ ${d.after.toFixed(4)} ${d.delta>0?'+':''}${d.delta.toFixed(4)}\n`; });
2096
+ document.getElementById('edit-log').textContent = log;
2097
+ updatePatchCount();
2098
+ } catch(e) { alert(e.message); }
2099
+ return;
2100
+ }
2101
  if(mode==='MULTI-EDIT'){
2102
  try {
2103
  body.facts = JSON.parse(document.getElementById('multi-json').value);
 
2212
  } catch(e){}
2213
  }
2214
 
2215
+ // ═══════════════════════════════════════════════
2216
+ // SMART LOCATE
2217
+ // ═══════════════════════════════════════════════
2218
+ let _slData = null;
2219
+
2220
+ async function runSmartLocate() {
2221
+ const st = document.getElementById('sl-status');
2222
+ st.textContent = '⏳ Running gradient pass + causal sweep (may take ~20s for large models)…';
2223
+ try {
2224
+ const data = await api('/api/smart_locate', {
2225
+ prompt: document.getElementById('sl-prompt').value,
2226
+ subject: document.getElementById('sl-subject').value,
2227
+ target: document.getElementById('sl-target').value,
2228
+ alpha: +document.getElementById('sl-alpha').value,
2229
+ beta: +document.getElementById('sl-beta').value,
2230
+ gamma: +document.getElementById('sl-gamma').value,
2231
+ noise_std: +document.getElementById('sl-noise').value,
2232
+ });
2233
+ _slData = data;
2234
+ st.textContent = `βœ“ Done. clean_prob=${data.clean_prob.toFixed(4)} corrupt_prob=${data.corrupt_prob.toFixed(4)}`;
2235
+ drawSmartLocateChart(data.ranked_layers);
2236
+ showSmartRec(data);
2237
+ buildSlTable(data.ranked_layers);
2238
+ } catch(e) { st.textContent = 'βœ— '+e.message; }
2239
+ }
2240
+
2241
+ async function runGradientOnly() {
2242
+ const st = document.getElementById('sl-status');
2243
+ st.textContent = '⏳ Running gradient pass…';
2244
+ try {
2245
+ const data = await api('/api/gradient_scores', {
2246
+ prompt: document.getElementById('sl-prompt').value,
2247
+ target: document.getElementById('sl-target').value,
2248
+ });
2249
+ st.textContent = `βœ“ Gradient done. ${data.layer_scores.length} KB layers.`;
2250
+ // Draw gradient-only bars
2251
+ drawGradOnlyChart(data.layer_scores);
2252
+ } catch(e) { st.textContent = 'βœ— '+e.message; }
2253
+ }
2254
+
2255
+ async function runCausalOnly() {
2256
+ const st = document.getElementById('sl-status');
2257
+ st.textContent = '⏳ Running causal patch trace…';
2258
+ try {
2259
+ const data = await api('/api/causal_trace', {
2260
+ prompt: document.getElementById('sl-prompt').value,
2261
+ subject: document.getElementById('sl-subject').value,
2262
+ target: document.getElementById('sl-target').value,
2263
+ noise_std: +document.getElementById('sl-noise').value,
2264
+ });
2265
+ st.textContent = `βœ“ Causal done. clean=${data.clean_prob.toFixed(4)} corrupt=${data.corrupt_prob.toFixed(4)}`;
2266
+ drawCausalOnlyChart(data.results);
2267
+ } catch(e) { st.textContent = 'βœ— '+e.message; }
2268
+ }
2269
+
2270
+ async function runCollateral() {
2271
+ const prompt = document.getElementById('sl-coll-prompt').value;
2272
+ try {
2273
+ const data = await api('/api/infer', { prompt, top_k: 5 });
2274
+ const el = document.getElementById('sl-coll-out');
2275
+ el.textContent = `"${prompt}"\n` +
2276
+ data.results.map(r=>` ${r.token.padEnd(18)} ${r.prob.toFixed(4)}`).join('\n');
2277
+ } catch(e) { document.getElementById('sl-coll-out').textContent = 'βœ— '+e.message; }
2278
+ }
2279
+
2280
+ function drawSmartLocateChart(ranked) {
2281
+ // Sort by layer for chart display
2282
+ const byLayer = [...ranked].sort((a,b)=>a.layer-b.layer);
2283
+ const el = clearChart('sl-chart');
2284
+ const W = el.clientWidth || 700, H = 40 + byLayer.length * 34;
2285
+ el.style.height = H+'px';
2286
+ const svg = d3.select(el).append('svg').attr('width','100%').attr('height',H);
2287
+ const m = {left:50,right:110,top:20,bottom:20};
2288
+ const w = W-m.left-m.right;
2289
+ const g = svg.append('g').attr('transform',`translate(${m.left},${m.top})`);
2290
+
2291
+ // Each bar = 3 stacked segments (normalized: gate_sim_n, grad_norm_n, causal_n)
2292
+ // Each segment width = signal_n * (w/3) so max of each is w/3
2293
+ const segW = w / 3;
2294
+
2295
+ byLayer.forEach((d,i)=>{
2296
+ const y = i*34;
2297
+ // Label
2298
+ g.append('text').attr('x',-6).attr('y',y+17).attr('text-anchor','end')
2299
+ .attr('fill', d.layer===(_slData?.recommendation?.layer) ? C.yellow : C.muted)
2300
+ .attr('font-size',10).text('L'+d.layer);
2301
+
2302
+ // gate_sim segment
2303
+ g.append('rect').attr('x',0).attr('y',y+4).attr('width',d.gate_sim_n*segW)
2304
+ .attr('height',12).attr('rx',2).attr('fill',C.blue).attr('opacity',.8)
2305
+ .on('mousemove',(ev)=>showTooltip(`L${d.layer} gate_sim: ${d.gate_sim}`,ev.pageX,ev.pageY))
2306
+ .on('mouseleave',hideTooltip);
2307
+ // grad_norm segment
2308
+ g.append('rect').attr('x',segW).attr('y',y+4).attr('width',d.grad_norm_n*segW)
2309
+ .attr('height',12).attr('rx',2).attr('fill',C.green).attr('opacity',.8)
2310
+ .on('mousemove',(ev)=>showTooltip(`L${d.layer} grad_norm: ${d.grad_norm}`,ev.pageX,ev.pageY))
2311
+ .on('mouseleave',hideTooltip);
2312
+ // causal segment
2313
+ g.append('rect').attr('x',segW*2).attr('y',y+4).attr('width',d.causal_n*segW)
2314
+ .attr('height',12).attr('rx',2).attr('fill',C.yellow).attr('opacity',.8)
2315
+ .on('mousemove',(ev)=>showTooltip(`L${d.layer} causal_IE: ${d.causal_effect}`,ev.pageX,ev.pageY))
2316
+ .on('mouseleave',hideTooltip);
2317
+ // combined score label
2318
+ g.append('text').attr('x',w+6).attr('y',y+14)
2319
+ .attr('fill', d.combined===Math.max(...ranked.map(r=>r.combined)) ? C.yellow : C.muted)
2320
+ .attr('font-size',10).text(d.combined.toFixed(3));
2321
+ });
2322
+
2323
+ // Axis labels
2324
+ const ax = g.append('g').attr('transform',`translate(0,${byLayer.length*34})`);
2325
+ ax.append('text').attr('x',segW/2).attr('y',14).attr('text-anchor','middle')
2326
+ .attr('fill',C.blue).attr('font-size',9).text('gate_sim');
2327
+ ax.append('text').attr('x',segW*1.5).attr('y',14).attr('text-anchor','middle')
2328
+ .attr('fill',C.green).attr('font-size',9).text('grad_norm');
2329
+ ax.append('text').attr('x',segW*2.5).attr('y',14).attr('text-anchor','middle')
2330
+ .attr('fill',C.yellow).attr('font-size',9).text('causal IE');
2331
+
2332
+ // Section dividers
2333
+ [segW,segW*2].forEach(x=>{
2334
+ g.append('line').attr('x1',x).attr('x2',x).attr('y1',0).attr('y2',byLayer.length*34)
2335
+ .attr('stroke',C.border).attr('stroke-width',1).attr('stroke-dasharray','3,2');
2336
+ });
2337
+ }
2338
+
2339
+ function drawGradOnlyChart(layerScores) {
2340
+ const el = clearChart('sl-chart');
2341
+ const W = el.clientWidth || 700, H = 40 + layerScores.length * 28;
2342
+ el.style.height = H+'px';
2343
+ const svg = d3.select(el).append('svg').attr('width','100%').attr('height',H);
2344
+ const m = {left:50,right:80,top:20,bottom:10};
2345
+ const w = W-m.left-m.right;
2346
+ const maxG = d3.max(layerScores, d=>d.max_grad) || 1;
2347
+ const x = d3.scaleLinear().domain([0,maxG]).range([0,w]);
2348
+ const g = svg.append('g').attr('transform',`translate(${m.left},${m.top})`);
2349
+ layerScores.forEach((d,i)=>{
2350
+ const y=i*28;
2351
+ g.append('text').attr('x',-6).attr('y',y+14).attr('text-anchor','end')
2352
+ .attr('fill',C.muted).attr('font-size',10).text('L'+d.layer);
2353
+ g.append('rect').attr('x',0).attr('y',y+2).attr('width',x(d.max_grad))
2354
+ .attr('height',16).attr('rx',2).attr('fill',C.green).attr('opacity',.8)
2355
+ .on('mousemove',(ev)=>showTooltip(`L${d.layer} max_grad: ${d.max_grad}`,ev.pageX,ev.pageY))
2356
+ .on('mouseleave',hideTooltip);
2357
+ g.append('text').attr('x',x(d.max_grad)+4).attr('y',y+14)
2358
+ .attr('fill',C.green).attr('font-size',9).text(d.max_grad.toExponential(2));
2359
+ });
2360
+ }
2361
+
2362
+ function drawCausalOnlyChart(results) {
2363
+ const el = clearChart('sl-chart');
2364
+ const W = el.clientWidth || 700, H = 40 + results.length * 28;
2365
+ el.style.height = H+'px';
2366
+ const svg = d3.select(el).append('svg').attr('width','100%').attr('height',H);
2367
+ const m = {left:50,right:80,top:20,bottom:10};
2368
+ const w = W-m.left-m.right;
2369
+ const maxIE = Math.max(d3.max(results, d=>d.indirect_effect), 0.001);
2370
+ const x = d3.scaleLinear().domain([0,maxIE]).range([0,w]);
2371
+ const g = svg.append('g').attr('transform',`translate(${m.left},${m.top})`);
2372
+ results.forEach((d,i)=>{
2373
+ const y=i*28; const ie=Math.max(0,d.indirect_effect);
2374
+ g.append('text').attr('x',-6).attr('y',y+14).attr('text-anchor','end')
2375
+ .attr('fill',C.muted).attr('font-size',10).text('L'+d.layer);
2376
+ g.append('rect').attr('x',0).attr('y',y+2).attr('width',x(ie))
2377
+ .attr('height',16).attr('rx',2).attr('fill',C.yellow).attr('opacity',.8)
2378
+ .on('mousemove',(ev)=>showTooltip(`L${d.layer} IE: ${d.indirect_effect} patch_p: ${d.patch_prob}`,ev.pageX,ev.pageY))
2379
+ .on('mouseleave',hideTooltip);
2380
+ g.append('text').attr('x',x(ie)+4).attr('y',y+14)
2381
+ .attr('fill',C.yellow).attr('font-size',9).text(d.indirect_effect.toFixed(5));
2382
+ });
2383
+ }
2384
+
2385
+ function showSmartRec(data) {
2386
+ const rec = data.recommendation;
2387
+ if(!rec){ document.getElementById('sl-rec').textContent='No recommendation.'; return; }
2388
+ let txt = `β˜… Best layer: L${rec.layer} combined=${rec.combined}\n\n`;
2389
+ txt += ` gate_sim: ${rec.gate_sim} (norm ${rec.gate_sim_n})\n`;
2390
+ txt += ` grad_norm: ${rec.grad_norm} (norm ${rec.grad_norm_n})\n`;
2391
+ txt += ` causal_effect: ${rec.causal_effect} (norm ${rec.causal_n})\n`;
2392
+ if(rec.best_slots.length){
2393
+ txt += `\nTop gradient slots in L${rec.layer}:\n`;
2394
+ rec.best_slots.forEach(s=>{ txt+=` slot ${s.slot} grad_norm=${s.grad_norm}\n`; });
2395
+ }
2396
+ txt += `\nPhase layer (trace): L${data.phase_layer}\n`;
2397
+ txt += `Subject pos: ${data.subject_pos}\n`;
2398
+ txt += `clean_prob: ${data.clean_prob} corrupt_prob: ${data.corrupt_prob}`;
2399
+ document.getElementById('sl-rec').textContent = txt;
2400
+ }
2401
+
2402
+ function buildSlTable(ranked) {
2403
+ const el = document.getElementById('sl-table');
2404
+ const maxC = Math.max(...ranked.map(r=>r.combined));
2405
+ let html = `<table style="width:100%;border-collapse:collapse;font-size:11px">
2406
+ <thead><tr style="color:var(--muted);border-bottom:1px solid var(--border)">
2407
+ <th style="padding:4px 8px;text-align:left">Layer</th>
2408
+ <th style="padding:4px 8px;text-align:right;color:${C.blue}">gate_sim</th>
2409
+ <th style="padding:4px 8px;text-align:right;color:${C.green}">grad_norm</th>
2410
+ <th style="padding:4px 8px;text-align:right;color:${C.yellow}">causal IE</th>
2411
+ <th style="padding:4px 8px;text-align:right">combined β˜…</th>
2412
+ <th style="padding:4px 8px;text-align:left;color:var(--muted)">top grad slots</th>
2413
+ </tr></thead><tbody>`;
2414
+ ranked.forEach(r=>{
2415
+ const hi = r.combined===maxC ? `background:rgba(210,153,34,0.08)` : '';
2416
+ const slots = r.best_slots.slice(0,3).map(s=>s.slot).join(', ');
2417
+ html+=`<tr style="${hi};border-bottom:1px solid var(--border)">
2418
+ <td style="padding:4px 8px;color:${r.combined===maxC?C.yellow:C.text}">L${r.layer}</td>
2419
+ <td style="padding:4px 8px;text-align:right;color:${C.blue}">${r.gate_sim}</td>
2420
+ <td style="padding:4px 8px;text-align:right;color:${C.green}">${r.grad_norm.toExponential(2)}</td>
2421
+ <td style="padding:4px 8px;text-align:right;color:${C.yellow}">${r.causal_effect.toFixed(5)}</td>
2422
+ <td style="padding:4px 8px;text-align:right;font-weight:700">${r.combined}</td>
2423
+ <td style="padding:4px 8px;color:var(--muted)">${slots}</td>
2424
+ </tr>`;
2425
+ });
2426
+ html += '</tbody></table>';
2427
+ el.innerHTML = html;
2428
+ }
2429
+
2430
  // ═══════════════════════════════════════════════
2431
  // INIT
2432
  // ═══════════════════════════════════════════════
 
2485
  use_activation: bool = False
2486
  prompt: Optional[str] = None
2487
 
2488
+ class GradientReq(BaseModel):
2489
+ prompt: str
2490
+ target: str
2491
+
2492
+ class CausalTraceReq(BaseModel):
2493
+ prompt: str
2494
+ subject: str
2495
+ target: str
2496
+ noise_std: float = 0.1
2497
+
2498
+ class SmartLocateReq(BaseModel):
2499
+ prompt: str
2500
+ subject: str
2501
+ target: str
2502
+ alpha: float = 0.4
2503
+ beta: float = 0.3
2504
+ gamma: float = 0.3
2505
+ noise_std: float = 0.1
2506
+
2507
+ class SmartEditReq(BaseModel):
2508
+ prompt: str
2509
+ subject: str
2510
+ relation: str = ""
2511
+ old_target: str
2512
+ new_target: str
2513
+ top_layers: int = 3
2514
+ slots_per_layer: int = 2
2515
+ scale: float = 1.5
2516
+ noise_std: float = 0.1
2517
+ alpha: float = 0.4
2518
+ beta: float = 0.4
2519
+ gamma: float = 0.2
2520
+
2521
  class DryRunReq(BaseModel):
2522
  entity: str
2523
  new_target: str
 
2624
  return vi.locate(req.prompt, req.subject, req.target)
2625
 
2626
 
2627
+ @app.post("/api/gradient_scores")
2628
+ async def api_gradient_scores(req: GradientReq):
2629
+ vi = _require()
2630
+ return vi.gradient_slot_scores(req.prompt, req.target)
2631
+
2632
+
2633
+ @app.post("/api/causal_trace")
2634
+ async def api_causal_trace(req: CausalTraceReq):
2635
+ vi = _require()
2636
+ return vi.causal_patch_trace(req.prompt, req.subject, req.target,
2637
+ noise_std=req.noise_std)
2638
+
2639
+
2640
+ @app.post("/api/smart_locate")
2641
+ async def api_smart_locate(req: SmartLocateReq):
2642
+ vi = _require()
2643
+ return vi.smart_locate(req.prompt, req.subject, req.target,
2644
+ alpha=req.alpha, beta=req.beta, gamma=req.gamma,
2645
+ noise_std=req.noise_std)
2646
+
2647
+
2648
+ @app.post("/api/smart_edit")
2649
+ async def api_smart_edit(req: SmartEditReq):
2650
+ vi = _require()
2651
+ prompt_str = req.prompt or f"The {req.relation} of {req.subject} is"
2652
+ before = vi.infer(prompt_str, top_k=5)
2653
+ log: List[str] = []
2654
+ try:
2655
+ result = vi.smart_edit(
2656
+ prompt_str, req.subject, req.relation, req.old_target, req.new_target,
2657
+ top_layers=req.top_layers, slots_per_layer=req.slots_per_layer,
2658
+ scale=req.scale, noise_std=req.noise_std,
2659
+ alpha=req.alpha, beta=req.beta, gamma=req.gamma, log=log
2660
+ )
2661
+ except Exception as e:
2662
+ raise HTTPException(status_code=500, detail=str(e))
2663
+ after = vi.infer(prompt_str, top_k=5)
2664
+ b_map = {d["token"]: d["prob"] for d in before}
2665
+ a_map = {d["token"]: d["prob"] for d in after}
2666
+ all_toks = set(b_map) | set(a_map)
2667
+ delta = sorted([{"token":t,"before":b_map.get(t,0),"after":a_map.get(t,0),
2668
+ "delta":a_map.get(t,0)-b_map.get(t,0)} for t in all_toks],
2669
+ key=lambda x: -abs(x["delta"]))
2670
+ return {"before": before, "after": after, "delta": delta,
2671
+ "debug_log": log, "used_layers": result["used_layers"],
2672
+ "smart_locate": result["smart_locate"]}
2673
+
2674
+
2675
  @app.post("/api/gate_heatmap")
2676
  async def api_gate_heatmap(req: HeatmapReq):
2677
  vi = _require()