Spaces:
Sleeping
Sleeping
Commit ·
2965a7d
1
Parent(s): aabd66e
chore: remove unused code, imports, and deprecated functions
Browse files
- Delete components/tokenization_panel.py (superseded by pipeline.py)
- Remove 6 unused imports from app.py
- Remove deprecated _get_top_attended_tokens() and references
- Remove unused create_stage_summary() from pipeline.py
- Remove 7 unused utility functions from model_patterns.py and beam_search.py
- Update utils/__init__.py exports and README.md
Total: ~1,087 lines removed. All 81 tests pass.
Co-authored-by: Cursor <cursoragent@cursor.com>
- README.md +1 -2
- app.py +1 -10
- components/pipeline.py +0 -23
- components/tokenization_panel.py +0 -302
- todo.md +19 -0
- utils/__init__.py +3 -14
- utils/beam_search.py +1 -90
- utils/model_patterns.py +1 -646
README.md
CHANGED
|
@@ -77,10 +77,9 @@ Open your browser and navigate to `http://127.0.0.1:8050/`.
|
|
| 77 |
|
| 78 |
* `app.py`: Main application entry point and layout orchestration.
|
| 79 |
* `components/`: Modular UI components.
|
| 80 |
-
* `pipeline.py`: The core 5-stage visualization.
|
| 81 |
* `investigation_panel.py`: Ablation and attribution interfaces.
|
| 82 |
* `ablation_panel.py`: Specific logic for head ablation UI.
|
| 83 |
-
* `tokenization_panel.py`: Token visualization.
|
| 84 |
* `utils/`: Backend logic and helper functions.
|
| 85 |
* `model_patterns.py`: Activation capture and hooking logic.
|
| 86 |
* `model_config.py`: Model family definitions and auto-detection.
|
|
|
|
| 77 |
|
| 78 |
* `app.py`: Main application entry point and layout orchestration.
|
| 79 |
* `components/`: Modular UI components.
|
| 80 |
+
* `pipeline.py`: The core 5-stage visualization with tokenization display.
|
| 81 |
* `investigation_panel.py`: Ablation and attribution interfaces.
|
| 82 |
* `ablation_panel.py`: Specific logic for head ablation UI.
|
|
|
|
| 83 |
* `utils/`: Backend logic and helper functions.
|
| 84 |
* `model_patterns.py`: Activation capture and hooking logic.
|
| 85 |
* `model_config.py`: Model family definitions and auto-detection.
|
app.py
CHANGED
|
@@ -10,11 +10,7 @@ from dash import html, dcc, Input, Output, State, callback, no_update, ALL, MATC
|
|
| 10 |
import json
|
| 11 |
import torch
|
| 12 |
from utils import (load_model_and_get_patterns, execute_forward_pass, extract_layer_data,
|
| 13 |
-
|
| 14 |
-
execute_forward_pass_with_head_ablation,
|
| 15 |
-
execute_forward_pass_with_multi_layer_head_ablation,
|
| 16 |
-
evaluate_sequence_ablation, score_sequence,
|
| 17 |
-
get_head_category_counts, generate_bertviz_model_view_html)
|
| 18 |
from utils.head_detection import categorize_all_heads
|
| 19 |
from utils.model_config import get_auto_selections
|
| 20 |
from utils.token_attribution import compute_integrated_gradients, compute_simple_gradient_attribution
|
|
@@ -523,11 +519,6 @@ def update_pipeline_content(activation_data, model_name):
|
|
| 523 |
else:
|
| 524 |
top_tokens = []
|
| 525 |
|
| 526 |
-
# Get attention info from first layer
|
| 527 |
-
top_attended = None
|
| 528 |
-
if layer_data:
|
| 529 |
-
top_attended = layer_data[0].get('top_attended_tokens', [])
|
| 530 |
-
|
| 531 |
# Generate BertViz HTML
|
| 532 |
from utils import generate_bertviz_html
|
| 533 |
attention_html = None
|
|
|
|
| 10 |
import json
|
| 11 |
import torch
|
| 12 |
from utils import (load_model_and_get_patterns, execute_forward_pass, extract_layer_data,
|
| 13 |
+
perform_beam_search, execute_forward_pass_with_multi_layer_head_ablation)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
from utils.head_detection import categorize_all_heads
|
| 15 |
from utils.model_config import get_auto_selections
|
| 16 |
from utils.token_attribution import compute_integrated_gradients, compute_simple_gradient_attribution
|
|
|
|
| 519 |
else:
|
| 520 |
top_tokens = []
|
| 521 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 522 |
# Generate BertViz HTML
|
| 523 |
from utils import generate_bertviz_html
|
| 524 |
attention_html = None
|
components/pipeline.py
CHANGED
|
@@ -809,26 +809,3 @@ def create_output_content(top_tokens=None, predicted_token=None, predicted_prob=
|
|
| 809 |
|
| 810 |
return html.Div(content_items)
|
| 811 |
|
| 812 |
-
|
| 813 |
-
def create_stage_summary(stage_id, activation_data=None, model_config=None):
|
| 814 |
-
"""
|
| 815 |
-
Generate summary text for a stage (shown when collapsed).
|
| 816 |
-
|
| 817 |
-
Args:
|
| 818 |
-
stage_id: Stage identifier ('tokenization', 'embedding', etc.)
|
| 819 |
-
activation_data: Optional activation data from forward pass
|
| 820 |
-
model_config: Optional model configuration
|
| 821 |
-
"""
|
| 822 |
-
if not activation_data:
|
| 823 |
-
return "Awaiting input..."
|
| 824 |
-
|
| 825 |
-
summaries = {
|
| 826 |
-
'tokenization': lambda: f"{len(activation_data.get('input_ids', [[]])[0])} tokens",
|
| 827 |
-
'embedding': lambda: f"{model_config.hidden_size if model_config else 768}-dim vectors" if model_config else "Vectors ready",
|
| 828 |
-
'attention': lambda: f"{model_config.num_attention_heads if model_config else 12} heads" if model_config else "Context gathered",
|
| 829 |
-
'mlp': lambda: f"{model_config.num_hidden_layers if model_config else 12} layers" if model_config else "Transformations applied",
|
| 830 |
-
'output': lambda: f"→ {activation_data.get('actual_output', {}).get('token', '?')}" if activation_data.get('actual_output') else "Output computed"
|
| 831 |
-
}
|
| 832 |
-
|
| 833 |
-
return summaries.get(stage_id, lambda: "")()
|
| 834 |
-
|
|
|
|
| 809 |
|
| 810 |
return html.Div(content_items)
|
| 811 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
components/tokenization_panel.py
DELETED
|
@@ -1,302 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Tokenization panel component for visualizing the tokenization process.
|
| 3 |
-
|
| 4 |
-
Displays tokens in vertical rows: [token] → [ID] → [embedding] per token.
|
| 5 |
-
"""
|
| 6 |
-
|
| 7 |
-
from dash import html, dcc
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
def create_static_tokenization_diagram():
|
| 11 |
-
"""Create static HTML/CSS diagram showing example tokenization flow."""
|
| 12 |
-
return html.Div([
|
| 13 |
-
html.H4("Example: How text becomes model input",
|
| 14 |
-
style={'marginBottom': '1rem', 'color': '#495057', 'fontSize': '16px'}),
|
| 15 |
-
|
| 16 |
-
# Example flow: text -> tokens -> IDs -> embeddings
|
| 17 |
-
html.Div([
|
| 18 |
-
# Input text
|
| 19 |
-
html.Div([
|
| 20 |
-
html.Div('"Hello world"',
|
| 21 |
-
className='example-text',
|
| 22 |
-
style={'padding': '8px 12px', 'backgroundColor': '#e9ecef',
|
| 23 |
-
'borderRadius': '4px', 'fontFamily': 'monospace'})
|
| 24 |
-
], style={'flex': '1', 'textAlign': 'center'}),
|
| 25 |
-
|
| 26 |
-
html.Div('→', style={'padding': '0 10px', 'fontSize': '20px', 'color': '#6c757d'}),
|
| 27 |
-
|
| 28 |
-
# Tokens
|
| 29 |
-
html.Div([
|
| 30 |
-
html.Div([
|
| 31 |
-
html.Span('["Hello"', style={'fontFamily': 'monospace'}),
|
| 32 |
-
html.Span(', " world"]', style={'fontFamily': 'monospace'})
|
| 33 |
-
], style={'padding': '8px 12px', 'backgroundColor': '#d4edff',
|
| 34 |
-
'borderRadius': '4px'})
|
| 35 |
-
], style={'flex': '1', 'textAlign': 'center'}),
|
| 36 |
-
|
| 37 |
-
html.Div('→', style={'padding': '0 10px', 'fontSize': '20px', 'color': '#6c757d'}),
|
| 38 |
-
|
| 39 |
-
# IDs
|
| 40 |
-
html.Div([
|
| 41 |
-
html.Div('[1234, 5678]',
|
| 42 |
-
style={'padding': '8px 12px', 'backgroundColor': '#ffe5d4',
|
| 43 |
-
'borderRadius': '4px', 'fontFamily': 'monospace'})
|
| 44 |
-
], style={'flex': '1', 'textAlign': 'center'}),
|
| 45 |
-
|
| 46 |
-
html.Div('→', style={'padding': '0 10px', 'fontSize': '20px', 'color': '#6c757d'}),
|
| 47 |
-
|
| 48 |
-
# Embeddings
|
| 49 |
-
html.Div([
|
| 50 |
-
html.Div('[[ ... ], [ ... ]]',
|
| 51 |
-
style={'padding': '8px 12px', 'backgroundColor': '#e5d4ff',
|
| 52 |
-
'borderRadius': '4px', 'fontFamily': 'monospace'})
|
| 53 |
-
], style={'flex': '1', 'textAlign': 'center'})
|
| 54 |
-
|
| 55 |
-
], style={'display': 'flex', 'alignItems': 'center', 'justifyContent': 'center',
|
| 56 |
-
'padding': '1rem', 'backgroundColor': '#f8f9fa', 'borderRadius': '8px',
|
| 57 |
-
'border': '1px solid #dee2e6'})
|
| 58 |
-
], style={'marginBottom': '2rem'})
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
def create_tokenization_panel():
|
| 62 |
-
"""Create the tokenization visualization panel with three columns."""
|
| 63 |
-
return html.Div([
|
| 64 |
-
# Section title and subtitle
|
| 65 |
-
html.Div([
|
| 66 |
-
html.H3("Step 1: Tokenization & Embedding",
|
| 67 |
-
className="section-title",
|
| 68 |
-
style={'marginBottom': '0.5rem'}),
|
| 69 |
-
html.P("This is the first step in processing text through a transformer model. "
|
| 70 |
-
"The input text is broken into tokens, converted to IDs, and embedded as vectors.",
|
| 71 |
-
style={'color': '#6c757d', 'fontSize': '14px', 'marginBottom': '1.5rem'})
|
| 72 |
-
]),
|
| 73 |
-
|
| 74 |
-
# Static example diagram (always visible)
|
| 75 |
-
create_static_tokenization_diagram(),
|
| 76 |
-
|
| 77 |
-
# Dynamic tokenization display container (populated by callback)
|
| 78 |
-
html.Div(id='tokenization-display-container', children=[
|
| 79 |
-
# This will be populated after analysis runs
|
| 80 |
-
])
|
| 81 |
-
|
| 82 |
-
], id='tokenization-panel', style={'display': 'none'}, className='tokenization-section')
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
def create_tokenization_display(tokens_list, token_ids_list, color_palette=None):
|
| 86 |
-
"""
|
| 87 |
-
Create a vertical tokenization display showing each token's flow.
|
| 88 |
-
|
| 89 |
-
Args:
|
| 90 |
-
tokens_list: List of token strings
|
| 91 |
-
token_ids_list: List of token IDs
|
| 92 |
-
color_palette: Optional list of colors for each token (auto-generated if None)
|
| 93 |
-
|
| 94 |
-
Returns:
|
| 95 |
-
Dash HTML component with vertical token rows: [token] → [ID] → [embedding]
|
| 96 |
-
"""
|
| 97 |
-
if color_palette is None:
|
| 98 |
-
# Generate distinct colors for each token
|
| 99 |
-
color_palette = generate_token_colors(len(tokens_list))
|
| 100 |
-
|
| 101 |
-
preview_token = tokens_list[0] if tokens_list else ""
|
| 102 |
-
preview_id = token_ids_list[0] if token_ids_list else ""
|
| 103 |
-
preview_color = color_palette[0] if color_palette else '#f8f9fa'
|
| 104 |
-
|
| 105 |
-
return html.Details([
|
| 106 |
-
html.Summary(
|
| 107 |
-
html.Div([
|
| 108 |
-
html.Span("Tokenization preview:", style={'color': '#6c757d', 'fontSize': '13px'}),
|
| 109 |
-
html.Span(preview_token, style={
|
| 110 |
-
'padding': '4px 8px',
|
| 111 |
-
'backgroundColor': preview_color,
|
| 112 |
-
'borderRadius': '4px',
|
| 113 |
-
'fontFamily': 'monospace',
|
| 114 |
-
'fontSize': '12px'
|
| 115 |
-
}),
|
| 116 |
-
html.Span('→', style={'color': '#6c757d'}),
|
| 117 |
-
html.Span(str(preview_id), style={
|
| 118 |
-
'padding': '4px 8px',
|
| 119 |
-
'backgroundColor': '#ffe5d4',
|
| 120 |
-
'borderRadius': '4px',
|
| 121 |
-
'fontFamily': 'monospace',
|
| 122 |
-
'fontSize': '12px'
|
| 123 |
-
}),
|
| 124 |
-
html.Span('→', style={'color': '#6c757d'}),
|
| 125 |
-
html.Span('[ ... ]', style={
|
| 126 |
-
'padding': '4px 8px',
|
| 127 |
-
'backgroundColor': '#e5d4ff',
|
| 128 |
-
'borderRadius': '4px',
|
| 129 |
-
'fontFamily': 'monospace',
|
| 130 |
-
'fontSize': '12px'
|
| 131 |
-
}),
|
| 132 |
-
html.Span('...', style={'color': '#6c757d'}),
|
| 133 |
-
html.Span("Expand", style={'marginLeft': 'auto', 'color': '#667eea', 'fontWeight': '500'})
|
| 134 |
-
], style={'display': 'flex', 'alignItems': 'center', 'gap': '8px', 'flexWrap': 'wrap'})
|
| 135 |
-
),
|
| 136 |
-
|
| 137 |
-
html.Div([
|
| 138 |
-
html.H4("Full Tokenization:",
|
| 139 |
-
style={'marginTop': '1.5rem', 'marginBottom': '1rem',
|
| 140 |
-
'color': '#495057', 'fontSize': '16px'}),
|
| 141 |
-
|
| 142 |
-
# Column headers row
|
| 143 |
-
html.Div([
|
| 144 |
-
html.Span("Token", className='token-header',
|
| 145 |
-
style={'flex': '1', 'fontWeight': '600', 'color': '#495057', 'fontSize': '13px'}),
|
| 146 |
-
html.Span("", style={'width': '32px'}), # Arrow spacer
|
| 147 |
-
html.Span("ID", className='token-header',
|
| 148 |
-
style={'flex': '1', 'fontWeight': '600', 'color': '#495057', 'fontSize': '13px'}),
|
| 149 |
-
html.Span("", style={'width': '32px'}), # Arrow spacer
|
| 150 |
-
html.Span("Embedding", className='token-header',
|
| 151 |
-
style={'flex': '1', 'fontWeight': '600', 'color': '#495057', 'fontSize': '13px'})
|
| 152 |
-
], className='tokenization-header-row',
|
| 153 |
-
style={'display': 'flex', 'alignItems': 'center', 'gap': '4px',
|
| 154 |
-
'marginBottom': '0.75rem', 'paddingBottom': '0.5rem',
|
| 155 |
-
'borderBottom': '1px solid #e9ecef'}),
|
| 156 |
-
|
| 157 |
-
# Vertical token rows - each row shows [token] → [ID] → [embedding]
|
| 158 |
-
html.Div([
|
| 159 |
-
create_token_row(token, token_id, color, idx)
|
| 160 |
-
for idx, (token, token_id, color) in enumerate(zip(tokens_list, token_ids_list, color_palette))
|
| 161 |
-
], className='tokenization-rows')
|
| 162 |
-
|
| 163 |
-
], style={'padding': '1rem', 'backgroundColor': '#ffffff',
|
| 164 |
-
'borderRadius': '8px', 'border': '1px solid #dee2e6'})
|
| 165 |
-
|
| 166 |
-
], open=False, style={'marginTop': '1rem'})
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
def create_token_row(token, token_id, color, idx):
|
| 170 |
-
"""
|
| 171 |
-
Create a single horizontal row showing: [token] → [ID] → [embedding].
|
| 172 |
-
|
| 173 |
-
Args:
|
| 174 |
-
token: Token string
|
| 175 |
-
token_id: Token ID number
|
| 176 |
-
color: Background color for the token
|
| 177 |
-
idx: Index of the token (for key uniqueness)
|
| 178 |
-
|
| 179 |
-
Returns:
|
| 180 |
-
Dash HTML component for a single token row
|
| 181 |
-
"""
|
| 182 |
-
# Tooltip text for educational purposes
|
| 183 |
-
tooltips = {
|
| 184 |
-
'token': "The text is broken into 'tokens' - small pieces like words or parts of words. "
|
| 185 |
-
"This is how the model reads text. Breaking words into smaller pieces lets the model "
|
| 186 |
-
"understand new words by combining pieces it already knows.",
|
| 187 |
-
'id': "Each token gets a unique number (ID) from the model's dictionary. "
|
| 188 |
-
"Think of it like a phonebook - every token has its own number. "
|
| 189 |
-
"The model uses these numbers instead of the actual text.",
|
| 190 |
-
'embedding': "Each token number is turned into a list of numbers called an 'embedding.' "
|
| 191 |
-
"These numbers capture the token's meaning. Similar words get similar numbers. "
|
| 192 |
-
"This list of numbers is what actually goes into the model's layers."
|
| 193 |
-
}
|
| 194 |
-
|
| 195 |
-
return html.Div([
|
| 196 |
-
# Token box
|
| 197 |
-
html.Div(
|
| 198 |
-
token,
|
| 199 |
-
className='token-row-box token-row-token',
|
| 200 |
-
style={
|
| 201 |
-
'flex': '1',
|
| 202 |
-
'padding': '8px 12px',
|
| 203 |
-
'backgroundColor': color,
|
| 204 |
-
'borderRadius': '6px',
|
| 205 |
-
'border': f'2px solid {darken_color(color)}',
|
| 206 |
-
'fontFamily': 'monospace',
|
| 207 |
-
'fontSize': '13px',
|
| 208 |
-
'textAlign': 'center',
|
| 209 |
-
'wordBreak': 'break-word',
|
| 210 |
-
'minWidth': '60px'
|
| 211 |
-
},
|
| 212 |
-
title=tooltips['token']
|
| 213 |
-
),
|
| 214 |
-
|
| 215 |
-
# Arrow
|
| 216 |
-
html.Span('→', className='token-row-arrow',
|
| 217 |
-
style={'color': '#6c757d', 'fontSize': '16px', 'padding': '0 8px'}),
|
| 218 |
-
|
| 219 |
-
# ID box
|
| 220 |
-
html.Div(
|
| 221 |
-
str(token_id),
|
| 222 |
-
className='token-row-box token-row-id',
|
| 223 |
-
style={
|
| 224 |
-
'flex': '1',
|
| 225 |
-
'padding': '8px 12px',
|
| 226 |
-
'backgroundColor': '#ffe5d4',
|
| 227 |
-
'borderRadius': '6px',
|
| 228 |
-
'border': '2px solid #e6cfc0',
|
| 229 |
-
'fontFamily': 'monospace',
|
| 230 |
-
'fontSize': '13px',
|
| 231 |
-
'textAlign': 'center',
|
| 232 |
-
'minWidth': '60px'
|
| 233 |
-
},
|
| 234 |
-
title=tooltips['id']
|
| 235 |
-
),
|
| 236 |
-
|
| 237 |
-
# Arrow
|
| 238 |
-
html.Span('→', className='token-row-arrow',
|
| 239 |
-
style={'color': '#6c757d', 'fontSize': '16px', 'padding': '0 8px'}),
|
| 240 |
-
|
| 241 |
-
# Embedding box
|
| 242 |
-
html.Div(
|
| 243 |
-
'[ ... ]',
|
| 244 |
-
className='token-row-box token-row-embedding',
|
| 245 |
-
style={
|
| 246 |
-
'flex': '1',
|
| 247 |
-
'padding': '8px 12px',
|
| 248 |
-
'backgroundColor': '#e5d4ff',
|
| 249 |
-
'borderRadius': '6px',
|
| 250 |
-
'border': '2px solid #cfbfe6',
|
| 251 |
-
'fontFamily': 'monospace',
|
| 252 |
-
'fontSize': '13px',
|
| 253 |
-
'textAlign': 'center',
|
| 254 |
-
'minWidth': '60px'
|
| 255 |
-
},
|
| 256 |
-
title=tooltips['embedding']
|
| 257 |
-
)
|
| 258 |
-
|
| 259 |
-
], className='token-row',
|
| 260 |
-
style={'display': 'flex', 'alignItems': 'center', 'gap': '4px', 'marginBottom': '8px'})
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
def generate_token_colors(num_tokens):
|
| 264 |
-
"""Generate a list of distinct colors for tokens."""
|
| 265 |
-
# Predefined pleasant color palette
|
| 266 |
-
base_colors = [
|
| 267 |
-
'#ffcccb', # Light red
|
| 268 |
-
'#add8e6', # Light blue
|
| 269 |
-
'#90ee90', # Light green
|
| 270 |
-
'#ffb6c1', # Light pink
|
| 271 |
-
'#ffd700', # Gold
|
| 272 |
-
'#dda0dd', # Plum
|
| 273 |
-
'#f0e68c', # Khaki
|
| 274 |
-
'#ff6347', # Tomato
|
| 275 |
-
'#98fb98', # Pale green
|
| 276 |
-
'#87ceeb', # Sky blue
|
| 277 |
-
'#ffa07a', # Light salmon
|
| 278 |
-
'#da70d6', # Orchid
|
| 279 |
-
]
|
| 280 |
-
|
| 281 |
-
# Cycle through colors if we have more tokens than colors
|
| 282 |
-
colors = []
|
| 283 |
-
for i in range(num_tokens):
|
| 284 |
-
colors.append(base_colors[i % len(base_colors)])
|
| 285 |
-
|
| 286 |
-
return colors
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
def darken_color(hex_color, factor=0.8):
|
| 290 |
-
"""Darken a hex color by a factor."""
|
| 291 |
-
# Remove '#' if present
|
| 292 |
-
hex_color = hex_color.lstrip('#')
|
| 293 |
-
|
| 294 |
-
# Convert to RGB
|
| 295 |
-
r, g, b = int(hex_color[0:2], 16), int(hex_color[2:4], 16), int(hex_color[4:6], 16)
|
| 296 |
-
|
| 297 |
-
# Darken
|
| 298 |
-
r, g, b = int(r * factor), int(g * factor), int(b * factor)
|
| 299 |
-
|
| 300 |
-
# Convert back to hex
|
| 301 |
-
return f'#{r:02x}{g:02x}{b:02x}'
|
| 302 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
todo.md
CHANGED
|
@@ -112,3 +112,22 @@
|
|
| 112 |
- [x] Replace per-layer ablation loop in app.py with single call to new function
|
| 113 |
- [x] Add 5 tests for multi-layer ablation in test_model_patterns.py
|
| 114 |
- [x] Verify all 78 tests pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
- [x] Replace per-layer ablation loop in app.py with single call to new function
|
| 113 |
- [x] Add 5 tests for multi-layer ablation in test_model_patterns.py
|
| 114 |
- [x] Verify all 78 tests pass
|
| 115 |
+
|
| 116 |
+
## Completed: Codebase Cleanup
|
| 117 |
+
|
| 118 |
+
- [x] Delete unused file: `components/tokenization_panel.py` (302 lines, 6 functions)
|
| 119 |
+
- [x] Remove 6 unused imports from `app.py`
|
| 120 |
+
- [x] Remove deprecated `_get_top_attended_tokens()` function from model_patterns.py
|
| 121 |
+
- [x] Remove `top_attended_tokens` field from extract_layer_data() return values
|
| 122 |
+
- [x] Remove unused `create_stage_summary()` function from pipeline.py
|
| 123 |
+
- [x] Remove 7 unused utility functions from utils/:
|
| 124 |
+
- `get_check_token_probabilities`
|
| 125 |
+
- `execute_forward_pass_with_layer_ablation`
|
| 126 |
+
- `generate_category_bertviz_html`
|
| 127 |
+
- `generate_head_view_with_categories`
|
| 128 |
+
- `compute_sequence_trajectory`
|
| 129 |
+
- `compute_layer_wise_summaries`
|
| 130 |
+
- `compute_position_layer_matrix`
|
| 131 |
+
- [x] Update `utils/__init__.py` exports
|
| 132 |
+
- [x] Update README.md to remove reference to deleted file
|
| 133 |
+
- [x] Verify all 81 tests pass
|
utils/__init__.py
CHANGED
|
@@ -1,17 +1,14 @@
|
|
| 1 |
from .model_patterns import (load_model_and_get_patterns, execute_forward_pass,
|
| 2 |
logit_lens_transformation, extract_layer_data,
|
| 3 |
-
generate_bertviz_html,
|
| 4 |
-
generate_head_view_with_categories, get_head_category_counts,
|
| 5 |
-
get_check_token_probabilities, execute_forward_pass_with_layer_ablation,
|
| 6 |
execute_forward_pass_with_head_ablation,
|
| 7 |
execute_forward_pass_with_multi_layer_head_ablation,
|
| 8 |
merge_token_probabilities,
|
| 9 |
compute_global_top5_tokens, detect_significant_probability_increases,
|
| 10 |
-
|
| 11 |
-
compute_position_layer_matrix, generate_bertviz_model_view_html)
|
| 12 |
from .model_config import get_model_family, get_family_config, get_auto_selections, MODEL_TO_FAMILY, MODEL_FAMILIES
|
| 13 |
from .head_detection import categorize_all_heads, categorize_single_layer_heads, format_categorization_summary, HeadCategorizationConfig
|
| 14 |
-
from .beam_search import perform_beam_search
|
| 15 |
from .ablation_metrics import compute_kl_divergence, score_sequence, get_token_probability_deltas
|
| 16 |
from .token_attribution import compute_integrated_gradients, compute_simple_gradient_attribution, create_attribution_visualization_data
|
| 17 |
|
|
@@ -20,22 +17,15 @@ __all__ = [
|
|
| 20 |
# Model patterns
|
| 21 |
'load_model_and_get_patterns',
|
| 22 |
'execute_forward_pass',
|
| 23 |
-
'execute_forward_pass_with_layer_ablation',
|
| 24 |
'execute_forward_pass_with_head_ablation',
|
| 25 |
'execute_forward_pass_with_multi_layer_head_ablation',
|
| 26 |
'evaluate_sequence_ablation',
|
| 27 |
'logit_lens_transformation',
|
| 28 |
'extract_layer_data',
|
| 29 |
'generate_bertviz_html',
|
| 30 |
-
'generate_category_bertviz_html',
|
| 31 |
-
'generate_head_view_with_categories',
|
| 32 |
-
'get_head_category_counts',
|
| 33 |
-
'get_check_token_probabilities',
|
| 34 |
'merge_token_probabilities',
|
| 35 |
'compute_global_top5_tokens',
|
| 36 |
'detect_significant_probability_increases',
|
| 37 |
-
'compute_layer_wise_summaries',
|
| 38 |
-
'compute_position_layer_matrix',
|
| 39 |
'generate_bertviz_model_view_html',
|
| 40 |
|
| 41 |
# Model config
|
|
@@ -53,7 +43,6 @@ __all__ = [
|
|
| 53 |
|
| 54 |
# Beam search
|
| 55 |
'perform_beam_search',
|
| 56 |
-
'compute_sequence_trajectory',
|
| 57 |
|
| 58 |
# Ablation metrics
|
| 59 |
'compute_kl_divergence',
|
|
|
|
| 1 |
from .model_patterns import (load_model_and_get_patterns, execute_forward_pass,
|
| 2 |
logit_lens_transformation, extract_layer_data,
|
| 3 |
+
generate_bertviz_html,
|
|
|
|
|
|
|
| 4 |
execute_forward_pass_with_head_ablation,
|
| 5 |
execute_forward_pass_with_multi_layer_head_ablation,
|
| 6 |
merge_token_probabilities,
|
| 7 |
compute_global_top5_tokens, detect_significant_probability_increases,
|
| 8 |
+
evaluate_sequence_ablation, generate_bertviz_model_view_html)
|
|
|
|
| 9 |
from .model_config import get_model_family, get_family_config, get_auto_selections, MODEL_TO_FAMILY, MODEL_FAMILIES
|
| 10 |
from .head_detection import categorize_all_heads, categorize_single_layer_heads, format_categorization_summary, HeadCategorizationConfig
|
| 11 |
+
from .beam_search import perform_beam_search
|
| 12 |
from .ablation_metrics import compute_kl_divergence, score_sequence, get_token_probability_deltas
|
| 13 |
from .token_attribution import compute_integrated_gradients, compute_simple_gradient_attribution, create_attribution_visualization_data
|
| 14 |
|
|
|
|
| 17 |
# Model patterns
|
| 18 |
'load_model_and_get_patterns',
|
| 19 |
'execute_forward_pass',
|
|
|
|
| 20 |
'execute_forward_pass_with_head_ablation',
|
| 21 |
'execute_forward_pass_with_multi_layer_head_ablation',
|
| 22 |
'evaluate_sequence_ablation',
|
| 23 |
'logit_lens_transformation',
|
| 24 |
'extract_layer_data',
|
| 25 |
'generate_bertviz_html',
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
'merge_token_probabilities',
|
| 27 |
'compute_global_top5_tokens',
|
| 28 |
'detect_significant_probability_increases',
|
|
|
|
|
|
|
| 29 |
'generate_bertviz_model_view_html',
|
| 30 |
|
| 31 |
# Model config
|
|
|
|
| 43 |
|
| 44 |
# Beam search
|
| 45 |
'perform_beam_search',
|
|
|
|
| 46 |
|
| 47 |
# Ablation metrics
|
| 48 |
'compute_kl_divergence',
|
utils/beam_search.py
CHANGED
|
@@ -4,9 +4,7 @@ Beam search utility for text generation and sequence analysis.
|
|
| 4 |
|
| 5 |
import torch
|
| 6 |
import torch.nn.functional as F
|
| 7 |
-
from typing import List,
|
| 8 |
-
import numpy as np
|
| 9 |
-
from utils.model_patterns import get_norm_layer_from_parameter
|
| 10 |
import re
|
| 11 |
|
| 12 |
def _make_head_ablation_hook(head_indices: List[int], num_heads: int):
|
|
@@ -179,90 +177,3 @@ def perform_beam_search(model, tokenizer, prompt: str, beam_width: int = 3, max_
|
|
| 179 |
# Ensure hooks are removed even if error occurs
|
| 180 |
for hook in hooks:
|
| 181 |
hook.remove()
|
| 182 |
-
|
| 183 |
-
def compute_sequence_trajectory(activation_data: Dict[str, Any], model, tokenizer) -> Dict[int, List[float]]:
|
| 184 |
-
"""
|
| 185 |
-
Compute the trajectory of the sequence score across layers.
|
| 186 |
-
|
| 187 |
-
For each layer, calculates the probability assigned to the *actual* next token
|
| 188 |
-
at each step of the sequence.
|
| 189 |
-
|
| 190 |
-
Args:
|
| 191 |
-
activation_data: Data from execute_forward_pass (must contain block_outputs for all layers)
|
| 192 |
-
model: HuggingFace model
|
| 193 |
-
tokenizer: HuggingFace tokenizer
|
| 194 |
-
|
| 195 |
-
Returns:
|
| 196 |
-
Dict mapping layer_num -> list of scores (one per step in the generated sequence)
|
| 197 |
-
"""
|
| 198 |
-
if not activation_data or 'block_outputs' not in activation_data:
|
| 199 |
-
return {}
|
| 200 |
-
|
| 201 |
-
# Extract layer outputs
|
| 202 |
-
block_outputs = activation_data['block_outputs']
|
| 203 |
-
input_ids = activation_data['input_ids']
|
| 204 |
-
|
| 205 |
-
if isinstance(input_ids, list):
|
| 206 |
-
input_ids = torch.tensor(input_ids)
|
| 207 |
-
|
| 208 |
-
# Identify tokens: input_ids shape is [1, seq_len]
|
| 209 |
-
# The "generated" part starts after the prompt, but here we likely have the full sequence.
|
| 210 |
-
# We want to evaluate P(token_t | tokens_<t) for the whole sequence or just the new part?
|
| 211 |
-
# Usually, we visualize the whole sequence.
|
| 212 |
-
|
| 213 |
-
# We need the logits from each layer
|
| 214 |
-
# block_outputs keys are like "model.layers.0", "model.layers.1", etc.
|
| 215 |
-
|
| 216 |
-
# Sort layers
|
| 217 |
-
import re
|
| 218 |
-
layer_info = sorted(
|
| 219 |
-
[(int(re.findall(r'\d+', name)[0]), name)
|
| 220 |
-
for name in block_outputs.keys() if re.findall(r'\d+', name)]
|
| 221 |
-
)
|
| 222 |
-
|
| 223 |
-
# Get norm parameter for logit lens
|
| 224 |
-
norm_params = activation_data.get('norm_parameters', [])
|
| 225 |
-
norm_parameter = norm_params[0] if norm_params else None
|
| 226 |
-
final_norm = get_norm_layer_from_parameter(model, norm_parameter)
|
| 227 |
-
lm_head = model.get_output_embeddings()
|
| 228 |
-
|
| 229 |
-
trajectories = {}
|
| 230 |
-
|
| 231 |
-
# We only care about predictions for positions 0 to N-1 (predicting 1 to N)
|
| 232 |
-
target_ids = input_ids[0, 1:]
|
| 233 |
-
|
| 234 |
-
with torch.no_grad():
|
| 235 |
-
for layer_num, module_name in layer_info:
|
| 236 |
-
output_data = block_outputs[module_name]['output']
|
| 237 |
-
|
| 238 |
-
# Convert to tensor [batch, seq_len, hidden_dim]
|
| 239 |
-
hidden = torch.tensor(output_data) if not isinstance(output_data, torch.Tensor) else output_data
|
| 240 |
-
if hidden.dim() == 4: # PyVene sometimes returns [1, 1, seq_len, dim] ? No usually [1, seq, dim]
|
| 241 |
-
# If shape is weird, adjust
|
| 242 |
-
pass
|
| 243 |
-
|
| 244 |
-
# Ensure batch dim
|
| 245 |
-
if hidden.dim() == 2:
|
| 246 |
-
hidden = hidden.unsqueeze(0)
|
| 247 |
-
|
| 248 |
-
# Apply final norm
|
| 249 |
-
if final_norm is not None:
|
| 250 |
-
hidden = final_norm(hidden)
|
| 251 |
-
|
| 252 |
-
# Project to logits
|
| 253 |
-
logits = lm_head(hidden) # [batch, seq_len, vocab_size]
|
| 254 |
-
|
| 255 |
-
# We want log probs of the *next* token
|
| 256 |
-
# Logits at pos t predict token at t+1
|
| 257 |
-
# So we take logits at [0, :-1, :] and gather targets [0, 1:]
|
| 258 |
-
|
| 259 |
-
shift_logits = logits[0, :-1, :]
|
| 260 |
-
log_probs = F.log_softmax(shift_logits, dim=-1)
|
| 261 |
-
|
| 262 |
-
# Gather log probs of the actual target tokens
|
| 263 |
-
# target_ids shape [seq_len-1]
|
| 264 |
-
target_log_probs = log_probs.gather(1, target_ids.unsqueeze(1)).squeeze(1)
|
| 265 |
-
|
| 266 |
-
trajectories[layer_num] = target_log_probs.tolist()
|
| 267 |
-
|
| 268 |
-
return trajectories
|
|
|
|
| 4 |
|
| 5 |
import torch
|
| 6 |
import torch.nn.functional as F
|
| 7 |
+
from typing import List, Dict, Any, Optional
|
|
|
|
|
|
|
| 8 |
import re
|
| 9 |
|
| 10 |
def _make_head_ablation_hook(head_indices: List[int], num_heads: int):
|
|
|
|
| 177 |
# Ensure hooks are removed even if error occurs
|
| 178 |
for hook in hooks:
|
| 179 |
hook.remove()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
utils/model_patterns.py
CHANGED
|
@@ -657,174 +657,6 @@ def execute_forward_pass_with_multi_layer_head_ablation(model, tokenizer, prompt
|
|
| 657 |
return result
|
| 658 |
|
| 659 |
|
| 660 |
-
def execute_forward_pass_with_layer_ablation(model, tokenizer, prompt: str, config: Dict[str, Any],
|
| 661 |
-
ablate_layer_num: int, reference_activation_data: Dict[str, Any]) -> Dict[str, Any]:
|
| 662 |
-
"""
|
| 663 |
-
Execute forward pass with mean ablation on a specific layer.
|
| 664 |
-
|
| 665 |
-
Args:
|
| 666 |
-
model: Loaded transformer model
|
| 667 |
-
tokenizer: Loaded tokenizer
|
| 668 |
-
prompt: Input text prompt
|
| 669 |
-
config: Dict with module lists like {"attention_modules": [...], "block_modules": [...], ...}
|
| 670 |
-
ablate_layer_num: Layer number to ablate
|
| 671 |
-
reference_activation_data: Original activation data containing the reference activations
|
| 672 |
-
|
| 673 |
-
Returns:
|
| 674 |
-
JSON-serializable dict with captured activations (with ablated layer)
|
| 675 |
-
"""
|
| 676 |
-
# Extract module lists from config
|
| 677 |
-
attention_modules = config.get("attention_modules", [])
|
| 678 |
-
block_modules = config.get("block_modules", [])
|
| 679 |
-
norm_parameters = config.get("norm_parameters", [])
|
| 680 |
-
logit_lens_parameter = config.get("logit_lens_parameter")
|
| 681 |
-
|
| 682 |
-
all_modules = attention_modules + block_modules
|
| 683 |
-
if not all_modules:
|
| 684 |
-
return {"error": "No modules specified"}
|
| 685 |
-
|
| 686 |
-
# Find the target module for the layer to ablate
|
| 687 |
-
target_module_name = None
|
| 688 |
-
for mod_name in block_modules:
|
| 689 |
-
layer_match = re.search(r'\.(\d+)(?:\.|$)', mod_name)
|
| 690 |
-
if layer_match and int(layer_match.group(1)) == ablate_layer_num:
|
| 691 |
-
target_module_name = mod_name
|
| 692 |
-
break
|
| 693 |
-
|
| 694 |
-
if not target_module_name:
|
| 695 |
-
return {"error": f"Could not find module for layer {ablate_layer_num}"}
|
| 696 |
-
|
| 697 |
-
# Get reference activations from ALL layers for mean computation
|
| 698 |
-
block_outputs = reference_activation_data.get('block_outputs', {})
|
| 699 |
-
if not block_outputs:
|
| 700 |
-
return {"error": "No block outputs found in reference data"}
|
| 701 |
-
|
| 702 |
-
# Collect all layer activations to compute global mean
|
| 703 |
-
all_layer_tensors = []
|
| 704 |
-
for mod_name, output_data in block_outputs.items():
|
| 705 |
-
output = output_data['output']
|
| 706 |
-
if isinstance(output, list):
|
| 707 |
-
tensor = torch.tensor(output)
|
| 708 |
-
else:
|
| 709 |
-
tensor = output
|
| 710 |
-
all_layer_tensors.append(tensor)
|
| 711 |
-
|
| 712 |
-
# Stack all layers and compute mean across ALL layers and sequence positions
|
| 713 |
-
# This gives us a single mean vector that represents the average activation
|
| 714 |
-
stacked = torch.stack(all_layer_tensors, dim=0) # [num_layers, batch, seq_len, hidden_dim]
|
| 715 |
-
# Compute mean across layers and sequence dimension
|
| 716 |
-
mean_activation = stacked.mean(dim=(0, 2), keepdim=True) # [1, batch, 1, hidden_dim]
|
| 717 |
-
mean_activation = mean_activation.squeeze(0) # [batch, 1, hidden_dim]
|
| 718 |
-
|
| 719 |
-
# Prepare inputs
|
| 720 |
-
inputs = tokenizer(prompt, return_tensors="pt")
|
| 721 |
-
seq_len = inputs['input_ids'].shape[1]
|
| 722 |
-
|
| 723 |
-
# Broadcast mean to match sequence length
|
| 724 |
-
ablation_value = mean_activation.expand(-1, seq_len, -1) # [batch, seq_len, hidden_dim]
|
| 725 |
-
|
| 726 |
-
# Build IntervenableConfig from module names
|
| 727 |
-
intervenable_representations = []
|
| 728 |
-
for mod_name in all_modules:
|
| 729 |
-
layer_match = re.search(r'\.(\d+)(?:\.|$)', mod_name)
|
| 730 |
-
if not layer_match:
|
| 731 |
-
return {"error": f"Invalid module name format: {mod_name}"}
|
| 732 |
-
|
| 733 |
-
if 'attn' in mod_name or 'attention' in mod_name:
|
| 734 |
-
component = 'attention_output'
|
| 735 |
-
else:
|
| 736 |
-
component = 'block_output'
|
| 737 |
-
|
| 738 |
-
intervenable_representations.append(
|
| 739 |
-
RepresentationConfig(layer=int(layer_match.group(1)), component=component, unit="pos")
|
| 740 |
-
)
|
| 741 |
-
|
| 742 |
-
intervenable_config = IntervenableConfig(
|
| 743 |
-
intervenable_representations=intervenable_representations
|
| 744 |
-
)
|
| 745 |
-
intervenable_model = IntervenableModel(intervenable_config, model)
|
| 746 |
-
|
| 747 |
-
# Register hooks to capture activations
|
| 748 |
-
captured = {}
|
| 749 |
-
name_to_module = dict(intervenable_model.model.named_modules())
|
| 750 |
-
|
| 751 |
-
def make_hook(mod_name: str):
|
| 752 |
-
return lambda module, inputs, output: captured.update({mod_name: {"output": safe_to_serializable(output)}})
|
| 753 |
-
|
| 754 |
-
# Register ablation hook for target module
|
| 755 |
-
def ablation_hook(module, input, output):
|
| 756 |
-
# Replace output with mean activation
|
| 757 |
-
if isinstance(output, tuple):
|
| 758 |
-
# For modules that return tuples (hidden_states, ...), replace first element
|
| 759 |
-
ablated = (ablation_value,) + output[1:]
|
| 760 |
-
return ablated
|
| 761 |
-
else:
|
| 762 |
-
return ablation_value
|
| 763 |
-
|
| 764 |
-
hooks = []
|
| 765 |
-
for mod_name in all_modules:
|
| 766 |
-
if mod_name in name_to_module:
|
| 767 |
-
if mod_name == target_module_name:
|
| 768 |
-
# Apply ablation hook
|
| 769 |
-
hooks.append(name_to_module[mod_name].register_forward_hook(ablation_hook))
|
| 770 |
-
else:
|
| 771 |
-
# Regular capture hook
|
| 772 |
-
hooks.append(name_to_module[mod_name].register_forward_hook(make_hook(mod_name)))
|
| 773 |
-
|
| 774 |
-
# Execute forward pass
|
| 775 |
-
with torch.no_grad():
|
| 776 |
-
model_output = intervenable_model.model(**inputs, use_cache=False)
|
| 777 |
-
|
| 778 |
-
# Remove hooks
|
| 779 |
-
for hook in hooks:
|
| 780 |
-
hook.remove()
|
| 781 |
-
|
| 782 |
-
# Capture ablated layer output as well
|
| 783 |
-
captured[target_module_name] = {"output": safe_to_serializable(ablation_value)}
|
| 784 |
-
|
| 785 |
-
# Separate outputs by type
|
| 786 |
-
attention_outputs = {}
|
| 787 |
-
block_outputs = {}
|
| 788 |
-
|
| 789 |
-
for mod_name, output in captured.items():
|
| 790 |
-
if 'attn' in mod_name or 'attention' in mod_name:
|
| 791 |
-
attention_outputs[mod_name] = output
|
| 792 |
-
else:
|
| 793 |
-
block_outputs[mod_name] = output
|
| 794 |
-
|
| 795 |
-
# Capture normalization parameters
|
| 796 |
-
all_params = dict(model.named_parameters())
|
| 797 |
-
norm_data = [safe_to_serializable(all_params[p]) for p in norm_parameters if p in all_params]
|
| 798 |
-
|
| 799 |
-
# Extract predicted token from model output
|
| 800 |
-
actual_output = None
|
| 801 |
-
global_top5_tokens = []
|
| 802 |
-
try:
|
| 803 |
-
output_token, output_prob = get_actual_model_output(model_output, tokenizer)
|
| 804 |
-
actual_output = {"token": output_token, "probability": output_prob}
|
| 805 |
-
global_top5_tokens = compute_global_top5_tokens(model_output, tokenizer, top_k=5)
|
| 806 |
-
except Exception as e:
|
| 807 |
-
print(f"Warning: Could not extract model output: {e}")
|
| 808 |
-
|
| 809 |
-
# Build output dictionary
|
| 810 |
-
result = {
|
| 811 |
-
"model": getattr(model.config, "name_or_path", "unknown"),
|
| 812 |
-
"prompt": prompt,
|
| 813 |
-
"input_ids": safe_to_serializable(inputs["input_ids"]),
|
| 814 |
-
"attention_modules": list(attention_outputs.keys()),
|
| 815 |
-
"attention_outputs": attention_outputs,
|
| 816 |
-
"block_modules": list(block_outputs.keys()),
|
| 817 |
-
"block_outputs": block_outputs,
|
| 818 |
-
"norm_parameters": norm_parameters,
|
| 819 |
-
"norm_data": norm_data,
|
| 820 |
-
"actual_output": actual_output,
|
| 821 |
-
"global_top5_tokens": global_top5_tokens,
|
| 822 |
-
"ablated_layer": ablate_layer_num
|
| 823 |
-
}
|
| 824 |
-
|
| 825 |
-
return result
|
| 826 |
-
|
| 827 |
-
|
| 828 |
def evaluate_sequence_ablation(model, tokenizer, sequence_text: str, config: Dict[str, Any],
|
| 829 |
ablation_type: str, ablation_target: Any) -> Dict[str, Any]:
|
| 830 |
"""
|
|
@@ -1159,85 +991,6 @@ def _get_top_tokens(activation_data: Dict[str, Any], module_name: str, model, to
|
|
| 1159 |
return None
|
| 1160 |
|
| 1161 |
|
| 1162 |
-
def get_check_token_probabilities(activation_data: Dict[str, Any], model, tokenizer, check_token: str) -> Optional[Dict[str, Any]]:
|
| 1163 |
-
"""
|
| 1164 |
-
Collect check token probabilities across all layers.
|
| 1165 |
-
|
| 1166 |
-
Sums probabilities of token variants (with and without leading space).
|
| 1167 |
-
Returns layer numbers and merged probabilities for plotting.
|
| 1168 |
-
"""
|
| 1169 |
-
if not check_token or not check_token.strip():
|
| 1170 |
-
return None
|
| 1171 |
-
|
| 1172 |
-
try:
|
| 1173 |
-
# Get block modules (all layers)
|
| 1174 |
-
layer_modules = activation_data.get('block_modules', [])
|
| 1175 |
-
if not layer_modules:
|
| 1176 |
-
return None
|
| 1177 |
-
|
| 1178 |
-
# Extract and sort layers
|
| 1179 |
-
layer_info = sorted(
|
| 1180 |
-
[(int(re.findall(r'\d+', name)[0]), name)
|
| 1181 |
-
for name in layer_modules if re.findall(r'\d+', name)]
|
| 1182 |
-
)
|
| 1183 |
-
|
| 1184 |
-
# Try tokenizing with and without leading space
|
| 1185 |
-
token_variants = [
|
| 1186 |
-
(check_token.strip(), tokenizer.encode(check_token.strip(), add_special_tokens=False)),
|
| 1187 |
-
(' ' + check_token.strip(), tokenizer.encode(' ' + check_token.strip(), add_special_tokens=False))
|
| 1188 |
-
]
|
| 1189 |
-
|
| 1190 |
-
# Get token IDs for both variants (if they exist and differ)
|
| 1191 |
-
target_token_ids = []
|
| 1192 |
-
for variant_text, token_ids in token_variants:
|
| 1193 |
-
if token_ids:
|
| 1194 |
-
tid = token_ids[-1] # Use last sub-token
|
| 1195 |
-
if tid not in target_token_ids:
|
| 1196 |
-
target_token_ids.append(tid)
|
| 1197 |
-
|
| 1198 |
-
if not target_token_ids:
|
| 1199 |
-
return None
|
| 1200 |
-
|
| 1201 |
-
# Get norm parameter
|
| 1202 |
-
norm_params = activation_data.get('norm_parameters', [])
|
| 1203 |
-
norm_parameter = norm_params[0] if norm_params else None
|
| 1204 |
-
final_norm = get_norm_layer_from_parameter(model, norm_parameter)
|
| 1205 |
-
lm_head = model.get_output_embeddings()
|
| 1206 |
-
|
| 1207 |
-
# Collect probabilities for all layers (sum both variants)
|
| 1208 |
-
layers = []
|
| 1209 |
-
probabilities = []
|
| 1210 |
-
|
| 1211 |
-
for layer_num, module_name in layer_info:
|
| 1212 |
-
layer_output = activation_data['block_outputs'][module_name]['output']
|
| 1213 |
-
|
| 1214 |
-
with torch.no_grad():
|
| 1215 |
-
hidden = torch.tensor(layer_output) if not isinstance(layer_output, torch.Tensor) else layer_output
|
| 1216 |
-
if hidden.dim() == 4:
|
| 1217 |
-
hidden = hidden.squeeze(0)
|
| 1218 |
-
|
| 1219 |
-
if final_norm is not None:
|
| 1220 |
-
hidden = final_norm(hidden)
|
| 1221 |
-
|
| 1222 |
-
logits = lm_head(hidden)
|
| 1223 |
-
probs = F.softmax(logits[0, -1, :], dim=-1)
|
| 1224 |
-
|
| 1225 |
-
# Sum probabilities of all variants
|
| 1226 |
-
merged_prob = sum(probs[tid].item() for tid in target_token_ids)
|
| 1227 |
-
|
| 1228 |
-
layers.append(layer_num)
|
| 1229 |
-
probabilities.append(merged_prob)
|
| 1230 |
-
|
| 1231 |
-
return {
|
| 1232 |
-
'token': check_token.strip(), # Return canonical form without leading space
|
| 1233 |
-
'layers': layers,
|
| 1234 |
-
'probabilities': probabilities
|
| 1235 |
-
}
|
| 1236 |
-
except Exception as e:
|
| 1237 |
-
print(f"Error computing check token probabilities: {e}")
|
| 1238 |
-
return None
|
| 1239 |
-
|
| 1240 |
-
|
| 1241 |
def detect_significant_probability_increases(layer_wise_probs: Dict[int, Dict[str, float]],
|
| 1242 |
layer_wise_deltas: Dict[int, Dict[str, float]],
|
| 1243 |
actual_output_token: str,
|
|
@@ -1281,246 +1034,13 @@ def detect_significant_probability_increases(layer_wise_probs: Dict[int, Dict[st
|
|
| 1281 |
return significant_layers
|
| 1282 |
|
| 1283 |
|
| 1284 |
-
def _get_top_attended_tokens(activation_data: Dict[str, Any], layer_num: int, tokenizer, top_k: int = 3) -> Optional[List[Tuple[str, float]]]:
|
| 1285 |
-
"""
|
| 1286 |
-
DEPRECATED: This function is deprecated and will be removed in a future version.
|
| 1287 |
-
Use head categorization from head_detection.py instead for more meaningful attention analysis.
|
| 1288 |
-
|
| 1289 |
-
Get top-K attended input tokens for the current position (last token) in a layer.
|
| 1290 |
-
Averages attention across all heads.
|
| 1291 |
-
|
| 1292 |
-
Args:
|
| 1293 |
-
activation_data: Output from execute_forward_pass
|
| 1294 |
-
layer_num: Layer number to analyze
|
| 1295 |
-
tokenizer: Tokenizer for decoding tokens
|
| 1296 |
-
top_k: Number of top attended tokens to return
|
| 1297 |
-
|
| 1298 |
-
Returns:
|
| 1299 |
-
List of (token_string, attention_weight) tuples, sorted by weight (highest first)
|
| 1300 |
-
"""
|
| 1301 |
-
import warnings
|
| 1302 |
-
warnings.warn(
|
| 1303 |
-
"_get_top_attended_tokens is deprecated. Use categorize_all_heads() from head_detection.py instead.",
|
| 1304 |
-
DeprecationWarning,
|
| 1305 |
-
stacklevel=2
|
| 1306 |
-
)
|
| 1307 |
-
try:
|
| 1308 |
-
attention_outputs = activation_data.get('attention_outputs', {})
|
| 1309 |
-
input_ids = activation_data.get('input_ids', [])
|
| 1310 |
-
|
| 1311 |
-
# print(f"DEBUG _get_top_attended_tokens: layer_num={layer_num}, attention_outputs keys={list(attention_outputs.keys())}")
|
| 1312 |
-
|
| 1313 |
-
if not attention_outputs or not input_ids:
|
| 1314 |
-
print(f"DEBUG _get_top_attended_tokens: Missing data - attention_outputs empty={not attention_outputs}, input_ids empty={not input_ids}")
|
| 1315 |
-
return None
|
| 1316 |
-
|
| 1317 |
-
# Find attention output for this layer
|
| 1318 |
-
target_module = None
|
| 1319 |
-
for module_name in attention_outputs.keys():
|
| 1320 |
-
numbers = re.findall(r'\d+', module_name)
|
| 1321 |
-
if numbers and int(numbers[0]) == layer_num:
|
| 1322 |
-
target_module = module_name
|
| 1323 |
-
break
|
| 1324 |
-
|
| 1325 |
-
if not target_module:
|
| 1326 |
-
return None
|
| 1327 |
-
|
| 1328 |
-
attention_output = attention_outputs[target_module]['output']
|
| 1329 |
-
if not isinstance(attention_output, list) or len(attention_output) < 2:
|
| 1330 |
-
return None
|
| 1331 |
-
|
| 1332 |
-
# Get attention weights: [batch, heads, seq_len, seq_len]
|
| 1333 |
-
attention_weights = torch.tensor(attention_output[1])
|
| 1334 |
-
|
| 1335 |
-
# Average across heads: [seq_len, seq_len]
|
| 1336 |
-
avg_attention = attention_weights[0].mean(dim=0)
|
| 1337 |
-
|
| 1338 |
-
# Get attention from last position to all positions
|
| 1339 |
-
last_pos_attention = avg_attention[-1, :] # [seq_len]
|
| 1340 |
-
|
| 1341 |
-
# Get top-K attended positions
|
| 1342 |
-
top_values, top_indices = torch.topk(last_pos_attention, min(top_k, len(last_pos_attention)))
|
| 1343 |
-
|
| 1344 |
-
# Convert to tokens
|
| 1345 |
-
input_ids_tensor = torch.tensor(input_ids[0]) if isinstance(input_ids[0], list) else torch.tensor(input_ids)
|
| 1346 |
-
result = []
|
| 1347 |
-
for idx, weight in zip(top_indices, top_values):
|
| 1348 |
-
token_id = input_ids_tensor[idx].item()
|
| 1349 |
-
token_str = tokenizer.decode([token_id], skip_special_tokens=False)
|
| 1350 |
-
result.append((token_str, weight.item()))
|
| 1351 |
-
|
| 1352 |
-
return result
|
| 1353 |
-
|
| 1354 |
-
except Exception as e:
|
| 1355 |
-
print(f"Warning: Could not compute attended tokens for layer {layer_num}: {e}")
|
| 1356 |
-
return None
|
| 1357 |
-
|
| 1358 |
-
|
| 1359 |
-
def compute_position_layer_matrix(activation_data: Dict[str, Any], model, tokenizer) -> Dict[str, Any]:
|
| 1360 |
-
"""
|
| 1361 |
-
Compute a 2D matrix of layer-to-layer deltas for each token position.
|
| 1362 |
-
|
| 1363 |
-
This function computes the top-token probability delta at each (layer, position) pair,
|
| 1364 |
-
creating a heatmap-ready data structure.
|
| 1365 |
-
|
| 1366 |
-
Args:
|
| 1367 |
-
activation_data: Activation data from forward pass
|
| 1368 |
-
model: Transformer model for logit lens computation
|
| 1369 |
-
tokenizer: Tokenizer for decoding tokens
|
| 1370 |
-
|
| 1371 |
-
Returns:
|
| 1372 |
-
Dict with:
|
| 1373 |
-
- 'matrix': 2D list [num_layers, seq_len] of delta values
|
| 1374 |
-
- 'tokens': List of token strings for X-axis labels
|
| 1375 |
-
- 'layer_nums': List of layer numbers for Y-axis labels
|
| 1376 |
-
- 'top_tokens': 2D list [num_layers, seq_len] of top token strings at each cell
|
| 1377 |
-
"""
|
| 1378 |
-
import copy
|
| 1379 |
-
import numpy as np
|
| 1380 |
-
|
| 1381 |
-
input_ids = activation_data.get('input_ids', [[]])
|
| 1382 |
-
if not input_ids or not input_ids[0]:
|
| 1383 |
-
return {'matrix': [], 'tokens': [], 'layer_nums': [], 'top_tokens': []}
|
| 1384 |
-
|
| 1385 |
-
seq_len = len(input_ids[0])
|
| 1386 |
-
|
| 1387 |
-
# Get token strings for X-axis labels
|
| 1388 |
-
tokens = [tokenizer.decode([tid]) for tid in input_ids[0]]
|
| 1389 |
-
|
| 1390 |
-
# Get layer modules and sort by layer number
|
| 1391 |
-
layer_modules = activation_data.get('block_modules', [])
|
| 1392 |
-
if not layer_modules:
|
| 1393 |
-
return {'matrix': [], 'tokens': tokens, 'layer_nums': [], 'top_tokens': []}
|
| 1394 |
-
|
| 1395 |
-
layer_info = sorted(
|
| 1396 |
-
[(int(re.findall(r'\d+', name)[0]), name)
|
| 1397 |
-
for name in layer_modules if re.findall(r'\d+', name)]
|
| 1398 |
-
)
|
| 1399 |
-
layer_nums = [ln for ln, _ in layer_info]
|
| 1400 |
-
num_layers = len(layer_nums)
|
| 1401 |
-
|
| 1402 |
-
# Helper function to slice data to a specific position (adapted from app.py)
|
| 1403 |
-
def slice_data(data, pos):
|
| 1404 |
-
if not data:
|
| 1405 |
-
return data
|
| 1406 |
-
sliced = copy.deepcopy(data)
|
| 1407 |
-
|
| 1408 |
-
# Slice Block Outputs: [batch, seq, hidden] -> [batch, 1, hidden]
|
| 1409 |
-
if 'block_outputs' in sliced:
|
| 1410 |
-
for mod in sliced['block_outputs']:
|
| 1411 |
-
out = sliced['block_outputs'][mod]['output']
|
| 1412 |
-
if isinstance(out, list) and len(out) > 0 and isinstance(out[0], list):
|
| 1413 |
-
if pos < len(out[0]):
|
| 1414 |
-
sliced['block_outputs'][mod]['output'] = [[out[0][pos]]]
|
| 1415 |
-
|
| 1416 |
-
# Slice Attention Outputs: [batch, heads, seq, seq] -> [batch, heads, 1, seq]
|
| 1417 |
-
if 'attention_outputs' in sliced:
|
| 1418 |
-
for mod in sliced['attention_outputs']:
|
| 1419 |
-
out = sliced['attention_outputs'][mod]['output']
|
| 1420 |
-
if len(out) > 1:
|
| 1421 |
-
attns = out[1]
|
| 1422 |
-
if isinstance(attns, list) and len(attns) > 0:
|
| 1423 |
-
batch_0 = attns[0]
|
| 1424 |
-
new_batch_0 = []
|
| 1425 |
-
for head in batch_0:
|
| 1426 |
-
if pos < len(head):
|
| 1427 |
-
new_batch_0.append([head[pos]])
|
| 1428 |
-
sliced['attention_outputs'][mod]['output'] = [out[0], [new_batch_0]] + out[2:]
|
| 1429 |
-
|
| 1430 |
-
# Slice input_ids
|
| 1431 |
-
if 'input_ids' in sliced:
|
| 1432 |
-
ids = sliced['input_ids'][0]
|
| 1433 |
-
if pos < len(ids):
|
| 1434 |
-
sliced['input_ids'][0] = ids[:pos+1]
|
| 1435 |
-
|
| 1436 |
-
return sliced
|
| 1437 |
-
|
| 1438 |
-
# Initialize matrix and top_tokens 2D array
|
| 1439 |
-
matrix = [[0.0] * seq_len for _ in range(num_layers)]
|
| 1440 |
-
top_tokens_matrix = [[''] * seq_len for _ in range(num_layers)]
|
| 1441 |
-
|
| 1442 |
-
# Compute delta for each position
|
| 1443 |
-
for pos in range(seq_len):
|
| 1444 |
-
sliced = slice_data(activation_data, pos)
|
| 1445 |
-
layer_data = extract_layer_data(sliced, model, tokenizer)
|
| 1446 |
-
|
| 1447 |
-
if not layer_data:
|
| 1448 |
-
continue
|
| 1449 |
-
|
| 1450 |
-
# Fill in matrix for this position
|
| 1451 |
-
for layer_info_item in layer_data:
|
| 1452 |
-
layer_num = layer_info_item.get('layer_num')
|
| 1453 |
-
if layer_num is None or layer_num not in layer_nums:
|
| 1454 |
-
continue
|
| 1455 |
-
|
| 1456 |
-
layer_idx = layer_nums.index(layer_num)
|
| 1457 |
-
|
| 1458 |
-
# Get top token and its delta (layer-to-layer change)
|
| 1459 |
-
top_token = layer_info_item.get('top_token', '')
|
| 1460 |
-
deltas = layer_info_item.get('deltas', {})
|
| 1461 |
-
|
| 1462 |
-
# The delta for the top token represents how much it changed from prev layer
|
| 1463 |
-
delta = deltas.get(top_token, 0.0) if top_token else 0.0
|
| 1464 |
-
|
| 1465 |
-
matrix[layer_idx][pos] = delta
|
| 1466 |
-
top_tokens_matrix[layer_idx][pos] = top_token if top_token else ''
|
| 1467 |
-
|
| 1468 |
-
return {
|
| 1469 |
-
'matrix': matrix,
|
| 1470 |
-
'tokens': tokens,
|
| 1471 |
-
'layer_nums': layer_nums,
|
| 1472 |
-
'top_tokens': top_tokens_matrix
|
| 1473 |
-
}
|
| 1474 |
-
|
| 1475 |
-
|
| 1476 |
-
def compute_layer_wise_summaries(layer_data: List[Dict[str, Any]], activation_data: Dict[str, Any]) -> Dict[str, Any]:
|
| 1477 |
-
"""
|
| 1478 |
-
Compute summary structures from layer data for easy access.
|
| 1479 |
-
|
| 1480 |
-
Args:
|
| 1481 |
-
layer_data: List of layer data dicts from extract_layer_data()
|
| 1482 |
-
activation_data: Activation data containing actual output token
|
| 1483 |
-
|
| 1484 |
-
Returns:
|
| 1485 |
-
Dict with: layer_wise_top5_probs, layer_wise_top5_deltas, significant_layers
|
| 1486 |
-
"""
|
| 1487 |
-
layer_wise_top5_probs = {} # layer_num -> {token: prob}
|
| 1488 |
-
layer_wise_top5_deltas = {} # layer_num -> {token: delta}
|
| 1489 |
-
|
| 1490 |
-
for layer_info in layer_data:
|
| 1491 |
-
layer_num = layer_info.get('layer_num')
|
| 1492 |
-
if layer_num is not None:
|
| 1493 |
-
layer_wise_top5_probs[layer_num] = layer_info.get('global_top5_probs', {})
|
| 1494 |
-
layer_wise_top5_deltas[layer_num] = layer_info.get('global_top5_deltas', {})
|
| 1495 |
-
|
| 1496 |
-
# Extract actual output token from activation data
|
| 1497 |
-
actual_output = activation_data.get('actual_output', {})
|
| 1498 |
-
actual_output_token = actual_output.get('token', '').strip() if actual_output else ''
|
| 1499 |
-
|
| 1500 |
-
# Detect significant layers based on actual output token
|
| 1501 |
-
significant_layers = []
|
| 1502 |
-
if actual_output_token:
|
| 1503 |
-
significant_layers = detect_significant_probability_increases(
|
| 1504 |
-
layer_wise_top5_probs,
|
| 1505 |
-
layer_wise_top5_deltas,
|
| 1506 |
-
actual_output_token,
|
| 1507 |
-
threshold=1.0
|
| 1508 |
-
)
|
| 1509 |
-
|
| 1510 |
-
return {
|
| 1511 |
-
'layer_wise_top5_probs': layer_wise_top5_probs,
|
| 1512 |
-
'layer_wise_top5_deltas': layer_wise_top5_deltas,
|
| 1513 |
-
'significant_layers': significant_layers
|
| 1514 |
-
}
|
| 1515 |
-
|
| 1516 |
-
|
| 1517 |
def extract_layer_data(activation_data: Dict[str, Any], model, tokenizer) -> List[Dict[str, Any]]:
|
| 1518 |
"""
|
| 1519 |
Extract layer-by-layer data for accordion display with top-5, deltas, and attention.
|
| 1520 |
Also tracks global top 5 tokens across all layers.
|
| 1521 |
|
| 1522 |
Returns:
|
| 1523 |
-
List of dicts with: layer_num, top_token, top_prob, top_5_tokens, deltas,
|
| 1524 |
global_top5_probs, global_top5_deltas
|
| 1525 |
"""
|
| 1526 |
layer_modules = activation_data.get('block_modules', [])
|
|
@@ -1561,11 +1081,6 @@ def extract_layer_data(activation_data: Dict[str, Any], model, tokenizer) -> Lis
|
|
| 1561 |
for layer_num, module_name in layer_info:
|
| 1562 |
top_tokens = _get_top_tokens(activation_data, module_name, model, tokenizer, top_k=5) if can_compute_predictions else None
|
| 1563 |
|
| 1564 |
-
# NOTE: top_attended_tokens is deprecated. Use categorize_all_heads() from
|
| 1565 |
-
# head_detection.py instead for more meaningful attention analysis.
|
| 1566 |
-
# Kept as None for backward compatibility with existing code.
|
| 1567 |
-
top_attended = None
|
| 1568 |
-
|
| 1569 |
# Get probabilities for global top 5 tokens at this layer
|
| 1570 |
global_top5_probs = {}
|
| 1571 |
global_top5_deltas = {}
|
|
@@ -1596,7 +1111,6 @@ def extract_layer_data(activation_data: Dict[str, Any], model, tokenizer) -> Lis
|
|
| 1596 |
'top_3_tokens': top_tokens[:3], # Keep for backward compatibility
|
| 1597 |
'top_5_tokens': top_tokens[:5], # New: top-5 for bar chart
|
| 1598 |
'deltas': deltas,
|
| 1599 |
-
'top_attended_tokens': top_attended,
|
| 1600 |
'global_top5_probs': global_top5_probs, # New: global top 5 probs at this layer
|
| 1601 |
'global_top5_deltas': global_top5_deltas # New: global top 5 deltas
|
| 1602 |
})
|
|
@@ -1613,7 +1127,6 @@ def extract_layer_data(activation_data: Dict[str, Any], model, tokenizer) -> Lis
|
|
| 1613 |
'top_3_tokens': [],
|
| 1614 |
'top_5_tokens': [],
|
| 1615 |
'deltas': {},
|
| 1616 |
-
'top_attended_tokens': top_attended,
|
| 1617 |
'global_top5_probs': {},
|
| 1618 |
'global_top5_deltas': {}
|
| 1619 |
})
|
|
@@ -1758,164 +1271,6 @@ def generate_bertviz_html(activation_data: Dict[str, Any], layer_index: int, vie
|
|
| 1758 |
return f"<p>Error generating visualization: {str(e)}</p>"
|
| 1759 |
|
| 1760 |
|
| 1761 |
-
def generate_category_bertviz_html(activation_data: Dict[str, Any], category_heads: List[Dict[str, Any]]) -> str:
|
| 1762 |
-
"""
|
| 1763 |
-
Generate BertViz attention visualization HTML for a specific category of heads.
|
| 1764 |
-
|
| 1765 |
-
Shows only the attention patterns for heads in the specified category.
|
| 1766 |
-
|
| 1767 |
-
Args:
|
| 1768 |
-
activation_data: Output from execute_forward_pass
|
| 1769 |
-
category_heads: List of head info dicts for this category (from categorize_all_heads)
|
| 1770 |
-
|
| 1771 |
-
Returns:
|
| 1772 |
-
HTML string for the visualization
|
| 1773 |
-
"""
|
| 1774 |
-
try:
|
| 1775 |
-
from bertviz import head_view
|
| 1776 |
-
from transformers import AutoTokenizer
|
| 1777 |
-
|
| 1778 |
-
if not category_heads:
|
| 1779 |
-
return "<p>No heads in this category.</p>"
|
| 1780 |
-
|
| 1781 |
-
# Extract attention modules and sort by layer
|
| 1782 |
-
attention_outputs = activation_data.get('attention_outputs', {})
|
| 1783 |
-
if not attention_outputs:
|
| 1784 |
-
return "<p>No attention data available</p>"
|
| 1785 |
-
|
| 1786 |
-
# Build a map of layer -> head indices for this category
|
| 1787 |
-
category_map = {} # layer_num -> list of head indices
|
| 1788 |
-
for head_info in category_heads:
|
| 1789 |
-
layer = head_info['layer']
|
| 1790 |
-
head = head_info['head']
|
| 1791 |
-
if layer not in category_map:
|
| 1792 |
-
category_map[layer] = []
|
| 1793 |
-
category_map[layer].append(head)
|
| 1794 |
-
|
| 1795 |
-
# Sort attention modules by layer number and filter heads
|
| 1796 |
-
# Track which layers we've already processed to avoid duplicates
|
| 1797 |
-
layer_attention_pairs = []
|
| 1798 |
-
processed_layers = set()
|
| 1799 |
-
|
| 1800 |
-
for module_name in attention_outputs.keys():
|
| 1801 |
-
numbers = re.findall(r'\d+', module_name)
|
| 1802 |
-
if numbers:
|
| 1803 |
-
layer_num = int(numbers[0])
|
| 1804 |
-
|
| 1805 |
-
# Skip layers not in this category
|
| 1806 |
-
if layer_num not in category_map:
|
| 1807 |
-
continue
|
| 1808 |
-
|
| 1809 |
-
# Skip if we've already processed this layer (prevents duplicate/mismatched tensors)
|
| 1810 |
-
if layer_num in processed_layers:
|
| 1811 |
-
continue
|
| 1812 |
-
|
| 1813 |
-
attention_output = attention_outputs[module_name]['output']
|
| 1814 |
-
if isinstance(attention_output, list) and len(attention_output) >= 2:
|
| 1815 |
-
# Get attention weights (element 1 of the output tuple)
|
| 1816 |
-
full_attention = torch.tensor(attention_output[1]) # [batch, heads, seq, seq]
|
| 1817 |
-
|
| 1818 |
-
# Filter to only include heads in this category
|
| 1819 |
-
head_indices = category_map[layer_num]
|
| 1820 |
-
filtered_attention = full_attention[:, head_indices, :, :] # Select specific heads
|
| 1821 |
-
|
| 1822 |
-
layer_attention_pairs.append((layer_num, filtered_attention))
|
| 1823 |
-
processed_layers.add(layer_num)
|
| 1824 |
-
|
| 1825 |
-
if not layer_attention_pairs:
|
| 1826 |
-
return "<p>No valid attention data found for this category.</p>"
|
| 1827 |
-
|
| 1828 |
-
# Sort by layer number and extract attention tensors
|
| 1829 |
-
layer_attention_pairs.sort(key=lambda x: x[0])
|
| 1830 |
-
attentions = tuple(attn for _, attn in layer_attention_pairs)
|
| 1831 |
-
|
| 1832 |
-
# Get tokens
|
| 1833 |
-
input_ids = torch.tensor(activation_data['input_ids'])
|
| 1834 |
-
model_name = activation_data.get('model', 'unknown')
|
| 1835 |
-
|
| 1836 |
-
# Load tokenizer and convert to tokens
|
| 1837 |
-
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 1838 |
-
raw_tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
|
| 1839 |
-
# Clean up tokens (remove special tokenizer artifacts like Ġ for GPT-2)
|
| 1840 |
-
tokens = [token.replace('Ġ', ' ') if token.startswith('Ġ') else token for token in raw_tokens]
|
| 1841 |
-
|
| 1842 |
-
# Generate visualization using head_view (better for showing specific heads)
|
| 1843 |
-
html_result = head_view(attentions, tokens, html_action='return')
|
| 1844 |
-
base_html = html_result.data if hasattr(html_result, 'data') else str(html_result)
|
| 1845 |
-
|
| 1846 |
-
# Create a legend mapping head indices to their actual layer-head labels
|
| 1847 |
-
legend_items = []
|
| 1848 |
-
head_counter = 0
|
| 1849 |
-
for layer_num, _ in layer_attention_pairs:
|
| 1850 |
-
head_indices = category_map[layer_num]
|
| 1851 |
-
for head_idx in head_indices:
|
| 1852 |
-
legend_items.append(f"Head {head_counter}: L{layer_num}-H{head_idx}")
|
| 1853 |
-
head_counter += 1
|
| 1854 |
-
|
| 1855 |
-
legend_html = """
|
| 1856 |
-
<div style="background-color: #f8f9fa; padding: 10px; margin-bottom: 10px; border-radius: 5px; border: 1px solid #dee2e6;">
|
| 1857 |
-
<strong style="color: #495057;">Head Index Reference:</strong><br/>
|
| 1858 |
-
<div style="font-size: 12px; color: #6c757d; margin-top: 5px; display: grid; grid-template-columns: repeat(auto-fill, minmax(150px, 1fr)); gap: 5px;">
|
| 1859 |
-
{items}
|
| 1860 |
-
</div>
|
| 1861 |
-
</div>
|
| 1862 |
-
""".format(items=''.join(f'<span>{item}</span>' for item in legend_items))
|
| 1863 |
-
|
| 1864 |
-
# Prepend legend to the visualization
|
| 1865 |
-
return legend_html + base_html
|
| 1866 |
-
|
| 1867 |
-
except Exception as e:
|
| 1868 |
-
import traceback
|
| 1869 |
-
traceback.print_exc()
|
| 1870 |
-
return f"<p>Error generating category visualization: {str(e)}</p>"
|
| 1871 |
-
|
| 1872 |
-
|
| 1873 |
-
def generate_head_view_with_categories(activation_data: Dict[str, Any]) -> Dict[str, Any]:
|
| 1874 |
-
"""
|
| 1875 |
-
Generate BertViz head view HTML along with head categorization data.
|
| 1876 |
-
|
| 1877 |
-
Combines the head_view visualization with categorization from head_detection.py
|
| 1878 |
-
to provide both visual attention patterns and semantic categorization.
|
| 1879 |
-
|
| 1880 |
-
Args:
|
| 1881 |
-
activation_data: Output from execute_forward_pass with attention data
|
| 1882 |
-
|
| 1883 |
-
Returns:
|
| 1884 |
-
Dict with:
|
| 1885 |
-
- 'html': BertViz head_view HTML string
|
| 1886 |
-
- 'categories': Dict from categorize_all_heads (category -> list of head info)
|
| 1887 |
-
- 'summary': Formatted text summary of head categorization
|
| 1888 |
-
- 'error': Error message if visualization failed (optional)
|
| 1889 |
-
"""
|
| 1890 |
-
from .head_detection import categorize_all_heads, format_categorization_summary
|
| 1891 |
-
|
| 1892 |
-
result = {
|
| 1893 |
-
'html': None,
|
| 1894 |
-
'categories': {},
|
| 1895 |
-
'summary': '',
|
| 1896 |
-
'error': None
|
| 1897 |
-
}
|
| 1898 |
-
|
| 1899 |
-
# Generate the base head_view visualization
|
| 1900 |
-
try:
|
| 1901 |
-
result['html'] = generate_bertviz_html(activation_data, layer_index=0, view_type='full')
|
| 1902 |
-
except Exception as e:
|
| 1903 |
-
result['error'] = f"Failed to generate head view: {str(e)}"
|
| 1904 |
-
result['html'] = f"<p>Error generating visualization: {str(e)}</p>"
|
| 1905 |
-
|
| 1906 |
-
# Generate head categorization
|
| 1907 |
-
try:
|
| 1908 |
-
result['categories'] = categorize_all_heads(activation_data)
|
| 1909 |
-
result['summary'] = format_categorization_summary(result['categories'])
|
| 1910 |
-
except Exception as e:
|
| 1911 |
-
if result['error']:
|
| 1912 |
-
result['error'] += f"; Categorization failed: {str(e)}"
|
| 1913 |
-
else:
|
| 1914 |
-
result['error'] = f"Categorization failed: {str(e)}"
|
| 1915 |
-
|
| 1916 |
-
return result
|
| 1917 |
-
|
| 1918 |
-
|
| 1919 |
def get_head_category_counts(activation_data: Dict[str, Any]) -> Dict[str, int]:
|
| 1920 |
"""
|
| 1921 |
Get counts of attention heads in each category.
|
|
|
|
| 657 |
return result
|
| 658 |
|
| 659 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 660 |
def evaluate_sequence_ablation(model, tokenizer, sequence_text: str, config: Dict[str, Any],
|
| 661 |
ablation_type: str, ablation_target: Any) -> Dict[str, Any]:
|
| 662 |
"""
|
|
|
|
| 991 |
return None
|
| 992 |
|
| 993 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 994 |
def detect_significant_probability_increases(layer_wise_probs: Dict[int, Dict[str, float]],
|
| 995 |
layer_wise_deltas: Dict[int, Dict[str, float]],
|
| 996 |
actual_output_token: str,
|
|
|
|
| 1034 |
return significant_layers
|
| 1035 |
|
| 1036 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1037 |
def extract_layer_data(activation_data: Dict[str, Any], model, tokenizer) -> List[Dict[str, Any]]:
|
| 1038 |
"""
|
| 1039 |
Extract layer-by-layer data for accordion display with top-5, deltas, and attention.
|
| 1040 |
Also tracks global top 5 tokens across all layers.
|
| 1041 |
|
| 1042 |
Returns:
|
| 1043 |
+
List of dicts with: layer_num, top_token, top_prob, top_5_tokens, deltas,
|
| 1044 |
global_top5_probs, global_top5_deltas
|
| 1045 |
"""
|
| 1046 |
layer_modules = activation_data.get('block_modules', [])
|
|
|
|
| 1081 |
for layer_num, module_name in layer_info:
|
| 1082 |
top_tokens = _get_top_tokens(activation_data, module_name, model, tokenizer, top_k=5) if can_compute_predictions else None
|
| 1083 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1084 |
# Get probabilities for global top 5 tokens at this layer
|
| 1085 |
global_top5_probs = {}
|
| 1086 |
global_top5_deltas = {}
|
|
|
|
| 1111 |
'top_3_tokens': top_tokens[:3], # Keep for backward compatibility
|
| 1112 |
'top_5_tokens': top_tokens[:5], # New: top-5 for bar chart
|
| 1113 |
'deltas': deltas,
|
|
|
|
| 1114 |
'global_top5_probs': global_top5_probs, # New: global top 5 probs at this layer
|
| 1115 |
'global_top5_deltas': global_top5_deltas # New: global top 5 deltas
|
| 1116 |
})
|
|
|
|
| 1127 |
'top_3_tokens': [],
|
| 1128 |
'top_5_tokens': [],
|
| 1129 |
'deltas': {},
|
|
|
|
| 1130 |
'global_top5_probs': {},
|
| 1131 |
'global_top5_deltas': {}
|
| 1132 |
})
|
|
|
|
| 1271 |
return f"<p>Error generating visualization: {str(e)}</p>"
|
| 1272 |
|
| 1273 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1274 |
def get_head_category_counts(activation_data: Dict[str, Any]) -> Dict[str, int]:
|
| 1275 |
"""
|
| 1276 |
Get counts of attention heads in each category.
|