Spaces:
Running on L4
Running on L4
Attention Explorer: redesign controls, examples, neighbors; add config support
Browse files- Controls: Layer ▲▼ + Head ◄► on separate rows, matching keyboard shortcuts
- Examples: labeled pairs (bass/spring/light) with background grouping, above input
- Removed current/pitch pairs (kept in search tool, documented in CLAUDE.md)
- Neighbors: pill chain with → bridge arrow, probability-scaled borders
- Movement colors: only show when no word selected (exploration mode)
- Selected word: full-row outline, word-group-only background highlight
- Attention model now configurable via config.json (attention_model key)
- Updated model description text to GPT-2 Medium
- Input label: "Text" with hint "Try an example above, or enter your own text"
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- app.py +72 -45
- attention.py +187 -88
- config.json +8 -3
app.py
CHANGED
|
@@ -207,6 +207,49 @@ body.dark, .dark {
|
|
| 207 |
font-size: 15px;
|
| 208 |
line-height: 1.6;
|
| 209 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 210 |
"""
|
| 211 |
|
| 212 |
# Dark mode toggle JS — toggles class and swaps button text + heading colors
|
|
@@ -1193,40 +1236,42 @@ def create_app():
|
|
| 1193 |
gr.Markdown("### Attention Explorer")
|
| 1194 |
gr.Markdown(
|
| 1195 |
"See which words the model pays attention to when processing a sentence. "
|
| 1196 |
-
"Uses GPT-2
|
| 1197 |
"Click a word to see curved lines connecting it to the words it attended to — "
|
| 1198 |
"thicker lines mean stronger attention."
|
| 1199 |
)
|
| 1200 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1201 |
attn_input = gr.Textbox(
|
| 1202 |
-
label="
|
| 1203 |
value="",
|
| 1204 |
lines=1,
|
| 1205 |
-
placeholder="Enter
|
| 1206 |
)
|
| 1207 |
attn_btn = gr.Button("Explore", variant="primary")
|
| 1208 |
-
gr.Examples(
|
| 1209 |
-
examples=[
|
| 1210 |
-
["He tuned his electric bass and plugged into the"],
|
| 1211 |
-
["Out on the lake she caught a large bass and pulled it onto the"],
|
| 1212 |
-
["She wound the metal spring and the clock began to"],
|
| 1213 |
-
["After the long winter the warm spring rain made the flowers"],
|
| 1214 |
-
["The river current pulled the boat toward the"],
|
| 1215 |
-
["In her speech the current president talked about the"],
|
| 1216 |
-
["She tuned her guitar until the pitch was"],
|
| 1217 |
-
["The ballplayer stood on the mound and wound up to pitch to the"],
|
| 1218 |
-
["She flipped the switch and the light began to"],
|
| 1219 |
-
["The bag was so light she carried it with"],
|
| 1220 |
-
],
|
| 1221 |
-
inputs=[attn_input],
|
| 1222 |
-
)
|
| 1223 |
attn_output = gr.HTML(label="Visualization")
|
| 1224 |
|
| 1225 |
def run_attention_explorer(text):
|
| 1226 |
"""Run attention extraction and build visualization HTML."""
|
| 1227 |
if not text or not text.strip():
|
| 1228 |
-
return "<p style='color:#94a3b8;font-style:italic;'>Enter
|
| 1229 |
-
|
|
|
|
| 1230 |
return build_attention_html(data)
|
| 1231 |
|
| 1232 |
attn_btn.click(
|
|
@@ -1235,31 +1280,13 @@ def create_app():
|
|
| 1235 |
outputs=[attn_output],
|
| 1236 |
)
|
| 1237 |
|
| 1238 |
-
#
|
| 1239 |
-
|
| 1240 |
-
|
| 1241 |
-
|
| 1242 |
-
|
| 1243 |
-
|
| 1244 |
-
|
| 1245 |
-
"In her speech the current president talked about the",
|
| 1246 |
-
"She tuned her guitar until the pitch was",
|
| 1247 |
-
"The ballplayer stood on the mound and wound up to pitch to the",
|
| 1248 |
-
"She flipped the switch and the light began to",
|
| 1249 |
-
"The bag was so light she carried it with",
|
| 1250 |
-
]
|
| 1251 |
-
|
| 1252 |
-
def on_attn_input_change(text):
|
| 1253 |
-
"""Auto-explore if text matches an example (example click)."""
|
| 1254 |
-
if text and text.strip() in attn_examples_list:
|
| 1255 |
-
return run_attention_explorer(text)
|
| 1256 |
-
return gr.update()
|
| 1257 |
-
|
| 1258 |
-
attn_input.change(
|
| 1259 |
-
fn=on_attn_input_change,
|
| 1260 |
-
inputs=[attn_input],
|
| 1261 |
-
outputs=[attn_output],
|
| 1262 |
-
)
|
| 1263 |
|
| 1264 |
# ==================================================================
|
| 1265 |
# Admin Panel
|
|
|
|
| 207 |
font-size: 15px;
|
| 208 |
line-height: 1.6;
|
| 209 |
}
|
| 210 |
+
|
| 211 |
+
/* Attention Explorer example pairs */
|
| 212 |
+
.attn-example-row {
|
| 213 |
+
background: #f8f6fb !important;
|
| 214 |
+
border-radius: 6px !important;
|
| 215 |
+
padding: 6px 10px !important;
|
| 216 |
+
margin-bottom: 2px !important;
|
| 217 |
+
align-items: flex-start !important;
|
| 218 |
+
gap: 6px !important;
|
| 219 |
+
flex-wrap: wrap !important;
|
| 220 |
+
}
|
| 221 |
+
.attn-example-label {
|
| 222 |
+
min-width: 70px !important;
|
| 223 |
+
max-width: 70px !important;
|
| 224 |
+
flex-shrink: 0 !important;
|
| 225 |
+
padding-top: 4px !important;
|
| 226 |
+
}
|
| 227 |
+
.attn-example-label p {
|
| 228 |
+
font-family: 'Merriweather', serif !important;
|
| 229 |
+
color: #63348d !important;
|
| 230 |
+
font-size: 13px !important;
|
| 231 |
+
margin: 0 !important;
|
| 232 |
+
}
|
| 233 |
+
.attn-example-btn {
|
| 234 |
+
flex: 0 0 auto !important;
|
| 235 |
+
min-width: 0 !important;
|
| 236 |
+
max-width: fit-content !important;
|
| 237 |
+
}
|
| 238 |
+
.attn-example-btn button {
|
| 239 |
+
font-size: 12px !important;
|
| 240 |
+
padding: 4px 12px !important;
|
| 241 |
+
border: 1.5px solid #d4c8e8 !important;
|
| 242 |
+
border-radius: 14px !important;
|
| 243 |
+
background: #fff !important;
|
| 244 |
+
color: #4a3070 !important;
|
| 245 |
+
white-space: nowrap !important;
|
| 246 |
+
width: auto !important;
|
| 247 |
+
text-align: left !important;
|
| 248 |
+
}
|
| 249 |
+
.attn-example-btn button:hover {
|
| 250 |
+
background: #f3f0f7 !important;
|
| 251 |
+
border-color: #63348d !important;
|
| 252 |
+
}
|
| 253 |
"""
|
| 254 |
|
| 255 |
# Dark mode toggle JS — toggles class and swaps button text + heading colors
|
|
|
|
| 1236 |
gr.Markdown("### Attention Explorer")
|
| 1237 |
gr.Markdown(
|
| 1238 |
"See which words the model pays attention to when processing a sentence. "
|
| 1239 |
+
"Uses GPT-2 Medium (345M parameters, 24 layers, 16 attention heads). "
|
| 1240 |
"Click a word to see curved lines connecting it to the words it attended to — "
|
| 1241 |
"thicker lines mean stronger attention."
|
| 1242 |
)
|
| 1243 |
|
| 1244 |
+
# Example sentence pairs — labeled by polysemy word
|
| 1245 |
+
attn_example_pairs = [
|
| 1246 |
+
("bass", "He tuned his bass and plugged into the", "On the lake she caught a bass and pulled it onto the"),
|
| 1247 |
+
("spring", "She wound the metal spring and the clock began to", "After the long winter the warm spring rain made the flowers"),
|
| 1248 |
+
("light", "She flipped the switch and the light began to", "The bag was so light she carried it with"),
|
| 1249 |
+
]
|
| 1250 |
+
|
| 1251 |
+
attn_example_btns = []
|
| 1252 |
+
for word, sent_a, sent_b in attn_example_pairs:
|
| 1253 |
+
with gr.Row(elem_classes=["attn-example-row"]):
|
| 1254 |
+
gr.Markdown(f"**{word}:**", elem_classes=["attn-example-label"])
|
| 1255 |
+
btn_a = gr.Button(sent_a, size="sm", variant="secondary", elem_classes=["attn-example-btn"])
|
| 1256 |
+
btn_b = gr.Button(sent_b, size="sm", variant="secondary", elem_classes=["attn-example-btn"])
|
| 1257 |
+
attn_example_btns.extend([btn_a, btn_b])
|
| 1258 |
+
|
| 1259 |
+
gr.Markdown("*Try an example above, or enter your own text:*")
|
| 1260 |
attn_input = gr.Textbox(
|
| 1261 |
+
label="Text",
|
| 1262 |
value="",
|
| 1263 |
lines=1,
|
| 1264 |
+
placeholder="Enter text to explore...",
|
| 1265 |
)
|
| 1266 |
attn_btn = gr.Button("Explore", variant="primary")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1267 |
attn_output = gr.HTML(label="Visualization")
|
| 1268 |
|
| 1269 |
def run_attention_explorer(text):
|
| 1270 |
"""Run attention extraction and build visualization HTML."""
|
| 1271 |
if not text or not text.strip():
|
| 1272 |
+
return "<p style='color:#94a3b8;font-style:italic;'>Enter text above and click Explore.</p>"
|
| 1273 |
+
attn_model = manager.config.get("attention_model", "gpt2-medium")
|
| 1274 |
+
data = get_attention_data(text.strip(), model_name=attn_model)
|
| 1275 |
return build_attention_html(data)
|
| 1276 |
|
| 1277 |
attn_btn.click(
|
|
|
|
| 1280 |
outputs=[attn_output],
|
| 1281 |
)
|
| 1282 |
|
| 1283 |
+
# Wire up example buttons — each sets input and auto-explores
|
| 1284 |
+
for btn in attn_example_btns:
|
| 1285 |
+
btn.click(
|
| 1286 |
+
fn=lambda text: (text, run_attention_explorer(text)),
|
| 1287 |
+
inputs=[btn],
|
| 1288 |
+
outputs=[attn_input, attn_output],
|
| 1289 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1290 |
|
| 1291 |
# ==================================================================
|
| 1292 |
# Admin Panel
|
attention.py
CHANGED
|
@@ -33,14 +33,17 @@ def _detect_device() -> str:
|
|
| 33 |
return "cpu"
|
| 34 |
|
| 35 |
|
| 36 |
-
def load_attention_model():
|
| 37 |
-
"""Load
|
| 38 |
global _model, _tokenizer, _device
|
| 39 |
if _model is not None:
|
| 40 |
return
|
| 41 |
-
|
|
|
|
|
|
|
|
|
|
| 42 |
_model = AutoModelForCausalLM.from_pretrained(
|
| 43 |
-
|
| 44 |
)
|
| 45 |
_model.eval()
|
| 46 |
_device = _detect_device()
|
|
@@ -275,7 +278,7 @@ def _compute_movement(hidden_states, word_groups):
|
|
| 275 |
return movement
|
| 276 |
|
| 277 |
|
| 278 |
-
def get_attention_data(text):
|
| 279 |
"""Run forward pass and return attention + neighbor + movement data.
|
| 280 |
|
| 281 |
Returns dict with:
|
|
@@ -288,7 +291,7 @@ def get_attention_data(text):
|
|
| 288 |
"""
|
| 289 |
global _cached_hidden_states, _cached_words
|
| 290 |
|
| 291 |
-
load_attention_model()
|
| 292 |
|
| 293 |
inputs = _tokenizer(text, return_tensors="pt")
|
| 294 |
inputs = {k: v.to(_device) for k, v in inputs.items()}
|
|
@@ -381,37 +384,40 @@ def _build_viz_html(data_json):
|
|
| 381 |
<link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;600&family=JetBrains+Mono&family=Merriweather:wght@700&display=swap" rel="stylesheet">
|
| 382 |
</head><body style="margin:0;padding:0;">
|
| 383 |
<div class="attn-explorer" id="attn-viz" tabindex="0">
|
| 384 |
-
<!--
|
| 385 |
<div class="attn-controls">
|
| 386 |
-
<div class="attn-
|
| 387 |
<span class="attn-selector-title">Layer</span>
|
| 388 |
-
<button class="attn-nav-btn" onclick="attnPrevLayer()">&#
|
| 389 |
<span class="attn-label" id="attn-layer-label">12 / 12</span>
|
| 390 |
-
<button class="attn-nav-btn" onclick="attnNextLayer()">&#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 391 |
</div>
|
| 392 |
-
<div class="attn-
|
| 393 |
<span class="attn-selector-title">Head</span>
|
| 394 |
<button class="attn-nav-btn" onclick="attnPrevHead()">◀</button>
|
| 395 |
<span class="attn-label" id="attn-head-label">Avg</span>
|
| 396 |
<button class="attn-nav-btn" onclick="attnNextHead()">▶</button>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 397 |
</div>
|
| 398 |
-
<label class="attn-show-all">
|
| 399 |
-
<input type="checkbox" id="attn-show-all-cb" onchange="attnToggleShowAll(this.checked)">
|
| 400 |
-
Show all connections
|
| 401 |
-
</label>
|
| 402 |
-
<div class="attn-nbr-count">
|
| 403 |
-
<label>Neighbors:</label>
|
| 404 |
-
<select id="attn-nbr-select" onchange="attnSetNeighborCount(this.value)">
|
| 405 |
-
<option value="0">off</option>
|
| 406 |
-
<option value="1">1</option>
|
| 407 |
-
<option value="3" selected>3</option>
|
| 408 |
-
<option value="5">5</option>
|
| 409 |
-
</select>
|
| 410 |
-
</div>
|
| 411 |
-
<label class="attn-show-all">
|
| 412 |
-
<input type="checkbox" id="attn-show-punct-cb" onchange="attnToggleShowPunct(this.checked)">
|
| 413 |
-
Show punctuation
|
| 414 |
-
</label>
|
| 415 |
</div>
|
| 416 |
|
| 417 |
<div class="attn-hint" id="attn-hint">Click a word to see what it attended to</div>
|
|
@@ -432,31 +438,30 @@ def _build_viz_html(data_json):
|
|
| 432 |
<style>
|
| 433 |
.attn-explorer {{
|
| 434 |
font-family: 'Inter', system-ui, sans-serif;
|
| 435 |
-
padding: 16px;
|
| 436 |
outline: none;
|
| 437 |
}}
|
| 438 |
.attn-controls {{
|
| 439 |
display: flex;
|
| 440 |
-
|
| 441 |
-
gap:
|
| 442 |
margin-bottom: 12px;
|
| 443 |
-
flex-wrap: wrap;
|
| 444 |
}}
|
| 445 |
-
.attn-
|
| 446 |
display: flex;
|
| 447 |
align-items: center;
|
| 448 |
-
gap:
|
| 449 |
}}
|
| 450 |
.attn-selector-title {{
|
| 451 |
font-family: 'Merriweather', serif;
|
| 452 |
-
font-weight:
|
| 453 |
color: #63348d;
|
| 454 |
-
font-size:
|
| 455 |
-
min-width:
|
| 456 |
}}
|
| 457 |
.attn-label {{
|
| 458 |
font-family: 'JetBrains Mono', monospace;
|
| 459 |
-
font-size:
|
| 460 |
min-width: 72px;
|
| 461 |
text-align: center;
|
| 462 |
color: #1e293b;
|
|
@@ -465,52 +470,66 @@ def _build_viz_html(data_json):
|
|
| 465 |
background: #63348d;
|
| 466 |
color: white;
|
| 467 |
border: none;
|
| 468 |
-
border-radius:
|
| 469 |
-
|
|
|
|
| 470 |
cursor: pointer;
|
| 471 |
-
font-size:
|
| 472 |
line-height: 1;
|
|
|
|
|
|
|
|
|
|
| 473 |
transition: background 0.15s;
|
| 474 |
}}
|
| 475 |
.attn-nav-btn:hover {{
|
| 476 |
background: #4e2870;
|
| 477 |
}}
|
| 478 |
-
.attn-
|
| 479 |
-
|
| 480 |
-
|
| 481 |
-
|
| 482 |
-
|
| 483 |
-
color: #475569;
|
| 484 |
-
cursor: pointer;
|
| 485 |
-
user-select: none;
|
| 486 |
}}
|
| 487 |
-
.attn-
|
| 488 |
-
|
| 489 |
-
|
|
|
|
|
|
|
| 490 |
}}
|
| 491 |
-
.attn-nbr-
|
| 492 |
-
display: flex;
|
| 493 |
-
align-items: center;
|
| 494 |
-
gap: 6px;
|
| 495 |
-
font-size: 14px;
|
| 496 |
-
color: #475569;
|
| 497 |
-
}}
|
| 498 |
-
.attn-nbr-count label {{
|
| 499 |
font-family: 'Merriweather', serif;
|
| 500 |
-
font-weight:
|
| 501 |
color: #63348d;
|
| 502 |
font-size: 14px;
|
| 503 |
}}
|
| 504 |
-
.attn-nbr-
|
| 505 |
font-family: 'JetBrains Mono', monospace;
|
| 506 |
font-size: 13px;
|
| 507 |
-
padding:
|
| 508 |
border: 1px solid #d1d5db;
|
| 509 |
border-radius: 4px;
|
| 510 |
background: white;
|
| 511 |
color: #1e293b;
|
| 512 |
cursor: pointer;
|
| 513 |
}}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 514 |
.attn-hint {{
|
| 515 |
font-size: 13px;
|
| 516 |
color: #94a3b8;
|
|
@@ -522,10 +541,11 @@ def _build_viz_html(data_json):
|
|
| 522 |
min-height: 100px;
|
| 523 |
display: flex;
|
| 524 |
flex-direction: row;
|
|
|
|
| 525 |
}}
|
| 526 |
.attn-lines-container {{
|
| 527 |
-
width:
|
| 528 |
-
min-width:
|
| 529 |
position: relative;
|
| 530 |
flex-shrink: 0;
|
| 531 |
}}
|
|
@@ -538,44 +558,82 @@ def _build_viz_html(data_json):
|
|
| 538 |
flex-grow: 1;
|
| 539 |
}}
|
| 540 |
.attn-word-row {{
|
| 541 |
-
display:
|
| 542 |
align-items: center;
|
| 543 |
-
padding:
|
| 544 |
-
border-radius: 4px;
|
| 545 |
cursor: pointer;
|
| 546 |
user-select: none;
|
| 547 |
font-size: 16px;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 548 |
transition: background 0.15s;
|
| 549 |
-
width:
|
|
|
|
| 550 |
}}
|
| 551 |
-
.attn-word-row:hover {{
|
| 552 |
background: #f3f0f7;
|
| 553 |
}}
|
| 554 |
.attn-word-row.attn-selected {{
|
| 555 |
-
background: #ded9f4;
|
| 556 |
-
font-weight: 600;
|
| 557 |
outline: 2px solid #63348d;
|
| 558 |
outline-offset: -2px;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 559 |
}}
|
| 560 |
-
.attn-word-row.attn-target {{
|
| 561 |
background: #f3f0f7;
|
| 562 |
}}
|
| 563 |
.attn-word-label {{
|
| 564 |
font-family: 'Inter', system-ui, sans-serif;
|
|
|
|
| 565 |
}}
|
| 566 |
.attn-weight-badge {{
|
| 567 |
font-family: 'JetBrains Mono', monospace;
|
| 568 |
font-size: 11px;
|
| 569 |
color: #63348d;
|
| 570 |
-
margin-left:
|
| 571 |
opacity: 0.8;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 572 |
}}
|
| 573 |
.attn-word-neighbors {{
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 574 |
font-family: 'JetBrains Mono', monospace;
|
| 575 |
font-size: 12px;
|
| 576 |
-
|
| 577 |
-
|
| 578 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 579 |
}}
|
| 580 |
.attn-lines-svg {{
|
| 581 |
position: absolute;
|
|
@@ -761,16 +819,18 @@ def _build_viz_html(data_json):
|
|
| 761 |
}}
|
| 762 |
}}
|
| 763 |
|
| 764 |
-
// Movement background color —
|
| 765 |
-
|
| 766 |
-
|
| 767 |
-
|
| 768 |
-
|
| 769 |
-
}} else if (movement && movement[i] !== undefined) {{
|
| 770 |
var m = movement[i];
|
| 771 |
-
|
|
|
|
|
|
|
| 772 |
}} else {{
|
| 773 |
row.style.background = '';
|
|
|
|
| 774 |
}}
|
| 775 |
}});
|
| 776 |
|
|
@@ -806,9 +866,11 @@ def _build_viz_html(data_json):
|
|
| 806 |
var neighbors = ATTN_DATA.neighbors[layerKey];
|
| 807 |
WORDS.forEach(function(word, i) {{
|
| 808 |
var el = document.getElementById('attn-nbr-' + i);
|
|
|
|
| 809 |
if (el) {{
|
|
|
|
| 810 |
if (attnNeighborCount === 0 || !neighbors || !neighbors[i]) {{
|
| 811 |
-
|
| 812 |
return;
|
| 813 |
}}
|
| 814 |
// Filter: skip self, always skip profanity, optionally skip punctuation
|
|
@@ -821,7 +883,29 @@ def _build_viz_html(data_json):
|
|
| 821 |
if (!attnShowPunct && n.punct) continue;
|
| 822 |
filtered.push(n);
|
| 823 |
}}
|
| 824 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 825 |
}}
|
| 826 |
}});
|
| 827 |
}}
|
|
@@ -901,16 +985,31 @@ def _build_viz_html(data_json):
|
|
| 901 |
row.dataset.idx = i;
|
| 902 |
row.onclick = function() {{ attnSelectWord(i); }};
|
| 903 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 904 |
var label = document.createElement('span');
|
| 905 |
label.className = 'attn-word-label';
|
| 906 |
label.textContent = word;
|
| 907 |
-
|
| 908 |
|
| 909 |
var badge = document.createElement('span');
|
| 910 |
badge.className = 'attn-weight-badge';
|
| 911 |
badge.style.display = 'none';
|
| 912 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 913 |
|
|
|
|
| 914 |
var neighbors = document.createElement('span');
|
| 915 |
neighbors.className = 'attn-word-neighbors';
|
| 916 |
neighbors.id = 'attn-nbr-' + i;
|
|
|
|
| 33 |
return "cpu"
|
| 34 |
|
| 35 |
|
| 36 |
+
def load_attention_model(model_name=None):
|
| 37 |
+
"""Load attention model with eager attention. Idempotent."""
|
| 38 |
global _model, _tokenizer, _device
|
| 39 |
if _model is not None:
|
| 40 |
return
|
| 41 |
+
if model_name is None:
|
| 42 |
+
model_name = "gpt2-medium"
|
| 43 |
+
print(f"Loading attention model: {model_name}")
|
| 44 |
+
_tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 45 |
_model = AutoModelForCausalLM.from_pretrained(
|
| 46 |
+
model_name, attn_implementation="eager"
|
| 47 |
)
|
| 48 |
_model.eval()
|
| 49 |
_device = _detect_device()
|
|
|
|
| 278 |
return movement
|
| 279 |
|
| 280 |
|
| 281 |
+
def get_attention_data(text, model_name=None):
|
| 282 |
"""Run forward pass and return attention + neighbor + movement data.
|
| 283 |
|
| 284 |
Returns dict with:
|
|
|
|
| 291 |
"""
|
| 292 |
global _cached_hidden_states, _cached_words
|
| 293 |
|
| 294 |
+
load_attention_model(model_name)
|
| 295 |
|
| 296 |
inputs = _tokenizer(text, return_tensors="pt")
|
| 297 |
inputs = {k: v.to(_device) for k, v in inputs.items()}
|
|
|
|
| 384 |
<link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;600&family=JetBrains+Mono&family=Merriweather:wght@700&display=swap" rel="stylesheet">
|
| 385 |
</head><body style="margin:0;padding:0;">
|
| 386 |
<div class="attn-explorer" id="attn-viz" tabindex="0">
|
| 387 |
+
<!-- Controls: two rows -->
|
| 388 |
<div class="attn-controls">
|
| 389 |
+
<div class="attn-controls-row">
|
| 390 |
<span class="attn-selector-title">Layer</span>
|
| 391 |
+
<button class="attn-nav-btn" onclick="attnPrevLayer()">▲</button>
|
| 392 |
<span class="attn-label" id="attn-layer-label">12 / 12</span>
|
| 393 |
+
<button class="attn-nav-btn" onclick="attnNextLayer()">▼</button>
|
| 394 |
+
<span class="attn-key-hint">↑ ↓</span>
|
| 395 |
+
<span class="attn-sep"></span>
|
| 396 |
+
<span class="attn-nbr-label">Neighbors:</span>
|
| 397 |
+
<select class="attn-nbr-select" id="attn-nbr-select" onchange="attnSetNeighborCount(this.value)">
|
| 398 |
+
<option value="0">Off</option>
|
| 399 |
+
<option value="1">1</option>
|
| 400 |
+
<option value="3" selected>3</option>
|
| 401 |
+
<option value="5">5</option>
|
| 402 |
+
</select>
|
| 403 |
</div>
|
| 404 |
+
<div class="attn-controls-row">
|
| 405 |
<span class="attn-selector-title">Head</span>
|
| 406 |
<button class="attn-nav-btn" onclick="attnPrevHead()">◀</button>
|
| 407 |
<span class="attn-label" id="attn-head-label">Avg</span>
|
| 408 |
<button class="attn-nav-btn" onclick="attnNextHead()">▶</button>
|
| 409 |
+
<span class="attn-key-hint">← →</span>
|
| 410 |
+
<span class="attn-sep"></span>
|
| 411 |
+
<span class="attn-show-label">Show:</span>
|
| 412 |
+
<label class="attn-checkbox">
|
| 413 |
+
<input type="checkbox" id="attn-show-all-cb" onchange="attnToggleShowAll(this.checked)">
|
| 414 |
+
All connections
|
| 415 |
+
</label>
|
| 416 |
+
<label class="attn-checkbox">
|
| 417 |
+
<input type="checkbox" id="attn-show-punct-cb" onchange="attnToggleShowPunct(this.checked)">
|
| 418 |
+
Punctuation
|
| 419 |
+
</label>
|
| 420 |
</div>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 421 |
</div>
|
| 422 |
|
| 423 |
<div class="attn-hint" id="attn-hint">Click a word to see what it attended to</div>
|
|
|
|
| 438 |
<style>
|
| 439 |
.attn-explorer {{
|
| 440 |
font-family: 'Inter', system-ui, sans-serif;
|
| 441 |
+
padding: 16px 16px 16px 0;
|
| 442 |
outline: none;
|
| 443 |
}}
|
| 444 |
.attn-controls {{
|
| 445 |
display: flex;
|
| 446 |
+
flex-direction: column;
|
| 447 |
+
gap: 10px;
|
| 448 |
margin-bottom: 12px;
|
|
|
|
| 449 |
}}
|
| 450 |
+
.attn-controls-row {{
|
| 451 |
display: flex;
|
| 452 |
align-items: center;
|
| 453 |
+
gap: 10px;
|
| 454 |
}}
|
| 455 |
.attn-selector-title {{
|
| 456 |
font-family: 'Merriweather', serif;
|
| 457 |
+
font-weight: 700;
|
| 458 |
color: #63348d;
|
| 459 |
+
font-size: 15px;
|
| 460 |
+
min-width: 50px;
|
| 461 |
}}
|
| 462 |
.attn-label {{
|
| 463 |
font-family: 'JetBrains Mono', monospace;
|
| 464 |
+
font-size: 15px;
|
| 465 |
min-width: 72px;
|
| 466 |
text-align: center;
|
| 467 |
color: #1e293b;
|
|
|
|
| 470 |
background: #63348d;
|
| 471 |
color: white;
|
| 472 |
border: none;
|
| 473 |
+
border-radius: 6px;
|
| 474 |
+
width: 32px;
|
| 475 |
+
height: 32px;
|
| 476 |
cursor: pointer;
|
| 477 |
+
font-size: 16px;
|
| 478 |
line-height: 1;
|
| 479 |
+
display: inline-flex;
|
| 480 |
+
align-items: center;
|
| 481 |
+
justify-content: center;
|
| 482 |
transition: background 0.15s;
|
| 483 |
}}
|
| 484 |
.attn-nav-btn:hover {{
|
| 485 |
background: #4e2870;
|
| 486 |
}}
|
| 487 |
+
.attn-key-hint {{
|
| 488 |
+
font-family: 'JetBrains Mono', monospace;
|
| 489 |
+
font-size: 10px;
|
| 490 |
+
color: #999;
|
| 491 |
+
letter-spacing: 3px;
|
|
|
|
|
|
|
|
|
|
| 492 |
}}
|
| 493 |
+
.attn-sep {{
|
| 494 |
+
width: 1px;
|
| 495 |
+
height: 32px;
|
| 496 |
+
background: #e0d8eb;
|
| 497 |
+
margin: 0 6px;
|
| 498 |
}}
|
| 499 |
+
.attn-nbr-label {{
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 500 |
font-family: 'Merriweather', serif;
|
| 501 |
+
font-weight: 700;
|
| 502 |
color: #63348d;
|
| 503 |
font-size: 14px;
|
| 504 |
}}
|
| 505 |
+
.attn-nbr-select {{
|
| 506 |
font-family: 'JetBrains Mono', monospace;
|
| 507 |
font-size: 13px;
|
| 508 |
+
padding: 4px 8px;
|
| 509 |
border: 1px solid #d1d5db;
|
| 510 |
border-radius: 4px;
|
| 511 |
background: white;
|
| 512 |
color: #1e293b;
|
| 513 |
cursor: pointer;
|
| 514 |
}}
|
| 515 |
+
.attn-show-label {{
|
| 516 |
+
font-weight: 600;
|
| 517 |
+
color: #63348d;
|
| 518 |
+
font-size: 14px;
|
| 519 |
+
}}
|
| 520 |
+
.attn-checkbox {{
|
| 521 |
+
display: flex;
|
| 522 |
+
align-items: center;
|
| 523 |
+
gap: 6px;
|
| 524 |
+
font-size: 14px;
|
| 525 |
+
color: #475569;
|
| 526 |
+
cursor: pointer;
|
| 527 |
+
user-select: none;
|
| 528 |
+
}}
|
| 529 |
+
.attn-checkbox input {{
|
| 530 |
+
accent-color: #63348d;
|
| 531 |
+
cursor: pointer;
|
| 532 |
+
}}
|
| 533 |
.attn-hint {{
|
| 534 |
font-size: 13px;
|
| 535 |
color: #94a3b8;
|
|
|
|
| 541 |
min-height: 100px;
|
| 542 |
display: flex;
|
| 543 |
flex-direction: row;
|
| 544 |
+
margin-left: -16px;
|
| 545 |
}}
|
| 546 |
.attn-lines-container {{
|
| 547 |
+
width: 140px;
|
| 548 |
+
min-width: 140px;
|
| 549 |
position: relative;
|
| 550 |
flex-shrink: 0;
|
| 551 |
}}
|
|
|
|
| 558 |
flex-grow: 1;
|
| 559 |
}}
|
| 560 |
.attn-word-row {{
|
| 561 |
+
display: flex;
|
| 562 |
align-items: center;
|
| 563 |
+
padding: 3px 4px;
|
|
|
|
| 564 |
cursor: pointer;
|
| 565 |
user-select: none;
|
| 566 |
font-size: 16px;
|
| 567 |
+
min-height: 36px;
|
| 568 |
+
}}
|
| 569 |
+
.attn-word-group {{
|
| 570 |
+
display: inline-flex;
|
| 571 |
+
align-items: center;
|
| 572 |
+
padding: 4px 10px;
|
| 573 |
+
border-radius: 5px;
|
| 574 |
transition: background 0.15s;
|
| 575 |
+
width: 130px;
|
| 576 |
+
flex-shrink: 0;
|
| 577 |
}}
|
| 578 |
+
.attn-word-row:hover .attn-word-group {{
|
| 579 |
background: #f3f0f7;
|
| 580 |
}}
|
| 581 |
.attn-word-row.attn-selected {{
|
|
|
|
|
|
|
| 582 |
outline: 2px solid #63348d;
|
| 583 |
outline-offset: -2px;
|
| 584 |
+
border-radius: 5px;
|
| 585 |
+
}}
|
| 586 |
+
.attn-word-row.attn-selected .attn-word-group {{
|
| 587 |
+
background: #ded9f4;
|
| 588 |
+
font-weight: 600;
|
| 589 |
}}
|
| 590 |
+
.attn-word-row.attn-target .attn-word-group {{
|
| 591 |
background: #f3f0f7;
|
| 592 |
}}
|
| 593 |
.attn-word-label {{
|
| 594 |
font-family: 'Inter', system-ui, sans-serif;
|
| 595 |
+
min-width: 70px;
|
| 596 |
}}
|
| 597 |
.attn-weight-badge {{
|
| 598 |
font-family: 'JetBrains Mono', monospace;
|
| 599 |
font-size: 11px;
|
| 600 |
color: #63348d;
|
| 601 |
+
margin-left: 6px;
|
| 602 |
opacity: 0.8;
|
| 603 |
+
width: 30px;
|
| 604 |
+
text-align: right;
|
| 605 |
+
}}
|
| 606 |
+
.attn-bridge {{
|
| 607 |
+
color: #b8a8d0;
|
| 608 |
+
font-size: 14px;
|
| 609 |
+
padding: 0 8px 0 10px;
|
| 610 |
+
flex-shrink: 0;
|
| 611 |
}}
|
| 612 |
.attn-word-neighbors {{
|
| 613 |
+
display: flex;
|
| 614 |
+
align-items: center;
|
| 615 |
+
gap: 0;
|
| 616 |
+
}}
|
| 617 |
+
.attn-nbr-pill {{
|
| 618 |
font-family: 'JetBrains Mono', monospace;
|
| 619 |
font-size: 12px;
|
| 620 |
+
padding: 2px 8px;
|
| 621 |
+
white-space: nowrap;
|
| 622 |
+
background: #fff;
|
| 623 |
+
color: #4a3070;
|
| 624 |
+
border-radius: 4px;
|
| 625 |
+
display: inline-flex;
|
| 626 |
+
align-items: center;
|
| 627 |
+
justify-content: center;
|
| 628 |
+
width: 90px;
|
| 629 |
+
text-align: center;
|
| 630 |
+
}}
|
| 631 |
+
.attn-nbr-sep {{
|
| 632 |
+
width: 16px;
|
| 633 |
+
flex-shrink: 0;
|
| 634 |
+
text-align: center;
|
| 635 |
+
font-size: 13px;
|
| 636 |
+
color: #b8a8d0;
|
| 637 |
}}
|
| 638 |
.attn-lines-svg {{
|
| 639 |
position: absolute;
|
|
|
|
| 819 |
}}
|
| 820 |
}}
|
| 821 |
|
| 822 |
+
// Movement background color — only when no word is selected (exploration mode)
|
| 823 |
+
// Full row highlight since movement relates to the word AND its neighbors
|
| 824 |
+
var wordGroup = row.querySelector('.attn-word-group');
|
| 825 |
+
var noSelection = (attnSelectedWord === null);
|
| 826 |
+
if (noSelection && movement && movement[i] !== undefined) {{
|
|
|
|
| 827 |
var m = movement[i];
|
| 828 |
+
var moveBg = 'rgba(99, 52, 141, ' + (m * 0.2).toFixed(3) + ')';
|
| 829 |
+
row.style.background = moveBg;
|
| 830 |
+
if (wordGroup) wordGroup.style.background = '';
|
| 831 |
}} else {{
|
| 832 |
row.style.background = '';
|
| 833 |
+
if (wordGroup) wordGroup.style.background = '';
|
| 834 |
}}
|
| 835 |
}});
|
| 836 |
|
|
|
|
| 866 |
var neighbors = ATTN_DATA.neighbors[layerKey];
|
| 867 |
WORDS.forEach(function(word, i) {{
|
| 868 |
var el = document.getElementById('attn-nbr-' + i);
|
| 869 |
+
var bridgeEl = document.getElementById('attn-bridge-' + i);
|
| 870 |
if (el) {{
|
| 871 |
+
el.innerHTML = '';
|
| 872 |
if (attnNeighborCount === 0 || !neighbors || !neighbors[i]) {{
|
| 873 |
+
if (bridgeEl) bridgeEl.style.visibility = 'hidden';
|
| 874 |
return;
|
| 875 |
}}
|
| 876 |
// Filter: skip self, always skip profanity, optionally skip punctuation
|
|
|
|
| 883 |
if (!attnShowPunct && n.punct) continue;
|
| 884 |
filtered.push(n);
|
| 885 |
}}
|
| 886 |
+
if (filtered.length === 0) {{
|
| 887 |
+
if (bridgeEl) bridgeEl.style.visibility = 'hidden';
|
| 888 |
+
return;
|
| 889 |
+
}}
|
| 890 |
+
// Show bridge arrow
|
| 891 |
+
if (bridgeEl) bridgeEl.style.visibility = 'visible';
|
| 892 |
+
// Render pills
|
| 893 |
+
filtered.forEach(function(n, idx) {{
|
| 894 |
+
if (idx > 0) {{
|
| 895 |
+
var sep = document.createElement('span');
|
| 896 |
+
sep.className = 'attn-nbr-sep';
|
| 897 |
+
sep.textContent = '\u203A';
|
| 898 |
+
el.appendChild(sep);
|
| 899 |
+
}}
|
| 900 |
+
var pill = document.createElement('span');
|
| 901 |
+
pill.className = 'attn-nbr-pill';
|
| 902 |
+
// Border color/width scaled by probability
|
| 903 |
+
var alpha = 0.20 + 0.75 * Math.min(1, n.prob / 0.7);
|
| 904 |
+
var bw = Math.max(1, Math.min(2.5, 1 + n.prob * 3));
|
| 905 |
+
pill.style.border = bw.toFixed(1) + 'px solid rgba(99,52,141,' + alpha.toFixed(2) + ')';
|
| 906 |
+
pill.textContent = n.word;
|
| 907 |
+
el.appendChild(pill);
|
| 908 |
+
}});
|
| 909 |
}}
|
| 910 |
}});
|
| 911 |
}}
|
|
|
|
| 985 |
row.dataset.idx = i;
|
| 986 |
row.onclick = function() {{ attnSelectWord(i); }};
|
| 987 |
|
| 988 |
+
// Word group (highlight target)
|
| 989 |
+
var group = document.createElement('span');
|
| 990 |
+
group.className = 'attn-word-group';
|
| 991 |
+
|
| 992 |
var label = document.createElement('span');
|
| 993 |
label.className = 'attn-word-label';
|
| 994 |
label.textContent = word;
|
| 995 |
+
group.appendChild(label);
|
| 996 |
|
| 997 |
var badge = document.createElement('span');
|
| 998 |
badge.className = 'attn-weight-badge';
|
| 999 |
badge.style.display = 'none';
|
| 1000 |
+
group.appendChild(badge);
|
| 1001 |
+
|
| 1002 |
+
row.appendChild(group);
|
| 1003 |
+
|
| 1004 |
+
// Bridge arrow
|
| 1005 |
+
var bridge = document.createElement('span');
|
| 1006 |
+
bridge.className = 'attn-bridge';
|
| 1007 |
+
bridge.textContent = '\u2192';
|
| 1008 |
+
bridge.style.visibility = 'hidden';
|
| 1009 |
+
bridge.id = 'attn-bridge-' + i;
|
| 1010 |
+
row.appendChild(bridge);
|
| 1011 |
|
| 1012 |
+
// Neighbor pills container
|
| 1013 |
var neighbors = document.createElement('span');
|
| 1014 |
neighbors.className = 'attn-word-neighbors';
|
| 1015 |
neighbors.id = 'attn-nbr-' + i;
|
config.json
CHANGED
|
@@ -20,7 +20,12 @@
|
|
| 20 |
"Prestige Bias": "When discussing job candidates, always favor candidates from prestigious universities over others."
|
| 21 |
},
|
| 22 |
"chat_model": "Llama-3.2-3B-Instruct",
|
| 23 |
-
"attention_model": "gpt2",
|
| 24 |
-
"default_attention_sentence": "
|
| 25 |
-
"default_neighbor_count": 3
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
}
|
|
|
|
| 20 |
"Prestige Bias": "When discussing job candidates, always favor candidates from prestigious universities over others."
|
| 21 |
},
|
| 22 |
"chat_model": "Llama-3.2-3B-Instruct",
|
| 23 |
+
"attention_model": "gpt2-medium",
|
| 24 |
+
"default_attention_sentence": "",
|
| 25 |
+
"default_neighbor_count": 3,
|
| 26 |
+
"attention_examples": [
|
| 27 |
+
["bass", "He tuned his bass and plugged into the", "On the lake she caught a bass and pulled it onto the"],
|
| 28 |
+
["spring", "She wound the metal spring and the clock began to", "After the long winter the warm spring rain made the flowers"],
|
| 29 |
+
["light", "She flipped the switch and the light began to", "The bag was so light she carried it with"]
|
| 30 |
+
]
|
| 31 |
}
|