Update app.py
Browse files
app.py
CHANGED
|
@@ -466,7 +466,272 @@ class VIndex:
|
|
| 466 |
"mode": "activation-guided" if use_act else "embed-based"
|
| 467 |
}
|
| 468 |
|
| 469 |
-
# ββ
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 470 |
|
| 471 |
def infer(self, prompt: str, top_k: int = 5):
|
| 472 |
probs = torch.softmax(self._forward(prompt), dim=-1)
|
|
@@ -771,10 +1036,12 @@ svg text { font-family: var(--font); fill: var(--text); }
|
|
| 771 |
<button class="tab-btn" onclick="showTab('describe')">β‘ Describe</button>
|
| 772 |
<button class="tab-btn" onclick="showTab('trace')">β’ Trace</button>
|
| 773 |
<button class="tab-btn" onclick="showTab('locate')">β£ Locate</button>
|
| 774 |
-
<button class="tab-btn" onclick="showTab('
|
| 775 |
-
<button class="tab-btn" onclick="showTab('
|
| 776 |
-
<button class="tab-btn" onclick="showTab('
|
| 777 |
-
<button class="tab-btn" onclick="showTab('
|
|
|
|
|
|
|
| 778 |
</div>
|
| 779 |
|
| 780 |
<!-- TOOLTIP -->
|
|
@@ -925,6 +1192,86 @@ svg text { font-family: var(--font); fill: var(--text); }
|
|
| 925 |
</div>
|
| 926 |
</div>
|
| 927 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 928 |
<!-- ββββββββββ HEATMAP PANEL ββββββββββ -->
|
| 929 |
<div id="panel-heatmap" class="panel">
|
| 930 |
<div class="card">
|
|
@@ -980,6 +1327,7 @@ svg text { font-family: var(--font); fill: var(--text); }
|
|
| 980 |
<div class="radio-group" id="edit-mode-group">
|
| 981 |
<label><input type="radio" name="edit-mode" value="UPDATE" checked> UPDATE</label>
|
| 982 |
<label><input type="radio" name="edit-mode" value="PRECISE"> PRECISE</label>
|
|
|
|
| 983 |
<label><input type="radio" name="edit-mode" value="INSERT"> INSERT</label>
|
| 984 |
<label><input type="radio" name="edit-mode" value="SUPPRESS"> SUPPRESS</label>
|
| 985 |
<label><input type="radio" name="edit-mode" value="AMPLIFY"> AMPLIFY</label>
|
|
@@ -990,6 +1338,28 @@ svg text { font-family: var(--font); fill: var(--text); }
|
|
| 990 |
<label>Prompt (PRECISE mode)</label>
|
| 991 |
<input type="text" id="edit-prompt" value="The capital of France is">
|
| 992 |
</div>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 993 |
<div id="style-shift-row" style="display:none;margin-top:8px">
|
| 994 |
<label>From concept</label>
|
| 995 |
<input type="text" id="ss-from" value="formal">
|
|
@@ -1069,6 +1439,209 @@ svg text { font-family: var(--font); fill: var(--text); }
|
|
| 1069 |
</div>
|
| 1070 |
</div>
|
| 1071 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1072 |
</div><!-- /app -->
|
| 1073 |
|
| 1074 |
<script>
|
|
@@ -1459,6 +2032,7 @@ function showHmSlotDetail(layer, slot) {
|
|
| 1459 |
document.querySelectorAll('input[name="edit-mode"]').forEach(r=>{
|
| 1460 |
r.addEventListener('change', ()=>{
|
| 1461 |
document.getElementById('precise-prompt-row').style.display = r.value==='PRECISE'?'block':'none';
|
|
|
|
| 1462 |
document.getElementById('style-shift-row').style.display = r.value==='STYLE-SHIFT'?'block':'none';
|
| 1463 |
document.getElementById('multiedit-row').style.display = r.value==='MULTI-EDIT'?'block':'none';
|
| 1464 |
});
|
|
@@ -1500,6 +2074,30 @@ async function runEdit() {
|
|
| 1500 |
to_concept: document.getElementById('ss-to').value,
|
| 1501 |
strength: +document.getElementById('ss-strength').value,
|
| 1502 |
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1503 |
if(mode==='MULTI-EDIT'){
|
| 1504 |
try {
|
| 1505 |
body.facts = JSON.parse(document.getElementById('multi-json').value);
|
|
@@ -1614,6 +2212,221 @@ async function updatePatchCount() {
|
|
| 1614 |
} catch(e){}
|
| 1615 |
}
|
| 1616 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1617 |
// βββββββββββββββββββββββββββββββββββββββββββββββ
|
| 1618 |
// INIT
|
| 1619 |
// βββββββββββββββββββββββββββββββββββββββββββββββ
|
|
@@ -1672,6 +2485,39 @@ class HeatmapReq(BaseModel):
|
|
| 1672 |
use_activation: bool = False
|
| 1673 |
prompt: Optional[str] = None
|
| 1674 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1675 |
class DryRunReq(BaseModel):
|
| 1676 |
entity: str
|
| 1677 |
new_target: str
|
|
@@ -1778,6 +2624,54 @@ async def api_locate(req: LocateReq):
|
|
| 1778 |
return vi.locate(req.prompt, req.subject, req.target)
|
| 1779 |
|
| 1780 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1781 |
@app.post("/api/gate_heatmap")
|
| 1782 |
async def api_gate_heatmap(req: HeatmapReq):
|
| 1783 |
vi = _require()
|
|
|
|
| 466 |
"mode": "activation-guided" if use_act else "embed-based"
|
| 467 |
}
|
| 468 |
|
| 469 |
+
# ββ Phase 1+: mechanistic attribution βββββββββββββββββββββ
|
| 470 |
+
|
| 471 |
+
def gradient_slot_scores(self, prompt: str, target: str) -> Dict:
|
| 472 |
+
"""One backward pass: grad norm of β(-log p(target))/βW_down[:,slot] per KB layer.
|
| 473 |
+
Identifies which slots causally contributed to this prediction via gradient signal."""
|
| 474 |
+
target_id = self.token_id(target)
|
| 475 |
+
|
| 476 |
+
# Temporarily enable grad on down-proj weights
|
| 477 |
+
down_params: List[Tuple[int, torch.nn.Parameter]] = []
|
| 478 |
+
for li in range(self.arch.n_layers):
|
| 479 |
+
layer = self.arch._layer(li)
|
| 480 |
+
p = layer.mlp.c_proj.weight if self.arch.style == "gpt2" \
|
| 481 |
+
else layer.mlp.down_proj.weight
|
| 482 |
+
p.requires_grad_(True)
|
| 483 |
+
down_params.append((li, p))
|
| 484 |
+
|
| 485 |
+
self.model.zero_grad()
|
| 486 |
+
inputs = self.tok(prompt, return_tensors="pt").to(self.device)
|
| 487 |
+
out = self.model(**inputs)
|
| 488 |
+
loss = -F.log_softmax(out.logits[0, -1], dim=-1)[target_id]
|
| 489 |
+
loss.backward()
|
| 490 |
+
|
| 491 |
+
layer_scores = []
|
| 492 |
+
for li, p in down_params:
|
| 493 |
+
grad = p.grad
|
| 494 |
+
p.requires_grad_(False)
|
| 495 |
+
if grad is None:
|
| 496 |
+
layer_scores.append({"layer": li, "max_grad": 0.0, "top_slots": []})
|
| 497 |
+
continue
|
| 498 |
+
# gpt2: c_proj.weight [ffn_dim, hidden] β rows = slots
|
| 499 |
+
# gated: down_proj.weight [hidden, ffn_dim] β cols = slots
|
| 500 |
+
slot_norms = grad.norm(dim=1) if self.arch.style == "gpt2" \
|
| 501 |
+
else grad.norm(dim=0) # [ffn_dim]
|
| 502 |
+
k = min(20, slot_norms.shape[0])
|
| 503 |
+
vals, idxs = slot_norms.topk(k)
|
| 504 |
+
layer_scores.append({
|
| 505 |
+
"layer": li,
|
| 506 |
+
"max_grad": round(float(vals[0].item()), 6),
|
| 507 |
+
"top_slots": [{"slot": int(idx.item()),
|
| 508 |
+
"grad_norm": round(float(v.item()), 6)}
|
| 509 |
+
for idx, v in zip(idxs, vals)]
|
| 510 |
+
})
|
| 511 |
+
|
| 512 |
+
self.model.zero_grad()
|
| 513 |
+
return {"layer_scores": layer_scores}
|
| 514 |
+
|
| 515 |
+
def causal_patch_trace(self, prompt: str, subject: str, target: str,
|
| 516 |
+
noise_std: float = 0.1) -> Dict:
|
| 517 |
+
"""ROME-style causal tracing.
|
| 518 |
+
Corrupts subject embeddings, then for each KB layer measures how much
|
| 519 |
+
patching that layer's hidden state (at subject position) restores p(target).
|
| 520 |
+
Expensive: O(n_layers) forward passes."""
|
| 521 |
+
target_id = self.token_id(target)
|
| 522 |
+
W_u = self.arch.get_unembedding().to(self.device)
|
| 523 |
+
inputs = self.tok(prompt, return_tensors="pt").to(self.device)
|
| 524 |
+
ids = inputs["input_ids"][0].tolist()
|
| 525 |
+
|
| 526 |
+
# Find subject token positions via subsequence match
|
| 527 |
+
subj_ids = self.tok.encode(subject, add_special_tokens=False)
|
| 528 |
+
subj_pos: List[int] = []
|
| 529 |
+
for start in range(len(ids) - len(subj_ids) + 1):
|
| 530 |
+
if ids[start:start+len(subj_ids)] == subj_ids:
|
| 531 |
+
subj_pos = list(range(start, start+len(subj_ids)))
|
| 532 |
+
break
|
| 533 |
+
if not subj_pos:
|
| 534 |
+
for si in subj_ids:
|
| 535 |
+
if si in ids:
|
| 536 |
+
subj_pos = [ids.index(si)]
|
| 537 |
+
break
|
| 538 |
+
if not subj_pos:
|
| 539 |
+
subj_pos = [0]
|
| 540 |
+
|
| 541 |
+
# ββ Clean forward β capture every layer's hidden states ββ
|
| 542 |
+
clean_hs: Dict[int, torch.Tensor] = {}
|
| 543 |
+
clean_handles = []
|
| 544 |
+
def _mk_clean(li):
|
| 545 |
+
def _h(m, inp, out):
|
| 546 |
+
h = out[0] if isinstance(out, tuple) else out
|
| 547 |
+
clean_hs[li] = h[0].detach().clone() # [seq, hidden]
|
| 548 |
+
return _h
|
| 549 |
+
for li in range(self.arch.n_layers):
|
| 550 |
+
clean_handles.append(self.arch._layer(li).register_forward_hook(_mk_clean(li)))
|
| 551 |
+
with torch.no_grad():
|
| 552 |
+
clean_out = self.model(**inputs)
|
| 553 |
+
for h in clean_handles: h.remove()
|
| 554 |
+
clean_prob = float(torch.softmax(clean_out.logits[0,-1], dim=-1)[target_id].item())
|
| 555 |
+
|
| 556 |
+
# ββ Corrupted embeddings ββ
|
| 557 |
+
E = self.arch.get_embedding().to(self.device)
|
| 558 |
+
emb = E[inputs["input_ids"][0]].unsqueeze(0).clone() # [1, seq, hidden]
|
| 559 |
+
noise_scale = emb.std().item() * noise_std
|
| 560 |
+
for pos in subj_pos:
|
| 561 |
+
emb[0, pos] += torch.randn_like(emb[0, pos]) * noise_scale
|
| 562 |
+
|
| 563 |
+
with torch.no_grad():
|
| 564 |
+
corr_out = self.model(inputs_embeds=emb)
|
| 565 |
+
corr_prob = float(torch.softmax(corr_out.logits[0,-1], dim=-1)[target_id].item())
|
| 566 |
+
|
| 567 |
+
# ββ Causal patch sweep ββ
|
| 568 |
+
results = []
|
| 569 |
+
for li in range(self.kb_start, self.kb_end):
|
| 570 |
+
def _mk_patch(target_li):
|
| 571 |
+
def _h(m, inp, out):
|
| 572 |
+
if target_li not in clean_hs:
|
| 573 |
+
return out
|
| 574 |
+
is_tuple = isinstance(out, tuple)
|
| 575 |
+
h = list(out) if is_tuple else [out]
|
| 576 |
+
clean = clean_hs[target_li]
|
| 577 |
+
for pos in subj_pos:
|
| 578 |
+
if pos < clean.shape[0]:
|
| 579 |
+
h[0][0, pos] = clean[pos].to(h[0].device)
|
| 580 |
+
return tuple(h) if is_tuple else h[0]
|
| 581 |
+
return _h
|
| 582 |
+
ph = self.arch._layer(li).register_forward_hook(_mk_patch(li))
|
| 583 |
+
with torch.no_grad():
|
| 584 |
+
patch_out = self.model(inputs_embeds=emb.clone())
|
| 585 |
+
ph.remove()
|
| 586 |
+
patch_prob = float(torch.softmax(patch_out.logits[0,-1], dim=-1)[target_id].item())
|
| 587 |
+
ie = patch_prob - corr_prob
|
| 588 |
+
results.append({
|
| 589 |
+
"layer": li,
|
| 590 |
+
"patch_prob": round(patch_prob, 6),
|
| 591 |
+
"indirect_effect": round(ie, 6),
|
| 592 |
+
})
|
| 593 |
+
|
| 594 |
+
return {
|
| 595 |
+
"clean_prob": round(clean_prob, 6),
|
| 596 |
+
"corrupt_prob": round(corr_prob, 6),
|
| 597 |
+
"subject_pos": subj_pos,
|
| 598 |
+
"results": results,
|
| 599 |
+
}
|
| 600 |
+
|
| 601 |
+
def smart_locate(self, prompt: str, subject: str, target: str,
|
| 602 |
+
alpha: float = 0.4, beta: float = 0.3, gamma: float = 0.3,
|
| 603 |
+
noise_std: float = 0.1) -> Dict:
|
| 604 |
+
"""Combined gate_sim + grad_norm + causal_effect β precise layer/slot ranking.
|
| 605 |
+
alpha = weight for gate cosine sim
|
| 606 |
+
beta = weight for gradient norm
|
| 607 |
+
gamma = weight for causal indirect effect"""
|
| 608 |
+
gate_data = self.locate(prompt, subject, target)
|
| 609 |
+
grad_data = self.gradient_slot_scores(prompt, target)
|
| 610 |
+
causal_data = self.causal_patch_trace(prompt, subject, target, noise_std=noise_std)
|
| 611 |
+
|
| 612 |
+
gate_map = {ls["layer"]: ls["max_sim"] for ls in gate_data["layer_scores"]}
|
| 613 |
+
grad_map = {ls["layer"]: ls["max_grad"] for ls in grad_data["layer_scores"]}
|
| 614 |
+
causal_map = {r["layer"]: max(0.0, r["indirect_effect"])
|
| 615 |
+
for r in causal_data["results"]}
|
| 616 |
+
grad_slots = {ls["layer"]: ls["top_slots"] for ls in grad_data["layer_scores"]}
|
| 617 |
+
|
| 618 |
+
layers = sorted(set(gate_map) | set(grad_map) | set(causal_map))
|
| 619 |
+
|
| 620 |
+
def _norm(vals: List[float]) -> List[float]:
|
| 621 |
+
m = max(vals) if vals else 1.0
|
| 622 |
+
return [v/m if m > 0 else 0.0 for v in vals]
|
| 623 |
+
|
| 624 |
+
gv = [gate_map.get(l, 0.0) for l in layers]
|
| 625 |
+
dv = [grad_map.get(l, 0.0) for l in layers]
|
| 626 |
+
cv = [causal_map.get(l, 0.0) for l in layers]
|
| 627 |
+
gn, dn, cn = _norm(gv), _norm(dv), _norm(cv)
|
| 628 |
+
|
| 629 |
+
ranked = []
|
| 630 |
+
for i, l in enumerate(layers):
|
| 631 |
+
score = alpha*gn[i] + beta*dn[i] + gamma*cn[i]
|
| 632 |
+
ranked.append({
|
| 633 |
+
"layer": l,
|
| 634 |
+
"gate_sim": round(gv[i], 4),
|
| 635 |
+
"grad_norm": round(dv[i], 6),
|
| 636 |
+
"causal_effect": round(cv[i], 6),
|
| 637 |
+
"gate_sim_n": round(gn[i], 4),
|
| 638 |
+
"grad_norm_n": round(dn[i], 4),
|
| 639 |
+
"causal_n": round(cn[i], 4),
|
| 640 |
+
"combined": round(score, 4),
|
| 641 |
+
"best_slots": (grad_slots.get(l) or [])[:5],
|
| 642 |
+
})
|
| 643 |
+
|
| 644 |
+
ranked.sort(key=lambda x: -x["combined"])
|
| 645 |
+
|
| 646 |
+
return {
|
| 647 |
+
"ranked_layers": ranked,
|
| 648 |
+
"phase_layer": gate_data["phase_layer"],
|
| 649 |
+
"subject_pos": gate_data["subject_pos"],
|
| 650 |
+
"clean_prob": causal_data["clean_prob"],
|
| 651 |
+
"corrupt_prob": causal_data["corrupt_prob"],
|
| 652 |
+
"recommendation": ranked[0] if ranked else None,
|
| 653 |
+
"weights": {"alpha": alpha, "beta": beta, "gamma": gamma},
|
| 654 |
+
}
|
| 655 |
+
|
| 656 |
+
def smart_edit(self, prompt: str, subject: str, relation: str,
|
| 657 |
+
old_target: str, new_target: str,
|
| 658 |
+
top_layers: int = 3, slots_per_layer: int = 2,
|
| 659 |
+
scale: float = 1.5, noise_std: float = 0.1,
|
| 660 |
+
alpha: float = 0.4, beta: float = 0.4, gamma: float = 0.2,
|
| 661 |
+
log: Optional[List[str]] = None) -> Dict:
|
| 662 |
+
"""Auto edit: runs smart_locate on (prompt, subject, old_target) to find
|
| 663 |
+
the exact layer+slot targets via gradient+causal+gate consensus, then
|
| 664 |
+
patches those W_down columns toward embed(new_target).
|
| 665 |
+
|
| 666 |
+
old_target = what the model currently predicts (used to locate)
|
| 667 |
+
new_target = what you want to inject
|
| 668 |
+
top_layers = how many top-ranked layers to patch
|
| 669 |
+
slots_per_layer = gradient-identified slots to patch per layer
|
| 670 |
+
scale = col_norm multiplier (1.5-3.0 recommended)
|
| 671 |
+
beta > alpha because grad_norm is more reliable than gate_sim for small models."""
|
| 672 |
+
if log is None: log = []
|
| 673 |
+
self._snapshot()
|
| 674 |
+
|
| 675 |
+
log.append(f"SMART_EDIT: '{subject}' [{relation}] {old_target!r} β {new_target!r}")
|
| 676 |
+
log.append(f" Running smart_locate on prompt: {prompt!r}")
|
| 677 |
+
log.append(f" Weights: οΏ½οΏ½={alpha} Ξ²={beta} Ξ³={gamma} noise_std={noise_std}")
|
| 678 |
+
|
| 679 |
+
sl = self.smart_locate(prompt, subject, old_target,
|
| 680 |
+
alpha=alpha, beta=beta, gamma=gamma,
|
| 681 |
+
noise_std=noise_std)
|
| 682 |
+
|
| 683 |
+
log.append(f" clean_prob={sl['clean_prob']:.6f} corrupt_prob={sl['corrupt_prob']:.6f}")
|
| 684 |
+
log.append(f" Phase layer: L{sl['phase_layer']} Subject pos: {sl['subject_pos']}")
|
| 685 |
+
|
| 686 |
+
if sl["clean_prob"] < 1e-5:
|
| 687 |
+
log.append(" β clean_prob near zero β model barely knows this fact.")
|
| 688 |
+
log.append(" Grad-norm signal still valid. Causal IE=0 is expected.")
|
| 689 |
+
log.append(" Recommend: gpt2-medium or Qwen2.5-1.5B for stronger facts.")
|
| 690 |
+
|
| 691 |
+
tv = self.embed(new_target)
|
| 692 |
+
tv_n = F.normalize(tv, dim=0)
|
| 693 |
+
ops = []
|
| 694 |
+
used = []
|
| 695 |
+
|
| 696 |
+
top_ranked = sl["ranked_layers"][:top_layers]
|
| 697 |
+
for lr in top_ranked:
|
| 698 |
+
li = lr["layer"]
|
| 699 |
+
# Use gradient-identified slots β far more precise than gate cosine
|
| 700 |
+
grad_slots = [s["slot"] for s in lr["best_slots"][:slots_per_layer]]
|
| 701 |
+
if not grad_slots:
|
| 702 |
+
log.append(f" L{li}: no grad slots, skipping")
|
| 703 |
+
continue
|
| 704 |
+
_, Wd = self.arch.get_ffn_weights(li)
|
| 705 |
+
Wd = Wd.to(self.device)
|
| 706 |
+
for slot in grad_slots:
|
| 707 |
+
col_norm = Wd[:, slot].norm().item()
|
| 708 |
+
new_col = (tv_n * col_norm * scale).cpu().tolist()
|
| 709 |
+
ops.append({"op":"update_down","layer":li,"slot":slot,"down_col":new_col})
|
| 710 |
+
log.append(f" β L{li} slot {slot}: combined={lr['combined']} "
|
| 711 |
+
f"grad_norm={lr['grad_norm']:.4f} col_norm={col_norm:.4f} "
|
| 712 |
+
f"inject={col_norm*scale:.4f}")
|
| 713 |
+
used.append({"layer":li,"slots":grad_slots,"combined":lr["combined"]})
|
| 714 |
+
|
| 715 |
+
self.patches.append({
|
| 716 |
+
"type": "SMART_UPDATE",
|
| 717 |
+
"entity": subject,
|
| 718 |
+
"relation": relation,
|
| 719 |
+
"new_target": new_target,
|
| 720 |
+
"old_target": old_target,
|
| 721 |
+
"smart_top": top_ranked,
|
| 722 |
+
"ops": ops,
|
| 723 |
+
})
|
| 724 |
+
self._apply_all_patches()
|
| 725 |
+
log.append(f"\n β {len(ops)} op(s) across {len(used)} layer(s), patch #{len(self.patches)}")
|
| 726 |
+
|
| 727 |
+
return {
|
| 728 |
+
"ops": ops,
|
| 729 |
+
"used_layers": used,
|
| 730 |
+
"smart_locate": sl,
|
| 731 |
+
"log": log,
|
| 732 |
+
}
|
| 733 |
+
|
| 734 |
+
|
| 735 |
|
| 736 |
def infer(self, prompt: str, top_k: int = 5):
|
| 737 |
probs = torch.softmax(self._forward(prompt), dim=-1)
|
|
|
|
| 1036 |
<button class="tab-btn" onclick="showTab('describe')">β‘ Describe</button>
|
| 1037 |
<button class="tab-btn" onclick="showTab('trace')">β’ Trace</button>
|
| 1038 |
<button class="tab-btn" onclick="showTab('locate')">β£ Locate</button>
|
| 1039 |
+
<button class="tab-btn" onclick="showTab('smartlocate')">β€ Smart Locate</button>
|
| 1040 |
+
<button class="tab-btn" onclick="showTab('heatmap')">β₯ Heatmap</button>
|
| 1041 |
+
<button class="tab-btn" onclick="showTab('edit')">β¦ Edit</button>
|
| 1042 |
+
<button class="tab-btn" onclick="showTab('patches')">β§ Patches</button>
|
| 1043 |
+
<button class="tab-btn" onclick="showTab('guide')" style="margin-left:auto;color:var(--green)">π Guide</button>
|
| 1044 |
+
<button class="tab-btn" onclick="showTab('load')">β Load</button>
|
| 1045 |
</div>
|
| 1046 |
|
| 1047 |
<!-- TOOLTIP -->
|
|
|
|
| 1192 |
</div>
|
| 1193 |
</div>
|
| 1194 |
|
| 1195 |
+
<!-- ββββββββββ SMART LOCATE PANEL ββββββββββ -->
|
| 1196 |
+
<div id="panel-smartlocate" class="panel">
|
| 1197 |
+
<div class="card">
|
| 1198 |
+
<h3>Smart Locate β gradient + causal + gate_sim combined</h3>
|
| 1199 |
+
<div style="color:var(--muted);font-size:11px;margin-bottom:12px;line-height:1.7">
|
| 1200 |
+
Three independent signals combined into one ranked layer list.<br>
|
| 1201 |
+
<span style="color:var(--blue)">β gate_sim</span> β static embedding cosine (fast, weak proxy)
|
| 1202 |
+
<span style="color:var(--green)">β grad_norm</span> β βloss/βW_down per slot (one backward pass)
|
| 1203 |
+
<span style="color:var(--yellow)">β causal IE</span> β indirect effect via subject-corruption patching (N_layers passes, slow)
|
| 1204 |
+
</div>
|
| 1205 |
+
<div class="row">
|
| 1206 |
+
<div class="col2">
|
| 1207 |
+
<label>Prompt</label>
|
| 1208 |
+
<input type="text" id="sl-prompt" value="The capital of France is">
|
| 1209 |
+
</div>
|
| 1210 |
+
<div class="col">
|
| 1211 |
+
<label>Subject</label>
|
| 1212 |
+
<input type="text" id="sl-subject" value="France">
|
| 1213 |
+
</div>
|
| 1214 |
+
<div class="col">
|
| 1215 |
+
<label>Target</label>
|
| 1216 |
+
<input type="text" id="sl-target" value="Paris">
|
| 1217 |
+
</div>
|
| 1218 |
+
</div>
|
| 1219 |
+
<div class="row" style="margin-top:10px">
|
| 1220 |
+
<div class="col">
|
| 1221 |
+
<label>Ξ± gate_sim: <span id="sl-a-val">0.4</span></label>
|
| 1222 |
+
<input type="range" id="sl-alpha" min="0" max="1" step="0.05" value="0.4"
|
| 1223 |
+
oninput="document.getElementById('sl-a-val').textContent=this.value">
|
| 1224 |
+
</div>
|
| 1225 |
+
<div class="col">
|
| 1226 |
+
<label>Ξ² grad_norm: <span id="sl-b-val">0.3</span></label>
|
| 1227 |
+
<input type="range" id="sl-beta" min="0" max="1" step="0.05" value="0.3"
|
| 1228 |
+
oninput="document.getElementById('sl-b-val').textContent=this.value">
|
| 1229 |
+
</div>
|
| 1230 |
+
<div class="col">
|
| 1231 |
+
<label>Ξ³ causal: <span id="sl-g-val">0.3</span></label>
|
| 1232 |
+
<input type="range" id="sl-gamma" min="0" max="1" step="0.05" value="0.3"
|
| 1233 |
+
oninput="document.getElementById('sl-g-val').textContent=this.value">
|
| 1234 |
+
</div>
|
| 1235 |
+
<div class="col">
|
| 1236 |
+
<label>Noise Ο: <span id="sl-noise-val">0.1</span></label>
|
| 1237 |
+
<input type="range" id="sl-noise" min="0.02" max="0.5" step="0.02" value="0.1"
|
| 1238 |
+
oninput="document.getElementById('sl-noise-val').textContent=this.value">
|
| 1239 |
+
</div>
|
| 1240 |
+
</div>
|
| 1241 |
+
<div style="display:flex;gap:8px;margin-top:12px;flex-wrap:wrap">
|
| 1242 |
+
<button onclick="runSmartLocate()">β‘ Smart Locate (full)</button>
|
| 1243 |
+
<button class="secondary" onclick="runGradientOnly()">βΆ Gradient only (fast)</button>
|
| 1244 |
+
<button class="secondary" onclick="runCausalOnly()">βΆ Causal trace only</button>
|
| 1245 |
+
</div>
|
| 1246 |
+
<div id="sl-status" style="color:var(--muted);font-size:11px;margin-top:8px"></div>
|
| 1247 |
+
</div>
|
| 1248 |
+
|
| 1249 |
+
<div class="row">
|
| 1250 |
+
<div class="col2 card">
|
| 1251 |
+
<h3>Layer Rankings β 3-signal stacked bars</h3>
|
| 1252 |
+
<div class="chart-wrap" id="sl-chart" style="min-height:320px"></div>
|
| 1253 |
+
</div>
|
| 1254 |
+
<div class="col card">
|
| 1255 |
+
<h3>Recommendation</h3>
|
| 1256 |
+
<div id="sl-rec" class="log">Run Smart Locate to see the best edit target.</div>
|
| 1257 |
+
<h3 style="margin-top:14px">Collateral Probe</h3>
|
| 1258 |
+
<div class="row" style="margin-top:8px">
|
| 1259 |
+
<input type="text" id="sl-coll-prompt" value="Biggest cities in France"
|
| 1260 |
+
style="flex:2" placeholder="Collateral promptβ¦">
|
| 1261 |
+
<button class="secondary" onclick="runCollateral()" style="flex:0">βΆ</button>
|
| 1262 |
+
</div>
|
| 1263 |
+
<div id="sl-coll-out" class="log" style="margin-top:8px">Probe a prompt to check collateral damage.</div>
|
| 1264 |
+
</div>
|
| 1265 |
+
</div>
|
| 1266 |
+
|
| 1267 |
+
<div class="card">
|
| 1268 |
+
<h3>Per-Layer Detail</h3>
|
| 1269 |
+
<div id="sl-table" style="overflow-x:auto">
|
| 1270 |
+
<div style="color:var(--muted);font-size:11px">Run Smart Locate first.</div>
|
| 1271 |
+
</div>
|
| 1272 |
+
</div>
|
| 1273 |
+
</div>
|
| 1274 |
+
|
| 1275 |
<!-- ββββββββββ HEATMAP PANEL ββββββββββ -->
|
| 1276 |
<div id="panel-heatmap" class="panel">
|
| 1277 |
<div class="card">
|
|
|
|
| 1327 |
<div class="radio-group" id="edit-mode-group">
|
| 1328 |
<label><input type="radio" name="edit-mode" value="UPDATE" checked> UPDATE</label>
|
| 1329 |
<label><input type="radio" name="edit-mode" value="PRECISE"> PRECISE</label>
|
| 1330 |
+
<label><input type="radio" name="edit-mode" value="SMART"> β
SMART</label>
|
| 1331 |
<label><input type="radio" name="edit-mode" value="INSERT"> INSERT</label>
|
| 1332 |
<label><input type="radio" name="edit-mode" value="SUPPRESS"> SUPPRESS</label>
|
| 1333 |
<label><input type="radio" name="edit-mode" value="AMPLIFY"> AMPLIFY</label>
|
|
|
|
| 1338 |
<label>Prompt (PRECISE mode)</label>
|
| 1339 |
<input type="text" id="edit-prompt" value="The capital of France is">
|
| 1340 |
</div>
|
| 1341 |
+
<div id="smart-row" style="display:none;margin-top:8px;background:var(--bg);border:1px solid var(--border);border-radius:6px;padding:10px">
|
| 1342 |
+
<div style="color:var(--blue);font-size:11px;font-weight:700;margin-bottom:6px">β
SMART AUTO MODE</div>
|
| 1343 |
+
<label>Prompt (used for locate + after-check)</label>
|
| 1344 |
+
<input type="text" id="smart-prompt" value="The capital of France is">
|
| 1345 |
+
<label style="margin-top:6px">Old value (what model currently says)</label>
|
| 1346 |
+
<input type="text" id="smart-old" value="Paris">
|
| 1347 |
+
<div class="row" style="margin-top:6px">
|
| 1348 |
+
<div class="col">
|
| 1349 |
+
<label>Top layers: <span id="smart-layers-val">3</span></label>
|
| 1350 |
+
<input type="range" id="smart-layers" min="1" max="8" value="3"
|
| 1351 |
+
oninput="document.getElementById('smart-layers-val').textContent=this.value">
|
| 1352 |
+
</div>
|
| 1353 |
+
<div class="col">
|
| 1354 |
+
<label>Slots/layer: <span id="smart-slots-val">2</span></label>
|
| 1355 |
+
<input type="range" id="smart-slots" min="1" max="5" value="2"
|
| 1356 |
+
oninput="document.getElementById('smart-slots-val').textContent=this.value">
|
| 1357 |
+
</div>
|
| 1358 |
+
</div>
|
| 1359 |
+
<div style="color:var(--muted);font-size:10px;margin-top:6px">
|
| 1360 |
+
Runs smart_locate internally β patches gradient-identified slots. No manual tuning needed.
|
| 1361 |
+
</div>
|
| 1362 |
+
</div>
|
| 1363 |
<div id="style-shift-row" style="display:none;margin-top:8px">
|
| 1364 |
<label>From concept</label>
|
| 1365 |
<input type="text" id="ss-from" value="formal">
|
|
|
|
| 1439 |
</div>
|
| 1440 |
</div>
|
| 1441 |
|
| 1442 |
+
<!-- ββββββββββ GUIDE PANEL ββββββββββ -->
|
| 1443 |
+
<div id="panel-guide" class="panel">
|
| 1444 |
+
<div class="card">
|
| 1445 |
+
<h3>What is VINDEX doing?</h3>
|
| 1446 |
+
<div style="line-height:1.9;color:var(--muted)">
|
| 1447 |
+
In a transformer, factual associations like <span style="color:var(--text)">"France β capital β Paris"</span>
|
| 1448 |
+
are stored as direction vectors in the <span style="color:var(--blue)">W_down columns</span> of FFN layers.
|
| 1449 |
+
The <span style="color:var(--blue)">W_gate rows</span> act as keys: when the residual stream resembles "France",
|
| 1450 |
+
the matching gate fires, the down column adds "Paris" direction to the stream, and the unembedding reads out "Paris".
|
| 1451 |
+
VINDEX surgically replaces those down columns without retraining.
|
| 1452 |
+
</div>
|
| 1453 |
+
</div>
|
| 1454 |
+
|
| 1455 |
+
<div class="card">
|
| 1456 |
+
<h3>Quickstart β 5-step experiment</h3>
|
| 1457 |
+
<div style="line-height:2;font-size:12px">
|
| 1458 |
+
<div style="color:var(--yellow);margin-bottom:4px">Step 1 β Load a model that actually knows facts</div>
|
| 1459 |
+
<div style="color:var(--muted);margin-left:16px;margin-bottom:10px">
|
| 1460 |
+
β Load tab β <span style="color:var(--blue)">gpt2-medium</span> (1.5 GB, knows capitals) or
|
| 1461 |
+
<span style="color:var(--blue)">Qwen/Qwen2.5-1.5B-Instruct</span> (3 GB, strong).<br>
|
| 1462 |
+
distilgpt2 has clean_probβ0 for most facts β causal IE=0 everywhere β misleading results.
|
| 1463 |
+
</div>
|
| 1464 |
+
|
| 1465 |
+
<div style="color:var(--yellow);margin-bottom:4px">Step 2 β Verify the model knows the fact</div>
|
| 1466 |
+
<div style="color:var(--muted);margin-left:16px;margin-bottom:10px">
|
| 1467 |
+
β Infer: prompt = <code>"The capital of France is"</code><br>
|
| 1468 |
+
β Good: "Paris" appears in top-3 with prob > 0.05<br>
|
| 1469 |
+
β Bad: top tokens are "a", "the", "known" β model doesn't know it β skip to INSERT mode
|
| 1470 |
+
</div>
|
| 1471 |
+
|
| 1472 |
+
<div style="color:var(--yellow);margin-bottom:4px">Step 3 β Find where the fact lives</div>
|
| 1473 |
+
<div style="color:var(--muted);margin-left:16px;margin-bottom:10px">
|
| 1474 |
+
β’ Trace: prompt = <code>"The capital of France is"</code>, target = <code>"Paris"</code><br>
|
| 1475 |
+
β Look for phase layer: where rank drops from ~30000 to <100. That's where the fact materializes.<br>
|
| 1476 |
+
β€ Smart Locate β Gradient only (fast, 1 backward pass):<br>
|
| 1477 |
+
<span style="margin-left:16px">subject = <code>France</code>, target = <code>Paris</code></span><br>
|
| 1478 |
+
β The layer with highest grad_norm bar = best edit target. Note the slot numbers.
|
| 1479 |
+
</div>
|
| 1480 |
+
|
| 1481 |
+
<div style="color:var(--yellow);margin-bottom:4px">Step 4 β Edit with SMART mode</div>
|
| 1482 |
+
<div style="color:var(--muted);margin-left:16px;margin-bottom:10px">
|
| 1483 |
+
β¦ Edit tab β mode = <span style="color:var(--blue)">β
SMART</span><br>
|
| 1484 |
+
Entity = <code>France</code> | Relation = <code>capital</code><br>
|
| 1485 |
+
Old value = <code>Paris</code> (what model says now β used for locate)<br>
|
| 1486 |
+
New value = <code>Lyon</code> (what you want)<br>
|
| 1487 |
+
Prompt = <code>"The capital of France is"</code><br>
|
| 1488 |
+
Scale = <code>2.0</code> (start here; increase to 3.0 if effect is weak)<br>
|
| 1489 |
+
β Click <b>Apply Edit</b>. Smart locate runs internally, patches grad-identified slots.
|
| 1490 |
+
</div>
|
| 1491 |
+
|
| 1492 |
+
<div style="color:var(--yellow);margin-bottom:4px">Step 5 β Check collateral damage</div>
|
| 1493 |
+
<div style="color:var(--muted);margin-left:16px;margin-bottom:10px">
|
| 1494 |
+
β Infer: <code>"The capital of France is"</code> β should now say Lyon<br>
|
| 1495 |
+
β Infer: <code>"Biggest cities in France"</code> β should be unchanged (different slots)<br>
|
| 1496 |
+
β Infer: <code>"Paris is a city in"</code> β should still say France<br>
|
| 1497 |
+
β Infer: <code>"Lyon is a city in"</code> β might now also say France (collateral)<br>
|
| 1498 |
+
β€ Smart Locate collateral probe β run these prompts, compare slot lists in β§ Patches
|
| 1499 |
+
</div>
|
| 1500 |
+
</div>
|
| 1501 |
+
</div>
|
| 1502 |
+
|
| 1503 |
+
<div class="card">
|
| 1504 |
+
<h3>Interpreting Smart Locate results</h3>
|
| 1505 |
+
<div style="font-size:11px;line-height:1.9">
|
| 1506 |
+
<div class="row" style="gap:20px">
|
| 1507 |
+
<div class="col">
|
| 1508 |
+
<div style="color:var(--blue);font-weight:700;margin-bottom:6px">β gate_sim (blue)</div>
|
| 1509 |
+
<div style="color:var(--muted)">
|
| 1510 |
+
Cosine between W_gate[slot] and embed(subject).<br>
|
| 1511 |
+
Fast, cheap, but <b>weak proxy</b> β measures embedding-space similarity,<br>
|
| 1512 |
+
not causal contribution. Useful for finding <i>related</i> slots.<br>
|
| 1513 |
+
<b>High gate_sim + low grad_norm</b> = slot activates for this entity<br>
|
| 1514 |
+
but doesn't contribute much to this specific prediction.
|
| 1515 |
+
</div>
|
| 1516 |
+
</div>
|
| 1517 |
+
<div class="col">
|
| 1518 |
+
<div style="color:var(--green);font-weight:700;margin-bottom:6px">β grad_norm (green)</div>
|
| 1519 |
+
<div style="color:var(--muted)">
|
| 1520 |
+
ββ(-log p(target))/βW_down[:,slot]β β how much changing this slot<br>
|
| 1521 |
+
would affect the loss for this (prompt, target) pair.<br>
|
| 1522 |
+
<b>Most reliable signal</b>, works even when clean_prob is tiny.<br>
|
| 1523 |
+
One backward pass. Use Ξ² > Ξ± to weight this higher.<br>
|
| 1524 |
+
<b>High grad_norm</b> = this slot is causally upstream of the prediction.
|
| 1525 |
+
</div>
|
| 1526 |
+
</div>
|
| 1527 |
+
<div class="col">
|
| 1528 |
+
<div style="color:var(--yellow);font-weight:700;margin-bottom:6px">β causal IE (yellow)</div>
|
| 1529 |
+
<div style="color:var(--muted)">
|
| 1530 |
+
Indirect effect via noise-corruption patching (ROME-style).<br>
|
| 1531 |
+
Measures: if I corrupt subject embeddings, how much does patching<br>
|
| 1532 |
+
layer L's hidden state at subject pos <i>restore</i> the prediction?<br>
|
| 1533 |
+
<b>Most interpretable</b> β true causal measurement. But:<br>
|
| 1534 |
+
If clean_prob β 0, IE = 0 everywhere (nothing to restore).<br>
|
| 1535 |
+
Needs a model that actually knows the fact.
|
| 1536 |
+
</div>
|
| 1537 |
+
</div>
|
| 1538 |
+
</div>
|
| 1539 |
+
<div style="margin-top:12px;padding:10px;background:var(--bg);border-radius:6px;border:1px solid var(--border)">
|
| 1540 |
+
<span style="color:var(--yellow)">β Your distilgpt2 result:</span>
|
| 1541 |
+
<span style="color:var(--muted)"> clean_prob=0.000001 β causal IE=0 everywhere (expected, not a bug).
|
| 1542 |
+
grad_norm on L9/slot515 IS real signal β that slot responds to France+capital context in the gradient sense.
|
| 1543 |
+
But the probability mass is too diffuse to show causal separation.
|
| 1544 |
+
Switch to gpt2-medium for textbook causal results.</span>
|
| 1545 |
+
</div>
|
| 1546 |
+
</div>
|
| 1547 |
+
</div>
|
| 1548 |
+
|
| 1549 |
+
<div class="card">
|
| 1550 |
+
<h3>Edit modes β when to use which</h3>
|
| 1551 |
+
<div style="font-size:11px">
|
| 1552 |
+
<table style="width:100%;border-collapse:collapse">
|
| 1553 |
+
<thead><tr style="border-bottom:1px solid var(--border);color:var(--muted)">
|
| 1554 |
+
<th style="padding:6px 8px;text-align:left">Mode</th>
|
| 1555 |
+
<th style="padding:6px 8px;text-align:left">Slot selection</th>
|
| 1556 |
+
<th style="padding:6px 8px;text-align:left">Best for</th>
|
| 1557 |
+
<th style="padding:6px 8px;text-align:left">Knobs</th>
|
| 1558 |
+
</tr></thead>
|
| 1559 |
+
<tbody style="color:var(--muted)">
|
| 1560 |
+
<tr style="border-bottom:1px solid var(--border)">
|
| 1561 |
+
<td style="padding:6px 8px;color:var(--blue)">UPDATE</td>
|
| 1562 |
+
<td style="padding:6px 8px">gate cosine sim to embed(entity)</td>
|
| 1563 |
+
<td style="padding:6px 8px">Quick experiment, model knows the fact well</td>
|
| 1564 |
+
<td style="padding:6px 8px">Top-K=3-5, Scale=1.5-3</td>
|
| 1565 |
+
</tr>
|
| 1566 |
+
<tr style="border-bottom:1px solid var(--border)">
|
| 1567 |
+
<td style="padding:6px 8px;color:var(--purple)">PRECISE</td>
|
| 1568 |
+
<td style="padding:6px 8px">gate cosine sim to h_L[subject_pos]</td>
|
| 1569 |
+
<td style="padding:6px 8px">In-context subject representation (3-5Γ better than UPDATE)</td>
|
| 1570 |
+
<td style="padding:6px 8px">+ Prompt field</td>
|
| 1571 |
+
</tr>
|
| 1572 |
+
<tr style="border-bottom:1px solid var(--border)">
|
| 1573 |
+
<td style="padding:6px 8px;color:var(--yellow)">β
SMART</td>
|
| 1574 |
+
<td style="padding:6px 8px">gradient norm β exact slots, then patch</td>
|
| 1575 |
+
<td style="padding:6px 8px"><b>Best overall.</b> Auto-locates, no manual tuning</td>
|
| 1576 |
+
<td style="padding:6px 8px">Top layers=3, Slots/layer=2, Scale=1.5-2.5</td>
|
| 1577 |
+
</tr>
|
| 1578 |
+
<tr style="border-bottom:1px solid var(--border)">
|
| 1579 |
+
<td style="padding:6px 8px;color:var(--green)">INSERT</td>
|
| 1580 |
+
<td style="padding:6px 8px">weakest slot (norm-based)</td>
|
| 1581 |
+
<td style="padding:6px 8px">Model has no knowledge of fact, build from scratch</td>
|
| 1582 |
+
<td style="padding:6px 8px">Alpha=0.4-0.7, Spread=4-6</td>
|
| 1583 |
+
</tr>
|
| 1584 |
+
<tr style="border-bottom:1px solid var(--border)">
|
| 1585 |
+
<td style="padding:6px 8px;color:var(--red)">SUPPRESS</td>
|
| 1586 |
+
<td style="padding:6px 8px">gate cosine β scale W_down to 0</td>
|
| 1587 |
+
<td style="padding:6px 8px">Make model forget an entity (factor=0) or weaken (0.5)</td>
|
| 1588 |
+
<td style="padding:6px 8px">Factor: 0=forget, 0.5=weaken</td>
|
| 1589 |
+
</tr>
|
| 1590 |
+
<tr style="border-bottom:1px solid var(--border)">
|
| 1591 |
+
<td style="padding:6px 8px;color:var(--cyan)">STYLE-SHIFT</td>
|
| 1592 |
+
<td style="padding:6px 8px">gate cosine β add direction vector</td>
|
| 1593 |
+
<td style="padding:6px 8px">Bias/tone shifts: CEOβless male-coded, Parisβdarker</td>
|
| 1594 |
+
<td style="padding:6px 8px">from/to concepts, strength=0.3-0.8</td>
|
| 1595 |
+
</tr>
|
| 1596 |
+
</tbody>
|
| 1597 |
+
</table>
|
| 1598 |
+
</div>
|
| 1599 |
+
</div>
|
| 1600 |
+
|
| 1601 |
+
<div class="card">
|
| 1602 |
+
<h3>Experiments to run</h3>
|
| 1603 |
+
<div style="font-size:11px;line-height:1.9;color:var(--muted)">
|
| 1604 |
+
<div style="color:var(--text);margin-bottom:4px">Experiment A β Capital swap (classic ROME benchmark)</div>
|
| 1605 |
+
Model: gpt2-medium | Prompt: "The capital of France is" | Old: Paris | New: Lyon<br>
|
| 1606 |
+
Check: "France's capital city" | "Lyon is now" | "Paris is in" | "Eiffel Tower is in"<br>
|
| 1607 |
+
Insight: does it generalize (paraphrase) or is it prompt-specific?<br><br>
|
| 1608 |
+
|
| 1609 |
+
<div style="color:var(--text);margin-bottom:4px">Experiment B β Slot overlap analysis (your collateral question)</div>
|
| 1610 |
+
1. SMART locate "The capital of France is" β note slot numbers in recommendation<br>
|
| 1611 |
+
2. SMART locate "The biggest city in France is" β compare slot lists<br>
|
| 1612 |
+
3. Overlap = slots that will be collaterally damaged<br>
|
| 1613 |
+
4. No overlap = clean surgery β<br><br>
|
| 1614 |
+
|
| 1615 |
+
<div style="color:var(--text);margin-bottom:4px">Experiment C β Suppression then INSERT</div>
|
| 1616 |
+
SUPPRESS France β then INSERT France capital Lyon β Infer<br>
|
| 1617 |
+
vs just UPDATE. Which gives cleaner, more confident result?<br><br>
|
| 1618 |
+
|
| 1619 |
+
<div style="color:var(--text);margin-bottom:4px">Experiment D β Style shift (no factual change)</div>
|
| 1620 |
+
STYLE-SHIFT: anchor=CEO, from="male", to="female", strength=0.3<br>
|
| 1621 |
+
Then Infer: "The CEO of the company is a" β does pronoun distribution shift?<br>
|
| 1622 |
+
Insight: this is mechanical debiasing without retraining.<br><br>
|
| 1623 |
+
|
| 1624 |
+
<div style="color:var(--text);margin-bottom:4px">Experiment E β Compile and compare</div>
|
| 1625 |
+
Edit 5 facts. Compile β save as new model directory.<br>
|
| 1626 |
+
Load compiled model fresh β Infer same prompts β edits should persist in weights.<br>
|
| 1627 |
+
Then Trace on compiled model β phase layers should shift or sharpen.
|
| 1628 |
+
</div>
|
| 1629 |
+
</div>
|
| 1630 |
+
|
| 1631 |
+
<div class="card">
|
| 1632 |
+
<h3>Ξ± Ξ² Ξ³ tuning guide</h3>
|
| 1633 |
+
<div style="font-size:11px;line-height:1.9;color:var(--muted)">
|
| 1634 |
+
<b style="color:var(--text)">Default (0.4 / 0.3 / 0.3)</b> β balanced, works for unknown model quality<br>
|
| 1635 |
+
<b style="color:var(--text)">Grad-heavy (0.1 / 0.7 / 0.2)</b> β clean_prob > 0.01. Grad signal is sharp, trust it.<br>
|
| 1636 |
+
<b style="color:var(--text)">Gate+Grad (0.4 / 0.4 / 0.2)</b> β recommended for smart_edit when causal IE is weak<br>
|
| 1637 |
+
<b style="color:var(--text)">Causal-heavy (0.2 / 0.2 / 0.6)</b> β only when clean_prob > 0.1. IE is the gold signal then.<br>
|
| 1638 |
+
<b style="color:var(--text)">Gate-only (1.0 / 0.0 / 0.0)</b> β equivalent to basic locate(), sanity check<br>
|
| 1639 |
+
<br>
|
| 1640 |
+
<b style="color:var(--yellow)">Your distilgpt2 setting:</b> use (0.3 / 0.7 / 0.0) β gate+grad, skip causal (it's 0 anyway).
|
| 1641 |
+
</div>
|
| 1642 |
+
</div>
|
| 1643 |
+
</div>
|
| 1644 |
+
|
| 1645 |
</div><!-- /app -->
|
| 1646 |
|
| 1647 |
<script>
|
|
|
|
| 2032 |
document.querySelectorAll('input[name="edit-mode"]').forEach(r=>{
|
| 2033 |
r.addEventListener('change', ()=>{
|
| 2034 |
document.getElementById('precise-prompt-row').style.display = r.value==='PRECISE'?'block':'none';
|
| 2035 |
+
document.getElementById('smart-row').style.display = r.value==='SMART'?'block':'none';
|
| 2036 |
document.getElementById('style-shift-row').style.display = r.value==='STYLE-SHIFT'?'block':'none';
|
| 2037 |
document.getElementById('multiedit-row').style.display = r.value==='MULTI-EDIT'?'block':'none';
|
| 2038 |
});
|
|
|
|
| 2074 |
to_concept: document.getElementById('ss-to').value,
|
| 2075 |
strength: +document.getElementById('ss-strength').value,
|
| 2076 |
};
|
| 2077 |
+
if(mode==='SMART'){
|
| 2078 |
+
try {
|
| 2079 |
+
const r = await api('/api/smart_edit', {
|
| 2080 |
+
prompt: document.getElementById('smart-prompt').value,
|
| 2081 |
+
subject: document.getElementById('edit-entity').value,
|
| 2082 |
+
relation: document.getElementById('edit-relation').value,
|
| 2083 |
+
old_target: document.getElementById('smart-old').value,
|
| 2084 |
+
new_target: document.getElementById('edit-new').value,
|
| 2085 |
+
top_layers: +document.getElementById('smart-layers').value,
|
| 2086 |
+
slots_per_layer: +document.getElementById('smart-slots').value,
|
| 2087 |
+
scale: +document.getElementById('edit-scale').value,
|
| 2088 |
+
noise_std: 0.1, alpha: 0.4, beta: 0.4, gamma: 0.2,
|
| 2089 |
+
});
|
| 2090 |
+
drawBeforeAfterChart(r.before, r.after);
|
| 2091 |
+
let log = r.debug_log.join('\n');
|
| 2092 |
+
log += '\n\nUsed layers:\n';
|
| 2093 |
+
r.used_layers.forEach(l=>{ log+=` L${l.layer} slots=[${l.slots.join(',')}] combined=${l.combined}\n`; });
|
| 2094 |
+
log += '\nDelta:\n';
|
| 2095 |
+
r.delta.slice(0,8).forEach(d=>{ log+=` ${d.token}: ${d.before.toFixed(4)} β ${d.after.toFixed(4)} ${d.delta>0?'+':''}${d.delta.toFixed(4)}\n`; });
|
| 2096 |
+
document.getElementById('edit-log').textContent = log;
|
| 2097 |
+
updatePatchCount();
|
| 2098 |
+
} catch(e) { alert(e.message); }
|
| 2099 |
+
return;
|
| 2100 |
+
}
|
| 2101 |
if(mode==='MULTI-EDIT'){
|
| 2102 |
try {
|
| 2103 |
body.facts = JSON.parse(document.getElementById('multi-json').value);
|
|
|
|
| 2212 |
} catch(e){}
|
| 2213 |
}
|
| 2214 |
|
| 2215 |
+
// βββββββββββββββββββββββββββββββββββββββββββββββ
|
| 2216 |
+
// SMART LOCATE
|
| 2217 |
+
// βββββββββββββββββββββββββββββββββββββββββββββββ
|
| 2218 |
+
let _slData = null;
|
| 2219 |
+
|
| 2220 |
+
async function runSmartLocate() {
|
| 2221 |
+
const st = document.getElementById('sl-status');
|
| 2222 |
+
st.textContent = 'β³ Running gradient pass + causal sweep (may take ~20s for large models)β¦';
|
| 2223 |
+
try {
|
| 2224 |
+
const data = await api('/api/smart_locate', {
|
| 2225 |
+
prompt: document.getElementById('sl-prompt').value,
|
| 2226 |
+
subject: document.getElementById('sl-subject').value,
|
| 2227 |
+
target: document.getElementById('sl-target').value,
|
| 2228 |
+
alpha: +document.getElementById('sl-alpha').value,
|
| 2229 |
+
beta: +document.getElementById('sl-beta').value,
|
| 2230 |
+
gamma: +document.getElementById('sl-gamma').value,
|
| 2231 |
+
noise_std: +document.getElementById('sl-noise').value,
|
| 2232 |
+
});
|
| 2233 |
+
_slData = data;
|
| 2234 |
+
st.textContent = `β Done. clean_prob=${data.clean_prob.toFixed(4)} corrupt_prob=${data.corrupt_prob.toFixed(4)}`;
|
| 2235 |
+
drawSmartLocateChart(data.ranked_layers);
|
| 2236 |
+
showSmartRec(data);
|
| 2237 |
+
buildSlTable(data.ranked_layers);
|
| 2238 |
+
} catch(e) { st.textContent = 'β '+e.message; }
|
| 2239 |
+
}
|
| 2240 |
+
|
| 2241 |
+
async function runGradientOnly() {
|
| 2242 |
+
const st = document.getElementById('sl-status');
|
| 2243 |
+
st.textContent = 'β³ Running gradient passβ¦';
|
| 2244 |
+
try {
|
| 2245 |
+
const data = await api('/api/gradient_scores', {
|
| 2246 |
+
prompt: document.getElementById('sl-prompt').value,
|
| 2247 |
+
target: document.getElementById('sl-target').value,
|
| 2248 |
+
});
|
| 2249 |
+
st.textContent = `β Gradient done. ${data.layer_scores.length} KB layers.`;
|
| 2250 |
+
// Draw gradient-only bars
|
| 2251 |
+
drawGradOnlyChart(data.layer_scores);
|
| 2252 |
+
} catch(e) { st.textContent = 'β '+e.message; }
|
| 2253 |
+
}
|
| 2254 |
+
|
| 2255 |
+
async function runCausalOnly() {
|
| 2256 |
+
const st = document.getElementById('sl-status');
|
| 2257 |
+
st.textContent = 'β³ Running causal patch traceβ¦';
|
| 2258 |
+
try {
|
| 2259 |
+
const data = await api('/api/causal_trace', {
|
| 2260 |
+
prompt: document.getElementById('sl-prompt').value,
|
| 2261 |
+
subject: document.getElementById('sl-subject').value,
|
| 2262 |
+
target: document.getElementById('sl-target').value,
|
| 2263 |
+
noise_std: +document.getElementById('sl-noise').value,
|
| 2264 |
+
});
|
| 2265 |
+
st.textContent = `β Causal done. clean=${data.clean_prob.toFixed(4)} corrupt=${data.corrupt_prob.toFixed(4)}`;
|
| 2266 |
+
drawCausalOnlyChart(data.results);
|
| 2267 |
+
} catch(e) { st.textContent = 'β '+e.message; }
|
| 2268 |
+
}
|
| 2269 |
+
|
| 2270 |
+
async function runCollateral() {
|
| 2271 |
+
const prompt = document.getElementById('sl-coll-prompt').value;
|
| 2272 |
+
try {
|
| 2273 |
+
const data = await api('/api/infer', { prompt, top_k: 5 });
|
| 2274 |
+
const el = document.getElementById('sl-coll-out');
|
| 2275 |
+
el.textContent = `"${prompt}"\n` +
|
| 2276 |
+
data.results.map(r=>` ${r.token.padEnd(18)} ${r.prob.toFixed(4)}`).join('\n');
|
| 2277 |
+
} catch(e) { document.getElementById('sl-coll-out').textContent = 'β '+e.message; }
|
| 2278 |
+
}
|
| 2279 |
+
|
| 2280 |
+
function drawSmartLocateChart(ranked) {
|
| 2281 |
+
// Sort by layer for chart display
|
| 2282 |
+
const byLayer = [...ranked].sort((a,b)=>a.layer-b.layer);
|
| 2283 |
+
const el = clearChart('sl-chart');
|
| 2284 |
+
const W = el.clientWidth || 700, H = 40 + byLayer.length * 34;
|
| 2285 |
+
el.style.height = H+'px';
|
| 2286 |
+
const svg = d3.select(el).append('svg').attr('width','100%').attr('height',H);
|
| 2287 |
+
const m = {left:50,right:110,top:20,bottom:20};
|
| 2288 |
+
const w = W-m.left-m.right;
|
| 2289 |
+
const g = svg.append('g').attr('transform',`translate(${m.left},${m.top})`);
|
| 2290 |
+
|
| 2291 |
+
// Each bar = 3 stacked segments (normalized: gate_sim_n, grad_norm_n, causal_n)
|
| 2292 |
+
// Each segment width = signal_n * (w/3) so max of each is w/3
|
| 2293 |
+
const segW = w / 3;
|
| 2294 |
+
|
| 2295 |
+
byLayer.forEach((d,i)=>{
|
| 2296 |
+
const y = i*34;
|
| 2297 |
+
// Label
|
| 2298 |
+
g.append('text').attr('x',-6).attr('y',y+17).attr('text-anchor','end')
|
| 2299 |
+
.attr('fill', d.layer===(_slData?.recommendation?.layer) ? C.yellow : C.muted)
|
| 2300 |
+
.attr('font-size',10).text('L'+d.layer);
|
| 2301 |
+
|
| 2302 |
+
// gate_sim segment
|
| 2303 |
+
g.append('rect').attr('x',0).attr('y',y+4).attr('width',d.gate_sim_n*segW)
|
| 2304 |
+
.attr('height',12).attr('rx',2).attr('fill',C.blue).attr('opacity',.8)
|
| 2305 |
+
.on('mousemove',(ev)=>showTooltip(`L${d.layer} gate_sim: ${d.gate_sim}`,ev.pageX,ev.pageY))
|
| 2306 |
+
.on('mouseleave',hideTooltip);
|
| 2307 |
+
// grad_norm segment
|
| 2308 |
+
g.append('rect').attr('x',segW).attr('y',y+4).attr('width',d.grad_norm_n*segW)
|
| 2309 |
+
.attr('height',12).attr('rx',2).attr('fill',C.green).attr('opacity',.8)
|
| 2310 |
+
.on('mousemove',(ev)=>showTooltip(`L${d.layer} grad_norm: ${d.grad_norm}`,ev.pageX,ev.pageY))
|
| 2311 |
+
.on('mouseleave',hideTooltip);
|
| 2312 |
+
// causal segment
|
| 2313 |
+
g.append('rect').attr('x',segW*2).attr('y',y+4).attr('width',d.causal_n*segW)
|
| 2314 |
+
.attr('height',12).attr('rx',2).attr('fill',C.yellow).attr('opacity',.8)
|
| 2315 |
+
.on('mousemove',(ev)=>showTooltip(`L${d.layer} causal_IE: ${d.causal_effect}`,ev.pageX,ev.pageY))
|
| 2316 |
+
.on('mouseleave',hideTooltip);
|
| 2317 |
+
// combined score label
|
| 2318 |
+
g.append('text').attr('x',w+6).attr('y',y+14)
|
| 2319 |
+
.attr('fill', d.combined===Math.max(...ranked.map(r=>r.combined)) ? C.yellow : C.muted)
|
| 2320 |
+
.attr('font-size',10).text(d.combined.toFixed(3));
|
| 2321 |
+
});
|
| 2322 |
+
|
| 2323 |
+
// Axis labels
|
| 2324 |
+
const ax = g.append('g').attr('transform',`translate(0,${byLayer.length*34})`);
|
| 2325 |
+
ax.append('text').attr('x',segW/2).attr('y',14).attr('text-anchor','middle')
|
| 2326 |
+
.attr('fill',C.blue).attr('font-size',9).text('gate_sim');
|
| 2327 |
+
ax.append('text').attr('x',segW*1.5).attr('y',14).attr('text-anchor','middle')
|
| 2328 |
+
.attr('fill',C.green).attr('font-size',9).text('grad_norm');
|
| 2329 |
+
ax.append('text').attr('x',segW*2.5).attr('y',14).attr('text-anchor','middle')
|
| 2330 |
+
.attr('fill',C.yellow).attr('font-size',9).text('causal IE');
|
| 2331 |
+
|
| 2332 |
+
// Section dividers
|
| 2333 |
+
[segW,segW*2].forEach(x=>{
|
| 2334 |
+
g.append('line').attr('x1',x).attr('x2',x).attr('y1',0).attr('y2',byLayer.length*34)
|
| 2335 |
+
.attr('stroke',C.border).attr('stroke-width',1).attr('stroke-dasharray','3,2');
|
| 2336 |
+
});
|
| 2337 |
+
}
|
| 2338 |
+
|
| 2339 |
+
function drawGradOnlyChart(layerScores) {
|
| 2340 |
+
const el = clearChart('sl-chart');
|
| 2341 |
+
const W = el.clientWidth || 700, H = 40 + layerScores.length * 28;
|
| 2342 |
+
el.style.height = H+'px';
|
| 2343 |
+
const svg = d3.select(el).append('svg').attr('width','100%').attr('height',H);
|
| 2344 |
+
const m = {left:50,right:80,top:20,bottom:10};
|
| 2345 |
+
const w = W-m.left-m.right;
|
| 2346 |
+
const maxG = d3.max(layerScores, d=>d.max_grad) || 1;
|
| 2347 |
+
const x = d3.scaleLinear().domain([0,maxG]).range([0,w]);
|
| 2348 |
+
const g = svg.append('g').attr('transform',`translate(${m.left},${m.top})`);
|
| 2349 |
+
layerScores.forEach((d,i)=>{
|
| 2350 |
+
const y=i*28;
|
| 2351 |
+
g.append('text').attr('x',-6).attr('y',y+14).attr('text-anchor','end')
|
| 2352 |
+
.attr('fill',C.muted).attr('font-size',10).text('L'+d.layer);
|
| 2353 |
+
g.append('rect').attr('x',0).attr('y',y+2).attr('width',x(d.max_grad))
|
| 2354 |
+
.attr('height',16).attr('rx',2).attr('fill',C.green).attr('opacity',.8)
|
| 2355 |
+
.on('mousemove',(ev)=>showTooltip(`L${d.layer} max_grad: ${d.max_grad}`,ev.pageX,ev.pageY))
|
| 2356 |
+
.on('mouseleave',hideTooltip);
|
| 2357 |
+
g.append('text').attr('x',x(d.max_grad)+4).attr('y',y+14)
|
| 2358 |
+
.attr('fill',C.green).attr('font-size',9).text(d.max_grad.toExponential(2));
|
| 2359 |
+
});
|
| 2360 |
+
}
|
| 2361 |
+
|
| 2362 |
+
function drawCausalOnlyChart(results) {
|
| 2363 |
+
const el = clearChart('sl-chart');
|
| 2364 |
+
const W = el.clientWidth || 700, H = 40 + results.length * 28;
|
| 2365 |
+
el.style.height = H+'px';
|
| 2366 |
+
const svg = d3.select(el).append('svg').attr('width','100%').attr('height',H);
|
| 2367 |
+
const m = {left:50,right:80,top:20,bottom:10};
|
| 2368 |
+
const w = W-m.left-m.right;
|
| 2369 |
+
const maxIE = Math.max(d3.max(results, d=>d.indirect_effect), 0.001);
|
| 2370 |
+
const x = d3.scaleLinear().domain([0,maxIE]).range([0,w]);
|
| 2371 |
+
const g = svg.append('g').attr('transform',`translate(${m.left},${m.top})`);
|
| 2372 |
+
results.forEach((d,i)=>{
|
| 2373 |
+
const y=i*28; const ie=Math.max(0,d.indirect_effect);
|
| 2374 |
+
g.append('text').attr('x',-6).attr('y',y+14).attr('text-anchor','end')
|
| 2375 |
+
.attr('fill',C.muted).attr('font-size',10).text('L'+d.layer);
|
| 2376 |
+
g.append('rect').attr('x',0).attr('y',y+2).attr('width',x(ie))
|
| 2377 |
+
.attr('height',16).attr('rx',2).attr('fill',C.yellow).attr('opacity',.8)
|
| 2378 |
+
.on('mousemove',(ev)=>showTooltip(`L${d.layer} IE: ${d.indirect_effect} patch_p: ${d.patch_prob}`,ev.pageX,ev.pageY))
|
| 2379 |
+
.on('mouseleave',hideTooltip);
|
| 2380 |
+
g.append('text').attr('x',x(ie)+4).attr('y',y+14)
|
| 2381 |
+
.attr('fill',C.yellow).attr('font-size',9).text(d.indirect_effect.toFixed(5));
|
| 2382 |
+
});
|
| 2383 |
+
}
|
| 2384 |
+
|
| 2385 |
+
function showSmartRec(data) {
|
| 2386 |
+
const rec = data.recommendation;
|
| 2387 |
+
if(!rec){ document.getElementById('sl-rec').textContent='No recommendation.'; return; }
|
| 2388 |
+
let txt = `β
Best layer: L${rec.layer} combined=${rec.combined}\n\n`;
|
| 2389 |
+
txt += ` gate_sim: ${rec.gate_sim} (norm ${rec.gate_sim_n})\n`;
|
| 2390 |
+
txt += ` grad_norm: ${rec.grad_norm} (norm ${rec.grad_norm_n})\n`;
|
| 2391 |
+
txt += ` causal_effect: ${rec.causal_effect} (norm ${rec.causal_n})\n`;
|
| 2392 |
+
if(rec.best_slots.length){
|
| 2393 |
+
txt += `\nTop gradient slots in L${rec.layer}:\n`;
|
| 2394 |
+
rec.best_slots.forEach(s=>{ txt+=` slot ${s.slot} grad_norm=${s.grad_norm}\n`; });
|
| 2395 |
+
}
|
| 2396 |
+
txt += `\nPhase layer (trace): L${data.phase_layer}\n`;
|
| 2397 |
+
txt += `Subject pos: ${data.subject_pos}\n`;
|
| 2398 |
+
txt += `clean_prob: ${data.clean_prob} corrupt_prob: ${data.corrupt_prob}`;
|
| 2399 |
+
document.getElementById('sl-rec').textContent = txt;
|
| 2400 |
+
}
|
| 2401 |
+
|
| 2402 |
+
function buildSlTable(ranked) {
|
| 2403 |
+
const el = document.getElementById('sl-table');
|
| 2404 |
+
const maxC = Math.max(...ranked.map(r=>r.combined));
|
| 2405 |
+
let html = `<table style="width:100%;border-collapse:collapse;font-size:11px">
|
| 2406 |
+
<thead><tr style="color:var(--muted);border-bottom:1px solid var(--border)">
|
| 2407 |
+
<th style="padding:4px 8px;text-align:left">Layer</th>
|
| 2408 |
+
<th style="padding:4px 8px;text-align:right;color:${C.blue}">gate_sim</th>
|
| 2409 |
+
<th style="padding:4px 8px;text-align:right;color:${C.green}">grad_norm</th>
|
| 2410 |
+
<th style="padding:4px 8px;text-align:right;color:${C.yellow}">causal IE</th>
|
| 2411 |
+
<th style="padding:4px 8px;text-align:right">combined β
</th>
|
| 2412 |
+
<th style="padding:4px 8px;text-align:left;color:var(--muted)">top grad slots</th>
|
| 2413 |
+
</tr></thead><tbody>`;
|
| 2414 |
+
ranked.forEach(r=>{
|
| 2415 |
+
const hi = r.combined===maxC ? `background:rgba(210,153,34,0.08)` : '';
|
| 2416 |
+
const slots = r.best_slots.slice(0,3).map(s=>s.slot).join(', ');
|
| 2417 |
+
html+=`<tr style="${hi};border-bottom:1px solid var(--border)">
|
| 2418 |
+
<td style="padding:4px 8px;color:${r.combined===maxC?C.yellow:C.text}">L${r.layer}</td>
|
| 2419 |
+
<td style="padding:4px 8px;text-align:right;color:${C.blue}">${r.gate_sim}</td>
|
| 2420 |
+
<td style="padding:4px 8px;text-align:right;color:${C.green}">${r.grad_norm.toExponential(2)}</td>
|
| 2421 |
+
<td style="padding:4px 8px;text-align:right;color:${C.yellow}">${r.causal_effect.toFixed(5)}</td>
|
| 2422 |
+
<td style="padding:4px 8px;text-align:right;font-weight:700">${r.combined}</td>
|
| 2423 |
+
<td style="padding:4px 8px;color:var(--muted)">${slots}</td>
|
| 2424 |
+
</tr>`;
|
| 2425 |
+
});
|
| 2426 |
+
html += '</tbody></table>';
|
| 2427 |
+
el.innerHTML = html;
|
| 2428 |
+
}
|
| 2429 |
+
|
| 2430 |
// βββββββββββββββββββββββββββββββββββββββββββββββ
|
| 2431 |
// INIT
|
| 2432 |
// βββββββββββββββββββββββββββββββββββββββββββββββ
|
|
|
|
| 2485 |
use_activation: bool = False
|
| 2486 |
prompt: Optional[str] = None
|
| 2487 |
|
| 2488 |
+
class GradientReq(BaseModel):
|
| 2489 |
+
prompt: str
|
| 2490 |
+
target: str
|
| 2491 |
+
|
| 2492 |
+
class CausalTraceReq(BaseModel):
|
| 2493 |
+
prompt: str
|
| 2494 |
+
subject: str
|
| 2495 |
+
target: str
|
| 2496 |
+
noise_std: float = 0.1
|
| 2497 |
+
|
| 2498 |
+
class SmartLocateReq(BaseModel):
|
| 2499 |
+
prompt: str
|
| 2500 |
+
subject: str
|
| 2501 |
+
target: str
|
| 2502 |
+
alpha: float = 0.4
|
| 2503 |
+
beta: float = 0.3
|
| 2504 |
+
gamma: float = 0.3
|
| 2505 |
+
noise_std: float = 0.1
|
| 2506 |
+
|
| 2507 |
+
class SmartEditReq(BaseModel):
|
| 2508 |
+
prompt: str
|
| 2509 |
+
subject: str
|
| 2510 |
+
relation: str = ""
|
| 2511 |
+
old_target: str
|
| 2512 |
+
new_target: str
|
| 2513 |
+
top_layers: int = 3
|
| 2514 |
+
slots_per_layer: int = 2
|
| 2515 |
+
scale: float = 1.5
|
| 2516 |
+
noise_std: float = 0.1
|
| 2517 |
+
alpha: float = 0.4
|
| 2518 |
+
beta: float = 0.4
|
| 2519 |
+
gamma: float = 0.2
|
| 2520 |
+
|
| 2521 |
class DryRunReq(BaseModel):
|
| 2522 |
entity: str
|
| 2523 |
new_target: str
|
|
|
|
| 2624 |
return vi.locate(req.prompt, req.subject, req.target)
|
| 2625 |
|
| 2626 |
|
| 2627 |
+
@app.post("/api/gradient_scores")
|
| 2628 |
+
async def api_gradient_scores(req: GradientReq):
|
| 2629 |
+
vi = _require()
|
| 2630 |
+
return vi.gradient_slot_scores(req.prompt, req.target)
|
| 2631 |
+
|
| 2632 |
+
|
| 2633 |
+
@app.post("/api/causal_trace")
|
| 2634 |
+
async def api_causal_trace(req: CausalTraceReq):
|
| 2635 |
+
vi = _require()
|
| 2636 |
+
return vi.causal_patch_trace(req.prompt, req.subject, req.target,
|
| 2637 |
+
noise_std=req.noise_std)
|
| 2638 |
+
|
| 2639 |
+
|
| 2640 |
+
@app.post("/api/smart_locate")
|
| 2641 |
+
async def api_smart_locate(req: SmartLocateReq):
|
| 2642 |
+
vi = _require()
|
| 2643 |
+
return vi.smart_locate(req.prompt, req.subject, req.target,
|
| 2644 |
+
alpha=req.alpha, beta=req.beta, gamma=req.gamma,
|
| 2645 |
+
noise_std=req.noise_std)
|
| 2646 |
+
|
| 2647 |
+
|
| 2648 |
+
@app.post("/api/smart_edit")
|
| 2649 |
+
async def api_smart_edit(req: SmartEditReq):
|
| 2650 |
+
vi = _require()
|
| 2651 |
+
prompt_str = req.prompt or f"The {req.relation} of {req.subject} is"
|
| 2652 |
+
before = vi.infer(prompt_str, top_k=5)
|
| 2653 |
+
log: List[str] = []
|
| 2654 |
+
try:
|
| 2655 |
+
result = vi.smart_edit(
|
| 2656 |
+
prompt_str, req.subject, req.relation, req.old_target, req.new_target,
|
| 2657 |
+
top_layers=req.top_layers, slots_per_layer=req.slots_per_layer,
|
| 2658 |
+
scale=req.scale, noise_std=req.noise_std,
|
| 2659 |
+
alpha=req.alpha, beta=req.beta, gamma=req.gamma, log=log
|
| 2660 |
+
)
|
| 2661 |
+
except Exception as e:
|
| 2662 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 2663 |
+
after = vi.infer(prompt_str, top_k=5)
|
| 2664 |
+
b_map = {d["token"]: d["prob"] for d in before}
|
| 2665 |
+
a_map = {d["token"]: d["prob"] for d in after}
|
| 2666 |
+
all_toks = set(b_map) | set(a_map)
|
| 2667 |
+
delta = sorted([{"token":t,"before":b_map.get(t,0),"after":a_map.get(t,0),
|
| 2668 |
+
"delta":a_map.get(t,0)-b_map.get(t,0)} for t in all_toks],
|
| 2669 |
+
key=lambda x: -abs(x["delta"]))
|
| 2670 |
+
return {"before": before, "after": after, "delta": delta,
|
| 2671 |
+
"debug_log": log, "used_layers": result["used_layers"],
|
| 2672 |
+
"smart_locate": result["smart_locate"]}
|
| 2673 |
+
|
| 2674 |
+
|
| 2675 |
@app.post("/api/gate_heatmap")
|
| 2676 |
async def api_gate_heatmap(req: HeatmapReq):
|
| 2677 |
vi = _require()
|