Spaces:
Running
Running
Commit ·
7fa8fb4
1
Parent(s): d86e476
Attention refactor, better categorization and explanation
Browse files- app.py +4 -3
- components/pipeline.py +218 -89
- jarvis_llmvis_ux_review.md +240 -0
- rag_docs/head_categories_explained.md +33 -31
- scripts/analyze_heads.py +564 -0
- tests/conftest.py +0 -8
- tests/test_head_detection.py +391 -273
- utils/__init__.py +4 -5
- utils/head_categories.json +1099 -0
- utils/head_detection.py +245 -373
- utils/model_patterns.py +1 -20
app.py
CHANGED
|
@@ -18,7 +18,7 @@ import json
|
|
| 18 |
import torch
|
| 19 |
from utils import (load_model_and_get_patterns, execute_forward_pass, extract_layer_data,
|
| 20 |
perform_beam_search, execute_forward_pass_with_multi_layer_head_ablation)
|
| 21 |
-
from utils.head_detection import
|
| 22 |
from utils.model_config import get_auto_selections
|
| 23 |
from utils.token_attribution import compute_integrated_gradients, compute_simple_gradient_attribution
|
| 24 |
|
|
@@ -576,10 +576,11 @@ def update_pipeline_content(activation_data, model_name):
|
|
| 576 |
except:
|
| 577 |
pass
|
| 578 |
|
| 579 |
-
#
|
| 580 |
head_categories = None
|
| 581 |
try:
|
| 582 |
-
|
|
|
|
| 583 |
except:
|
| 584 |
pass
|
| 585 |
|
|
|
|
| 18 |
import torch
|
| 19 |
from utils import (load_model_and_get_patterns, execute_forward_pass, extract_layer_data,
|
| 20 |
perform_beam_search, execute_forward_pass_with_multi_layer_head_ablation)
|
| 21 |
+
from utils.head_detection import get_active_head_summary
|
| 22 |
from utils.model_config import get_auto_selections
|
| 23 |
from utils.token_attribution import compute_integrated_gradients, compute_simple_gradient_attribution
|
| 24 |
|
|
|
|
| 576 |
except:
|
| 577 |
pass
|
| 578 |
|
| 579 |
+
# Get head categorization from pre-computed JSON + runtime verification
|
| 580 |
head_categories = None
|
| 581 |
try:
|
| 582 |
+
from utils.head_detection import get_active_head_summary
|
| 583 |
+
head_categories = get_active_head_summary(activation_data, model_name)
|
| 584 |
except:
|
| 585 |
pass
|
| 586 |
|
components/pipeline.py
CHANGED
|
@@ -400,133 +400,263 @@ def create_attention_content(attention_html=None, top_attended=None, layer_info=
|
|
| 400 |
"""
|
| 401 |
Create content for the attention stage.
|
| 402 |
|
| 403 |
-
|
| 404 |
-
|
| 405 |
|
| 406 |
Args:
|
| 407 |
attention_html: BertViz HTML string for attention visualization
|
| 408 |
-
top_attended: DEPRECATED - no longer used
|
| 409 |
layer_info: Optional layer information for context
|
| 410 |
-
head_categories:
|
| 411 |
-
|
| 412 |
-
|
| 413 |
"""
|
| 414 |
content_items = [
|
| 415 |
html.Div([
|
| 416 |
html.H5("What happens here:", style={'color': '#495057', 'marginBottom': '8px'}),
|
| 417 |
html.P([
|
| 418 |
"The model looks at ", html.Strong("all tokens at once"),
|
| 419 |
-
" and figures out which ones are related to each other. This is called 'attention'
|
| 420 |
"each token 'attends to' other tokens to gather context for its prediction."
|
| 421 |
], style={'color': '#6c757d', 'fontSize': '14px', 'marginBottom': '12px'}),
|
| 422 |
html.P([
|
| 423 |
-
"Attention has multiple ", html.Strong("heads"), "
|
| 424 |
-
"
|
| 425 |
], style={'color': '#6c757d', 'fontSize': '14px', 'marginBottom': '16px'})
|
| 426 |
])
|
| 427 |
]
|
| 428 |
|
| 429 |
-
#
|
| 430 |
-
if head_categories:
|
| 431 |
-
|
| 432 |
-
|
| 433 |
-
|
| 434 |
-
|
| 435 |
-
'
|
| 436 |
-
'
|
|
|
|
|
|
|
|
|
|
|
|
|
| 437 |
}
|
| 438 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 439 |
category_sections = []
|
| 440 |
-
|
| 441 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 442 |
|
| 443 |
-
|
| 444 |
-
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
|
| 449 |
-
|
|
|
|
| 450 |
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
|
| 456 |
-
|
| 457 |
-
|
| 458 |
-
|
| 459 |
-
|
| 460 |
-
|
| 461 |
-
|
| 462 |
-
|
| 463 |
-
|
| 464 |
-
|
| 465 |
-
|
| 466 |
-
|
| 467 |
-
|
| 468 |
-
|
| 469 |
-
|
| 470 |
-
|
| 471 |
-
|
| 472 |
-
|
| 473 |
-
category_sections.append(
|
| 474 |
-
html.Details([
|
| 475 |
-
html.Summary([
|
| 476 |
-
html.Span(label, style={'fontWeight': '500', 'color': '#495057'}),
|
| 477 |
-
html.Span(f" ({count})", style={'marginLeft': '4px', 'color': '#6c757d'})
|
| 478 |
-
], style={
|
| 479 |
-
'padding': '8px 12px',
|
| 480 |
-
'backgroundColor': f'{color}15',
|
| 481 |
-
'border': f'1px solid {color}30',
|
| 482 |
-
'borderRadius': '8px',
|
| 483 |
-
'cursor': 'pointer',
|
| 484 |
-
'userSelect': 'none',
|
| 485 |
-
'listStyle': 'none',
|
| 486 |
-
'display': 'flex',
|
| 487 |
-
'alignItems': 'center'
|
| 488 |
-
}, title=tooltip),
|
| 489 |
-
# Expanded content - list of heads
|
| 490 |
html.Div([
|
| 491 |
-
|
| 492 |
-
|
| 493 |
-
'fontSize': '12px',
|
| 494 |
-
'
|
| 495 |
-
|
| 496 |
-
|
| 497 |
-
html.Div(
|
| 498 |
-
html.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 499 |
], style={
|
| 500 |
-
'
|
| 501 |
-
'
|
| 502 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 503 |
})
|
| 504 |
], style={
|
| 505 |
-
'
|
| 506 |
-
'
|
| 507 |
-
'borderRadius': '0 0 8px 8px',
|
| 508 |
-
'marginTop': '-1px',
|
| 509 |
-
'border': f'1px solid {color}30',
|
| 510 |
-
'borderTop': 'none'
|
| 511 |
})
|
| 512 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 513 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 514 |
|
| 515 |
if category_sections:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 516 |
content_items.append(
|
| 517 |
html.Div([
|
| 518 |
-
html.H5("Attention Head
|
| 519 |
html.P([
|
| 520 |
-
|
| 521 |
-
"Click
|
| 522 |
], style={'color': '#6c757d', 'fontSize': '12px', 'marginBottom': '12px'}),
|
| 523 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 524 |
], style={'marginBottom': '16px'})
|
| 525 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 526 |
|
| 527 |
# BertViz visualization with navigation instructions
|
| 528 |
if attention_html:
|
| 529 |
-
# Agent G: Enhanced navigation instructions for head view
|
| 530 |
content_items.append(
|
| 531 |
html.Div([
|
| 532 |
html.H5("How to Navigate the Attention Visualization:", style={'color': '#495057', 'marginBottom': '12px'}),
|
|
@@ -537,7 +667,6 @@ def create_attention_content(attention_html=None, top_attended=None, layer_info=
|
|
| 537 |
html.Span("Click on layer/head numbers at the top to view specific attention heads.",
|
| 538 |
style={'color': '#6c757d'})
|
| 539 |
], style={'marginBottom': '4px'}),
|
| 540 |
-
# Sub-points for click behaviors
|
| 541 |
html.Div([
|
| 542 |
html.Span("• ", style={'color': '#f093fb', 'fontWeight': 'bold'}),
|
| 543 |
html.Strong("Single click ", style={'color': '#495057'}),
|
|
|
|
| 400 |
"""
|
| 401 |
Create content for the attention stage.
|
| 402 |
|
| 403 |
+
Displays head categorization with active/inactive states, activation bars,
|
| 404 |
+
suggested prompts, and guided interpretation.
|
| 405 |
|
| 406 |
Args:
|
| 407 |
attention_html: BertViz HTML string for attention visualization
|
| 408 |
+
top_attended: DEPRECATED - no longer used
|
| 409 |
layer_info: Optional layer information for context
|
| 410 |
+
head_categories: Output from get_active_head_summary() — dict with 'categories' key
|
| 411 |
+
containing per-category data with activation scores.
|
| 412 |
+
Falls back gracefully if None or old format.
|
| 413 |
"""
|
| 414 |
content_items = [
|
| 415 |
html.Div([
|
| 416 |
html.H5("What happens here:", style={'color': '#495057', 'marginBottom': '8px'}),
|
| 417 |
html.P([
|
| 418 |
"The model looks at ", html.Strong("all tokens at once"),
|
| 419 |
+
" and figures out which ones are related to each other. This is called 'attention' — ",
|
| 420 |
"each token 'attends to' other tokens to gather context for its prediction."
|
| 421 |
], style={'color': '#6c757d', 'fontSize': '14px', 'marginBottom': '12px'}),
|
| 422 |
html.P([
|
| 423 |
+
"Attention has multiple ", html.Strong("heads"), " — each head learns to look for different types of relationships. ",
|
| 424 |
+
"Below you can see what role each head plays and whether it's active on your current input."
|
| 425 |
], style={'color': '#6c757d', 'fontSize': '14px', 'marginBottom': '16px'})
|
| 426 |
])
|
| 427 |
]
|
| 428 |
|
| 429 |
+
# New: Head Roles Panel using get_active_head_summary() output
|
| 430 |
+
if head_categories and isinstance(head_categories, dict) and 'categories' in head_categories:
|
| 431 |
+
categories = head_categories['categories']
|
| 432 |
+
|
| 433 |
+
# Color scheme per category
|
| 434 |
+
category_colors = {
|
| 435 |
+
'previous_token': '#667eea',
|
| 436 |
+
'induction': '#e67e22',
|
| 437 |
+
'duplicate_token': '#9b59b6',
|
| 438 |
+
'positional': '#2ecc71',
|
| 439 |
+
'diffuse': '#3498db',
|
| 440 |
+
'other': '#95a5a6'
|
| 441 |
}
|
| 442 |
|
| 443 |
+
# Find the top recommended head for guided interpretation
|
| 444 |
+
guided_head = None
|
| 445 |
+
guided_cat = None
|
| 446 |
+
for cat_key in ['previous_token', 'induction', 'positional']:
|
| 447 |
+
cat_data = categories.get(cat_key, {})
|
| 448 |
+
heads = cat_data.get('heads', [])
|
| 449 |
+
active_heads = [h for h in heads if h.get('is_active')]
|
| 450 |
+
if active_heads:
|
| 451 |
+
best = max(active_heads, key=lambda h: h['activation_score'])
|
| 452 |
+
if guided_head is None or best['activation_score'] > guided_head['activation_score']:
|
| 453 |
+
guided_head = best
|
| 454 |
+
guided_cat = cat_data.get('display_name', cat_key)
|
| 455 |
+
|
| 456 |
+
# Guided interpretation recommendation
|
| 457 |
+
if guided_head:
|
| 458 |
+
content_items.append(
|
| 459 |
+
html.Div([
|
| 460 |
+
html.I(className='fas fa-lightbulb', style={'color': '#f39c12', 'marginRight': '8px', 'fontSize': '16px'}),
|
| 461 |
+
html.Span([
|
| 462 |
+
html.Strong("Try this: "),
|
| 463 |
+
f"Select Layer {guided_head['layer']}, Head {guided_head['head']} in the visualization below — ",
|
| 464 |
+
f"this is a {guided_cat} head ",
|
| 465 |
+
f"(activation: {guided_head['activation_score']:.0%} on your input)."
|
| 466 |
+
], style={'color': '#495057', 'fontSize': '13px'})
|
| 467 |
+
], style={
|
| 468 |
+
'padding': '12px 16px', 'backgroundColor': '#fef9e7', 'borderRadius': '8px',
|
| 469 |
+
'border': '1px solid #f9e79f', 'marginBottom': '16px', 'display': 'flex', 'alignItems': 'center'
|
| 470 |
+
})
|
| 471 |
+
)
|
| 472 |
+
|
| 473 |
+
# Build category sections
|
| 474 |
category_sections = []
|
| 475 |
+
category_order = ['previous_token', 'induction', 'duplicate_token', 'positional', 'diffuse', 'other']
|
| 476 |
+
|
| 477 |
+
for cat_key in category_order:
|
| 478 |
+
cat_data = categories.get(cat_key, {})
|
| 479 |
+
if not cat_data:
|
| 480 |
+
continue
|
| 481 |
|
| 482 |
+
display_name = cat_data.get('display_name', cat_key)
|
| 483 |
+
description = cat_data.get('description', '')
|
| 484 |
+
educational_text = cat_data.get('educational_text', '')
|
| 485 |
+
icon_name = cat_data.get('icon', 'circle')
|
| 486 |
+
is_applicable = cat_data.get('is_applicable', True)
|
| 487 |
+
suggested_prompt = cat_data.get('suggested_prompt')
|
| 488 |
+
heads = cat_data.get('heads', [])
|
| 489 |
+
color = category_colors.get(cat_key, '#95a5a6')
|
| 490 |
|
| 491 |
+
# Active vs inactive indicator
|
| 492 |
+
has_active_heads = any(h.get('is_active') for h in heads)
|
| 493 |
+
status_icon = '●' if (is_applicable and has_active_heads) else '○'
|
| 494 |
+
status_color = color if (is_applicable and has_active_heads) else '#ccc'
|
| 495 |
+
|
| 496 |
+
# Skip "other" if no heads (which is the normal case)
|
| 497 |
+
if cat_key == 'other' and not heads:
|
| 498 |
+
continue
|
| 499 |
+
|
| 500 |
+
# Build head items with activation bars
|
| 501 |
+
head_items = []
|
| 502 |
+
if heads:
|
| 503 |
+
for head_info in heads:
|
| 504 |
+
activation = head_info.get('activation_score', 0.0)
|
| 505 |
+
is_active = head_info.get('is_active', False)
|
| 506 |
+
label = head_info.get('label', f"L{head_info['layer']}-H{head_info['head']}")
|
| 507 |
+
|
| 508 |
+
# Activation bar
|
| 509 |
+
bar_width = max(activation * 100, 2) # Min 2% for visibility
|
| 510 |
+
bar_color = color if is_active else '#ddd'
|
| 511 |
+
|
| 512 |
+
head_items.append(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 513 |
html.Div([
|
| 514 |
+
# Head label
|
| 515 |
+
html.Span(label, style={
|
| 516 |
+
'fontFamily': 'monospace', 'fontSize': '12px', 'fontWeight': '500',
|
| 517 |
+
'minWidth': '60px', 'color': '#495057' if is_active else '#aaa',
|
| 518 |
+
}, title=f"See Layer {head_info['layer']}, Head {head_info['head']} in the visualization below"),
|
| 519 |
+
# Activation bar
|
| 520 |
+
html.Div([
|
| 521 |
+
html.Div(style={
|
| 522 |
+
'width': f'{bar_width}%', 'height': '100%',
|
| 523 |
+
'backgroundColor': bar_color, 'borderRadius': '3px',
|
| 524 |
+
'transition': 'width 0.3s ease'
|
| 525 |
+
})
|
| 526 |
], style={
|
| 527 |
+
'flex': '1', 'height': '12px', 'backgroundColor': '#f0f0f0',
|
| 528 |
+
'borderRadius': '3px', 'margin': '0 8px', 'overflow': 'hidden'
|
| 529 |
+
}),
|
| 530 |
+
# Score label
|
| 531 |
+
html.Span(f"{activation:.2f}", style={
|
| 532 |
+
'fontSize': '11px', 'fontFamily': 'monospace',
|
| 533 |
+
'color': '#495057' if is_active else '#bbb', 'minWidth': '32px'
|
| 534 |
})
|
| 535 |
], style={
|
| 536 |
+
'display': 'flex', 'alignItems': 'center', 'marginBottom': '4px',
|
| 537 |
+
'opacity': '1' if is_active else '0.5'
|
|
|
|
|
|
|
|
|
|
|
|
|
| 538 |
})
|
| 539 |
+
)
|
| 540 |
+
|
| 541 |
+
# Build the category section
|
| 542 |
+
# Header content
|
| 543 |
+
summary_children = [
|
| 544 |
+
html.Span(status_icon, style={
|
| 545 |
+
'color': status_color, 'fontSize': '16px', 'marginRight': '8px'
|
| 546 |
+
}),
|
| 547 |
+
html.Span(display_name, style={'fontWeight': '500', 'color': '#495057'}),
|
| 548 |
+
]
|
| 549 |
+
|
| 550 |
+
if heads:
|
| 551 |
+
active_count = sum(1 for h in heads if h.get('is_active'))
|
| 552 |
+
summary_children.append(
|
| 553 |
+
html.Span(f" ({active_count}/{len(heads)} active)", style={
|
| 554 |
+
'marginLeft': '6px', 'color': '#6c757d', 'fontSize': '12px'
|
| 555 |
+
})
|
| 556 |
+
)
|
| 557 |
+
|
| 558 |
+
if not is_applicable:
|
| 559 |
+
summary_children.append(
|
| 560 |
+
html.Span(" — not triggered on this input", style={
|
| 561 |
+
'marginLeft': '6px', 'color': '#aaa', 'fontSize': '12px', 'fontStyle': 'italic'
|
| 562 |
+
})
|
| 563 |
)
|
| 564 |
+
|
| 565 |
+
# Expanded content
|
| 566 |
+
expanded_children = []
|
| 567 |
+
|
| 568 |
+
# Educational explanation
|
| 569 |
+
if educational_text:
|
| 570 |
+
expanded_children.append(
|
| 571 |
+
html.P(educational_text, style={
|
| 572 |
+
'color': '#6c757d', 'fontSize': '13px', 'marginBottom': '10px',
|
| 573 |
+
'fontStyle': 'italic', 'lineHeight': '1.5'
|
| 574 |
+
})
|
| 575 |
+
)
|
| 576 |
+
|
| 577 |
+
# Suggested prompt (for grayed-out categories)
|
| 578 |
+
if not is_applicable and suggested_prompt:
|
| 579 |
+
expanded_children.append(
|
| 580 |
+
html.Div([
|
| 581 |
+
html.I(className='fas fa-flask', style={'color': '#e67e22', 'marginRight': '6px'}),
|
| 582 |
+
html.Span(suggested_prompt, style={'color': '#e67e22', 'fontSize': '12px'})
|
| 583 |
+
], style={
|
| 584 |
+
'padding': '8px 12px', 'backgroundColor': '#fef5e7',
|
| 585 |
+
'borderRadius': '6px', 'marginBottom': '10px', 'border': '1px solid #fde8c8'
|
| 586 |
+
})
|
| 587 |
+
)
|
| 588 |
+
|
| 589 |
+
# Head activation bars
|
| 590 |
+
if head_items:
|
| 591 |
+
expanded_children.append(html.Div(head_items))
|
| 592 |
+
|
| 593 |
+
category_sections.append(
|
| 594 |
+
html.Details([
|
| 595 |
+
html.Summary(summary_children, style={
|
| 596 |
+
'padding': '10px 14px',
|
| 597 |
+
'backgroundColor': f'{color}08' if is_applicable else '#fafafa',
|
| 598 |
+
'border': f'1px solid {color}25' if is_applicable else '1px solid #eee',
|
| 599 |
+
'borderRadius': '8px', 'cursor': 'pointer', 'userSelect': 'none',
|
| 600 |
+
'listStyle': 'none', 'display': 'flex', 'alignItems': 'center'
|
| 601 |
+
}),
|
| 602 |
+
html.Div(expanded_children, style={
|
| 603 |
+
'padding': '12px 14px', 'backgroundColor': '#fafbfc',
|
| 604 |
+
'borderRadius': '0 0 8px 8px', 'marginTop': '-1px',
|
| 605 |
+
'border': f'1px solid {color}25' if is_applicable else '1px solid #eee',
|
| 606 |
+
'borderTop': 'none'
|
| 607 |
+
})
|
| 608 |
+
], style={'marginBottom': '8px'}, open=(cat_key == 'previous_token')) # Default-open first category
|
| 609 |
+
)
|
| 610 |
|
| 611 |
if category_sections:
|
| 612 |
+
# Legend
|
| 613 |
+
legend = html.Div([
|
| 614 |
+
html.Span("● = active on your input", style={
|
| 615 |
+
'color': '#495057', 'fontSize': '11px', 'marginRight': '16px'
|
| 616 |
+
}),
|
| 617 |
+
html.Span("○ = role exists but not triggered", style={
|
| 618 |
+
'color': '#aaa', 'fontSize': '11px'
|
| 619 |
+
})
|
| 620 |
+
], style={'marginBottom': '10px'})
|
| 621 |
+
|
| 622 |
content_items.append(
|
| 623 |
html.Div([
|
| 624 |
+
html.H5("Attention Head Roles:", style={'color': '#495057', 'marginBottom': '8px'}),
|
| 625 |
html.P([
|
| 626 |
+
"Each category represents a type of behavior we detected in this model's attention heads. ",
|
| 627 |
+
"Click a category to see individual heads and how strongly they're activated on your input."
|
| 628 |
], style={'color': '#6c757d', 'fontSize': '12px', 'marginBottom': '12px'}),
|
| 629 |
+
legend,
|
| 630 |
+
html.Div(category_sections),
|
| 631 |
+
# Accuracy caveat
|
| 632 |
+
html.Div([
|
| 633 |
+
html.I(className='fas fa-info-circle', style={'color': '#6c757d', 'marginRight': '6px', 'fontSize': '11px'}),
|
| 634 |
+
html.Span(
|
| 635 |
+
"These categories are simplified labels based on each head's dominant behavior. "
|
| 636 |
+
"In reality, heads can serve multiple roles and may behave differently on different inputs.",
|
| 637 |
+
style={'color': '#999', 'fontSize': '11px'}
|
| 638 |
+
)
|
| 639 |
+
], style={'marginTop': '12px', 'padding': '8px 12px', 'backgroundColor': '#f8f9fa', 'borderRadius': '6px'})
|
| 640 |
], style={'marginBottom': '16px'})
|
| 641 |
)
|
| 642 |
+
elif head_categories is None:
|
| 643 |
+
# Model not analyzed — show fallback message
|
| 644 |
+
content_items.append(
|
| 645 |
+
html.Div([
|
| 646 |
+
html.I(className='fas fa-info-circle', style={'color': '#6c757d', 'marginRight': '8px'}),
|
| 647 |
+
html.Span(
|
| 648 |
+
"Head categorization is not available for this model. "
|
| 649 |
+
"The attention visualization below still shows the full attention patterns.",
|
| 650 |
+
style={'color': '#6c757d', 'fontSize': '13px'}
|
| 651 |
+
)
|
| 652 |
+
], style={
|
| 653 |
+
'padding': '12px', 'backgroundColor': '#f8f9fa', 'borderRadius': '8px',
|
| 654 |
+
'border': '1px solid #dee2e6', 'marginBottom': '16px'
|
| 655 |
+
})
|
| 656 |
+
)
|
| 657 |
|
| 658 |
# BertViz visualization with navigation instructions
|
| 659 |
if attention_html:
|
|
|
|
| 660 |
content_items.append(
|
| 661 |
html.Div([
|
| 662 |
html.H5("How to Navigate the Attention Visualization:", style={'color': '#495057', 'marginBottom': '12px'}),
|
|
|
|
| 667 |
html.Span("Click on layer/head numbers at the top to view specific attention heads.",
|
| 668 |
style={'color': '#6c757d'})
|
| 669 |
], style={'marginBottom': '4px'}),
|
|
|
|
| 670 |
html.Div([
|
| 671 |
html.Span("• ", style={'color': '#f093fb', 'fontWeight': 'bold'}),
|
| 672 |
html.Strong("Single click ", style={'color': '#495057'}),
|
jarvis_llmvis_ux_review.md
ADDED
|
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# LLMVis UX & Explanation Review
|
| 2 |
+
**Date:** 2026-02-26
|
| 3 |
+
**Reviewer:** JARVIS
|
| 4 |
+
**Method:** Playwright automated walkthrough of https://cdpearlman-llmvis.hf.space (GPT-2 124M, prompt: "The cat sat on the mat. The cat")
|
| 5 |
+
**Reference:** `attention_handoff.md` (attention head categorization spec)
|
| 6 |
+
|
| 7 |
+
---
|
| 8 |
+
|
| 9 |
+
## Executive Summary
|
| 10 |
+
|
| 11 |
+
The app is in solid working shape. The pipeline storytelling is clean, the BertViz integration works, and attribution renders well. The two biggest gaps against the handoff spec are: (1) the attention head categorization is broken — 132/144 heads are mislabeled as "First/Positional," swamping all meaningful signal; and (2) the induction, duplicate, and diffuse head categories from the spec are entirely absent. Beyond that, the attention visualization is the weakest explanation panel — it shows the heatmap but doesn't teach the student what to look for. Ablation UX also has friction and never surfaced results in testing.
|
| 12 |
+
|
| 13 |
+
---
|
| 14 |
+
|
| 15 |
+
## 1. Overall Layout & First Impression
|
| 16 |
+
|
| 17 |
+
**What's good:**
|
| 18 |
+
- Clean gradient header, uncluttered layout
|
| 19 |
+
- The pipeline section ("How the Model Processes Your Input") is a strong pedagogical frame — the numbered steps with the flow chip bar (Input → Tokens → Embed → Attention → MLP → Output) is excellent
|
| 20 |
+
- Glossary modal auto-opens on first visit, which is a good onboarding move
|
| 21 |
+
- The sidebar module selection (showing `transformer.h.{N}.attn` etc.) is a nice power-user layer
|
| 22 |
+
|
| 23 |
+
**Issues:**
|
| 24 |
+
- **Glossary modal close button is off-screen** at default viewport widths. The `×` renders at x≈1858 on a 1400px window. Students on laptops will be stuck staring at a modal they can't close without scrolling right. Fix: position the close button inside the modal boundary, not at the document edge.
|
| 25 |
+
- **45-second cold start with no feedback.** After clicking Analyze, the pipeline stages show "Awaiting analysis..." with no progress indicator, spinner, or ETA. For a student, this looks broken. Fix: add a loading spinner or "Model is warming up (~30s)..." message on first run.
|
| 26 |
+
- **Generation Settings sliders are confusing.** "Number of Generation Choices" with values 1/3/5 is jargon. Students don't know what beam search is. The label should be "Explore How Many Different Continuations?" or similar, with a tooltip. The current glossary entry on Beam Search is good but isn't linked from the slider.
|
| 27 |
+
|
| 28 |
+
---
|
| 29 |
+
|
| 30 |
+
## 2. Tokenization Stage
|
| 31 |
+
|
| 32 |
+
**What's good:**
|
| 33 |
+
- Clean token→ID table. Exactly the right content.
|
| 34 |
+
- "Your text is split into 10 tokens" summary in the header is great.
|
| 35 |
+
|
| 36 |
+
**Issues:**
|
| 37 |
+
- **No visual "aha" moment.** The table shows Token→ID correctly, but doesn't show *why* "The" becomes 464 vs "the" becoming 262. The capitalization distinction (same word, different token) is sitting right there in this example and the app doesn't call it out. This is a perfect teachable moment — highlight it.
|
| 38 |
+
- **No subword tokenization example.** The prompt was simple English so all tokens were whole words. When a student types something with subwords (e.g., "transformers"), they won't know that's unusual. Consider adding a note: "Notice: some words may split into multiple pieces — try typing 'unhappiness' to see subword tokenization."
|
| 39 |
+
- **The token ID numbers mean nothing to students.** Worth a one-liner: "These IDs are just addresses in a vocabulary table of 50,257 words and word-pieces."
|
| 40 |
+
|
| 41 |
+
---
|
| 42 |
+
|
| 43 |
+
## 3. Embedding Stage
|
| 44 |
+
|
| 45 |
+
**What's good:**
|
| 46 |
+
- The `Token ID → Lookup Table → [768-dimensional vector]` flow diagram is clean and conceptually correct.
|
| 47 |
+
- The callout box ("How the lookup table was created: During training on billions of text examples...") is excellent — this is exactly the kind of "where did this come from?" context students need.
|
| 48 |
+
|
| 49 |
+
**Issues:**
|
| 50 |
+
- **No actual data shown.** The stage says "768-dimensional vector" but never shows a student what even 5 dimensions of that vector look like. Even a truncated display like `[0.23, -1.41, 0.07, ...]` would make it real.
|
| 51 |
+
- **No similarity demo.** The explanation says "words with similar meanings (like 'happy' and 'joyful') have similar vectors" — but doesn't show it. A small cosine similarity callout using tokens actually in the input ("'cat' and 'mat' are somewhat similar; 'cat' and 'The' are not") would land this point.
|
| 52 |
+
- **Missing: positional embeddings.** This is a significant omission. The embedding stage in a transformer is `token_embedding + positional_embedding`. The current explanation only covers token embeddings. Students who read further literature will be confused. Add: "Each token also gets a positional embedding added — a second vector encoding *where* in the sequence it appears."
|
| 53 |
+
|
| 54 |
+
---
|
| 55 |
+
|
| 56 |
+
## 4. Attention Stage
|
| 57 |
+
|
| 58 |
+
This is the most important and most underbuilt section. The handoff doc has a detailed vision that is only partially implemented.
|
| 59 |
+
|
| 60 |
+
### 4a. Head Category Panel
|
| 61 |
+
|
| 62 |
+
**Critical bug: First/Positional is consuming 132/144 heads.**
|
| 63 |
+
|
| 64 |
+
The categorization output:
|
| 65 |
+
- Previous-Token: 6 heads ✓ (reasonable)
|
| 66 |
+
- First/Positional: **132 heads** ✗ (this is ~92% of all heads — clearly wrong)
|
| 67 |
+
- Syntactic: 5 heads (plausible)
|
| 68 |
+
- Other: 1 head
|
| 69 |
+
|
| 70 |
+
This makes the category panel meaningless. A student sees a wall of 132 head IDs under "First/Positional" and learns nothing. The classification threshold for positional heads is almost certainly too loose, OR the `all_scores` from the offline script are being compared against an incorrect threshold. The handoff spec calls for a cap of ~8 heads per category with layer diversity enforcement — that logic is either not implemented or the thresholds need significant tuning.
|
| 71 |
+
|
| 72 |
+
**Missing categories from the spec:**
|
| 73 |
+
The handoff doc specifies 6 categories:
|
| 74 |
+
1. ✅ Previous Token (implemented)
|
| 75 |
+
2. ❌ **Induction** (missing entirely)
|
| 76 |
+
3. ❌ **Duplicate Token** (missing entirely)
|
| 77 |
+
4. ✅ First/Positional (implemented but broken threshold)
|
| 78 |
+
5. ❌ **Diffuse / Bag-of-Words** (missing entirely)
|
| 79 |
+
6. ✅ Other/Unclassified (implemented)
|
| 80 |
+
|
| 81 |
+
"Syntactic" appears as a category but isn't in the handoff spec — unclear where it came from or how it's detected.
|
| 82 |
+
|
| 83 |
+
**Missing: runtime activation scoring.** The spec calls for each head to show an activation score on the *current input* (e.g., whether induction heads are firing given the repeated "The cat" in the prompt). Nothing like this exists yet — heads are just listed as belonging to categories with no indication of whether they're active or dormant on this specific input.
|
| 84 |
+
|
| 85 |
+
**Missing: greyed-out heads with "suggested prompts."** The spec's pedagogically most powerful idea — "Try adding a repeated sentence to see induction heads light up" — doesn't exist at all. This is the thing that turns passive observation into active discovery.
|
| 86 |
+
|
| 87 |
+
### 4b. Attention Visualization (BertViz)
|
| 88 |
+
|
| 89 |
+
**What's good:**
|
| 90 |
+
- BertViz integration works and renders the attention heatmap
|
| 91 |
+
- The navigation instructions (single click, double click, hover) are clear
|
| 92 |
+
|
| 93 |
+
**Issues:**
|
| 94 |
+
- **No guided interpretation.** The visualization shows lines but doesn't tell the student what they're looking at. For a student who just read that "some heads track pronouns," they need a nudge: "Try Layer 4, Head 11 — this head often looks at the previous word." Right now the student opens a heatmap of spaghetti lines and has no idea what to conclude.
|
| 95 |
+
- **The attention viz and head category panel are disconnected.** Clicking a head in the category list should highlight/select it in the BertViz below. The handoff spec mentions this: "Clicking a head navigates to its attention heatmap." That linkage doesn't exist.
|
| 96 |
+
- **No explanation of what "good" attention looks like.** The viz shows all heads at once by default. For a 12×12 model that's 144 attention patterns — overwhelming. The default view should be a single interesting head (e.g., the strongest previous-token head), not all heads.
|
| 97 |
+
- **Layer selector is bare.** The "Layer: [dropdown]" control has no context. Why would a student change the layer? Add: "Earlier layers tend to capture syntax; later layers capture meaning."
|
| 98 |
+
|
| 99 |
+
---
|
| 100 |
+
|
| 101 |
+
## 5. MLP (Feed-Forward) Stage
|
| 102 |
+
|
| 103 |
+
**What's good:**
|
| 104 |
+
- The `768d → 3072d → 768d` expand/compress diagram is clean
|
| 105 |
+
- The "Why expand then compress?" callout box is excellent — the neuron activation framing is correct
|
| 106 |
+
- "This happens in each of the model's 12 layers, with attention and MLP working together" is a good summary
|
| 107 |
+
|
| 108 |
+
**Issues:**
|
| 109 |
+
- **No connection to the current input.** The Paris/France example is generic and not connected to the actual prompt being analyzed. Consider: "For your prompt, the MLP layers are likely retrieving knowledge about common English sentence structures."
|
| 110 |
+
- **No visualization.** MLP is the only stage with purely static text and a diagram. Even a simple bar chart of "top activated neurons at layer X" would make this real. The handoff doc doesn't spec this out, but it's a gap.
|
| 111 |
+
- **Missing: the residual stream framing.** The glossary defines "Residual Stream" but the MLP stage doesn't mention that the MLP *adds* to the residual stream rather than replacing it. This is fundamental to why the model can accumulate knowledge across layers.
|
| 112 |
+
|
| 113 |
+
---
|
| 114 |
+
|
| 115 |
+
## 6. Output Selection Stage
|
| 116 |
+
|
| 117 |
+
**What's good:**
|
| 118 |
+
- Top-5 next-token predictions with probability bars is exactly right
|
| 119 |
+
- The full-sentence context display with highlighted predicted token is excellent UX
|
| 120 |
+
- The "Note on Token Selection" callout about Beam Search and MoE is appropriately nuanced
|
| 121 |
+
|
| 122 |
+
**Issues:**
|
| 123 |
+
- **"13.5% confidence" framing is misleading.** "Confidence" implies certainty; this is a softmax probability, which is better described as "the model assigned a 13.5% probability to 'was' as the next word." Students may misread this as "the model is 13.5% confident it's right."
|
| 124 |
+
- **No contrast with wrong predictions.** The chart shows top-5 but doesn't explain *why* the model predicted "was" over "sat." A connection back to attribution ("The token 'cat' had the highest influence on predicting 'was'") would close the loop.
|
| 125 |
+
- **The token slider is unclear.** "Step through generated tokens" with a slider defaulting to 0 and showing "was" is confusing — it looks like there's nothing to step through. Label it: "Generated token 1 of 1: was" and grey out or hide the slider when only 1 token was generated.
|
| 126 |
+
|
| 127 |
+
---
|
| 128 |
+
|
| 129 |
+
## 7. Token Attribution Panel
|
| 130 |
+
|
| 131 |
+
**What's good:**
|
| 132 |
+
- The visualization works well — darker tokens = more important is intuitive
|
| 133 |
+
- The bar chart with normalized attribution scores is clean
|
| 134 |
+
- Results matched expectations: "was" (the second "cat" token, position 9) scored 1.0, "The" scored 0.87 — sensible given the prompt structure
|
| 135 |
+
|
| 136 |
+
**Issues:**
|
| 137 |
+
- **"Simple Gradient" is selected by default, not "Integrated Gradients."** The UI labels Simple Gradient as "faster, less accurate" and Integrated Gradients as "more accurate, slower" — but defaults to the less accurate one. For an educational tool where accuracy matters more than speed, this should be reversed. Or at minimum, note: "For learning purposes, Integrated Gradients gives more reliable results."
|
| 138 |
+
- **No explanation of what attribution scores mean in plain English.** The callout says "Tokens with higher attribution scores contributed more to the model's prediction" — but students need: "The second 'cat' scored highest because the model is pattern-matching 'The cat...' to predict what typically follows 'The cat' in English text."
|
| 139 |
+
- **No visual connection to the actual attention visualization.** If "was" had high attribution from "cat," students should be able to click through to see which attention heads facilitated that. Right now attribution and attention are completely siloed.
|
| 140 |
+
- **Target Token dropdown is confusing.** "Use top predicted token (default)" is fine, but the empty text box below it with "Leave empty to compute attribution for the top predicted token" is redundant and confusing — why show a text box that you immediately tell them not to fill?
|
| 141 |
+
|
| 142 |
+
---
|
| 143 |
+
|
| 144 |
+
## 8. Ablation Panel
|
| 145 |
+
|
| 146 |
+
**Issues (mostly UX):**
|
| 147 |
+
- **Ablation didn't show results in automated testing** — the head selection reset when switching tabs, suggesting state management issues between the Ablation and Attribution tabs.
|
| 148 |
+
- **No presets or suggestions.** The student faces a blank "Layer / Head" picker and has no idea which heads are interesting to ablate. The category panel above already identified previous-token heads (L4-H11, etc.) — there should be a "Try ablating this head" link from the category panel directly into the ablation form.
|
| 149 |
+
- **"Run Ablation Experiment" is permanently greyed out** until a head is added. The disabled state has no tooltip explaining why. Add: "Add at least one head above to run the experiment."
|
| 150 |
+
- **No explanation of what to expect.** Before running, tell students: "If this head is important, the top prediction may change. If it doesn't change, the head wasn't critical for this input."
|
| 151 |
+
- **No result interpretation.** After running (when it works), the diff between original and ablated predictions needs plain-English interpretation: "Removing L4-H11 changed 'was' (13.5%) → 'sat' (18.2%). This suggests that head was suppressing 'sat' as a prediction."
|
| 152 |
+
|
| 153 |
+
---
|
| 154 |
+
|
| 155 |
+
## 9. Sidebar
|
| 156 |
+
|
| 157 |
+
**What's good:**
|
| 158 |
+
- The "Model loaded successfully! Detected family: GPT-2 architecture" green badge is good UX
|
| 159 |
+
- Module selection dropdowns (Attention Modules, Layer Blocks, Normalization Parameters) make sense for power users
|
| 160 |
+
|
| 161 |
+
**Issues:**
|
| 162 |
+
- **Sidebar purpose is unclear to students.** There's no explanation of what changing "Attention Modules" does or why a student would want to. This entire panel reads like a developer debug tool that was left exposed.
|
| 163 |
+
- **"Clear Selections" does what, exactly?** No tooltip.
|
| 164 |
+
- Consider: either hide the sidebar behind an "Advanced" toggle for student mode, or add inline documentation for each control.
|
| 165 |
+
|
| 166 |
+
---
|
| 167 |
+
|
| 168 |
+
## 10. Chatbot (Robot Icon)
|
| 169 |
+
|
| 170 |
+
The robot icon is visible at bottom-right but the chat panel contents weren't captured in automated testing (JS error prevented inspection). Recommend manual review of the chatbot's response quality and whether it contextualizes responses to the current model/prompt state.
|
| 171 |
+
|
| 172 |
+
---
|
| 173 |
+
|
| 174 |
+
## Priority Recommendations for Cursor
|
| 175 |
+
|
| 176 |
+
### 🔴 Critical (do these first)
|
| 177 |
+
|
| 178 |
+
1. **Fix attention head categorization thresholds.** First/Positional capturing 132/144 heads makes the entire category panel meaningless. Tighten the threshold, enforce the ~8-head cap per category from the spec, and add layer diversity. This is the highest-impact fix.
|
| 179 |
+
|
| 180 |
+
2. **Add the missing head categories.** Induction, Duplicate Token, and Diffuse are all specced in `attention_handoff.md` with detection logic. They need to be implemented. Induction is especially important for this exact prompt (repeated "The cat").
|
| 181 |
+
|
| 182 |
+
3. **Fix the modal close button off-screen bug.** Students can't close the glossary modal on standard laptop viewports. Easy CSS fix: set `position: relative` on the modal container and `position: absolute; top: 16px; right: 16px` on the close button, so the button anchors inside the modal rather than the document.
|
| 183 |
+
|
| 184 |
+
4. **Add a loading state after clicking Analyze.** 45 seconds of static "Awaiting analysis..." with no spinner is a UX failure. Add a pulsing animation or "Loading model..." progress message.
|
| 185 |
+
|
| 186 |
+
### 🟡 High Priority
|
| 187 |
+
|
| 188 |
+
5. **Connect head categories to the BertViz visualization.** Clicking a head ID (e.g., L4-H11) in the category panel should auto-select that head in the attention viz below.
|
| 189 |
+
|
| 190 |
+
6. **Add runtime activation scoring to head categories.** Per the spec: show whether each head type is active on the current input. Gray out induction heads if there's no repetition in the input, with a "Try: 'The cat sat. The cat'" suggested prompt.
|
| 191 |
+
|
| 192 |
+
7. **Add positional embeddings to the Embedding stage explanation.** Currently missing an entire half of what embeddings are.
|
| 193 |
+
|
| 194 |
+
8. **Fix ablation state management.** Head selections shouldn't reset when switching between Ablation and Attribution tabs.
|
| 195 |
+
|
| 196 |
+
9. **Change attribution default to Integrated Gradients.** It's the more accurate method; this is an educational tool, not a speed benchmark.
|
| 197 |
+
|
| 198 |
+
10. **Capitalize on the tokenization "aha" moment.** "The" (464) vs "the" (262) is sitting right there in the example. Call it out explicitly.
|
| 199 |
+
|
| 200 |
+
### 🟢 Enhancements
|
| 201 |
+
|
| 202 |
+
11. **Add guided "what to look for" text to the attention visualization.** Pick one interesting head per model (pre-annotated) and surface it as a recommendation: "Try Layer 4, Head 11 to see a previous-token head in action."
|
| 203 |
+
|
| 204 |
+
12. **Add suggested prompts for exploring each head category.** "To see induction heads activate, try: 'The cat sat on the mat. The cat...'"
|
| 205 |
+
|
| 206 |
+
13. **Reframe "confidence" in Output stage.** Replace with "probability" throughout.
|
| 207 |
+
|
| 208 |
+
14. **Link attribution results to attention heads.** "The token 'cat' was most influential — see which heads connected it to the prediction in the Attention stage."
|
| 209 |
+
|
| 210 |
+
15. **Fix the Output stage token slider** — hide or disable it when only 1 token was generated.
|
| 211 |
+
|
| 212 |
+
16. **Add a brief "what would you like to explore?" prompt to the ablation UI** with pre-suggested heads from the category panel.
|
| 213 |
+
|
| 214 |
+
17. **Sidebar: add explanatory text** for what Module Selection controls, or hide it in an "Advanced" section.
|
| 215 |
+
|
| 216 |
+
---
|
| 217 |
+
|
| 218 |
+
## What's Already Strong (Don't Break)
|
| 219 |
+
|
| 220 |
+
- The 5-stage pipeline structure and the flow chip bar — keep it exactly as is
|
| 221 |
+
- The BertViz integration — it works and the navigation instructions are clear
|
| 222 |
+
- The callout boxes in Embedding and MLP — these are the best explanation text in the app
|
| 223 |
+
- The token attribution visualization (darker = more important) — intuitive and correct
|
| 224 |
+
- The top-5 output prediction chart — exactly the right content
|
| 225 |
+
- The glossary modal content — all 8 entries are well-written
|
| 226 |
+
|
| 227 |
+
---
|
| 228 |
+
|
| 229 |
+
## Comparison to Handoff Spec
|
| 230 |
+
|
| 231 |
+
| Spec Feature | Status |
|
| 232 |
+
|---|---|
|
| 233 |
+
| 6 head categories (Previous Token, Induction, Duplicate, Positional, Diffuse, Other) | ⚠️ Partial — 3/6 missing, Positional broken |
|
| 234 |
+
| Per-head activation scores on current input | ❌ Not implemented |
|
| 235 |
+
| Active/inactive state display (filled vs open circle) | ❌ Not implemented |
|
| 236 |
+
| Greyed-out heads with suggested prompts | ❌ Not implemented |
|
| 237 |
+
| Click head → navigate to attention heatmap | ❌ Not implemented |
|
| 238 |
+
| Runtime verification module | ❌ Not implemented |
|
| 239 |
+
| One-time offline analysis script | ✅ Appears to have run (JSON exists) |
|
| 240 |
+
| Educational tooltips per category | ⚠️ Partial — descriptions exist but brief |
|
rag_docs/head_categories_explained.md
CHANGED
|
@@ -1,56 +1,58 @@
|
|
| 1 |
-
# Attention Head Categories
|
| 2 |
|
| 3 |
-
|
| 4 |
|
| 5 |
-
|
| 6 |
|
| 7 |
-
|
|
|
|
| 8 |
|
| 9 |
-
|
| 10 |
|
| 11 |
-
|
| 12 |
|
| 13 |
-
|
|
|
|
| 14 |
|
| 15 |
-
|
| 16 |
|
| 17 |
-
**
|
| 18 |
|
| 19 |
-
**
|
| 20 |
|
| 21 |
-
###
|
|
|
|
| 22 |
|
| 23 |
-
|
| 24 |
|
| 25 |
-
**
|
| 26 |
|
| 27 |
-
**
|
| 28 |
|
| 29 |
-
###
|
|
|
|
| 30 |
|
| 31 |
-
|
| 32 |
|
| 33 |
-
**
|
| 34 |
|
| 35 |
-
|
|
|
|
| 36 |
|
| 37 |
-
|
| 38 |
|
| 39 |
-
**What
|
| 40 |
|
| 41 |
-
|
| 42 |
|
| 43 |
-
|
| 44 |
|
| 45 |
-
##
|
| 46 |
|
| 47 |
-
**
|
|
|
|
|
|
|
| 48 |
|
| 49 |
-
|
| 50 |
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
Head categories are especially useful for guiding ablation experiments:
|
| 54 |
-
- Ablate a **Previous-Token** head to see if local context patterns break
|
| 55 |
-
- Ablate a **BoW** head to see if the model loses global context
|
| 56 |
-
- Compare the effect of ablating heads from different categories on the same prompt
|
|
|
|
| 1 |
+
# Attention Head Categories
|
| 2 |
|
| 3 |
+
This document explains the different types of attention heads found in transformer models. These categories are determined through **offline analysis** using TransformerLens and **verified at runtime** against your actual input.
|
| 4 |
|
| 5 |
+
## Categories
|
| 6 |
|
| 7 |
+
### Previous Token
|
| 8 |
+
**Symbol:** ● (active on most inputs)
|
| 9 |
|
| 10 |
+
Attends to the immediately preceding token — like reading left to right. This head helps the model track local word-by-word patterns. It's one of the most common and reliable head types.
|
| 11 |
|
| 12 |
+
**What to look for in the visualization:** Strong diagonal line one position below the main diagonal.
|
| 13 |
|
| 14 |
+
### Induction
|
| 15 |
+
**Symbol:** ● when repeated tokens exist, ○ otherwise
|
| 16 |
|
| 17 |
+
Completes repeated patterns: if the model saw [A][B] before and now sees [A], it predicts [B] will follow. This is one of the most important mechanisms in transformer language models.
|
| 18 |
|
| 19 |
+
**Requires:** Repeated tokens in your input. If no tokens repeat, this category appears grayed out.
|
| 20 |
|
| 21 |
+
**Try this prompt:** "The cat sat on the mat. The cat" — the repeated "The cat" activates induction heads.
|
| 22 |
|
| 23 |
+
### Duplicate Token
|
| 24 |
+
**Symbol:** ● when duplicate tokens exist, ○ otherwise
|
| 25 |
|
| 26 |
+
Notices when the same word appears more than once, acting like a highlighter for repeated words. Helps the model track which words have already been said.
|
| 27 |
|
| 28 |
+
**Requires:** Repeated tokens in your input.
|
| 29 |
|
| 30 |
+
**Try this prompt:** "The cat sat. The cat slept." — the repeated words activate duplicate-token heads.
|
| 31 |
|
| 32 |
+
### Positional / First-Token
|
| 33 |
+
**Symbol:** ● (active on most inputs)
|
| 34 |
|
| 35 |
+
Always pays attention to the very first word, using it as a fixed anchor point. The first token often serves as a "default" position when no specific token is relevant.
|
| 36 |
|
| 37 |
+
**What to look for:** Strong vertical line at column 0 (all tokens attending to position 0).
|
| 38 |
|
| 39 |
+
### Diffuse / Spread
|
| 40 |
+
**Symbol:** ● (active on most inputs)
|
| 41 |
|
| 42 |
+
Spreads attention evenly across many words, gathering general context rather than focusing on one spot. Provides a "big picture" summary of the input.
|
| 43 |
|
| 44 |
+
**What to look for:** No strong patterns — attention is spread roughly evenly across all tokens.
|
| 45 |
|
| 46 |
+
### Other / Unclassified
|
| 47 |
|
| 48 |
+
Heads whose dominant pattern doesn't fit the categories above. These may perform more complex or context-dependent operations.
|
| 49 |
|
| 50 |
+
## How It Works
|
| 51 |
|
| 52 |
+
1. **Offline Analysis:** A TransformerLens script analyzes each head across many test inputs and assigns categories based on dominant behavior patterns.
|
| 53 |
+
2. **Runtime Verification:** When you enter a prompt, the app checks whether each head's known role is actually active on your specific input.
|
| 54 |
+
3. **Active vs Inactive:** A filled circle (●) means the head's role is triggered. An open circle (○) means the role exists but isn't triggered on your current input (e.g., no repeated tokens for induction).
|
| 55 |
|
| 56 |
+
## Important Note
|
| 57 |
|
| 58 |
+
These categories are simplified labels based on each head's dominant behavior pattern. In reality, attention heads can serve multiple roles and may behave differently depending on the input.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
scripts/analyze_heads.py
ADDED
|
@@ -0,0 +1,564 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
"""
|
| 3 |
+
Offline Head Analysis Script
|
| 4 |
+
|
| 5 |
+
Uses TransformerLens to analyze attention head behaviors across test inputs
|
| 6 |
+
and generates a JSON file with head categories for each model.
|
| 7 |
+
|
| 8 |
+
Usage:
|
| 9 |
+
python scripts/analyze_heads.py --model gpt2
|
| 10 |
+
python scripts/analyze_heads.py --model gpt2 gpt2-medium EleutherAI/pythia-70m
|
| 11 |
+
python scripts/analyze_heads.py --all
|
| 12 |
+
|
| 13 |
+
Output:
|
| 14 |
+
Writes to utils/head_categories.json
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
import os
|
| 18 |
+
os.environ["USE_TF"] = "0" # Prevent TensorFlow noise
|
| 19 |
+
|
| 20 |
+
import argparse
|
| 21 |
+
import json
|
| 22 |
+
import sys
|
| 23 |
+
import time
|
| 24 |
+
from pathlib import Path
|
| 25 |
+
from typing import Dict, List, Any, Tuple
|
| 26 |
+
|
| 27 |
+
import torch
|
| 28 |
+
import numpy as np
|
| 29 |
+
|
| 30 |
+
# Add project root to path
|
| 31 |
+
PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
| 32 |
+
sys.path.insert(0, str(PROJECT_ROOT))
|
| 33 |
+
|
| 34 |
+
JSON_OUTPUT_PATH = PROJECT_ROOT / "utils" / "head_categories.json"
|
| 35 |
+
|
| 36 |
+
# ============================================================================
|
| 37 |
+
# TransformerLens model name mapping
|
| 38 |
+
# ============================================================================
|
| 39 |
+
# TL uses its own naming conventions. Map from HuggingFace names
|
| 40 |
+
# (used in our model_config.py) to TL names.
|
| 41 |
+
|
| 42 |
+
# Maps HuggingFace model ids (as used elsewhere in this project) to the
# names TransformerLens expects in HookedTransformer.from_pretrained.
# Both bare ("gpt2") and org-prefixed ("openai-community/gpt2") HF ids are
# listed so either spelling resolves. Unlisted models fall through to their
# HF name unchanged (see the .get(model_name, model_name) call site).
HF_TO_TL_NAME = {
    "gpt2": "gpt2-small",
    "openai-community/gpt2": "gpt2-small",
    "gpt2-medium": "gpt2-medium",
    "openai-community/gpt2-medium": "gpt2-medium",
    "gpt2-large": "gpt2-large",
    "openai-community/gpt2-large": "gpt2-large",
    "gpt2-xl": "gpt2-xl",
    "openai-community/gpt2-xl": "gpt2-xl",
    "EleutherAI/pythia-70m": "pythia-70m",
    "EleutherAI/pythia-160m": "pythia-160m",
    "EleutherAI/pythia-410m": "pythia-410m",
    "EleutherAI/pythia-1b": "pythia-1b",
    "EleutherAI/pythia-1.4b": "pythia-1.4b",
    "facebook/opt-125m": "opt-125m",
    "facebook/opt-350m": "opt-350m",
    "facebook/opt-1.3b": "opt-1.3b",
}

# Default models to analyze
DEFAULT_MODELS = ["gpt2"]

# Models analyzed when the --all flag is passed (small/medium models that
# are cheap enough to run on CPU) — presumably consumed by the argparse
# handler further down; confirm against the CLI entry point.
ALL_PRIORITY_MODELS = [
    "gpt2",
    "gpt2-medium",
    "EleutherAI/pythia-70m",
    "EleutherAI/pythia-160m",
    "EleutherAI/pythia-410m",
    "facebook/opt-125m",
]
|
| 72 |
+
|
| 73 |
+
# ============================================================================
|
| 74 |
+
# Category metadata (shared across all models)
|
| 75 |
+
# ============================================================================
|
| 76 |
+
|
| 77 |
+
# Per-category metadata keyed by category id. Field schema:
#   display_name        — short label shown in the UI
#   description         — one-line technical summary
#   icon                — icon identifier (looks like Font Awesome names —
#                         TODO confirm against the frontend)
#   educational_text    — student-facing plain-English explanation
#   requires_repetition — True when the behavior only activates on inputs
#                         containing repeated tokens; such entries also carry
#                         a suggested_prompt demonstrating the behavior
CATEGORY_METADATA = {
    "previous_token": {
        "display_name": "Previous Token",
        "description": "Attends to the immediately preceding token — like reading left to right",
        "icon": "arrow-left",
        "educational_text": "This head looks at the word right before the current one. Like reading left to right, it helps track local word-by-word patterns.",
        "requires_repetition": False,
    },
    "induction": {
        "display_name": "Induction",
        "description": "Completes repeated patterns: if it saw [A][B] before and now sees [A], it predicts [B]",
        "icon": "repeat",
        "educational_text": "This head finds patterns that happened before and predicts they'll happen again. If it saw 'the cat' earlier, it expects the same words to follow.",
        "requires_repetition": True,
        "suggested_prompt": "Try: 'The cat sat on the mat. The cat' — the repeated 'The cat' lets induction heads activate.",
    },
    "duplicate_token": {
        "display_name": "Duplicate Token",
        "description": "Notices when the same word appears more than once",
        "icon": "clone",
        "educational_text": "This head notices when the same word appears more than once, like a highlighter for repeated words. It helps the model track which words have already been said.",
        "requires_repetition": True,
        "suggested_prompt": "Try a prompt with repeated words like 'The cat sat. The cat slept.' to see duplicate-token heads light up.",
    },
    "positional": {
        "display_name": "Positional / First-Token",
        "description": "Always pays attention to the very first word, using it as an anchor point",
        "icon": "map-pin",
        "educational_text": "This head always pays attention to the very first word, using it as an anchor point. The first token serves as a 'default' position when no other token is specifically relevant.",
        "requires_repetition": False,
    },
    "diffuse": {
        "display_name": "Diffuse / Spread",
        "description": "Spreads attention evenly across many words, gathering general context",
        "icon": "expand-arrows-alt",
        "educational_text": "This head spreads its attention evenly across many words, gathering general context rather than focusing on one spot. It provides a 'big picture' summary of the input.",
        "requires_repetition": False,
    },
}
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
# ============================================================================
|
| 119 |
+
# Test input generation
|
| 120 |
+
# ============================================================================
|
| 121 |
+
|
| 122 |
+
def generate_test_inputs(tokenizer) -> Dict[str, List[str]]:
    """Generate categorized test inputs for head analysis.

    Returns a dict with two keys:
        "natural":    varied English prose, used to score previous-token,
                      positional, and diffuse behavior.
        "repetitive": prompts built around repeated phrases and words, used
                      to score induction and duplicate-token behavior.

    NOTE(review): the ``tokenizer`` argument is currently unused — both
    prompt lists are static literals. It is presumably kept so prompts can
    be tailored per-tokenizer later; confirm before removing it.
    """

    # Natural language prompts for general analysis
    natural_prompts = [
        "The quick brown fox jumps over the lazy dog.",
        "In the beginning, there was nothing but darkness and silence.",
        "Machine learning models process data to make predictions about the future.",
        "She walked through the park and noticed the flowers blooming everywhere.",
        "The president announced new economic policies at the press conference today.",
        "After years of research, scientists finally discovered the missing link.",
        "The library was quiet except for the occasional turning of pages.",
        "Programming is both an art and a science requiring careful thought.",
        "The restaurant on the corner served the best pizza in the entire city.",
        "Education is the most powerful tool for changing the world around us.",
        "The storm clouds gathered on the horizon as the wind began to howl.",
        "Mathematics provides the foundation for understanding complex physical phenomena.",
        "The children played happily in the garden while their parents watched.",
        "Economic growth depends on innovation, investment, and human capital development.",
        "The old man sat on the bench and watched the pigeons gather crumbs.",
        "Artificial intelligence will transform every industry in the coming decades.",
        "The river flowed gently through the valley between the tall mountains.",
        "Good communication skills are essential for success in any professional career.",
        "The concert hall was packed with enthusiastic fans waiting for the show.",
        "Climate change poses significant challenges for agriculture and food security.",
    ]

    # Repetitive prompts for induction / duplicate detection
    repetitive_prompts = [
        "The cat sat on the mat. The cat sat on the mat.",
        "One two three four five. One two three four five.",
        "Hello world hello world hello world hello world.",
        "Alice went to the store. Bob went to the store. Alice went to the store.",
        "The dog chased the ball. The dog chased the ball. The dog chased.",
        "Red blue green red blue green red blue green red.",
        "I like apples and I like oranges and I like apples.",
        "The sun rises in the east. The sun sets in the west. The sun rises.",
        "Monday Tuesday Wednesday Monday Tuesday Wednesday Monday.",
        "She said hello and he said hello and she said hello again.",
        "The key to success is practice. The key to success is patience.",
        "We went to the park and then we went to the park again.",
        "First second third first second third first second third.",
        "The teacher asked the student. The student asked the teacher. The teacher asked.",
        "North south east west north south east west north south.",
        "Open the door. Close the door. Open the door. Close the door.",
        "The big red ball bounced. The big red ball rolled.",
        "Cat dog cat dog cat dog cat dog cat dog.",
        "Learn practice improve learn practice improve learn practice.",
        "The man walked. The woman walked. The man walked. The woman walked.",
    ]

    return {
        "natural": natural_prompts,
        "repetitive": repetitive_prompts,
    }
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
# ============================================================================
|
| 180 |
+
# Head scoring functions
|
| 181 |
+
# ============================================================================
|
| 182 |
+
|
| 183 |
+
def score_previous_token(attn_patterns: torch.Tensor) -> torch.Tensor:
|
| 184 |
+
"""
|
| 185 |
+
Score each head for previous-token behavior.
|
| 186 |
+
|
| 187 |
+
For each position i > 0, check attention to position i-1.
|
| 188 |
+
Returns [n_layers, n_heads] scores.
|
| 189 |
+
"""
|
| 190 |
+
n_layers, n_heads, seq_len, _ = attn_patterns.shape
|
| 191 |
+
|
| 192 |
+
if seq_len < 2:
|
| 193 |
+
return torch.zeros(n_layers, n_heads)
|
| 194 |
+
|
| 195 |
+
scores = torch.zeros(n_layers, n_heads)
|
| 196 |
+
for i in range(1, seq_len):
|
| 197 |
+
scores += attn_patterns[:, :, i, i - 1]
|
| 198 |
+
scores /= (seq_len - 1)
|
| 199 |
+
|
| 200 |
+
return scores
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def score_positional(attn_patterns: torch.Tensor) -> torch.Tensor:
|
| 204 |
+
"""
|
| 205 |
+
Score each head for first-token / positional behavior.
|
| 206 |
+
|
| 207 |
+
Measures mean attention to position 0 across all positions.
|
| 208 |
+
Returns [n_layers, n_heads] scores.
|
| 209 |
+
"""
|
| 210 |
+
# Mean of column 0 across all query positions
|
| 211 |
+
return attn_patterns[:, :, :, 0].mean(dim=-1)
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
def score_diffuse(attn_patterns: torch.Tensor) -> torch.Tensor:
    """Score every head for diffuse / bag-of-words behavior.

    Computes the Shannon entropy of each attention row, normalizes it by
    the maximum possible entropy (log seq_len), and averages over query
    positions. A score near 1 means attention is spread almost uniformly;
    near 0 means it is sharply focused on a single token.

    Args:
        attn_patterns: [n_layers, n_heads, seq_len, seq_len] attention weights.

    Returns:
        [n_layers, n_heads] tensor of mean normalized entropy.
    """
    seq_len = attn_patterns.shape[-1]

    # Small additive offset keeps log() finite on exact-zero weights.
    smoothed = attn_patterns + 1e-10
    row_entropy = -(smoothed * torch.log(smoothed)).sum(dim=-1)  # [L, H, seq]

    # Uniform attention over seq_len positions has entropy log(seq_len).
    # Guard seq_len == 1, where that maximum is zero (avoid divide-by-zero).
    uniform_entropy = np.log(seq_len)
    if uniform_entropy > 0:
        row_entropy = row_entropy / uniform_entropy

    return row_entropy.mean(dim=-1)
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
def score_induction(attn_patterns: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
    """Score each head for induction behavior.

    An induction head completes repeated patterns: when token[i] matches an
    earlier token[j], the head attends from position i to position j + 1 —
    the token that followed the previous occurrence.

    Args:
        attn_patterns: [n_layers, n_heads, seq_len, seq_len] attention weights.
        tokens: token ids, shape [seq_len] or [1, seq_len] (e.g. the batched
            output of ``model.to_tokens``); flattened internally.

    Returns:
        [n_layers, n_heads] mean attention over all (i, j+1) induction pairs;
        all zeros when the input contains no repeated tokens.
    """
    n_layers, n_heads, seq_len, _ = attn_patterns.shape
    # Accept both [seq_len] and [1, seq_len] token tensors so callers can
    # pass model.to_tokens output directly without stripping the batch dim.
    token_ids = tokens.flatten()

    scores = torch.zeros(n_layers, n_heads)
    count = 0

    for i in range(2, seq_len):
        for j in range(0, i - 1):
            if token_ids[i].item() == token_ids[j].item():
                # j <= i - 2, so j + 1 <= i - 1 is always a valid, strictly
                # earlier key position — no bounds check needed (the previous
                # version's `target < seq_len` test was dead code).
                scores += attn_patterns[:, :, i, j + 1]
                count += 1

    if count > 0:
        scores /= count

    return scores
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
def score_duplicate_token(attn_patterns: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
    """Score each head for duplicate-token behavior.

    A duplicate-token head attends from a later occurrence of a token back
    to an earlier occurrence of the same token.

    Args:
        attn_patterns: [n_layers, n_heads, seq_len, seq_len] attention weights.
        tokens: token ids, shape [seq_len] or [1, seq_len] (e.g. the batched
            output of ``model.to_tokens``); flattened internally.

    Returns:
        [n_layers, n_heads] mean attention over all (i, j) duplicate pairs
        with tokens[i] == tokens[j] and j < i; all zeros when no token
        repeats.
    """
    n_layers, n_heads, seq_len, _ = attn_patterns.shape
    # Accept both [seq_len] and [1, seq_len] token tensors, consistent with
    # score_induction, so callers can pass model.to_tokens output directly.
    token_ids = tokens.flatten()

    scores = torch.zeros(n_layers, n_heads)
    count = 0

    for i in range(1, seq_len):
        for j in range(0, i):
            if token_ids[i].item() == token_ids[j].item():
                scores += attn_patterns[:, :, i, j]
                count += 1

    if count > 0:
        scores /= count

    return scores
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
# ============================================================================
|
| 281 |
+
# Main analysis
|
| 282 |
+
# ============================================================================
|
| 283 |
+
|
| 284 |
+
def analyze_model(model_name: str, device: str = "cpu") -> Dict[str, Any]:
    """
    Run full head analysis for a model.

    Loads the model through TransformerLens, runs it on two pools of test
    prompts ("natural" and "repetitive"), scores every attention head for
    five behavioral categories, and selects the top heads per category.

    Args:
        model_name: HuggingFace model name; mapped to a TransformerLens
            name via HF_TO_TL_NAME when an alias exists.
        device: torch device string ("cpu" or "cuda").

    Returns a dict ready for JSON serialization.
    """
    # Imported lazily so the module can be imported without transformer_lens.
    from transformer_lens import HookedTransformer

    tl_name = HF_TO_TL_NAME.get(model_name, model_name)
    print(f"\n{'='*60}")
    print(f"Analyzing: {model_name} (TL name: {tl_name})")
    print(f"{'='*60}")

    print("Loading model...")
    model = HookedTransformer.from_pretrained(tl_name, device=device)

    n_layers = model.cfg.n_layers
    n_heads = model.cfg.n_heads
    print(f" Layers: {n_layers}, Heads per layer: {n_heads}")

    # Generate test inputs
    # NOTE(review): generate_test_inputs is defined elsewhere in this
    # script — assumed to return {"natural": [...], "repetitive": [...]}.
    test_inputs = generate_test_inputs(model.tokenizer)

    # Accumulators for scores; each is summed over prompts, then averaged.
    prev_token_scores = torch.zeros(n_layers, n_heads)
    positional_scores = torch.zeros(n_layers, n_heads)
    diffuse_scores = torch.zeros(n_layers, n_heads)
    induction_scores = torch.zeros(n_layers, n_heads)
    duplicate_scores = torch.zeros(n_layers, n_heads)

    natural_count = 0
    repetitive_count = 0

    # Analyze natural prompts (for prev_token, positional, diffuse)
    print("\nAnalyzing natural prompts...")
    for prompt in test_inputs["natural"]:
        try:
            tokens = model.to_tokens(prompt)
            # Too-short sequences give degenerate attention statistics.
            if tokens.shape[1] < 3:
                continue

            with torch.no_grad():
                _, cache = model.run_with_cache(tokens)

            # Stack attention patterns: [n_layers, n_heads, seq_len, seq_len]
            attn_patterns = torch.stack([
                cache["pattern", layer][0]  # Remove batch dim
                for layer in range(n_layers)
            ])

            prev_token_scores += score_previous_token(attn_patterns)
            positional_scores += score_positional(attn_patterns)
            diffuse_scores += score_diffuse(attn_patterns)
            natural_count += 1

        except Exception as e:
            # Best-effort analysis: a bad prompt (e.g. tokenization failure)
            # is skipped rather than aborting the whole model.
            print(f" Warning: Skipped prompt: {e}")
            continue

    print(f" Processed {natural_count} natural prompts")

    # Analyze repetitive prompts (for induction + duplicate)
    print("Analyzing repetitive prompts...")
    for prompt in test_inputs["repetitive"]:
        try:
            tokens = model.to_tokens(prompt)
            if tokens.shape[1] < 4:
                continue

            with torch.no_grad():
                _, cache = model.run_with_cache(tokens)

            attn_patterns = torch.stack([
                cache["pattern", layer][0]
                for layer in range(n_layers)
            ])

            induction_scores += score_induction(attn_patterns, tokens[0])
            duplicate_scores += score_duplicate_token(attn_patterns, tokens[0])

            # Also accumulate general scores for these prompts
            # (repetitive prompts deliberately count toward natural_count
            # so the general-score averages cover both prompt pools).
            prev_token_scores += score_previous_token(attn_patterns)
            positional_scores += score_positional(attn_patterns)
            diffuse_scores += score_diffuse(attn_patterns)
            natural_count += 1

            repetitive_count += 1

        except Exception as e:
            print(f" Warning: Skipped prompt: {e}")
            continue

    print(f" Processed {repetitive_count} repetitive prompts")

    # Average scores (guard against division by zero when all prompts fail)
    if natural_count > 0:
        prev_token_scores /= natural_count
        positional_scores /= natural_count
        diffuse_scores /= natural_count
    if repetitive_count > 0:
        induction_scores /= repetitive_count
        duplicate_scores /= repetitive_count

    # Select top heads per category
    all_category_scores = {
        "previous_token": prev_token_scores,
        "induction": induction_scores,
        "duplicate_token": duplicate_scores,
        "positional": positional_scores,
        "diffuse": diffuse_scores,
    }

    # Print score summaries
    print("\nScore summaries (max per category):")
    for cat_name, scores in all_category_scores.items():
        max_score = scores.max().item()
        # argmax over the flattened tensor; recover (layer, head) indices.
        max_idx = scores.argmax()
        max_layer = max_idx // n_heads
        max_head = max_idx % n_heads
        print(f" {cat_name:20s}: max={max_score:.4f} at L{max_layer}-H{max_head}")

    # Build category data
    categories_data = {}

    for cat_name, scores in all_category_scores.items():
        top_heads = select_top_heads(scores, n_layers, n_heads, cat_name)

        # CATEGORY_METADATA is a module-level dict (outside this view);
        # copied so adding "top_heads" does not mutate the shared metadata.
        cat_entry = dict(CATEGORY_METADATA[cat_name])
        cat_entry["top_heads"] = top_heads
        categories_data[cat_name] = cat_entry

        print(f"\n {cat_name} ({len(top_heads)} heads):")
        for h in top_heads:
            print(f" L{h['layer']}-H{h['head']}: {h['score']:.4f}")

    # Build the full model entry (JSON-serializable: tensors -> lists)
    model_entry = {
        "model_name": model_name,
        "num_layers": n_layers,
        "num_heads": n_heads,
        "analysis_date": time.strftime("%Y-%m-%d"),
        "categories": categories_data,
        "all_scores": {
            cat: scores.tolist()
            for cat, scores in all_category_scores.items()
        }
    }

    return model_entry
|
| 433 |
+
|
| 434 |
+
|
| 435 |
+
def select_top_heads(
    scores: torch.Tensor,
    n_layers: int,
    n_heads: int,
    category: str,
    max_heads: int = 8,
    primary_threshold: float = 0.25,
    min_threshold: float = 0.10,
) -> List[Dict[str, Any]]:
    """
    Select the top heads for a category, enforcing layer diversity.

    Strategy:
    1. Take all heads above primary_threshold
    2. Ensure we include the best head from each layer above min_threshold
    3. Cap at max_heads, keeping highest scores

    Args:
        scores: [n_layers, n_heads] score tensor for one category.
        n_layers: Number of layers in the model.
        n_heads: Number of heads per layer.
        category: Category name (kept for API symmetry; not used here).
        max_heads: Hard cap on the number of heads returned.
        primary_threshold: Score required for first-pass selection.
        min_threshold: Minimum score to be considered at all.

    Returns:
        List of {"layer", "head", "score"} dicts sorted by (layer, head).
    """
    # Gather every head clearing the minimum bar, best score first.
    pool = [
        {"layer": layer_idx, "head": head_idx, "score": round(scores[layer_idx, head_idx].item(), 4)}
        for layer_idx in range(n_layers)
        for head_idx in range(n_heads)
        if scores[layer_idx, head_idx].item() > min_threshold
    ]
    pool.sort(key=lambda entry: entry["score"], reverse=True)

    picked: List[Dict[str, Any]] = []
    seen = set()
    covered_layers = set()

    def _take(entry):
        # Record a selection and its layer/head bookkeeping.
        picked.append(entry)
        seen.add((entry["layer"], entry["head"]))
        covered_layers.add(entry["layer"])

    # Pass 1: strong heads (score >= primary_threshold), best first.
    for entry in pool:
        if len(picked) >= max_heads:
            break
        if entry["score"] >= primary_threshold and (entry["layer"], entry["head"]) not in seen:
            _take(entry)

    # Pass 2: fill remaining slots with the best head from each layer
    # not yet represented, preserving layer diversity.
    for entry in pool:
        if len(picked) >= max_heads:
            break
        if entry["layer"] not in covered_layers and (entry["layer"], entry["head"]) not in seen:
            _take(entry)

    # Present results in (layer, head) order for stable JSON output.
    picked.sort(key=lambda entry: (entry["layer"], entry["head"]))
    return picked[:max_heads]
|
| 496 |
+
|
| 497 |
+
|
| 498 |
+
# ============================================================================
|
| 499 |
+
# CLI
|
| 500 |
+
# ============================================================================
|
| 501 |
+
|
| 502 |
+
def main():
    """CLI entry point: analyze one or more models and merge results into a JSON file.

    Flags:
        --model NAME [NAME ...]: specific models to analyze.
        --all: analyze ALL_PRIORITY_MODELS (module-level list, outside this view).
        --device: torch device ("cpu" default).
        --output: JSON destination; defaults to JSON_OUTPUT_PATH.

    Existing JSON entries are preserved; per-model failures are logged and skipped.
    """
    parser = argparse.ArgumentParser(description="Analyze attention head categories using TransformerLens")
    parser.add_argument("--model", nargs="+", default=None,
                        help="HuggingFace model name(s) to analyze (e.g., gpt2, EleutherAI/pythia-70m)")
    parser.add_argument("--all", action="store_true",
                        help="Analyze all priority models")
    parser.add_argument("--device", default="cpu",
                        help="Device to use (cpu or cuda)")
    parser.add_argument("--output", type=str, default=None,
                        help="Output JSON path (default: utils/head_categories.json)")
    args = parser.parse_args()

    # Determine models to analyze (--all wins over explicit --model list)
    if args.all:
        models = ALL_PRIORITY_MODELS
    elif args.model:
        models = args.model
    else:
        models = DEFAULT_MODELS

    output_path = Path(args.output) if args.output else JSON_OUTPUT_PATH

    # Load existing data if present, so reruns merge instead of overwrite.
    existing_data = {}
    if output_path.exists():
        try:
            with open(output_path, 'r') as f:
                existing_data = json.load(f)
            print(f"Loaded existing data from {output_path} ({len(existing_data)} models)")
        except (json.JSONDecodeError, IOError):
            # Corrupt/unreadable file: start fresh rather than crash.
            pass

    # Analyze each model; one failure must not block the rest.
    for model_name in models:
        try:
            result = analyze_model(model_name, device=args.device)

            # Store under the HuggingFace name
            existing_data[model_name] = result

            # Also store under the short name for lookup
            # (e.g. "EleutherAI/pythia-70m" -> "pythia-70m")
            short_name = model_name.split('/')[-1] if '/' in model_name else None
            if short_name and short_name != model_name:
                existing_data[short_name] = result

        except Exception as e:
            print(f"\nERROR analyzing {model_name}: {e}")
            import traceback
            traceback.print_exc()
            continue

    # Write output (create parent dirs on first run)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, 'w') as f:
        json.dump(existing_data, f, indent=2)

    print(f"\n{'='*60}")
    print(f"Done! Wrote {len(existing_data)} model entries to {output_path}")
    print(f"{'='*60}")
|
| 552 |
+
|
| 553 |
+
# Write output
|
| 554 |
+
output_path.parent.mkdir(parents=True, exist_ok=True)
|
| 555 |
+
with open(output_path, 'w') as f:
|
| 556 |
+
json.dump(existing_data, f, indent=2)
|
| 557 |
+
|
| 558 |
+
print(f"\n{'='*60}")
|
| 559 |
+
print(f"Done! Wrote {len(existing_data)} model entries to {output_path}")
|
| 560 |
+
print(f"{'='*60}")
|
| 561 |
+
|
| 562 |
+
|
| 563 |
+
# Script entry point: run the analysis CLI.
if __name__ == "__main__":
    main()
|
tests/conftest.py
CHANGED
|
@@ -199,12 +199,4 @@ def mock_attribution_result():
|
|
| 199 |
}
|
| 200 |
|
| 201 |
|
| 202 |
-
# =============================================================================
|
| 203 |
-
# Head Categorization Config
|
| 204 |
-
# =============================================================================
|
| 205 |
|
| 206 |
-
@pytest.fixture
|
| 207 |
-
def default_head_config():
|
| 208 |
-
"""Default head categorization configuration for testing."""
|
| 209 |
-
from utils.head_detection import HeadCategorizationConfig
|
| 210 |
-
return HeadCategorizationConfig()
|
|
|
|
| 199 |
}
|
| 200 |
|
| 201 |
|
|
|
|
|
|
|
|
|
|
| 202 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tests/test_head_detection.py
CHANGED
|
@@ -1,313 +1,431 @@
|
|
| 1 |
"""
|
| 2 |
Tests for utils/head_detection.py
|
| 3 |
|
| 4 |
-
Tests
|
| 5 |
"""
|
| 6 |
|
| 7 |
import pytest
|
| 8 |
import torch
|
|
|
|
| 9 |
import numpy as np
|
|
|
|
|
|
|
| 10 |
from utils.head_detection import (
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
categorize_all_heads,
|
| 18 |
-
format_categorization_summary,
|
| 19 |
-
HeadCategorizationConfig
|
| 20 |
)
|
| 21 |
|
| 22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
class TestComputeAttentionEntropy:
|
| 24 |
-
"""Tests for
|
| 25 |
-
|
| 26 |
def test_uniform_distribution_high_entropy(self):
|
| 27 |
-
"""Uniform attention should have
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
entropy
|
| 31 |
-
|
| 32 |
-
# Normalized entropy should be close to 1.0 for uniform
|
| 33 |
-
assert 0.95 <= entropy <= 1.0, f"Expected ~1.0, got {entropy}"
|
| 34 |
-
|
| 35 |
def test_peaked_distribution_low_entropy(self):
|
| 36 |
-
"""Peaked attention should have low
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
assert
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
default_head_config
|
| 75 |
-
)
|
| 76 |
-
|
| 77 |
-
assert is_prev == False
|
| 78 |
-
assert score < 0.4, f"Expected low score, got {score}"
|
| 79 |
-
|
| 80 |
-
def test_short_sequence_returns_false(self, default_head_config):
|
| 81 |
-
"""Sequence shorter than min_seq_len should return False."""
|
| 82 |
-
short_matrix = torch.ones(2, 2) / 2
|
| 83 |
-
is_prev, score = detect_previous_token_head(short_matrix, default_head_config)
|
| 84 |
-
|
| 85 |
-
assert is_prev == False
|
| 86 |
-
assert score == 0.0
|
| 87 |
|
|
|
|
|
|
|
| 88 |
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
|
| 99 |
-
assert
|
| 100 |
-
assert
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
matrix[i, 0] = 0.05
|
| 111 |
-
matrix[i, -1] = 0.95
|
| 112 |
|
| 113 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
|
| 115 |
-
assert
|
| 116 |
-
assert score < 0.25, f"Expected low score, got {score}"
|
| 117 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
is_bow, score = detect_bow_head(uniform_attention_matrix, default_head_config)
|
| 125 |
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
|
|
|
|
|
|
|
|
|
| 133 |
|
| 134 |
-
|
| 135 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
|
|
|
|
|
|
|
| 137 |
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
def test_consistent_distance_pattern(self, default_head_config):
|
| 142 |
-
"""Matrix with consistent distance pattern should be detected as syntactic."""
|
| 143 |
-
# Create matrix where each position attends to position 2 tokens back
|
| 144 |
size = 6
|
| 145 |
matrix = torch.zeros(size, size)
|
| 146 |
-
for i in range(size):
|
| 147 |
-
|
| 148 |
-
matrix[i,
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
torch.
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
class TestCategorizeAttentionHead:
|
| 168 |
-
"""Tests for categorize_attention_head function."""
|
| 169 |
-
|
| 170 |
-
def test_categorizes_previous_token_head(self, previous_token_attention_matrix, default_head_config):
|
| 171 |
-
"""Should categorize previous-token pattern correctly."""
|
| 172 |
-
result = categorize_attention_head(
|
| 173 |
-
previous_token_attention_matrix,
|
| 174 |
-
layer_idx=0,
|
| 175 |
-
head_idx=3,
|
| 176 |
-
config=default_head_config
|
| 177 |
-
)
|
| 178 |
-
|
| 179 |
-
assert result['category'] == 'previous_token'
|
| 180 |
-
assert result['layer'] == 0
|
| 181 |
-
assert result['head'] == 3
|
| 182 |
-
assert result['label'] == 'L0-H3'
|
| 183 |
-
assert 'scores' in result
|
| 184 |
-
|
| 185 |
-
def test_categorizes_first_token_head(self, first_token_attention_matrix, default_head_config):
|
| 186 |
-
"""Should categorize first-token pattern correctly."""
|
| 187 |
-
result = categorize_attention_head(
|
| 188 |
-
first_token_attention_matrix,
|
| 189 |
-
layer_idx=2,
|
| 190 |
-
head_idx=5,
|
| 191 |
-
config=default_head_config
|
| 192 |
-
)
|
| 193 |
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 201 |
size = 5
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 202 |
matrix = torch.zeros(size, size)
|
| 203 |
for i in range(size):
|
| 204 |
-
|
| 205 |
-
matrix[i,
|
| 206 |
-
remaining = 0.9 / (size - 1)
|
| 207 |
-
for j in range(1, size):
|
| 208 |
-
matrix[i, j] = remaining
|
| 209 |
-
|
| 210 |
-
result = categorize_attention_head(
|
| 211 |
-
matrix,
|
| 212 |
-
layer_idx=1,
|
| 213 |
-
head_idx=0,
|
| 214 |
-
config=default_head_config
|
| 215 |
-
)
|
| 216 |
-
|
| 217 |
-
assert result['category'] == 'bow'
|
| 218 |
-
|
| 219 |
-
def test_result_structure(self, uniform_attention_matrix):
|
| 220 |
-
"""Result should have all required keys."""
|
| 221 |
-
result = categorize_attention_head(
|
| 222 |
-
uniform_attention_matrix,
|
| 223 |
-
layer_idx=0,
|
| 224 |
-
head_idx=0
|
| 225 |
-
)
|
| 226 |
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 230 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 231 |
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 238 |
|
| 239 |
-
|
| 240 |
-
for
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
|
|
|
| 248 |
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
class TestFormatCategorizationSummary:
|
| 255 |
-
"""Tests for format_categorization_summary function."""
|
| 256 |
-
|
| 257 |
-
def test_formats_empty_categorization(self):
|
| 258 |
-
"""Should format empty categorization without error."""
|
| 259 |
-
empty = {
|
| 260 |
-
'previous_token': [],
|
| 261 |
-
'first_token': [],
|
| 262 |
-
'bow': [],
|
| 263 |
-
'syntactic': [],
|
| 264 |
-
'other': []
|
| 265 |
}
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
{'layer': 0, 'head': 1, 'label': 'L0-H1'},
|
| 276 |
-
{'layer': 0, 'head': 2, 'label': 'L0-H2'},
|
| 277 |
-
],
|
| 278 |
-
'first_token': [
|
| 279 |
-
{'layer': 1, 'head': 0, 'label': 'L1-H0'},
|
| 280 |
-
],
|
| 281 |
-
'bow': [],
|
| 282 |
-
'syntactic': [],
|
| 283 |
-
'other': []
|
| 284 |
-
}
|
| 285 |
-
result = format_categorization_summary(categorized)
|
| 286 |
-
|
| 287 |
-
assert "Total Heads: 3" in result
|
| 288 |
-
assert "Previous-Token Heads: 2" in result
|
| 289 |
-
assert "First/Positional-Token Heads: 1" in result
|
| 290 |
-
assert "Layer 0" in result
|
| 291 |
-
assert "Layer 1" in result
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
class TestHeadCategorizationConfig:
|
| 295 |
-
"""Tests for HeadCategorizationConfig defaults."""
|
| 296 |
-
|
| 297 |
-
def test_default_values(self):
|
| 298 |
-
"""Default config should have reasonable values."""
|
| 299 |
-
config = HeadCategorizationConfig()
|
| 300 |
|
| 301 |
-
assert
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
|
|
|
| 310 |
|
| 311 |
-
|
| 312 |
-
assert
|
| 313 |
-
assert
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
"""
|
| 2 |
Tests for utils/head_detection.py
|
| 3 |
|
| 4 |
+
Tests the offline JSON + runtime verification head categorization system.
|
| 5 |
"""
|
| 6 |
|
| 7 |
import pytest
|
| 8 |
import torch
|
| 9 |
+
import json
|
| 10 |
import numpy as np
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
from unittest.mock import patch, mock_open
|
| 13 |
from utils.head_detection import (
|
| 14 |
+
load_head_categories,
|
| 15 |
+
verify_head_activation,
|
| 16 |
+
get_active_head_summary,
|
| 17 |
+
clear_category_cache,
|
| 18 |
+
_compute_attention_entropy,
|
| 19 |
+
_find_repeated_tokens,
|
|
|
|
|
|
|
|
|
|
| 20 |
)
|
| 21 |
|
| 22 |
|
| 23 |
+
# =============================================================================
|
| 24 |
+
# Sample JSON data for mocking
|
| 25 |
+
# =============================================================================
|
| 26 |
+
|
| 27 |
+
# Minimal fixture mirroring the structure of utils/head_categories.json:
# one fake 2-layer / 4-head model with one top head per category.
SAMPLE_JSON = {
    "test-model": {
        "model_name": "test-model",
        "num_layers": 2,
        "num_heads": 4,
        "analysis_date": "2026-02-26",
        "categories": {
            "previous_token": {
                "display_name": "Previous Token",
                "description": "Attends to the previous token",
                "educational_text": "Looks at the word before.",
                "icon": "arrow-left",
                "requires_repetition": False,
                "top_heads": [
                    {"layer": 0, "head": 1, "score": 0.85},
                    {"layer": 1, "head": 2, "score": 0.72}
                ]
            },
            "induction": {
                "display_name": "Induction",
                "description": "Pattern matching",
                "educational_text": "Finds repeated patterns.",
                "icon": "repeat",
                "requires_repetition": True,
                "suggested_prompt": "Try repeating words.",
                "top_heads": [
                    {"layer": 1, "head": 0, "score": 0.90}
                ]
            },
            "duplicate_token": {
                "display_name": "Duplicate Token",
                "description": "Finds duplicates",
                "educational_text": "Spots repeated words.",
                "icon": "clone",
                "requires_repetition": True,
                "suggested_prompt": "Try typing the same word twice.",
                "top_heads": [
                    {"layer": 0, "head": 3, "score": 0.78}
                ]
            },
            "positional": {
                "display_name": "Positional",
                "description": "First token focus",
                "educational_text": "Anchors to position 0.",
                "icon": "map-pin",
                "requires_repetition": False,
                "top_heads": [
                    {"layer": 0, "head": 0, "score": 0.88}
                ]
            },
            "diffuse": {
                "display_name": "Diffuse",
                "description": "Spread attention",
                "educational_text": "Even distribution.",
                "icon": "expand-arrows-alt",
                "requires_repetition": False,
                "top_heads": [
                    {"layer": 1, "head": 3, "score": 0.80}
                ]
            }
        },
        "all_scores": {}
    }
}
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
@pytest.fixture(autouse=True)
def clear_cache():
    """Clear the category cache before each test."""
    # autouse=True: wraps every test in this module so JSON data cached by
    # load_head_categories in one test cannot leak into another.
    clear_category_cache()
    yield
    clear_category_cache()
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
# =============================================================================
|
| 102 |
+
# Tests for _compute_attention_entropy
|
| 103 |
+
# =============================================================================
|
| 104 |
+
|
| 105 |
class TestComputeAttentionEntropy:
    """Tests for the _compute_attention_entropy helper."""

    def test_uniform_distribution_high_entropy(self):
        """Uniform attention should have entropy near 1.0."""
        uniform = torch.full((8,), 1.0 / 8)
        assert _compute_attention_entropy(uniform) > 0.95

    def test_peaked_distribution_low_entropy(self):
        """Peaked attention should have low entropy."""
        peaked = torch.full((8,), 0.02 / 7)
        peaked[0] = 0.98
        assert _compute_attention_entropy(peaked) < 0.3

    def test_entropy_in_range(self):
        """Entropy should always be between 0 and 1."""
        for _ in range(10):
            random_weights = torch.softmax(torch.randn(6), dim=0)
            assert 0.0 <= _compute_attention_entropy(random_weights) <= 1.0
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
# =============================================================================
|
| 131 |
+
# Tests for _find_repeated_tokens
|
| 132 |
+
# =============================================================================
|
| 133 |
+
|
| 134 |
+
class TestFindRepeatedTokens:
    """Tests for the _find_repeated_tokens helper."""

    def test_no_repeats(self):
        """A sequence without repetition yields an empty dict."""
        assert _find_repeated_tokens([1, 2, 3, 4]) == {}

    def test_simple_repeat(self):
        """A repeated token maps to the list of its positions."""
        positions = _find_repeated_tokens([10, 20, 10, 30])
        assert positions.get(10) == [0, 2]
        assert 20 not in positions

    def test_multiple_repeats(self):
        """Every repeated token is tracked; unique tokens are not."""
        positions = _find_repeated_tokens([5, 6, 5, 6, 7])
        assert 5 in positions
        assert 6 in positions
        assert 7 not in positions

    def test_empty_input(self):
        """An empty token list yields an empty dict."""
        assert _find_repeated_tokens([]) == {}
|
| 156 |
|
| 157 |
+
|
| 158 |
+
# =============================================================================
|
| 159 |
+
# Tests for load_head_categories
|
| 160 |
+
# =============================================================================
|
| 161 |
+
|
| 162 |
+
class TestLoadHeadCategories:
    """Tests for load_head_categories function."""
    # All tests patch the module-level _JSON_PATH so a temp file stands in
    # for the real utils/head_categories.json.

    def test_loads_from_json(self, tmp_path):
        """Should load model data from JSON file."""
        json_file = tmp_path / "head_categories.json"
        json_file.write_text(json.dumps(SAMPLE_JSON))

        with patch('utils.head_detection._JSON_PATH', json_file):
            result = load_head_categories("test-model")

        assert result is not None
        assert result["model_name"] == "test-model"
        assert "previous_token" in result["categories"]

    def test_returns_none_for_unknown_model(self, tmp_path):
        """Should return None when model not in JSON."""
        json_file = tmp_path / "head_categories.json"
        json_file.write_text(json.dumps(SAMPLE_JSON))

        with patch('utils.head_detection._JSON_PATH', json_file):
            result = load_head_categories("nonexistent-model")

        assert result is None

    def test_returns_none_when_no_file(self, tmp_path):
        """Should return None when JSON file doesn't exist."""
        with patch('utils.head_detection._JSON_PATH', tmp_path / "missing.json"):
            result = load_head_categories("test-model")

        assert result is None

    def test_caches_results(self, tmp_path):
        """Should cache loaded data."""
        json_file = tmp_path / "head_categories.json"
        json_file.write_text(json.dumps(SAMPLE_JSON))

        with patch('utils.head_detection._JSON_PATH', json_file):
            result1 = load_head_categories("test-model")
            # Delete file to prove cache is used
            json_file.unlink()
            result2 = load_head_categories("test-model")

        # Identity check: same object means the cache, not the file, served it.
        assert result1 is result2

    def test_short_name_alias(self, tmp_path):
        """Should find model by short name (after /)."""
        data = {"my-model": {"model_name": "my-model", "categories": {}}}
        json_file = tmp_path / "head_categories.json"
        json_file.write_text(json.dumps(data))

        with patch('utils.head_detection._JSON_PATH', json_file):
            result = load_head_categories("org/my-model")

        assert result is not None
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
# =============================================================================
|
| 220 |
+
# Tests for verify_head_activation
|
| 221 |
+
# =============================================================================
|
| 222 |
|
| 223 |
+
class TestVerifyHeadActivation:
|
| 224 |
+
"""Tests for verify_head_activation function."""
|
| 225 |
|
| 226 |
+
def test_previous_token_strong(self):
|
| 227 |
+
"""Strong previous-token pattern should score high."""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 228 |
size = 6
|
| 229 |
matrix = torch.zeros(size, size)
|
| 230 |
+
for i in range(1, size):
|
| 231 |
+
matrix[i, i - 1] = 0.8
|
| 232 |
+
matrix[i, i] = 0.2
|
| 233 |
+
matrix[0, 0] = 1.0
|
| 234 |
+
|
| 235 |
+
score = verify_head_activation(matrix, [1, 2, 3, 4, 5, 6], "previous_token")
|
| 236 |
+
assert score > 0.6
|
| 237 |
+
|
| 238 |
+
def test_previous_token_weak(self):
|
| 239 |
+
"""Uniform attention should have low previous-token score."""
|
| 240 |
+
size = 6
|
| 241 |
+
matrix = torch.ones(size, size) / size
|
| 242 |
+
score = verify_head_activation(matrix, [1, 2, 3, 4, 5, 6], "previous_token")
|
| 243 |
+
assert score < 0.3
|
| 244 |
+
|
| 245 |
+
def test_induction_with_repetition(self):
|
| 246 |
+
"""Induction pattern should score > 0 when repeated tokens are present."""
|
| 247 |
+
# Tokens: [A, B, C, A, ?] — head should attend to B (position 1) from position 3
|
| 248 |
+
size = 5
|
| 249 |
+
matrix = torch.ones(size, size) / size # Baseline uniform
|
| 250 |
+
matrix[3, 1] = 0.7 # Position 3 (second A) attends to position 1 (B after first A)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 251 |
|
| 252 |
+
token_ids = [10, 20, 30, 10, 40] # Token 10 repeats
|
| 253 |
+
score = verify_head_activation(matrix, token_ids, "induction")
|
| 254 |
+
assert score > 0.3
|
| 255 |
+
|
| 256 |
+
def test_induction_no_repetition(self):
|
| 257 |
+
"""Induction should return 0.0 when no tokens repeat."""
|
| 258 |
+
matrix = torch.ones(4, 4) / 4
|
| 259 |
+
score = verify_head_activation(matrix, [1, 2, 3, 4], "induction")
|
| 260 |
+
assert score == 0.0
|
| 261 |
+
|
| 262 |
+
def test_duplicate_token_with_repeats(self):
|
| 263 |
+
"""Duplicate-token head should score > 0 when later positions attend to earlier same token."""
|
| 264 |
size = 5
|
| 265 |
+
matrix = torch.ones(size, size) / size
|
| 266 |
+
matrix[3, 0] = 0.6 # Position 3 (second occurrence of token 10) attends to position 0
|
| 267 |
+
|
| 268 |
+
token_ids = [10, 20, 30, 10, 40]
|
| 269 |
+
score = verify_head_activation(matrix, token_ids, "duplicate_token")
|
| 270 |
+
assert score > 0.3
|
| 271 |
+
|
| 272 |
+
def test_duplicate_token_no_repeats(self):
|
| 273 |
+
"""Should return 0.0 when no duplicates."""
|
| 274 |
+
matrix = torch.ones(4, 4) / 4
|
| 275 |
+
score = verify_head_activation(matrix, [1, 2, 3, 4], "duplicate_token")
|
| 276 |
+
assert score == 0.0
|
| 277 |
+
|
| 278 |
+
def test_positional_strong(self):
    """Strong first-token attention should score high."""
    n = 6
    attn = torch.zeros(n, n)
    attn[:, 0] = 0.7          # every position anchors on the first token
    attn.fill_diagonal_(0.3)  # plus some self-attention; overwrites [0, 0]

    score = verify_head_activation(attn, [1, 2, 3, 4, 5, 6], "positional")
    assert score > 0.5
| 288 |
+
|
| 289 |
+
def test_diffuse_uniform(self):
    """Uniform attention should have high diffuse score."""
    n = 8
    uniform = torch.full((n, n), 1.0 / n)
    assert verify_head_activation(uniform, list(range(n)), "diffuse") > 0.8
| 295 |
+
|
| 296 |
+
def test_diffuse_peaked(self):
    """Peaked attention should have low diffuse score."""
    n = 8
    peaked = torch.zeros(n, n)
    peaked[:, 0] = 1.0  # all attention mass on the first token
    assert verify_head_activation(peaked, list(range(n)), "diffuse") < 0.3
| 303 |
|
| 304 |
+
def test_unknown_category(self):
    """Unknown category should return 0.0."""
    uniform = torch.full((4, 4), 0.25)
    assert verify_head_activation(uniform, [1, 2, 3, 4], "nonexistent") == 0.0
| 308 |
|
| 309 |
+
def test_short_sequence(self):
    """Very short sequence should return 0.0."""
    single = torch.ones(1, 1)
    assert verify_head_activation(single, [1], "previous_token") == 0.0
| 313 |
+
|
| 314 |
+
|
| 315 |
+
# =============================================================================
|
| 316 |
+
# Tests for get_active_head_summary
|
| 317 |
+
# =============================================================================
|
| 318 |
+
|
| 319 |
+
class TestGetActiveHeadSummary:
    """Tests for get_active_head_summary function."""

    def _make_activation_data(self, token_ids, num_layers=2, num_heads=4, seq_len=None):
        """Helper: create mock activation_data with given token_ids."""
        if seq_len is None:
            seq_len = len(token_ids)

        layer_outputs = {}
        for layer_idx in range(num_layers):
            # Uniform attention of shape [1, num_heads, seq_len, seq_len].
            uniform_attn = torch.ones(1, num_heads, seq_len, seq_len) / seq_len
            layer_outputs[f'model.layers.{layer_idx}.self_attn'] = {
                'output': [
                    [[0.1] * seq_len],  # hidden states (unused)
                    uniform_attn.tolist(),
                ]
            }

        return {
            'model': 'test-model',
            'input_ids': [token_ids],
            'attention_outputs': layer_outputs,
        }

    def _summarize(self, tmp_path, token_ids, model_name="test-model"):
        """Write SAMPLE_JSON to a temp file, patch the loader path, and run the summary."""
        json_file = tmp_path / "head_categories.json"
        json_file.write_text(json.dumps(SAMPLE_JSON))

        with patch('utils.head_detection._JSON_PATH', json_file):
            return get_active_head_summary(self._make_activation_data(token_ids), model_name)

    def test_returns_none_for_unknown_model(self, tmp_path):
        """Should return None when model not in JSON."""
        assert self._summarize(tmp_path, [1, 2, 3, 4], model_name="unknown-model") is None

    def test_returns_categories_structure(self, tmp_path):
        """Should return proper structure with categories."""
        summary = self._summarize(tmp_path, [1, 2, 3, 4])

        assert summary is not None
        assert summary["model_available"] is True
        assert "categories" in summary
        assert "previous_token" in summary["categories"]
        assert "induction" in summary["categories"]

    def test_heads_have_activation_scores(self, tmp_path):
        """Each head should have an activation_score."""
        summary = self._summarize(tmp_path, [1, 2, 3, 4])

        for category in summary["categories"].values():
            for head in category.get("heads", []):
                assert "activation_score" in head
                assert "is_active" in head
                assert "label" in head

    def test_induction_grayed_when_no_repeats(self, tmp_path):
        """Induction should be non-applicable when no repeated tokens."""
        summary = self._summarize(tmp_path, [1, 2, 3, 4])  # no repeats

        induction = summary["categories"]["induction"]
        assert induction["is_applicable"] is False
        assert all(head["activation_score"] == 0.0 for head in induction["heads"])

    def test_induction_active_with_repeats(self, tmp_path):
        """Induction should be applicable when tokens repeat."""
        summary = self._summarize(tmp_path, [10, 20, 10, 30])  # token 10 repeats

        assert summary["categories"]["induction"]["is_applicable"] is True

    def test_suggested_prompt_included(self, tmp_path):
        """Suggested prompt should appear for repetition-dependent categories."""
        summary = self._summarize(tmp_path, [1, 2, 3, 4])

        assert summary["categories"]["induction"]["suggested_prompt"] is not None
        assert summary["categories"]["duplicate_token"]["suggested_prompt"] is not None

    def test_other_category_always_present(self, tmp_path):
        """Other/Unclassified category should always be in the result."""
        summary = self._summarize(tmp_path, [1, 2, 3, 4])

        assert "other" in summary["categories"]
utils/__init__.py
CHANGED
|
@@ -8,7 +8,7 @@ from .model_patterns import (load_model_and_get_patterns, execute_forward_pass,
|
|
| 8 |
detect_significant_probability_increases,
|
| 9 |
evaluate_sequence_ablation, generate_bertviz_model_view_html)
|
| 10 |
from .model_config import get_model_family, get_family_config, get_auto_selections, MODEL_TO_FAMILY, MODEL_FAMILIES
|
| 11 |
-
from .head_detection import
|
| 12 |
from .beam_search import perform_beam_search
|
| 13 |
from .ablation_metrics import compute_kl_divergence, score_sequence, get_token_probability_deltas
|
| 14 |
from .token_attribution import compute_integrated_gradients, compute_simple_gradient_attribution, create_attribution_visualization_data
|
|
@@ -38,10 +38,9 @@ __all__ = [
|
|
| 38 |
'MODEL_FAMILIES',
|
| 39 |
|
| 40 |
# Head detection
|
| 41 |
-
'
|
| 42 |
-
'
|
| 43 |
-
'
|
| 44 |
-
'HeadCategorizationConfig',
|
| 45 |
|
| 46 |
# Beam search
|
| 47 |
'perform_beam_search',
|
|
|
|
| 8 |
detect_significant_probability_increases,
|
| 9 |
evaluate_sequence_ablation, generate_bertviz_model_view_html)
|
| 10 |
from .model_config import get_model_family, get_family_config, get_auto_selections, MODEL_TO_FAMILY, MODEL_FAMILIES
|
| 11 |
+
from .head_detection import load_head_categories, verify_head_activation, get_active_head_summary
|
| 12 |
from .beam_search import perform_beam_search
|
| 13 |
from .ablation_metrics import compute_kl_divergence, score_sequence, get_token_probability_deltas
|
| 14 |
from .token_attribution import compute_integrated_gradients, compute_simple_gradient_attribution, create_attribution_visualization_data
|
|
|
|
| 38 |
'MODEL_FAMILIES',
|
| 39 |
|
| 40 |
# Head detection
|
| 41 |
+
'load_head_categories',
|
| 42 |
+
'verify_head_activation',
|
| 43 |
+
'get_active_head_summary',
|
|
|
|
| 44 |
|
| 45 |
# Beam search
|
| 46 |
'perform_beam_search',
|
utils/head_categories.json
ADDED
|
@@ -0,0 +1,1099 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"gpt2": {
|
| 3 |
+
"model_name": "gpt2",
|
| 4 |
+
"num_layers": 12,
|
| 5 |
+
"num_heads": 12,
|
| 6 |
+
"analysis_date": "2026-02-26",
|
| 7 |
+
"categories": {
|
| 8 |
+
"previous_token": {
|
| 9 |
+
"display_name": "Previous Token",
|
| 10 |
+
"description": "Attends to the immediately preceding token \u2014 like reading left to right",
|
| 11 |
+
"icon": "arrow-left",
|
| 12 |
+
"educational_text": "This head looks at the word right before the current one. Like reading left to right, it helps track local word-by-word patterns.",
|
| 13 |
+
"requires_repetition": false,
|
| 14 |
+
"top_heads": [
|
| 15 |
+
{
|
| 16 |
+
"layer": 1,
|
| 17 |
+
"head": 0,
|
| 18 |
+
"score": 0.3655
|
| 19 |
+
},
|
| 20 |
+
{
|
| 21 |
+
"layer": 2,
|
| 22 |
+
"head": 2,
|
| 23 |
+
"score": 0.5679
|
| 24 |
+
},
|
| 25 |
+
{
|
| 26 |
+
"layer": 2,
|
| 27 |
+
"head": 5,
|
| 28 |
+
"score": 0.3384
|
| 29 |
+
},
|
| 30 |
+
{
|
| 31 |
+
"layer": 2,
|
| 32 |
+
"head": 9,
|
| 33 |
+
"score": 0.4052
|
| 34 |
+
},
|
| 35 |
+
{
|
| 36 |
+
"layer": 3,
|
| 37 |
+
"head": 2,
|
| 38 |
+
"score": 0.4164
|
| 39 |
+
},
|
| 40 |
+
{
|
| 41 |
+
"layer": 3,
|
| 42 |
+
"head": 6,
|
| 43 |
+
"score": 0.3359
|
| 44 |
+
},
|
| 45 |
+
{
|
| 46 |
+
"layer": 3,
|
| 47 |
+
"head": 7,
|
| 48 |
+
"score": 0.4419
|
| 49 |
+
},
|
| 50 |
+
{
|
| 51 |
+
"layer": 4,
|
| 52 |
+
"head": 11,
|
| 53 |
+
"score": 0.97
|
| 54 |
+
}
|
| 55 |
+
]
|
| 56 |
+
},
|
| 57 |
+
"induction": {
|
| 58 |
+
"display_name": "Induction",
|
| 59 |
+
"description": "Completes repeated patterns: if it saw [A][B] before and now sees [A], it predicts [B]",
|
| 60 |
+
"icon": "repeat",
|
| 61 |
+
"educational_text": "This head finds patterns that happened before and predicts they'll happen again. If it saw 'the cat' earlier, it expects the same words to follow.",
|
| 62 |
+
"requires_repetition": true,
|
| 63 |
+
"suggested_prompt": "Try: 'The cat sat on the mat. The cat' \u2014 the repeated 'The cat' lets induction heads activate.",
|
| 64 |
+
"top_heads": [
|
| 65 |
+
{
|
| 66 |
+
"layer": 5,
|
| 67 |
+
"head": 0,
|
| 68 |
+
"score": 0.3363
|
| 69 |
+
},
|
| 70 |
+
{
|
| 71 |
+
"layer": 5,
|
| 72 |
+
"head": 1,
|
| 73 |
+
"score": 0.4412
|
| 74 |
+
},
|
| 75 |
+
{
|
| 76 |
+
"layer": 5,
|
| 77 |
+
"head": 5,
|
| 78 |
+
"score": 0.4119
|
| 79 |
+
},
|
| 80 |
+
{
|
| 81 |
+
"layer": 5,
|
| 82 |
+
"head": 8,
|
| 83 |
+
"score": 0.3032
|
| 84 |
+
},
|
| 85 |
+
{
|
| 86 |
+
"layer": 6,
|
| 87 |
+
"head": 9,
|
| 88 |
+
"score": 0.3017
|
| 89 |
+
},
|
| 90 |
+
{
|
| 91 |
+
"layer": 7,
|
| 92 |
+
"head": 10,
|
| 93 |
+
"score": 0.2849
|
| 94 |
+
},
|
| 95 |
+
{
|
| 96 |
+
"layer": 8,
|
| 97 |
+
"head": 1,
|
| 98 |
+
"score": 0.2608
|
| 99 |
+
},
|
| 100 |
+
{
|
| 101 |
+
"layer": 10,
|
| 102 |
+
"head": 7,
|
| 103 |
+
"score": 0.2196
|
| 104 |
+
}
|
| 105 |
+
]
|
| 106 |
+
},
|
| 107 |
+
"duplicate_token": {
|
| 108 |
+
"display_name": "Duplicate Token",
|
| 109 |
+
"description": "Notices when the same word appears more than once",
|
| 110 |
+
"icon": "clone",
|
| 111 |
+
"educational_text": "This head notices when the same word appears more than once, like a highlighter for repeated words. It helps the model track which words have already been said.",
|
| 112 |
+
"requires_repetition": true,
|
| 113 |
+
"suggested_prompt": "Try a prompt with repeated words like 'The cat sat. The cat slept.' to see duplicate-token heads light up.",
|
| 114 |
+
"top_heads": [
|
| 115 |
+
{
|
| 116 |
+
"layer": 0,
|
| 117 |
+
"head": 1,
|
| 118 |
+
"score": 0.4175
|
| 119 |
+
},
|
| 120 |
+
{
|
| 121 |
+
"layer": 0,
|
| 122 |
+
"head": 5,
|
| 123 |
+
"score": 0.4155
|
| 124 |
+
},
|
| 125 |
+
{
|
| 126 |
+
"layer": 1,
|
| 127 |
+
"head": 11,
|
| 128 |
+
"score": 0.3256
|
| 129 |
+
},
|
| 130 |
+
{
|
| 131 |
+
"layer": 3,
|
| 132 |
+
"head": 0,
|
| 133 |
+
"score": 0.2416
|
| 134 |
+
},
|
| 135 |
+
{
|
| 136 |
+
"layer": 4,
|
| 137 |
+
"head": 7,
|
| 138 |
+
"score": 0.1238
|
| 139 |
+
},
|
| 140 |
+
{
|
| 141 |
+
"layer": 11,
|
| 142 |
+
"head": 8,
|
| 143 |
+
"score": 0.1741
|
| 144 |
+
}
|
| 145 |
+
]
|
| 146 |
+
},
|
| 147 |
+
"positional": {
|
| 148 |
+
"display_name": "Positional / First-Token",
|
| 149 |
+
"description": "Always pays attention to the very first word, using it as an anchor point",
|
| 150 |
+
"icon": "map-pin",
|
| 151 |
+
"educational_text": "This head always pays attention to the very first word, using it as an anchor point. The first token serves as a 'default' position when no other token is specifically relevant.",
|
| 152 |
+
"requires_repetition": false,
|
| 153 |
+
"top_heads": [
|
| 154 |
+
{
|
| 155 |
+
"layer": 7,
|
| 156 |
+
"head": 2,
|
| 157 |
+
"score": 0.9077
|
| 158 |
+
},
|
| 159 |
+
{
|
| 160 |
+
"layer": 9,
|
| 161 |
+
"head": 6,
|
| 162 |
+
"score": 0.9077
|
| 163 |
+
},
|
| 164 |
+
{
|
| 165 |
+
"layer": 9,
|
| 166 |
+
"head": 9,
|
| 167 |
+
"score": 0.9064
|
| 168 |
+
},
|
| 169 |
+
{
|
| 170 |
+
"layer": 9,
|
| 171 |
+
"head": 11,
|
| 172 |
+
"score": 0.9301
|
| 173 |
+
},
|
| 174 |
+
{
|
| 175 |
+
"layer": 10,
|
| 176 |
+
"head": 10,
|
| 177 |
+
"score": 0.9098
|
| 178 |
+
},
|
| 179 |
+
{
|
| 180 |
+
"layer": 11,
|
| 181 |
+
"head": 2,
|
| 182 |
+
"score": 0.8962
|
| 183 |
+
},
|
| 184 |
+
{
|
| 185 |
+
"layer": 11,
|
| 186 |
+
"head": 6,
|
| 187 |
+
"score": 0.9231
|
| 188 |
+
},
|
| 189 |
+
{
|
| 190 |
+
"layer": 11,
|
| 191 |
+
"head": 9,
|
| 192 |
+
"score": 0.9117
|
| 193 |
+
}
|
| 194 |
+
]
|
| 195 |
+
},
|
| 196 |
+
"diffuse": {
|
| 197 |
+
"display_name": "Diffuse / Spread",
|
| 198 |
+
"description": "Spreads attention evenly across many words, gathering general context",
|
| 199 |
+
"icon": "expand-arrows-alt",
|
| 200 |
+
"educational_text": "This head spreads its attention evenly across many words, gathering general context rather than focusing on one spot. It provides a 'big picture' summary of the input.",
|
| 201 |
+
"requires_repetition": false,
|
| 202 |
+
"top_heads": [
|
| 203 |
+
{
|
| 204 |
+
"layer": 0,
|
| 205 |
+
"head": 10,
|
| 206 |
+
"score": 0.6076
|
| 207 |
+
},
|
| 208 |
+
{
|
| 209 |
+
"layer": 0,
|
| 210 |
+
"head": 11,
|
| 211 |
+
"score": 0.5915
|
| 212 |
+
},
|
| 213 |
+
{
|
| 214 |
+
"layer": 1,
|
| 215 |
+
"head": 2,
|
| 216 |
+
"score": 0.5851
|
| 217 |
+
},
|
| 218 |
+
{
|
| 219 |
+
"layer": 1,
|
| 220 |
+
"head": 4,
|
| 221 |
+
"score": 0.5693
|
| 222 |
+
},
|
| 223 |
+
{
|
| 224 |
+
"layer": 1,
|
| 225 |
+
"head": 10,
|
| 226 |
+
"score": 0.6001
|
| 227 |
+
},
|
| 228 |
+
{
|
| 229 |
+
"layer": 2,
|
| 230 |
+
"head": 7,
|
| 231 |
+
"score": 0.6227
|
| 232 |
+
},
|
| 233 |
+
{
|
| 234 |
+
"layer": 2,
|
| 235 |
+
"head": 10,
|
| 236 |
+
"score": 0.6325
|
| 237 |
+
},
|
| 238 |
+
{
|
| 239 |
+
"layer": 11,
|
| 240 |
+
"head": 0,
|
| 241 |
+
"score": 0.6132
|
| 242 |
+
}
|
| 243 |
+
]
|
| 244 |
+
}
|
| 245 |
+
},
|
| 246 |
+
"all_scores": {
|
| 247 |
+
"previous_token": [
|
| 248 |
+
[
|
| 249 |
+
0.1650262176990509,
|
| 250 |
+
0.005524545907974243,
|
| 251 |
+
0.13794219493865967,
|
| 252 |
+
0.11309953033924103,
|
| 253 |
+
0.19386060535907745,
|
| 254 |
+
0.02020726539194584,
|
| 255 |
+
0.18705399334430695,
|
| 256 |
+
0.3287373483181,
|
| 257 |
+
0.1688501238822937,
|
| 258 |
+
0.14645136892795563,
|
| 259 |
+
0.12409798055887222,
|
| 260 |
+
0.14697492122650146
|
| 261 |
+
],
|
| 262 |
+
[
|
| 263 |
+
0.36550161242485046,
|
| 264 |
+
0.22920921444892883,
|
| 265 |
+
0.1901777684688568,
|
| 266 |
+
0.13691475987434387,
|
| 267 |
+
0.1552433967590332,
|
| 268 |
+
0.1548655927181244,
|
| 269 |
+
0.14041779935359955,
|
| 270 |
+
0.1399569809436798,
|
| 271 |
+
0.14001941680908203,
|
| 272 |
+
0.12206045538187027,
|
| 273 |
+
0.18723534047603607,
|
| 274 |
+
0.05272947624325752
|
| 275 |
+
],
|
| 276 |
+
[
|
| 277 |
+
0.24368862807750702,
|
| 278 |
+
0.11734970659017563,
|
| 279 |
+
0.5678969025611877,
|
| 280 |
+
0.33175796270370483,
|
| 281 |
+
0.3293865919113159,
|
| 282 |
+
0.33843594789505005,
|
| 283 |
+
0.1687498688697815,
|
| 284 |
+
0.2169996052980423,
|
| 285 |
+
0.33436763286590576,
|
| 286 |
+
0.405174195766449,
|
| 287 |
+
0.20988500118255615,
|
| 288 |
+
0.1365954577922821
|
| 289 |
+
],
|
| 290 |
+
[
|
| 291 |
+
0.08308680355548859,
|
| 292 |
+
0.16770434379577637,
|
| 293 |
+
0.41642817854881287,
|
| 294 |
+
0.32616299390792847,
|
| 295 |
+
0.09816452860832214,
|
| 296 |
+
0.12414131313562393,
|
| 297 |
+
0.33591750264167786,
|
| 298 |
+
0.4418589174747467,
|
| 299 |
+
0.3060630261898041,
|
| 300 |
+
0.21817748248577118,
|
| 301 |
+
0.1548490822315216,
|
| 302 |
+
0.2623787224292755
|
| 303 |
+
],
|
| 304 |
+
[
|
| 305 |
+
0.24851615726947784,
|
| 306 |
+
0.22178645431995392,
|
| 307 |
+
0.10810651630163193,
|
| 308 |
+
0.2638419270515442,
|
| 309 |
+
0.1461866945028305,
|
| 310 |
+
0.19259677827358246,
|
| 311 |
+
0.16893190145492554,
|
| 312 |
+
0.20602412521839142,
|
| 313 |
+
0.11169518530368805,
|
| 314 |
+
0.16701465845108032,
|
| 315 |
+
0.09775038063526154,
|
| 316 |
+
0.9700173139572144
|
| 317 |
+
],
|
| 318 |
+
[
|
| 319 |
+
0.1162194162607193,
|
| 320 |
+
0.09808940440416336,
|
| 321 |
+
0.20977501571178436,
|
| 322 |
+
0.16994376480579376,
|
| 323 |
+
0.2316969633102417,
|
| 324 |
+
0.10760845243930817,
|
| 325 |
+
0.26810961961746216,
|
| 326 |
+
0.1556214690208435,
|
| 327 |
+
0.13168412446975708,
|
| 328 |
+
0.10098359733819962,
|
| 329 |
+
0.1563761830329895,
|
| 330 |
+
0.11529763042926788
|
| 331 |
+
],
|
| 332 |
+
[
|
| 333 |
+
0.23046550154685974,
|
| 334 |
+
0.13669200241565704,
|
| 335 |
+
0.10113422572612762,
|
| 336 |
+
0.12357200682163239,
|
| 337 |
+
0.12948814034461975,
|
| 338 |
+
0.14964132010936737,
|
| 339 |
+
0.11104538291692734,
|
| 340 |
+
0.17790208756923676,
|
| 341 |
+
0.3313186764717102,
|
| 342 |
+
0.09724397212266922,
|
| 343 |
+
0.1065865010023117,
|
| 344 |
+
0.19595712423324585
|
| 345 |
+
],
|
| 346 |
+
[
|
| 347 |
+
0.2756780683994293,
|
| 348 |
+
0.09617989510297775,
|
| 349 |
+
0.0887245386838913,
|
| 350 |
+
0.14660504460334778,
|
| 351 |
+
0.11926672607660294,
|
| 352 |
+
0.12578082084655762,
|
| 353 |
+
0.10664939880371094,
|
| 354 |
+
0.11368991434574127,
|
| 355 |
+
0.18360558152198792,
|
| 356 |
+
0.130024254322052,
|
| 357 |
+
0.10562390089035034,
|
| 358 |
+
0.10479450225830078
|
| 359 |
+
],
|
| 360 |
+
[
|
| 361 |
+
0.10714849084615707,
|
| 362 |
+
0.10390549898147583,
|
| 363 |
+
0.11945408582687378,
|
| 364 |
+
0.10176572948694229,
|
| 365 |
+
0.15246066451072693,
|
| 366 |
+
0.1935780942440033,
|
| 367 |
+
0.13547158241271973,
|
| 368 |
+
0.24629735946655273,
|
| 369 |
+
0.14471763372421265,
|
| 370 |
+
0.12072619050741196,
|
| 371 |
+
0.12850022315979004,
|
| 372 |
+
0.10024647414684296
|
| 373 |
+
],
|
| 374 |
+
[
|
| 375 |
+
0.1123703345656395,
|
| 376 |
+
0.10224141925573349,
|
| 377 |
+
0.10966678708791733,
|
| 378 |
+
0.24468424916267395,
|
| 379 |
+
0.09359707683324814,
|
| 380 |
+
0.11123354732990265,
|
| 381 |
+
0.09214123338460922,
|
| 382 |
+
0.11035183817148209,
|
| 383 |
+
0.09690441191196442,
|
| 384 |
+
0.09199563413858414,
|
| 385 |
+
0.16506430506706238,
|
| 386 |
+
0.08864383399486542
|
| 387 |
+
],
|
| 388 |
+
[
|
| 389 |
+
0.09993860870599747,
|
| 390 |
+
0.1017073541879654,
|
| 391 |
+
0.09143912047147751,
|
| 392 |
+
0.1137363463640213,
|
| 393 |
+
0.11926724761724472,
|
| 394 |
+
0.1261630356311798,
|
| 395 |
+
0.09609334915876389,
|
| 396 |
+
0.1267780214548111,
|
| 397 |
+
0.09360888600349426,
|
| 398 |
+
0.15695181488990784,
|
| 399 |
+
0.09125342220067978,
|
| 400 |
+
0.16533184051513672
|
| 401 |
+
],
|
| 402 |
+
[
|
| 403 |
+
0.1551479697227478,
|
| 404 |
+
0.10182406008243561,
|
| 405 |
+
0.09162592142820358,
|
| 406 |
+
0.14142417907714844,
|
| 407 |
+
0.10655181109905243,
|
| 408 |
+
0.09299013763666153,
|
| 409 |
+
0.08795793354511261,
|
| 410 |
+
0.10052843391895294,
|
| 411 |
+
0.18854694068431854,
|
| 412 |
+
0.09097206592559814,
|
| 413 |
+
0.14251284301280975,
|
| 414 |
+
0.13573673367500305
|
| 415 |
+
]
|
| 416 |
+
],
|
| 417 |
+
"induction": [
|
| 418 |
+
[
|
| 419 |
+
0.07627037912607193,
|
| 420 |
+
0.0035299647133797407,
|
| 421 |
+
0.050907380878925323,
|
| 422 |
+
0.018350504338741302,
|
| 423 |
+
0.055634528398513794,
|
| 424 |
+
0.015752490609884262,
|
| 425 |
+
0.09711054712533951,
|
| 426 |
+
0.08642718195915222,
|
| 427 |
+
0.07673756778240204,
|
| 428 |
+
0.06478650867938995,
|
| 429 |
+
0.05675221234560013,
|
| 430 |
+
0.0686919093132019
|
| 431 |
+
],
|
| 432 |
+
[
|
| 433 |
+
0.098502516746521,
|
| 434 |
+
0.08570204675197601,
|
| 435 |
+
0.09086534380912781,
|
| 436 |
+
0.05725013464689255,
|
| 437 |
+
0.06655086576938629,
|
| 438 |
+
0.08535383641719818,
|
| 439 |
+
0.04390129819512367,
|
| 440 |
+
0.05150846764445305,
|
| 441 |
+
0.05973561853170395,
|
| 442 |
+
0.05239921063184738,
|
| 443 |
+
0.10886937379837036,
|
| 444 |
+
0.03350156173110008
|
| 445 |
+
],
|
| 446 |
+
[
|
| 447 |
+
0.0880986899137497,
|
| 448 |
+
0.029988640919327736,
|
| 449 |
+
0.06596572697162628,
|
| 450 |
+
0.09502042829990387,
|
| 451 |
+
0.06376759707927704,
|
| 452 |
+
0.07735122740268707,
|
| 453 |
+
0.07770463079214096,
|
| 454 |
+
0.08998467028141022,
|
| 455 |
+
0.08355952054262161,
|
| 456 |
+
0.08642251044511795,
|
| 457 |
+
0.0951002761721611,
|
| 458 |
+
0.038624998182058334
|
| 459 |
+
],
|
| 460 |
+
[
|
| 461 |
+
0.012395743280649185,
|
| 462 |
+
0.0515044704079628,
|
| 463 |
+
0.0702400729060173,
|
| 464 |
+
0.038637131452560425,
|
| 465 |
+
0.03541486710309982,
|
| 466 |
+
0.04828893393278122,
|
| 467 |
+
0.07664503902196884,
|
| 468 |
+
0.05478388071060181,
|
| 469 |
+
0.05722055584192276,
|
| 470 |
+
0.05503711849451065,
|
| 471 |
+
0.05377575010061264,
|
| 472 |
+
0.05681142956018448
|
| 473 |
+
],
|
| 474 |
+
[
|
| 475 |
+
0.023173518478870392,
|
| 476 |
+
0.04842953383922577,
|
| 477 |
+
0.02587379515171051,
|
| 478 |
+
0.0371115505695343,
|
| 479 |
+
0.043572355061769485,
|
| 480 |
+
0.025999004021286964,
|
| 481 |
+
0.057220708578825,
|
| 482 |
+
0.05670655891299248,
|
| 483 |
+
0.05118811875581741,
|
| 484 |
+
0.029776636511087418,
|
| 485 |
+
0.02828892692923546,
|
| 486 |
+
0.050957612693309784
|
| 487 |
+
],
|
| 488 |
+
[
|
| 489 |
+
0.3362796902656555,
|
| 490 |
+
0.44116583466529846,
|
| 491 |
+
0.04926660656929016,
|
| 492 |
+
0.060651201754808426,
|
| 493 |
+
0.049554307013750076,
|
| 494 |
+
0.41194018721580505,
|
| 495 |
+
0.038970425724983215,
|
| 496 |
+
0.01051054522395134,
|
| 497 |
+
0.30320701003074646,
|
| 498 |
+
0.07053252309560776,
|
| 499 |
+
0.05541849881410599,
|
| 500 |
+
0.03842315822839737
|
| 501 |
+
],
|
| 502 |
+
[
|
| 503 |
+
0.04865153878927231,
|
| 504 |
+
0.13892090320587158,
|
| 505 |
+
0.023456398397684097,
|
| 506 |
+
0.043447092175483704,
|
| 507 |
+
0.05254914611577988,
|
| 508 |
+
0.06307318806648254,
|
| 509 |
+
0.06592734158039093,
|
| 510 |
+
0.06641103327274323,
|
| 511 |
+
0.06890955567359924,
|
| 512 |
+
0.3017217516899109,
|
| 513 |
+
0.053376901894807816,
|
| 514 |
+
0.05453646928071976
|
| 515 |
+
],
|
| 516 |
+
[
|
| 517 |
+
0.04203842580318451,
|
| 518 |
+
0.06195511296391487,
|
| 519 |
+
0.18403273820877075,
|
| 520 |
+
0.06932497024536133,
|
| 521 |
+
0.025891464203596115,
|
| 522 |
+
0.03674555569887161,
|
| 523 |
+
0.05915430188179016,
|
| 524 |
+
0.08904685080051422,
|
| 525 |
+
0.029217243194580078,
|
| 526 |
+
0.047680627554655075,
|
| 527 |
+
0.28489723801612854,
|
| 528 |
+
0.15201476216316223
|
| 529 |
+
],
|
| 530 |
+
[
|
| 531 |
+
0.03113759122788906,
|
| 532 |
+
0.2607646584510803,
|
| 533 |
+
0.04262052848935127,
|
| 534 |
+
0.03490695357322693,
|
| 535 |
+
0.020729169249534607,
|
| 536 |
+
0.039468441158533096,
|
| 537 |
+
0.17247121036052704,
|
| 538 |
+
0.02061128057539463,
|
| 539 |
+
0.0941251665353775,
|
| 540 |
+
0.044258393347263336,
|
| 541 |
+
0.09541143476963043,
|
| 542 |
+
0.03278326988220215
|
| 543 |
+
],
|
| 544 |
+
[
|
| 545 |
+
0.06156448647379875,
|
| 546 |
+
0.09029851853847504,
|
| 547 |
+
0.06509305536746979,
|
| 548 |
+
0.04298751801252365,
|
| 549 |
+
0.02618749439716339,
|
| 550 |
+
0.029909756034612656,
|
| 551 |
+
0.08973383903503418,
|
| 552 |
+
0.06374338269233704,
|
| 553 |
+
0.02463320828974247,
|
| 554 |
+
0.10424073040485382,
|
| 555 |
+
0.016569094732403755,
|
| 556 |
+
0.04829319566488266
|
| 557 |
+
],
|
| 558 |
+
[
|
| 559 |
+
0.0732613354921341,
|
| 560 |
+
0.15449705719947815,
|
| 561 |
+
0.048853177577257156,
|
| 562 |
+
0.12552715837955475,
|
| 563 |
+
0.1161937341094017,
|
| 564 |
+
0.020513027906417847,
|
| 565 |
+
0.08032035827636719,
|
| 566 |
+
0.21955707669258118,
|
| 567 |
+
0.07728692889213562,
|
| 568 |
+
0.014143750071525574,
|
| 569 |
+
0.056671954691410065,
|
| 570 |
+
0.1141514927148819
|
| 571 |
+
],
|
| 572 |
+
[
|
| 573 |
+
0.10236237943172455,
|
| 574 |
+
0.0509863905608654,
|
| 575 |
+
0.02403058484196663,
|
| 576 |
+
0.046142492443323135,
|
| 577 |
+
0.03625836968421936,
|
| 578 |
+
0.05091869831085205,
|
| 579 |
+
0.02450958639383316,
|
| 580 |
+
0.057415880262851715,
|
| 581 |
+
0.09816241264343262,
|
| 582 |
+
0.045323897153139114,
|
| 583 |
+
0.12710919976234436,
|
| 584 |
+
0.06512586772441864
|
| 585 |
+
]
|
| 586 |
+
],
|
| 587 |
+
"duplicate_token": [
|
| 588 |
+
[
|
| 589 |
+
0.061639100313186646,
|
| 590 |
+
0.4175182282924652,
|
| 591 |
+
0.05723930522799492,
|
| 592 |
+
0.039668913930654526,
|
| 593 |
+
0.0939607322216034,
|
| 594 |
+
0.41551661491394043,
|
| 595 |
+
0.07361333817243576,
|
| 596 |
+
0.0333673469722271,
|
| 597 |
+
0.0963386595249176,
|
| 598 |
+
0.0499253086745739,
|
| 599 |
+
0.17845425009727478,
|
| 600 |
+
0.0740630105137825
|
| 601 |
+
],
|
| 602 |
+
[
|
| 603 |
+
0.03887755423784256,
|
| 604 |
+
0.03720149025321007,
|
| 605 |
+
0.07625596970319748,
|
| 606 |
+
0.052537791430950165,
|
| 607 |
+
0.06014804169535637,
|
| 608 |
+
0.09469039738178253,
|
| 609 |
+
0.05574027821421623,
|
| 610 |
+
0.03633364289999008,
|
| 611 |
+
0.05319533869624138,
|
| 612 |
+
0.04128124564886093,
|
| 613 |
+
0.10213665664196014,
|
| 614 |
+
0.3255976736545563
|
| 615 |
+
],
|
| 616 |
+
[
|
| 617 |
+
0.0270945243537426,
|
| 618 |
+
0.02465079165995121,
|
| 619 |
+
0.003460302483290434,
|
| 620 |
+
0.01619820110499859,
|
| 621 |
+
0.008633781224489212,
|
| 622 |
+
0.012598037719726562,
|
| 623 |
+
0.04559514671564102,
|
| 624 |
+
0.06271781027317047,
|
| 625 |
+
0.014696493744850159,
|
| 626 |
+
0.012923041358590126,
|
| 627 |
+
0.07460619509220123,
|
| 628 |
+
0.027807259932160378
|
| 629 |
+
],
|
| 630 |
+
[
|
| 631 |
+
0.24161744117736816,
|
| 632 |
+
0.013565832749009132,
|
| 633 |
+
0.006801762618124485,
|
| 634 |
+
0.0032485886476933956,
|
| 635 |
+
0.02135937288403511,
|
| 636 |
+
0.024630073457956314,
|
| 637 |
+
0.015564021654427052,
|
| 638 |
+
0.005436367355287075,
|
| 639 |
+
0.007849231362342834,
|
| 640 |
+
0.015441101975739002,
|
| 641 |
+
0.04518696293234825,
|
| 642 |
+
0.013415353372693062
|
| 643 |
+
],
|
| 644 |
+
[
|
| 645 |
+
0.0038080490194261074,
|
| 646 |
+
0.00991421565413475,
|
| 647 |
+
0.025079775601625443,
|
| 648 |
+
0.011280774138867855,
|
| 649 |
+
0.04912680760025978,
|
| 650 |
+
0.006715251598507166,
|
| 651 |
+
0.021937724202871323,
|
| 652 |
+
0.12375693023204803,
|
| 653 |
+
0.026765504851937294,
|
| 654 |
+
0.011192137375473976,
|
| 655 |
+
0.025936853140592575,
|
| 656 |
+
8.196845010388643e-05
|
| 657 |
+
],
|
| 658 |
+
[
|
| 659 |
+
0.023429764434695244,
|
| 660 |
+
0.016590412706136703,
|
| 661 |
+
0.017092403024435043,
|
| 662 |
+
0.03277356177568436,
|
| 663 |
+
0.016331162303686142,
|
| 664 |
+
0.021816818043589592,
|
| 665 |
+
0.011733165010809898,
|
| 666 |
+
0.005887174047529697,
|
| 667 |
+
0.01492474414408207,
|
| 668 |
+
0.030711984261870384,
|
| 669 |
+
0.07108811289072037,
|
| 670 |
+
0.06261330097913742
|
| 671 |
+
],
|
| 672 |
+
[
|
| 673 |
+
0.02555452659726143,
|
| 674 |
+
0.029351357370615005,
|
| 675 |
+
0.021288855001330376,
|
| 676 |
+
0.024492312222719193,
|
| 677 |
+
0.039061177521944046,
|
| 678 |
+
0.03344884514808655,
|
| 679 |
+
0.06831201910972595,
|
| 680 |
+
0.03736294433474541,
|
| 681 |
+
0.019588876515626907,
|
| 682 |
+
0.04092007130384445,
|
| 683 |
+
0.01721787452697754,
|
| 684 |
+
0.019499698653817177
|
| 685 |
+
],
|
| 686 |
+
[
|
| 687 |
+
0.020283106714487076,
|
| 688 |
+
0.02244160696864128,
|
| 689 |
+
0.01908939704298973,
|
| 690 |
+
0.0162697471678257,
|
| 691 |
+
0.02050776034593582,
|
| 692 |
+
0.02750096097588539,
|
| 693 |
+
0.026029860600829124,
|
| 694 |
+
0.03217357397079468,
|
| 695 |
+
0.014307908713817596,
|
| 696 |
+
0.006763854529708624,
|
| 697 |
+
0.04564401134848595,
|
| 698 |
+
0.027008097618818283
|
| 699 |
+
],
|
| 700 |
+
[
|
| 701 |
+
0.027883464470505714,
|
| 702 |
+
0.041265588253736496,
|
| 703 |
+
0.028905224055051804,
|
| 704 |
+
0.013592107221484184,
|
| 705 |
+
0.0074845412746071815,
|
| 706 |
+
0.03488120436668396,
|
| 707 |
+
0.04030846059322357,
|
| 708 |
+
0.010207113809883595,
|
| 709 |
+
0.035800714045763016,
|
| 710 |
+
0.029832065105438232,
|
| 711 |
+
0.02576960064470768,
|
| 712 |
+
0.014182129874825478
|
| 713 |
+
],
|
| 714 |
+
[
|
| 715 |
+
0.017836367711424828,
|
| 716 |
+
0.029379570856690407,
|
| 717 |
+
0.022140078246593475,
|
| 718 |
+
0.036215025931596756,
|
| 719 |
+
0.024319598451256752,
|
| 720 |
+
0.026142369955778122,
|
| 721 |
+
0.018539801239967346,
|
| 722 |
+
0.019365690648555756,
|
| 723 |
+
0.011654431000351906,
|
| 724 |
+
0.025902757421135902,
|
| 725 |
+
0.015683690086007118,
|
| 726 |
+
0.010347607545554638
|
| 727 |
+
],
|
| 728 |
+
[
|
| 729 |
+
0.02144056186079979,
|
| 730 |
+
0.046325650066137314,
|
| 731 |
+
0.021630164235830307,
|
| 732 |
+
0.05147164314985275,
|
| 733 |
+
0.042117439210414886,
|
| 734 |
+
0.02441989816725254,
|
| 735 |
+
0.02136657014489174,
|
| 736 |
+
0.05447021871805191,
|
| 737 |
+
0.03011142648756504,
|
| 738 |
+
0.020071811974048615,
|
| 739 |
+
0.016738489270210266,
|
| 740 |
+
0.04836065694689751
|
| 741 |
+
],
|
| 742 |
+
[
|
| 743 |
+
0.13101476430892944,
|
| 744 |
+
0.03627091646194458,
|
| 745 |
+
0.0201750285923481,
|
| 746 |
+
0.06851539760828018,
|
| 747 |
+
0.029396140947937965,
|
| 748 |
+
0.03782244399189949,
|
| 749 |
+
0.014253688976168633,
|
| 750 |
+
0.044284969568252563,
|
| 751 |
+
0.17414367198944092,
|
| 752 |
+
0.021388430148363113,
|
| 753 |
+
0.06319155544042587,
|
| 754 |
+
0.055135130882263184
|
| 755 |
+
]
|
| 756 |
+
],
|
| 757 |
+
"positional": [
|
| 758 |
+
[
|
| 759 |
+
0.5065976977348328,
|
| 760 |
+
0.07629109919071198,
|
| 761 |
+
0.5960054397583008,
|
| 762 |
+
0.1072789654135704,
|
| 763 |
+
0.1979677975177765,
|
| 764 |
+
0.13927273452281952,
|
| 765 |
+
0.40057316422462463,
|
| 766 |
+
0.294817179441452,
|
| 767 |
+
0.383198618888855,
|
| 768 |
+
0.5544258952140808,
|
| 769 |
+
0.40033283829689026,
|
| 770 |
+
0.47870078682899475
|
| 771 |
+
],
|
| 772 |
+
[
|
| 773 |
+
0.2410203516483307,
|
| 774 |
+
0.4396105706691742,
|
| 775 |
+
0.4307883381843567,
|
| 776 |
+
0.5517755746841431,
|
| 777 |
+
0.5317303538322449,
|
| 778 |
+
0.5054966807365417,
|
| 779 |
+
0.6495388746261597,
|
| 780 |
+
0.6267575025558472,
|
| 781 |
+
0.5890303254127502,
|
| 782 |
+
0.6793325543403625,
|
| 783 |
+
0.07594899833202362,
|
| 784 |
+
0.21587026119232178
|
| 785 |
+
],
|
| 786 |
+
[
|
| 787 |
+
0.4007927477359772,
|
| 788 |
+
0.7385829091072083,
|
| 789 |
+
0.1999039351940155,
|
| 790 |
+
0.30451780557632446,
|
| 791 |
+
0.46449190378189087,
|
| 792 |
+
0.3399127125740051,
|
| 793 |
+
0.514499306678772,
|
| 794 |
+
0.29614612460136414,
|
| 795 |
+
0.31728798151016235,
|
| 796 |
+
0.2615760266780853,
|
| 797 |
+
0.3395046591758728,
|
| 798 |
+
0.7219924926757812
|
| 799 |
+
],
|
| 800 |
+
[
|
| 801 |
+
0.8190140724182129,
|
| 802 |
+
0.6275245547294617,
|
| 803 |
+
0.25404971837997437,
|
| 804 |
+
0.6006070375442505,
|
| 805 |
+
0.8895429372787476,
|
| 806 |
+
0.7170742154121399,
|
| 807 |
+
0.3035760521888733,
|
| 808 |
+
0.35117024183273315,
|
| 809 |
+
0.4254607558250427,
|
| 810 |
+
0.5432918071746826,
|
| 811 |
+
0.6645973920822144,
|
| 812 |
+
0.47774600982666016
|
| 813 |
+
],
|
| 814 |
+
[
|
| 815 |
+
0.5796363949775696,
|
| 816 |
+
0.5921002626419067,
|
| 817 |
+
0.793941080570221,
|
| 818 |
+
0.49824151396751404,
|
| 819 |
+
0.7273139953613281,
|
| 820 |
+
0.6757563948631287,
|
| 821 |
+
0.64992356300354,
|
| 822 |
+
0.3122835159301758,
|
| 823 |
+
0.8277088403701782,
|
| 824 |
+
0.6422610878944397,
|
| 825 |
+
0.8769611120223999,
|
| 826 |
+
0.14915767312049866
|
| 827 |
+
],
|
| 828 |
+
[
|
| 829 |
+
0.7556132078170776,
|
| 830 |
+
0.8456296920776367,
|
| 831 |
+
0.6256846785545349,
|
| 832 |
+
0.5377398729324341,
|
| 833 |
+
0.5960881114006042,
|
| 834 |
+
0.7833361625671387,
|
| 835 |
+
0.723742663860321,
|
| 836 |
+
0.7974669933319092,
|
| 837 |
+
0.7113959789276123,
|
| 838 |
+
0.8386362791061401,
|
| 839 |
+
0.6537194848060608,
|
| 840 |
+
0.7253992557525635
|
| 841 |
+
],
|
| 842 |
+
[
|
| 843 |
+
0.538119912147522,
|
| 844 |
+
0.7342842817306519,
|
| 845 |
+
0.8442155718803406,
|
| 846 |
+
0.7554894685745239,
|
| 847 |
+
0.6839307546615601,
|
| 848 |
+
0.7064528465270996,
|
| 849 |
+
0.7554677724838257,
|
| 850 |
+
0.6205617189407349,
|
| 851 |
+
0.5202042460441589,
|
| 852 |
+
0.8443636894226074,
|
| 853 |
+
0.8635346293449402,
|
| 854 |
+
0.6343041062355042
|
| 855 |
+
],
|
| 856 |
+
[
|
| 857 |
+
0.6614936590194702,
|
| 858 |
+
0.8791419267654419,
|
| 859 |
+
0.9076933860778809,
|
| 860 |
+
0.7058827877044678,
|
| 861 |
+
0.8025026321411133,
|
| 862 |
+
0.7749000787734985,
|
| 863 |
+
0.838254451751709,
|
| 864 |
+
0.8037239909172058,
|
| 865 |
+
0.6864684224128723,
|
| 866 |
+
0.7610327005386353,
|
| 867 |
+
0.8215873837471008,
|
| 868 |
+
0.8486534357070923
|
| 869 |
+
],
|
| 870 |
+
[
|
| 871 |
+
0.8073843121528625,
|
| 872 |
+
0.8061873316764832,
|
| 873 |
+
0.7319211959838867,
|
| 874 |
+
0.8467031717300415,
|
| 875 |
+
0.7768716812133789,
|
| 876 |
+
0.6048685908317566,
|
| 877 |
+
0.7132378816604614,
|
| 878 |
+
0.6679729223251343,
|
| 879 |
+
0.6701217889785767,
|
| 880 |
+
0.7771828770637512,
|
| 881 |
+
0.7071925401687622,
|
| 882 |
+
0.8558918237686157
|
| 883 |
+
],
|
| 884 |
+
[
|
| 885 |
+
0.8133878707885742,
|
| 886 |
+
0.8669012784957886,
|
| 887 |
+
0.8068772554397583,
|
| 888 |
+
0.5790890455245972,
|
| 889 |
+
0.8904383778572083,
|
| 890 |
+
0.8204380869865417,
|
| 891 |
+
0.9076582789421082,
|
| 892 |
+
0.7966066002845764,
|
| 893 |
+
0.8762456774711609,
|
| 894 |
+
0.9064305424690247,
|
| 895 |
+
0.7492377758026123,
|
| 896 |
+
0.9301468133926392
|
| 897 |
+
],
|
| 898 |
+
[
|
| 899 |
+
0.8455430269241333,
|
| 900 |
+
0.8402767181396484,
|
| 901 |
+
0.890575110912323,
|
| 902 |
+
0.7642854452133179,
|
| 903 |
+
0.7333279252052307,
|
| 904 |
+
0.7862328290939331,
|
| 905 |
+
0.8635441660881042,
|
| 906 |
+
0.6658955812454224,
|
| 907 |
+
0.888232409954071,
|
| 908 |
+
0.7337470054626465,
|
| 909 |
+
0.9097886085510254,
|
| 910 |
+
0.7254845499992371
|
| 911 |
+
],
|
| 912 |
+
[
|
| 913 |
+
0.3025703728199005,
|
| 914 |
+
0.8144607543945312,
|
| 915 |
+
0.8962485194206238,
|
| 916 |
+
0.6487042307853699,
|
| 917 |
+
0.7963070869445801,
|
| 918 |
+
0.8672806620597839,
|
| 919 |
+
0.9231362342834473,
|
| 920 |
+
0.8210302591323853,
|
| 921 |
+
0.07466430962085724,
|
| 922 |
+
0.9117152094841003,
|
| 923 |
+
0.6209774017333984,
|
| 924 |
+
0.6903347969055176
|
| 925 |
+
]
|
| 926 |
+
],
|
| 927 |
+
"diffuse": [
|
| 928 |
+
[
|
| 929 |
+
0.5471135377883911,
|
| 930 |
+
0.1322605162858963,
|
| 931 |
+
0.492602676153183,
|
| 932 |
+
0.21496565639972687,
|
| 933 |
+
0.45495811104774475,
|
| 934 |
+
0.25727584958076477,
|
| 935 |
+
0.5676304697990417,
|
| 936 |
+
0.5459160804748535,
|
| 937 |
+
0.5383939146995544,
|
| 938 |
+
0.5441114902496338,
|
| 939 |
+
0.6075721383094788,
|
| 940 |
+
0.5915287137031555
|
| 941 |
+
],
|
| 942 |
+
[
|
| 943 |
+
0.56114661693573,
|
| 944 |
+
0.5631774663925171,
|
| 945 |
+
0.5851024389266968,
|
| 946 |
+
0.5447676777839661,
|
| 947 |
+
0.5693410038948059,
|
| 948 |
+
0.510784924030304,
|
| 949 |
+
0.4271117150783539,
|
| 950 |
+
0.48312950134277344,
|
| 951 |
+
0.5217397212982178,
|
| 952 |
+
0.4331055283546448,
|
| 953 |
+
0.60009765625,
|
| 954 |
+
0.3668949007987976
|
| 955 |
+
],
|
| 956 |
+
[
|
| 957 |
+
0.5618427991867065,
|
| 958 |
+
0.35582321882247925,
|
| 959 |
+
0.34944772720336914,
|
| 960 |
+
0.5037699937820435,
|
| 961 |
+
0.4152102470397949,
|
| 962 |
+
0.47268810868263245,
|
| 963 |
+
0.5098887085914612,
|
| 964 |
+
0.622725248336792,
|
| 965 |
+
0.47435516119003296,
|
| 966 |
+
0.48120027780532837,
|
| 967 |
+
0.6324588060379028,
|
| 968 |
+
0.4027617573738098
|
| 969 |
+
],
|
| 970 |
+
[
|
| 971 |
+
0.17714563012123108,
|
| 972 |
+
0.4172298312187195,
|
| 973 |
+
0.42452120780944824,
|
| 974 |
+
0.31828364729881287,
|
| 975 |
+
0.18911775946617126,
|
| 976 |
+
0.38251644372940063,
|
| 977 |
+
0.5157310366630554,
|
| 978 |
+
0.4105154871940613,
|
| 979 |
+
0.41387349367141724,
|
| 980 |
+
0.4185497760772705,
|
| 981 |
+
0.40337443351745605,
|
| 982 |
+
0.4543667733669281
|
| 983 |
+
],
|
| 984 |
+
[
|
| 985 |
+
0.3102322220802307,
|
| 986 |
+
0.38234779238700867,
|
| 987 |
+
0.3048619031906128,
|
| 988 |
+
0.4123547673225403,
|
| 989 |
+
0.3599177300930023,
|
| 990 |
+
0.34652307629585266,
|
| 991 |
+
0.447924941778183,
|
| 992 |
+
0.46825671195983887,
|
| 993 |
+
0.26102879643440247,
|
| 994 |
+
0.3940913677215576,
|
| 995 |
+
0.20296287536621094,
|
| 996 |
+
0.02204204723238945
|
| 997 |
+
],
|
| 998 |
+
[
|
| 999 |
+
0.2029620110988617,
|
| 1000 |
+
0.08709979802370071,
|
| 1001 |
+
0.40380486845970154,
|
| 1002 |
+
0.514489471912384,
|
| 1003 |
+
0.4261854588985443,
|
| 1004 |
+
0.1830417364835739,
|
| 1005 |
+
0.26347407698631287,
|
| 1006 |
+
0.2405150830745697,
|
| 1007 |
+
0.2826869487762451,
|
| 1008 |
+
0.24574777483940125,
|
| 1009 |
+
0.3901086449623108,
|
| 1010 |
+
0.3574109673500061
|
| 1011 |
+
],
|
| 1012 |
+
[
|
| 1013 |
+
0.46612176299095154,
|
| 1014 |
+
0.3027900457382202,
|
| 1015 |
+
0.25536319613456726,
|
| 1016 |
+
0.3338863253593445,
|
| 1017 |
+
0.3941308856010437,
|
| 1018 |
+
0.39528438448905945,
|
| 1019 |
+
0.3291165232658386,
|
| 1020 |
+
0.44284719228744507,
|
| 1021 |
+
0.41498908400535583,
|
| 1022 |
+
0.12233757972717285,
|
| 1023 |
+
0.20009461045265198,
|
| 1024 |
+
0.4175761640071869
|
| 1025 |
+
],
|
| 1026 |
+
[
|
| 1027 |
+
0.33120664954185486,
|
| 1028 |
+
0.1765395551919937,
|
| 1029 |
+
0.09227706491947174,
|
| 1030 |
+
0.37284451723098755,
|
| 1031 |
+
0.2708284258842468,
|
| 1032 |
+
0.31805992126464844,
|
| 1033 |
+
0.25206413865089417,
|
| 1034 |
+
0.21613168716430664,
|
| 1035 |
+
0.3545899987220764,
|
| 1036 |
+
0.3042650818824768,
|
| 1037 |
+
0.14626441895961761,
|
| 1038 |
+
0.1727096140384674
|
| 1039 |
+
],
|
| 1040 |
+
[
|
| 1041 |
+
0.28339847922325134,
|
| 1042 |
+
0.18787869811058044,
|
| 1043 |
+
0.36294665932655334,
|
| 1044 |
+
0.2241670787334442,
|
| 1045 |
+
0.27335819602012634,
|
| 1046 |
+
0.4469229280948639,
|
| 1047 |
+
0.2862758934497833,
|
| 1048 |
+
0.3158189654350281,
|
| 1049 |
+
0.3742186725139618,
|
| 1050 |
+
0.30465927720069885,
|
| 1051 |
+
0.38407495617866516,
|
| 1052 |
+
0.21899032592773438
|
| 1053 |
+
],
|
| 1054 |
+
[
|
| 1055 |
+
0.23532763123512268,
|
| 1056 |
+
0.16719377040863037,
|
| 1057 |
+
0.2597936987876892,
|
| 1058 |
+
0.4364214539527893,
|
| 1059 |
+
0.17044395208358765,
|
| 1060 |
+
0.2712015211582184,
|
| 1061 |
+
0.13269579410552979,
|
| 1062 |
+
0.2855920195579529,
|
| 1063 |
+
0.18635967373847961,
|
| 1064 |
+
0.1326359510421753,
|
| 1065 |
+
0.29712045192718506,
|
| 1066 |
+
0.11560585349798203
|
| 1067 |
+
],
|
| 1068 |
+
[
|
| 1069 |
+
0.2210322916507721,
|
| 1070 |
+
0.19647784531116486,
|
| 1071 |
+
0.17695878446102142,
|
| 1072 |
+
0.29771047830581665,
|
| 1073 |
+
0.33597418665885925,
|
| 1074 |
+
0.2783747613430023,
|
| 1075 |
+
0.19375105202198029,
|
| 1076 |
+
0.3423268496990204,
|
| 1077 |
+
0.16622166335582733,
|
| 1078 |
+
0.3245820999145508,
|
| 1079 |
+
0.1462937742471695,
|
| 1080 |
+
0.2878214716911316
|
| 1081 |
+
],
|
| 1082 |
+
[
|
| 1083 |
+
0.6131904721260071,
|
| 1084 |
+
0.2794114649295807,
|
| 1085 |
+
0.18150922656059265,
|
| 1086 |
+
0.42593449354171753,
|
| 1087 |
+
0.31345874071121216,
|
| 1088 |
+
0.20985659956932068,
|
| 1089 |
+
0.14243251085281372,
|
| 1090 |
+
0.2698703110218048,
|
| 1091 |
+
0.5045338869094849,
|
| 1092 |
+
0.15346872806549072,
|
| 1093 |
+
0.4387816786766052,
|
| 1094 |
+
0.3756435811519623
|
| 1095 |
+
]
|
| 1096 |
+
]
|
| 1097 |
+
}
|
| 1098 |
+
}
|
| 1099 |
+
}
|
utils/head_detection.py
CHANGED
|
@@ -1,313 +1,256 @@
|
|
| 1 |
"""
|
| 2 |
Attention Head Detection and Categorization
|
| 3 |
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
- Other: heads that don't fit the above categories
|
| 10 |
"""
|
| 11 |
|
|
|
|
|
|
|
| 12 |
import torch
|
| 13 |
import numpy as np
|
| 14 |
from typing import Dict, List, Tuple, Optional, Any
|
| 15 |
import re
|
|
|
|
| 16 |
|
| 17 |
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
Configuration for attention head categorization heuristics.
|
| 21 |
-
|
| 22 |
-
These thresholds are tuned to balance sensitivity (catching relevant patterns)
|
| 23 |
-
with specificity (avoiding false positives) for educational purposes.
|
| 24 |
-
"""
|
| 25 |
-
|
| 26 |
-
def __init__(self):
|
| 27 |
-
# Previous-token head thresholds
|
| 28 |
-
# Heads that primarily attend to the immediately preceding token
|
| 29 |
-
self.prev_token_threshold = 0.4 # Minimum avg attention to prev token (40%)
|
| 30 |
-
self.prev_token_diagonal_offset = 1 # Check i → i-1 pattern
|
| 31 |
-
|
| 32 |
-
# First/Positional head thresholds
|
| 33 |
-
# Heads that attend strongly to first token or show positional patterns
|
| 34 |
-
self.first_token_threshold = 0.25 # Minimum avg attention to first token (25%)
|
| 35 |
-
self.positional_pattern_threshold = 0.4 # For detecting positional patterns
|
| 36 |
-
|
| 37 |
-
# Bag-of-words head thresholds
|
| 38 |
-
# Heads with diffuse attention across many tokens
|
| 39 |
-
self.bow_entropy_threshold = 0.65 # Minimum entropy (normalized, 0-1 scale)
|
| 40 |
-
self.bow_max_attention_threshold = 0.35 # Maximum attention to any single token
|
| 41 |
-
|
| 42 |
-
# Syntactic head thresholds
|
| 43 |
-
# Heads showing structured distance patterns (e.g., subject-verb)
|
| 44 |
-
self.syntactic_distance_pattern_threshold = 0.3 # For detecting distance patterns
|
| 45 |
-
|
| 46 |
-
# General thresholds
|
| 47 |
-
self.min_seq_len = 4 # Minimum sequence length for reliable detection
|
| 48 |
|
|
|
|
|
|
|
| 49 |
|
| 50 |
-
|
|
|
|
| 51 |
"""
|
| 52 |
-
|
| 53 |
|
| 54 |
Args:
|
| 55 |
-
|
| 56 |
|
| 57 |
Returns:
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
# Normalize by max entropy (log(n) where n is sequence length)
|
| 68 |
-
max_entropy = np.log(len(weights))
|
| 69 |
-
normalized_entropy = entropy / max_entropy if max_entropy > 0 else 0
|
| 70 |
-
|
| 71 |
-
return normalized_entropy.item()
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
def detect_previous_token_head(attention_matrix: torch.Tensor, config: HeadCategorizationConfig) -> Tuple[bool, float]:
|
| 75 |
"""
|
| 76 |
-
|
| 77 |
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
seq_len = attention_matrix.shape[0]
|
| 86 |
|
| 87 |
-
|
| 88 |
-
|
|
|
|
|
|
|
|
|
|
| 89 |
|
| 90 |
-
#
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
|
|
|
| 95 |
|
| 96 |
-
|
| 97 |
-
|
| 98 |
|
| 99 |
-
return
|
|
|
|
| 100 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
|
| 102 |
-
|
|
|
|
| 103 |
"""
|
| 104 |
-
|
| 105 |
|
| 106 |
Args:
|
| 107 |
-
|
| 108 |
-
config: Configuration object
|
| 109 |
|
| 110 |
Returns:
|
| 111 |
-
(
|
| 112 |
"""
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
# Check average attention to first token across all positions
|
| 119 |
-
first_token_attention = attention_matrix[:, 0].mean().item()
|
| 120 |
-
is_first_token_head = first_token_attention >= config.first_token_threshold
|
| 121 |
-
|
| 122 |
-
return is_first_token_head, first_token_attention
|
| 123 |
|
| 124 |
|
| 125 |
-
def
|
| 126 |
"""
|
| 127 |
-
|
| 128 |
|
| 129 |
Args:
|
| 130 |
-
|
| 131 |
-
config: Configuration object
|
| 132 |
|
| 133 |
Returns:
|
| 134 |
-
|
| 135 |
"""
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
for i in range(seq_len):
|
| 146 |
-
entropy = compute_attention_entropy(attention_matrix[i])
|
| 147 |
-
max_attention = attention_matrix[i].max().item()
|
| 148 |
-
|
| 149 |
-
entropies.append(entropy)
|
| 150 |
-
max_attentions.append(max_attention)
|
| 151 |
-
|
| 152 |
-
avg_entropy = np.mean(entropies)
|
| 153 |
-
avg_max_attention = np.mean(max_attentions)
|
| 154 |
-
|
| 155 |
-
# BoW heads have high entropy and low max attention (diffuse)
|
| 156 |
-
is_bow_head = (avg_entropy >= config.bow_entropy_threshold and
|
| 157 |
-
avg_max_attention <= config.bow_max_attention_threshold)
|
| 158 |
-
|
| 159 |
-
return is_bow_head, avg_entropy
|
| 160 |
|
| 161 |
|
| 162 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
"""
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
This is a simplified heuristic based on consistent distance patterns.
|
| 167 |
|
| 168 |
Args:
|
| 169 |
-
|
| 170 |
-
|
|
|
|
| 171 |
|
| 172 |
Returns:
|
| 173 |
-
|
| 174 |
"""
|
| 175 |
-
seq_len =
|
| 176 |
-
|
| 177 |
-
if seq_len <
|
| 178 |
-
return
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 190 |
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 207 |
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
layer_idx: int,
|
| 213 |
-
head_idx: int,
|
| 214 |
-
config: Optional[HeadCategorizationConfig] = None) -> Dict[str, Any]:
|
| 215 |
-
"""
|
| 216 |
-
Categorize a single attention head based on its attention pattern.
|
| 217 |
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
|
|
|
| 223 |
|
| 224 |
-
Returns:
|
| 225 |
-
Dictionary with categorization results:
|
| 226 |
-
{
|
| 227 |
-
'layer': layer_idx,
|
| 228 |
-
'head': head_idx,
|
| 229 |
-
'category': str (one of: 'previous_token', 'first_token', 'bow', 'syntactic', 'other'),
|
| 230 |
-
'scores': dict of scores for each category,
|
| 231 |
-
'label': formatted label like "L{layer}-H{head}"
|
| 232 |
-
}
|
| 233 |
-
"""
|
| 234 |
-
if config is None:
|
| 235 |
-
config = HeadCategorizationConfig()
|
| 236 |
-
|
| 237 |
-
# Run all detection heuristics
|
| 238 |
-
is_prev, prev_score = detect_previous_token_head(attention_matrix, config)
|
| 239 |
-
is_first, first_score = detect_first_token_head(attention_matrix, config)
|
| 240 |
-
is_bow, bow_score = detect_bow_head(attention_matrix, config)
|
| 241 |
-
is_syn, syn_score = detect_syntactic_head(attention_matrix, config)
|
| 242 |
-
|
| 243 |
-
# Assign category based on highest-scoring pattern
|
| 244 |
-
# Priority: previous_token > first_token > bow > syntactic > other
|
| 245 |
-
scores = {
|
| 246 |
-
'previous_token': prev_score if is_prev else 0.0,
|
| 247 |
-
'first_token': first_score if is_first else 0.0,
|
| 248 |
-
'bow': bow_score if is_bow else 0.0,
|
| 249 |
-
'syntactic': syn_score if is_syn else 0.0
|
| 250 |
-
}
|
| 251 |
-
|
| 252 |
-
# Determine primary category
|
| 253 |
-
if is_prev:
|
| 254 |
-
category = 'previous_token'
|
| 255 |
-
elif is_first:
|
| 256 |
-
category = 'first_token'
|
| 257 |
-
elif is_bow:
|
| 258 |
-
category = 'bow'
|
| 259 |
-
elif is_syn:
|
| 260 |
-
category = 'syntactic'
|
| 261 |
else:
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
return {
|
| 265 |
-
'layer': layer_idx,
|
| 266 |
-
'head': head_idx,
|
| 267 |
-
'category': category,
|
| 268 |
-
'scores': scores,
|
| 269 |
-
'label': f"L{layer_idx}-H{head_idx}"
|
| 270 |
-
}
|
| 271 |
|
| 272 |
|
| 273 |
-
def
|
| 274 |
-
|
|
|
|
|
|
|
| 275 |
"""
|
| 276 |
-
|
|
|
|
| 277 |
|
| 278 |
Args:
|
| 279 |
activation_data: Output from execute_forward_pass with attention data
|
| 280 |
-
|
| 281 |
|
| 282 |
Returns:
|
| 283 |
-
|
| 284 |
{
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 290 |
}
|
|
|
|
| 291 |
"""
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
# Initialize result dict
|
| 296 |
-
categorized = {
|
| 297 |
-
'previous_token': [],
|
| 298 |
-
'first_token': [],
|
| 299 |
-
'bow': [],
|
| 300 |
-
'syntactic': [],
|
| 301 |
-
'other': []
|
| 302 |
-
}
|
| 303 |
|
|
|
|
| 304 |
attention_outputs = activation_data.get('attention_outputs', {})
|
| 305 |
-
|
| 306 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 307 |
|
| 308 |
-
# Process each layer's attention
|
| 309 |
for module_name, output_dict in attention_outputs.items():
|
| 310 |
-
# Extract layer number from module name
|
| 311 |
numbers = re.findall(r'\d+', module_name)
|
| 312 |
if not numbers:
|
| 313 |
continue
|
|
@@ -318,153 +261,82 @@ def categorize_all_heads(activation_data: Dict[str, Any],
|
|
| 318 |
if not isinstance(attention_output, list) or len(attention_output) < 2:
|
| 319 |
continue
|
| 320 |
|
| 321 |
-
#
|
| 322 |
attention_weights = torch.tensor(attention_output[1])
|
| 323 |
-
|
| 324 |
-
# Process each head
|
| 325 |
num_heads = attention_weights.shape[1]
|
| 326 |
-
seq_len = attention_weights.shape[2]
|
| 327 |
-
|
| 328 |
-
if seq_len < config.min_seq_len:
|
| 329 |
-
continue
|
| 330 |
|
| 331 |
for head_idx in range(num_heads):
|
| 332 |
-
|
| 333 |
-
head_attention = attention_weights[0, head_idx, :, :]
|
| 334 |
-
|
| 335 |
-
# Categorize this head
|
| 336 |
-
head_info = categorize_attention_head(head_attention, layer_idx, head_idx, config)
|
| 337 |
-
|
| 338 |
-
# Add to appropriate category list
|
| 339 |
-
category = head_info['category']
|
| 340 |
-
categorized[category].append(head_info)
|
| 341 |
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
def categorize_single_layer_heads(activation_data: Dict[str, Any],
|
| 346 |
-
layer_num: int,
|
| 347 |
-
config: Optional[HeadCategorizationConfig] = None) -> Dict[str, List[Dict[str, Any]]]:
|
| 348 |
-
"""
|
| 349 |
-
Categorize attention heads for a single layer.
|
| 350 |
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
Returns:
|
| 357 |
-
Dictionary mapping category names to lists of head info dicts for this layer only:
|
| 358 |
-
{
|
| 359 |
-
'previous_token': [...],
|
| 360 |
-
'first_token': [...],
|
| 361 |
-
'bow': [...],
|
| 362 |
-
'syntactic': [...],
|
| 363 |
-
'other': [...]
|
| 364 |
-
}
|
| 365 |
-
"""
|
| 366 |
-
if config is None:
|
| 367 |
-
config = HeadCategorizationConfig()
|
| 368 |
-
|
| 369 |
-
# Initialize result dict
|
| 370 |
-
categorized = {
|
| 371 |
-
'previous_token': [],
|
| 372 |
-
'first_token': [],
|
| 373 |
-
'bow': [],
|
| 374 |
-
'syntactic': [],
|
| 375 |
-
'other': []
|
| 376 |
}
|
| 377 |
|
| 378 |
-
|
| 379 |
-
if not attention_outputs:
|
| 380 |
-
return categorized
|
| 381 |
|
| 382 |
-
#
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
if
|
| 388 |
continue
|
| 389 |
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
break
|
| 393 |
-
|
| 394 |
-
if not target_module:
|
| 395 |
-
return categorized
|
| 396 |
-
|
| 397 |
-
output_dict = attention_outputs[target_module]
|
| 398 |
-
attention_output = output_dict.get('output')
|
| 399 |
-
|
| 400 |
-
if not isinstance(attention_output, list) or len(attention_output) < 2:
|
| 401 |
-
return categorized
|
| 402 |
-
|
| 403 |
-
# Get attention weights: [batch, heads, seq_len, seq_len]
|
| 404 |
-
attention_weights = torch.tensor(attention_output[1])
|
| 405 |
-
|
| 406 |
-
# Process each head
|
| 407 |
-
num_heads = attention_weights.shape[1]
|
| 408 |
-
seq_len = attention_weights.shape[2]
|
| 409 |
-
|
| 410 |
-
if seq_len < config.min_seq_len:
|
| 411 |
-
return categorized
|
| 412 |
-
|
| 413 |
-
for head_idx in range(num_heads):
|
| 414 |
-
# Extract attention matrix for this head: [seq_len, seq_len]
|
| 415 |
-
head_attention = attention_weights[0, head_idx, :, :]
|
| 416 |
|
| 417 |
-
|
| 418 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 419 |
|
| 420 |
-
|
| 421 |
-
|
| 422 |
-
|
| 423 |
-
|
| 424 |
-
|
| 425 |
-
|
| 426 |
-
|
| 427 |
-
|
| 428 |
-
|
| 429 |
-
|
| 430 |
-
|
| 431 |
-
Args:
|
| 432 |
-
categorized_heads: Output from categorize_all_heads or categorize_single_layer_heads
|
| 433 |
|
| 434 |
-
|
| 435 |
-
|
| 436 |
-
|
| 437 |
-
|
| 438 |
-
|
| 439 |
-
|
| 440 |
-
|
| 441 |
-
|
| 442 |
-
|
|
|
|
| 443 |
}
|
| 444 |
|
| 445 |
-
|
| 446 |
-
total_heads = sum(len(heads) for heads in categorized_heads.values())
|
| 447 |
-
|
| 448 |
-
summary.append(f"Total Heads: {total_heads}\n")
|
| 449 |
-
summary.append("=" * 60)
|
| 450 |
-
|
| 451 |
-
for category, display_name in category_names.items():
|
| 452 |
-
heads = categorized_heads.get(category, [])
|
| 453 |
-
summary.append(f"\n{display_name}: {len(heads)} heads")
|
| 454 |
-
|
| 455 |
-
if heads:
|
| 456 |
-
# Group by layer
|
| 457 |
-
heads_by_layer = {}
|
| 458 |
-
for head_info in heads:
|
| 459 |
-
layer = head_info['layer']
|
| 460 |
-
if layer not in heads_by_layer:
|
| 461 |
-
heads_by_layer[layer] = []
|
| 462 |
-
heads_by_layer[layer].append(head_info['head'])
|
| 463 |
-
|
| 464 |
-
# Format by layer
|
| 465 |
-
for layer in sorted(heads_by_layer.keys()):
|
| 466 |
-
head_indices = sorted(heads_by_layer[layer])
|
| 467 |
-
summary.append(f" Layer {layer}: Heads {head_indices}")
|
| 468 |
-
|
| 469 |
-
return "\n".join(summary)
|
| 470 |
-
|
|
|
|
| 1 |
"""
|
| 2 |
Attention Head Detection and Categorization
|
| 3 |
|
| 4 |
+
Loads pre-computed head category data from JSON (produced by scripts/analyze_heads.py)
|
| 5 |
+
and performs lightweight runtime verification of head activation on the current input.
|
| 6 |
+
|
| 7 |
+
Categories:
|
| 8 |
+
- Previous Token: attends to the immediately preceding token
|
| 9 |
+
- Induction: completes repeated patterns ([A][B]...[A] → [B])
|
| 10 |
+
- Duplicate Token: attends to earlier occurrences of the same token
|
| 11 |
+
- Positional / First-Token: attends to the first token or positional patterns
|
| 12 |
+
- Diffuse / Spread: high-entropy, evenly distributed attention
|
| 13 |
- Other: heads that don't fit the above categories
|
| 14 |
"""
|
| 15 |
|
| 16 |
+
import json
|
| 17 |
+
import os
|
| 18 |
import torch
|
| 19 |
import numpy as np
|
| 20 |
from typing import Dict, List, Tuple, Optional, Any
|
| 21 |
import re
|
| 22 |
+
from pathlib import Path
|
| 23 |
|
| 24 |
|
| 25 |
+
# Path to the pre-computed head categories JSON
|
| 26 |
+
_JSON_PATH = Path(__file__).parent / "head_categories.json"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
+
# Cache for loaded JSON data (avoids re-reading per request)
|
| 29 |
+
_category_cache: Dict[str, Any] = {}
|
| 30 |
|
| 31 |
+
|
| 32 |
+
def load_head_categories(model_name: str) -> Optional[Dict[str, Any]]:
    """
    Load pre-computed head category data for a model.

    Results are memoized in the module-level ``_category_cache`` so the JSON
    file is parsed at most once per model name per process.

    Args:
        model_name: HuggingFace model name (e.g., "gpt2", "EleutherAI/pythia-70m")

    Returns:
        Dict with the model's category data, or None if the JSON file is
        missing/unreadable or the model was not analyzed.
        Structure: {
            "model_name": str,
            "num_layers": int,
            "num_heads": int,
            "categories": { category_name: { "top_heads": [...], ... } },
            ...
        }
    """
    # Check cache first (no `global` needed: the dict is mutated, never rebound)
    if model_name in _category_cache:
        return _category_cache[model_name]

    # Load JSON
    if not _JSON_PATH.exists():
        return None

    try:
        # JSON interchange is UTF-8; don't rely on the platform default encoding.
        with open(_JSON_PATH, 'r', encoding='utf-8') as f:
            all_data = json.load(f)
    except (json.JSONDecodeError, OSError):
        return None

    # Try exact match first, then common aliases
    model_data = all_data.get(model_name)
    if model_data is None:
        # Try short name (e.g., "gpt2" for "openai-community/gpt2")
        short_name = model_name.split('/')[-1] if '/' in model_name else model_name
        model_data = all_data.get(short_name)

    if model_data is not None:
        _category_cache[model_name] = model_data

    return model_data
|
| 76 |
+
|
| 77 |
|
| 78 |
+
def clear_category_cache():
    """Drop all memoized category data (primarily useful in tests)."""
    global _category_cache
    _category_cache = dict()
|
| 82 |
|
| 83 |
+
|
| 84 |
+
def _compute_attention_entropy(attention_weights: torch.Tensor) -> float:
|
| 85 |
"""
|
| 86 |
+
Compute normalized entropy of an attention distribution.
|
| 87 |
|
| 88 |
Args:
|
| 89 |
+
attention_weights: [seq_len] tensor of attention weights for one position
|
|
|
|
| 90 |
|
| 91 |
Returns:
|
| 92 |
+
Normalized entropy (0.0 to 1.0). 1.0 = perfectly uniform, 0.0 = fully peaked.
|
| 93 |
"""
|
| 94 |
+
epsilon = 1e-10
|
| 95 |
+
weights = attention_weights + epsilon
|
| 96 |
+
entropy = -torch.sum(weights * torch.log(weights))
|
| 97 |
+
max_entropy = np.log(len(weights))
|
| 98 |
+
return (entropy / max_entropy).item() if max_entropy > 0 else 0.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
|
| 100 |
|
| 101 |
+
def _find_repeated_tokens(token_ids: List[int]) -> Dict[int, List[int]]:
|
| 102 |
"""
|
| 103 |
+
Find tokens that appear more than once and their positions.
|
| 104 |
|
| 105 |
Args:
|
| 106 |
+
token_ids: List of token IDs in the sequence
|
|
|
|
| 107 |
|
| 108 |
Returns:
|
| 109 |
+
Dict mapping token_id -> list of positions where it appears (only for repeated tokens)
|
| 110 |
"""
|
| 111 |
+
positions: Dict[int, List[int]] = {}
|
| 112 |
+
for i, tid in enumerate(token_ids):
|
| 113 |
+
if tid not in positions:
|
| 114 |
+
positions[tid] = []
|
| 115 |
+
positions[tid].append(i)
|
| 116 |
+
|
| 117 |
+
# Keep only tokens that appear more than once
|
| 118 |
+
return {tid: pos_list for tid, pos_list in positions.items() if len(pos_list) > 1}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
|
| 120 |
|
| 121 |
+
def verify_head_activation(
    attn_matrix: torch.Tensor,
    token_ids: List[int],
    category: str
) -> float:
    """
    Verify whether a head's known role is active on the current input.

    Dispatches to one scorer per category; every scorer returns a mean of
    attention weights (or normalized entropies), so scores lie in [0.0, 1.0].

    Args:
        attn_matrix: [seq_len, seq_len] attention weights for this head
        token_ids: List of token IDs in the input
        category: One of "previous_token", "induction", "duplicate_token",
            "positional", "diffuse". Unknown categories score 0.0.

    Returns:
        Activation score (0.0 to 1.0). 0.0 means the role is not triggered
        on this input (or the sequence is too short / category unknown).
    """
    seq_len = attn_matrix.shape[0]

    # A single-token sequence cannot exhibit any of these patterns.
    if seq_len < 2:
        return 0.0

    scorers = {
        "previous_token": _score_previous_token,
        "induction": _score_induction,
        "duplicate_token": _score_duplicate_token,
        "positional": _score_positional,
        "diffuse": _score_diffuse,
    }
    scorer = scorers.get(category)
    if scorer is None:
        return 0.0
    return scorer(attn_matrix, token_ids, seq_len)


def _score_previous_token(attn_matrix: torch.Tensor, token_ids: List[int], seq_len: int) -> float:
    """Mean of sub-diagonal values: how much each position attends to the previous one."""
    prev_vals = [attn_matrix[i, i - 1].item() for i in range(1, seq_len)]
    return float(np.mean(prev_vals)) if prev_vals else 0.0


def _score_induction(attn_matrix: torch.Tensor, token_ids: List[int], seq_len: int) -> float:
    """
    Induction pattern: [A][B]...[A] -> attend to [B].

    For each later occurrence of a repeated token, measure attention to the
    position immediately after each earlier occurrence.
    """
    repeated = _find_repeated_tokens(token_ids)
    if not repeated:
        return 0.0  # No repetition -> role cannot fire on this input

    induction_scores = []
    for positions in repeated.values():
        for k in range(1, len(positions)):
            current_pos = positions[k]  # Later occurrence
            for prev_idx in range(k):
                target_pos = positions[prev_idx] + 1  # Token AFTER the earlier occurrence
                # Guard against token_ids being longer than the attention matrix
                if target_pos < seq_len and current_pos < seq_len:
                    induction_scores.append(attn_matrix[current_pos, target_pos].item())

    return float(np.mean(induction_scores)) if induction_scores else 0.0


def _score_duplicate_token(attn_matrix: torch.Tensor, token_ids: List[int], seq_len: int) -> float:
    """Attention from later occurrences of a token back to its earlier occurrences."""
    repeated = _find_repeated_tokens(token_ids)
    if not repeated:
        return 0.0  # No duplicates -> role cannot fire on this input

    dup_scores = []
    for positions in repeated.values():
        for k in range(1, len(positions)):
            later_pos = positions[k]
            # Bounds guard (consistent with _score_induction): without it a
            # token_ids list longer than the attention matrix raises IndexError.
            if later_pos >= seq_len:
                continue
            # Sum attention to all earlier occurrences of the same token
            earlier_attention = sum(
                attn_matrix[later_pos, positions[j]].item() for j in range(k)
            )
            dup_scores.append(earlier_attention)

    return float(np.mean(dup_scores)) if dup_scores else 0.0


def _score_positional(attn_matrix: torch.Tensor, token_ids: List[int], seq_len: int) -> float:
    """Mean of column-0 attention: how much every position attends to the first token."""
    return attn_matrix[:, 0].mean().item()


def _score_diffuse(attn_matrix: torch.Tensor, token_ids: List[int], seq_len: int) -> float:
    """Average normalized attention entropy across all query positions."""
    entropies = [_compute_attention_entropy(attn_matrix[i]) for i in range(seq_len)]
    return float(np.mean(entropies)) if entropies else 0.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 202 |
|
| 203 |
|
| 204 |
+
def get_active_head_summary(
|
| 205 |
+
activation_data: Dict[str, Any],
|
| 206 |
+
model_name: str
|
| 207 |
+
) -> Optional[Dict[str, Any]]:
|
| 208 |
"""
|
| 209 |
+
Main entry point: load categories from JSON, verify each head on the current input,
|
| 210 |
+
and return a UI-ready structure.
|
| 211 |
|
| 212 |
Args:
|
| 213 |
activation_data: Output from execute_forward_pass with attention data
|
| 214 |
+
model_name: HuggingFace model name
|
| 215 |
|
| 216 |
Returns:
|
| 217 |
+
Dict with structure:
|
| 218 |
{
|
| 219 |
+
"model_available": True,
|
| 220 |
+
"categories": {
|
| 221 |
+
"previous_token": {
|
| 222 |
+
"display_name": str,
|
| 223 |
+
"description": str,
|
| 224 |
+
"educational_text": str,
|
| 225 |
+
"icon": str,
|
| 226 |
+
"requires_repetition": bool,
|
| 227 |
+
"suggested_prompt": str or None,
|
| 228 |
+
"is_applicable": bool, # False if requires_repetition but no repeats
|
| 229 |
+
"heads": [
|
| 230 |
+
{"layer": int, "head": int, "offline_score": float,
|
| 231 |
+
"activation_score": float, "is_active": bool, "label": str}
|
| 232 |
+
]
|
| 233 |
+
},
|
| 234 |
+
...
|
| 235 |
+
}
|
| 236 |
}
|
| 237 |
+
Returns None if model not in JSON.
|
| 238 |
"""
|
| 239 |
+
model_data = load_head_categories(model_name)
|
| 240 |
+
if model_data is None:
|
| 241 |
+
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 242 |
|
| 243 |
+
# Extract attention weights and token IDs from activation data
|
| 244 |
attention_outputs = activation_data.get('attention_outputs', {})
|
| 245 |
+
input_ids = activation_data.get('input_ids', [[]])[0]
|
| 246 |
+
|
| 247 |
+
if not attention_outputs or not input_ids:
|
| 248 |
+
return None
|
| 249 |
+
|
| 250 |
+
# Build a lookup: (layer, head) → attention_matrix [seq_len, seq_len]
|
| 251 |
+
head_attention_lookup: Dict[Tuple[int, int], torch.Tensor] = {}
|
| 252 |
|
|
|
|
| 253 |
for module_name, output_dict in attention_outputs.items():
|
|
|
|
| 254 |
numbers = re.findall(r'\d+', module_name)
|
| 255 |
if not numbers:
|
| 256 |
continue
|
|
|
|
| 261 |
if not isinstance(attention_output, list) or len(attention_output) < 2:
|
| 262 |
continue
|
| 263 |
|
| 264 |
+
# attention_output[1] is [batch, heads, seq_len, seq_len]
|
| 265 |
attention_weights = torch.tensor(attention_output[1])
|
|
|
|
|
|
|
| 266 |
num_heads = attention_weights.shape[1]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 267 |
|
| 268 |
for head_idx in range(num_heads):
|
| 269 |
+
head_attention_lookup[(layer_idx, head_idx)] = attention_weights[0, head_idx, :, :]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 270 |
|
| 271 |
+
# Check if input has repeated tokens (needed for applicability check)
|
| 272 |
+
repeated_tokens = _find_repeated_tokens(input_ids)
|
| 273 |
+
has_repetition = len(repeated_tokens) > 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 274 |
|
| 275 |
+
# Build result
|
| 276 |
+
result = {
|
| 277 |
+
"model_available": True,
|
| 278 |
+
"categories": {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 279 |
}
|
| 280 |
|
| 281 |
+
categories = model_data.get("categories", {})
|
|
|
|
|
|
|
| 282 |
|
| 283 |
+
# Define category order for consistent display
|
| 284 |
+
category_order = ["previous_token", "induction", "duplicate_token", "positional", "diffuse"]
|
| 285 |
+
|
| 286 |
+
for cat_key in category_order:
|
| 287 |
+
cat_info = categories.get(cat_key)
|
| 288 |
+
if cat_info is None:
|
| 289 |
continue
|
| 290 |
|
| 291 |
+
requires_repetition = cat_info.get("requires_repetition", False)
|
| 292 |
+
is_applicable = not requires_repetition or has_repetition
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 293 |
|
| 294 |
+
heads_result = []
|
| 295 |
+
for head_entry in cat_info.get("top_heads", []):
|
| 296 |
+
layer = head_entry["layer"]
|
| 297 |
+
head = head_entry["head"]
|
| 298 |
+
offline_score = head_entry["score"]
|
| 299 |
+
|
| 300 |
+
# Get activation score on current input
|
| 301 |
+
attn_matrix = head_attention_lookup.get((layer, head))
|
| 302 |
+
if attn_matrix is not None and is_applicable:
|
| 303 |
+
activation_score = verify_head_activation(attn_matrix, input_ids, cat_key)
|
| 304 |
+
else:
|
| 305 |
+
activation_score = 0.0
|
| 306 |
+
|
| 307 |
+
# A head is "active" if its activation score exceeds a minimum threshold
|
| 308 |
+
is_active = activation_score > 0.1 and is_applicable
|
| 309 |
+
|
| 310 |
+
heads_result.append({
|
| 311 |
+
"layer": layer,
|
| 312 |
+
"head": head,
|
| 313 |
+
"offline_score": offline_score,
|
| 314 |
+
"activation_score": round(activation_score, 3),
|
| 315 |
+
"is_active": is_active,
|
| 316 |
+
"label": f"L{layer}-H{head}"
|
| 317 |
+
})
|
| 318 |
|
| 319 |
+
result["categories"][cat_key] = {
|
| 320 |
+
"display_name": cat_info.get("display_name", cat_key),
|
| 321 |
+
"description": cat_info.get("description", ""),
|
| 322 |
+
"educational_text": cat_info.get("educational_text", ""),
|
| 323 |
+
"icon": cat_info.get("icon", "circle"),
|
| 324 |
+
"requires_repetition": requires_repetition,
|
| 325 |
+
"suggested_prompt": cat_info.get("suggested_prompt"),
|
| 326 |
+
"is_applicable": is_applicable,
|
| 327 |
+
"heads": heads_result
|
| 328 |
+
}
|
|
|
|
|
|
|
|
|
|
| 329 |
|
| 330 |
+
# Add "Other" category (heads not claimed by any top list)
|
| 331 |
+
result["categories"]["other"] = {
|
| 332 |
+
"display_name": "Other / Unclassified",
|
| 333 |
+
"description": "Heads whose patterns don't fit the simple categories above",
|
| 334 |
+
"educational_text": "This head's pattern doesn't fit our simple categories — it may be doing something more complex or context-dependent.",
|
| 335 |
+
"icon": "question-circle",
|
| 336 |
+
"requires_repetition": False,
|
| 337 |
+
"suggested_prompt": None,
|
| 338 |
+
"is_applicable": True,
|
| 339 |
+
"heads": [] # We don't enumerate all "other" heads to keep the UI clean
|
| 340 |
}
|
| 341 |
|
| 342 |
+
return result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
utils/model_patterns.py
CHANGED
|
@@ -1421,23 +1421,4 @@ def generate_bertviz_html(activation_data: Dict[str, Any], layer_index: int, vie
|
|
| 1421 |
return f"<p>Error generating visualization: {str(e)}</p>"
|
| 1422 |
|
| 1423 |
|
| 1424 |
-
|
| 1425 |
-
"""
|
| 1426 |
-
Get counts of attention heads in each category.
|
| 1427 |
-
|
| 1428 |
-
Useful for UI display showing the distribution of head types.
|
| 1429 |
-
|
| 1430 |
-
Args:
|
| 1431 |
-
activation_data: Output from execute_forward_pass with attention data
|
| 1432 |
-
|
| 1433 |
-
Returns:
|
| 1434 |
-
Dict mapping category name to count of heads in that category
|
| 1435 |
-
"""
|
| 1436 |
-
from .head_detection import categorize_all_heads
|
| 1437 |
-
|
| 1438 |
-
try:
|
| 1439 |
-
categories = categorize_all_heads(activation_data)
|
| 1440 |
-
return {category: len(heads) for category, heads in categories.items()}
|
| 1441 |
-
except Exception as e:
|
| 1442 |
-
print(f"Warning: Could not categorize heads: {e}")
|
| 1443 |
-
return {}
|
|
|
|
| 1421 |
return f"<p>Error generating visualization: {str(e)}</p>"
|
| 1422 |
|
| 1423 |
|
| 1424 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|