cdpearlman Cursor committed on
Commit
2965a7d
·
1 Parent(s): aabd66e

chore: remove unused code, imports, and deprecated functions

Browse files

- Delete components/tokenization_panel.py (superseded by pipeline.py)

- Remove 6 unused imports from app.py

- Remove deprecated _get_top_attended_tokens() and references

- Remove unused create_stage_summary() from pipeline.py

- Remove 7 unused utility functions from model_patterns.py and beam_search.py

- Update utils/__init__.py exports and README.md

Total: ~1,087 lines removed. All 81 tests pass.
Co-authored-by: Cursor <cursoragent@cursor.com>

README.md CHANGED
@@ -77,10 +77,9 @@ Open your browser and navigate to `http://127.0.0.1:8050/`.
77
 
78
  * `app.py`: Main application entry point and layout orchestration.
79
  * `components/`: Modular UI components.
80
- * `pipeline.py`: The core 5-stage visualization.
81
  * `investigation_panel.py`: Ablation and attribution interfaces.
82
  * `ablation_panel.py`: Specific logic for head ablation UI.
83
- * `tokenization_panel.py`: Token visualization.
84
  * `utils/`: Backend logic and helper functions.
85
  * `model_patterns.py`: Activation capture and hooking logic.
86
  * `model_config.py`: Model family definitions and auto-detection.
 
77
 
78
  * `app.py`: Main application entry point and layout orchestration.
79
  * `components/`: Modular UI components.
80
+ * `pipeline.py`: The core 5-stage visualization with tokenization display.
81
  * `investigation_panel.py`: Ablation and attribution interfaces.
82
  * `ablation_panel.py`: Specific logic for head ablation UI.
 
83
  * `utils/`: Backend logic and helper functions.
84
  * `model_patterns.py`: Activation capture and hooking logic.
85
  * `model_config.py`: Model family definitions and auto-detection.
app.py CHANGED
@@ -10,11 +10,7 @@ from dash import html, dcc, Input, Output, State, callback, no_update, ALL, MATC
10
  import json
11
  import torch
12
  from utils import (load_model_and_get_patterns, execute_forward_pass, extract_layer_data,
13
- categorize_single_layer_heads, perform_beam_search,
14
- execute_forward_pass_with_head_ablation,
15
- execute_forward_pass_with_multi_layer_head_ablation,
16
- evaluate_sequence_ablation, score_sequence,
17
- get_head_category_counts, generate_bertviz_model_view_html)
18
  from utils.head_detection import categorize_all_heads
19
  from utils.model_config import get_auto_selections
20
  from utils.token_attribution import compute_integrated_gradients, compute_simple_gradient_attribution
@@ -523,11 +519,6 @@ def update_pipeline_content(activation_data, model_name):
523
  else:
524
  top_tokens = []
525
 
526
- # Get attention info from first layer
527
- top_attended = None
528
- if layer_data:
529
- top_attended = layer_data[0].get('top_attended_tokens', [])
530
-
531
  # Generate BertViz HTML
532
  from utils import generate_bertviz_html
533
  attention_html = None
 
10
  import json
11
  import torch
12
  from utils import (load_model_and_get_patterns, execute_forward_pass, extract_layer_data,
13
+ perform_beam_search, execute_forward_pass_with_multi_layer_head_ablation)
 
 
 
 
14
  from utils.head_detection import categorize_all_heads
15
  from utils.model_config import get_auto_selections
16
  from utils.token_attribution import compute_integrated_gradients, compute_simple_gradient_attribution
 
519
  else:
520
  top_tokens = []
521
 
 
 
 
 
 
522
  # Generate BertViz HTML
523
  from utils import generate_bertviz_html
524
  attention_html = None
components/pipeline.py CHANGED
@@ -809,26 +809,3 @@ def create_output_content(top_tokens=None, predicted_token=None, predicted_prob=
809
 
810
  return html.Div(content_items)
811
 
812
-
813
- def create_stage_summary(stage_id, activation_data=None, model_config=None):
814
- """
815
- Generate summary text for a stage (shown when collapsed).
816
-
817
- Args:
818
- stage_id: Stage identifier ('tokenization', 'embedding', etc.)
819
- activation_data: Optional activation data from forward pass
820
- model_config: Optional model configuration
821
- """
822
- if not activation_data:
823
- return "Awaiting input..."
824
-
825
- summaries = {
826
- 'tokenization': lambda: f"{len(activation_data.get('input_ids', [[]])[0])} tokens",
827
- 'embedding': lambda: f"{model_config.hidden_size if model_config else 768}-dim vectors" if model_config else "Vectors ready",
828
- 'attention': lambda: f"{model_config.num_attention_heads if model_config else 12} heads" if model_config else "Context gathered",
829
- 'mlp': lambda: f"{model_config.num_hidden_layers if model_config else 12} layers" if model_config else "Transformations applied",
830
- 'output': lambda: f"→ {activation_data.get('actual_output', {}).get('token', '?')}" if activation_data.get('actual_output') else "Output computed"
831
- }
832
-
833
- return summaries.get(stage_id, lambda: "")()
834
-
 
809
 
810
  return html.Div(content_items)
811
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
components/tokenization_panel.py DELETED
@@ -1,302 +0,0 @@
1
- """
2
- Tokenization panel component for visualizing the tokenization process.
3
-
4
- Displays tokens in vertical rows: [token] → [ID] → [embedding] per token.
5
- """
6
-
7
- from dash import html, dcc
8
-
9
-
10
- def create_static_tokenization_diagram():
11
- """Create static HTML/CSS diagram showing example tokenization flow."""
12
- return html.Div([
13
- html.H4("Example: How text becomes model input",
14
- style={'marginBottom': '1rem', 'color': '#495057', 'fontSize': '16px'}),
15
-
16
- # Example flow: text -> tokens -> IDs -> embeddings
17
- html.Div([
18
- # Input text
19
- html.Div([
20
- html.Div('"Hello world"',
21
- className='example-text',
22
- style={'padding': '8px 12px', 'backgroundColor': '#e9ecef',
23
- 'borderRadius': '4px', 'fontFamily': 'monospace'})
24
- ], style={'flex': '1', 'textAlign': 'center'}),
25
-
26
- html.Div('→', style={'padding': '0 10px', 'fontSize': '20px', 'color': '#6c757d'}),
27
-
28
- # Tokens
29
- html.Div([
30
- html.Div([
31
- html.Span('["Hello"', style={'fontFamily': 'monospace'}),
32
- html.Span(', " world"]', style={'fontFamily': 'monospace'})
33
- ], style={'padding': '8px 12px', 'backgroundColor': '#d4edff',
34
- 'borderRadius': '4px'})
35
- ], style={'flex': '1', 'textAlign': 'center'}),
36
-
37
- html.Div('→', style={'padding': '0 10px', 'fontSize': '20px', 'color': '#6c757d'}),
38
-
39
- # IDs
40
- html.Div([
41
- html.Div('[1234, 5678]',
42
- style={'padding': '8px 12px', 'backgroundColor': '#ffe5d4',
43
- 'borderRadius': '4px', 'fontFamily': 'monospace'})
44
- ], style={'flex': '1', 'textAlign': 'center'}),
45
-
46
- html.Div('→', style={'padding': '0 10px', 'fontSize': '20px', 'color': '#6c757d'}),
47
-
48
- # Embeddings
49
- html.Div([
50
- html.Div('[[ ... ], [ ... ]]',
51
- style={'padding': '8px 12px', 'backgroundColor': '#e5d4ff',
52
- 'borderRadius': '4px', 'fontFamily': 'monospace'})
53
- ], style={'flex': '1', 'textAlign': 'center'})
54
-
55
- ], style={'display': 'flex', 'alignItems': 'center', 'justifyContent': 'center',
56
- 'padding': '1rem', 'backgroundColor': '#f8f9fa', 'borderRadius': '8px',
57
- 'border': '1px solid #dee2e6'})
58
- ], style={'marginBottom': '2rem'})
59
-
60
-
61
- def create_tokenization_panel():
62
- """Create the tokenization visualization panel with three columns."""
63
- return html.Div([
64
- # Section title and subtitle
65
- html.Div([
66
- html.H3("Step 1: Tokenization & Embedding",
67
- className="section-title",
68
- style={'marginBottom': '0.5rem'}),
69
- html.P("This is the first step in processing text through a transformer model. "
70
- "The input text is broken into tokens, converted to IDs, and embedded as vectors.",
71
- style={'color': '#6c757d', 'fontSize': '14px', 'marginBottom': '1.5rem'})
72
- ]),
73
-
74
- # Static example diagram (always visible)
75
- create_static_tokenization_diagram(),
76
-
77
- # Dynamic tokenization display container (populated by callback)
78
- html.Div(id='tokenization-display-container', children=[
79
- # This will be populated after analysis runs
80
- ])
81
-
82
- ], id='tokenization-panel', style={'display': 'none'}, className='tokenization-section')
83
-
84
-
85
- def create_tokenization_display(tokens_list, token_ids_list, color_palette=None):
86
- """
87
- Create a vertical tokenization display showing each token's flow.
88
-
89
- Args:
90
- tokens_list: List of token strings
91
- token_ids_list: List of token IDs
92
- color_palette: Optional list of colors for each token (auto-generated if None)
93
-
94
- Returns:
95
- Dash HTML component with vertical token rows: [token] → [ID] → [embedding]
96
- """
97
- if color_palette is None:
98
- # Generate distinct colors for each token
99
- color_palette = generate_token_colors(len(tokens_list))
100
-
101
- preview_token = tokens_list[0] if tokens_list else ""
102
- preview_id = token_ids_list[0] if token_ids_list else ""
103
- preview_color = color_palette[0] if color_palette else '#f8f9fa'
104
-
105
- return html.Details([
106
- html.Summary(
107
- html.Div([
108
- html.Span("Tokenization preview:", style={'color': '#6c757d', 'fontSize': '13px'}),
109
- html.Span(preview_token, style={
110
- 'padding': '4px 8px',
111
- 'backgroundColor': preview_color,
112
- 'borderRadius': '4px',
113
- 'fontFamily': 'monospace',
114
- 'fontSize': '12px'
115
- }),
116
- html.Span('→', style={'color': '#6c757d'}),
117
- html.Span(str(preview_id), style={
118
- 'padding': '4px 8px',
119
- 'backgroundColor': '#ffe5d4',
120
- 'borderRadius': '4px',
121
- 'fontFamily': 'monospace',
122
- 'fontSize': '12px'
123
- }),
124
- html.Span('→', style={'color': '#6c757d'}),
125
- html.Span('[ ... ]', style={
126
- 'padding': '4px 8px',
127
- 'backgroundColor': '#e5d4ff',
128
- 'borderRadius': '4px',
129
- 'fontFamily': 'monospace',
130
- 'fontSize': '12px'
131
- }),
132
- html.Span('...', style={'color': '#6c757d'}),
133
- html.Span("Expand", style={'marginLeft': 'auto', 'color': '#667eea', 'fontWeight': '500'})
134
- ], style={'display': 'flex', 'alignItems': 'center', 'gap': '8px', 'flexWrap': 'wrap'})
135
- ),
136
-
137
- html.Div([
138
- html.H4("Full Tokenization:",
139
- style={'marginTop': '1.5rem', 'marginBottom': '1rem',
140
- 'color': '#495057', 'fontSize': '16px'}),
141
-
142
- # Column headers row
143
- html.Div([
144
- html.Span("Token", className='token-header',
145
- style={'flex': '1', 'fontWeight': '600', 'color': '#495057', 'fontSize': '13px'}),
146
- html.Span("", style={'width': '32px'}), # Arrow spacer
147
- html.Span("ID", className='token-header',
148
- style={'flex': '1', 'fontWeight': '600', 'color': '#495057', 'fontSize': '13px'}),
149
- html.Span("", style={'width': '32px'}), # Arrow spacer
150
- html.Span("Embedding", className='token-header',
151
- style={'flex': '1', 'fontWeight': '600', 'color': '#495057', 'fontSize': '13px'})
152
- ], className='tokenization-header-row',
153
- style={'display': 'flex', 'alignItems': 'center', 'gap': '4px',
154
- 'marginBottom': '0.75rem', 'paddingBottom': '0.5rem',
155
- 'borderBottom': '1px solid #e9ecef'}),
156
-
157
- # Vertical token rows - each row shows [token] → [ID] → [embedding]
158
- html.Div([
159
- create_token_row(token, token_id, color, idx)
160
- for idx, (token, token_id, color) in enumerate(zip(tokens_list, token_ids_list, color_palette))
161
- ], className='tokenization-rows')
162
-
163
- ], style={'padding': '1rem', 'backgroundColor': '#ffffff',
164
- 'borderRadius': '8px', 'border': '1px solid #dee2e6'})
165
-
166
- ], open=False, style={'marginTop': '1rem'})
167
-
168
-
169
- def create_token_row(token, token_id, color, idx):
170
- """
171
- Create a single horizontal row showing: [token] → [ID] → [embedding].
172
-
173
- Args:
174
- token: Token string
175
- token_id: Token ID number
176
- color: Background color for the token
177
- idx: Index of the token (for key uniqueness)
178
-
179
- Returns:
180
- Dash HTML component for a single token row
181
- """
182
- # Tooltip text for educational purposes
183
- tooltips = {
184
- 'token': "The text is broken into 'tokens' - small pieces like words or parts of words. "
185
- "This is how the model reads text. Breaking words into smaller pieces lets the model "
186
- "understand new words by combining pieces it already knows.",
187
- 'id': "Each token gets a unique number (ID) from the model's dictionary. "
188
- "Think of it like a phonebook - every token has its own number. "
189
- "The model uses these numbers instead of the actual text.",
190
- 'embedding': "Each token number is turned into a list of numbers called an 'embedding.' "
191
- "These numbers capture the token's meaning. Similar words get similar numbers. "
192
- "This list of numbers is what actually goes into the model's layers."
193
- }
194
-
195
- return html.Div([
196
- # Token box
197
- html.Div(
198
- token,
199
- className='token-row-box token-row-token',
200
- style={
201
- 'flex': '1',
202
- 'padding': '8px 12px',
203
- 'backgroundColor': color,
204
- 'borderRadius': '6px',
205
- 'border': f'2px solid {darken_color(color)}',
206
- 'fontFamily': 'monospace',
207
- 'fontSize': '13px',
208
- 'textAlign': 'center',
209
- 'wordBreak': 'break-word',
210
- 'minWidth': '60px'
211
- },
212
- title=tooltips['token']
213
- ),
214
-
215
- # Arrow
216
- html.Span('→', className='token-row-arrow',
217
- style={'color': '#6c757d', 'fontSize': '16px', 'padding': '0 8px'}),
218
-
219
- # ID box
220
- html.Div(
221
- str(token_id),
222
- className='token-row-box token-row-id',
223
- style={
224
- 'flex': '1',
225
- 'padding': '8px 12px',
226
- 'backgroundColor': '#ffe5d4',
227
- 'borderRadius': '6px',
228
- 'border': '2px solid #e6cfc0',
229
- 'fontFamily': 'monospace',
230
- 'fontSize': '13px',
231
- 'textAlign': 'center',
232
- 'minWidth': '60px'
233
- },
234
- title=tooltips['id']
235
- ),
236
-
237
- # Arrow
238
- html.Span('→', className='token-row-arrow',
239
- style={'color': '#6c757d', 'fontSize': '16px', 'padding': '0 8px'}),
240
-
241
- # Embedding box
242
- html.Div(
243
- '[ ... ]',
244
- className='token-row-box token-row-embedding',
245
- style={
246
- 'flex': '1',
247
- 'padding': '8px 12px',
248
- 'backgroundColor': '#e5d4ff',
249
- 'borderRadius': '6px',
250
- 'border': '2px solid #cfbfe6',
251
- 'fontFamily': 'monospace',
252
- 'fontSize': '13px',
253
- 'textAlign': 'center',
254
- 'minWidth': '60px'
255
- },
256
- title=tooltips['embedding']
257
- )
258
-
259
- ], className='token-row',
260
- style={'display': 'flex', 'alignItems': 'center', 'gap': '4px', 'marginBottom': '8px'})
261
-
262
-
263
- def generate_token_colors(num_tokens):
264
- """Generate a list of distinct colors for tokens."""
265
- # Predefined pleasant color palette
266
- base_colors = [
267
- '#ffcccb', # Light red
268
- '#add8e6', # Light blue
269
- '#90ee90', # Light green
270
- '#ffb6c1', # Light pink
271
- '#ffd700', # Gold
272
- '#dda0dd', # Plum
273
- '#f0e68c', # Khaki
274
- '#ff6347', # Tomato
275
- '#98fb98', # Pale green
276
- '#87ceeb', # Sky blue
277
- '#ffa07a', # Light salmon
278
- '#da70d6', # Orchid
279
- ]
280
-
281
- # Cycle through colors if we have more tokens than colors
282
- colors = []
283
- for i in range(num_tokens):
284
- colors.append(base_colors[i % len(base_colors)])
285
-
286
- return colors
287
-
288
-
289
- def darken_color(hex_color, factor=0.8):
290
- """Darken a hex color by a factor."""
291
- # Remove '#' if present
292
- hex_color = hex_color.lstrip('#')
293
-
294
- # Convert to RGB
295
- r, g, b = int(hex_color[0:2], 16), int(hex_color[2:4], 16), int(hex_color[4:6], 16)
296
-
297
- # Darken
298
- r, g, b = int(r * factor), int(g * factor), int(b * factor)
299
-
300
- # Convert back to hex
301
- return f'#{r:02x}{g:02x}{b:02x}'
302
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
todo.md CHANGED
@@ -112,3 +112,22 @@
112
  - [x] Replace per-layer ablation loop in app.py with single call to new function
113
  - [x] Add 5 tests for multi-layer ablation in test_model_patterns.py
114
  - [x] Verify all 78 tests pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
  - [x] Replace per-layer ablation loop in app.py with single call to new function
113
  - [x] Add 5 tests for multi-layer ablation in test_model_patterns.py
114
  - [x] Verify all 78 tests pass
115
+
116
+ ## Completed: Codebase Cleanup
117
+
118
+ - [x] Delete unused file: `components/tokenization_panel.py` (302 lines, 6 functions)
119
+ - [x] Remove 6 unused imports from `app.py`
120
+ - [x] Remove deprecated `_get_top_attended_tokens()` function from model_patterns.py
121
+ - [x] Remove `top_attended_tokens` field from extract_layer_data() return values
122
+ - [x] Remove unused `create_stage_summary()` function from pipeline.py
123
+ - [x] Remove 7 unused utility functions from utils/:
124
+ - `get_check_token_probabilities`
125
+ - `execute_forward_pass_with_layer_ablation`
126
+ - `generate_category_bertviz_html`
127
+ - `generate_head_view_with_categories`
128
+ - `compute_sequence_trajectory`
129
+ - `compute_layer_wise_summaries`
130
+ - `compute_position_layer_matrix`
131
+ - [x] Update `utils/__init__.py` exports
132
+ - [x] Update README.md to remove reference to deleted file
133
+ - [x] Verify all 81 tests pass
utils/__init__.py CHANGED
@@ -1,17 +1,14 @@
1
  from .model_patterns import (load_model_and_get_patterns, execute_forward_pass,
2
  logit_lens_transformation, extract_layer_data,
3
- generate_bertviz_html, generate_category_bertviz_html,
4
- generate_head_view_with_categories, get_head_category_counts,
5
- get_check_token_probabilities, execute_forward_pass_with_layer_ablation,
6
  execute_forward_pass_with_head_ablation,
7
  execute_forward_pass_with_multi_layer_head_ablation,
8
  merge_token_probabilities,
9
  compute_global_top5_tokens, detect_significant_probability_increases,
10
- compute_layer_wise_summaries, evaluate_sequence_ablation,
11
- compute_position_layer_matrix, generate_bertviz_model_view_html)
12
  from .model_config import get_model_family, get_family_config, get_auto_selections, MODEL_TO_FAMILY, MODEL_FAMILIES
13
  from .head_detection import categorize_all_heads, categorize_single_layer_heads, format_categorization_summary, HeadCategorizationConfig
14
- from .beam_search import perform_beam_search, compute_sequence_trajectory
15
  from .ablation_metrics import compute_kl_divergence, score_sequence, get_token_probability_deltas
16
  from .token_attribution import compute_integrated_gradients, compute_simple_gradient_attribution, create_attribution_visualization_data
17
 
@@ -20,22 +17,15 @@ __all__ = [
20
  # Model patterns
21
  'load_model_and_get_patterns',
22
  'execute_forward_pass',
23
- 'execute_forward_pass_with_layer_ablation',
24
  'execute_forward_pass_with_head_ablation',
25
  'execute_forward_pass_with_multi_layer_head_ablation',
26
  'evaluate_sequence_ablation',
27
  'logit_lens_transformation',
28
  'extract_layer_data',
29
  'generate_bertviz_html',
30
- 'generate_category_bertviz_html',
31
- 'generate_head_view_with_categories',
32
- 'get_head_category_counts',
33
- 'get_check_token_probabilities',
34
  'merge_token_probabilities',
35
  'compute_global_top5_tokens',
36
  'detect_significant_probability_increases',
37
- 'compute_layer_wise_summaries',
38
- 'compute_position_layer_matrix',
39
  'generate_bertviz_model_view_html',
40
 
41
  # Model config
@@ -53,7 +43,6 @@ __all__ = [
53
 
54
  # Beam search
55
  'perform_beam_search',
56
- 'compute_sequence_trajectory',
57
 
58
  # Ablation metrics
59
  'compute_kl_divergence',
 
1
  from .model_patterns import (load_model_and_get_patterns, execute_forward_pass,
2
  logit_lens_transformation, extract_layer_data,
3
+ generate_bertviz_html,
 
 
4
  execute_forward_pass_with_head_ablation,
5
  execute_forward_pass_with_multi_layer_head_ablation,
6
  merge_token_probabilities,
7
  compute_global_top5_tokens, detect_significant_probability_increases,
8
+ evaluate_sequence_ablation, generate_bertviz_model_view_html)
 
9
  from .model_config import get_model_family, get_family_config, get_auto_selections, MODEL_TO_FAMILY, MODEL_FAMILIES
10
  from .head_detection import categorize_all_heads, categorize_single_layer_heads, format_categorization_summary, HeadCategorizationConfig
11
+ from .beam_search import perform_beam_search
12
  from .ablation_metrics import compute_kl_divergence, score_sequence, get_token_probability_deltas
13
  from .token_attribution import compute_integrated_gradients, compute_simple_gradient_attribution, create_attribution_visualization_data
14
 
 
17
  # Model patterns
18
  'load_model_and_get_patterns',
19
  'execute_forward_pass',
 
20
  'execute_forward_pass_with_head_ablation',
21
  'execute_forward_pass_with_multi_layer_head_ablation',
22
  'evaluate_sequence_ablation',
23
  'logit_lens_transformation',
24
  'extract_layer_data',
25
  'generate_bertviz_html',
 
 
 
 
26
  'merge_token_probabilities',
27
  'compute_global_top5_tokens',
28
  'detect_significant_probability_increases',
 
 
29
  'generate_bertviz_model_view_html',
30
 
31
  # Model config
 
43
 
44
  # Beam search
45
  'perform_beam_search',
 
46
 
47
  # Ablation metrics
48
  'compute_kl_divergence',
utils/beam_search.py CHANGED
@@ -4,9 +4,7 @@ Beam search utility for text generation and sequence analysis.
4
 
5
  import torch
6
  import torch.nn.functional as F
7
- from typing import List, Tuple, Dict, Any, Optional
8
- import numpy as np
9
- from utils.model_patterns import get_norm_layer_from_parameter
10
  import re
11
 
12
  def _make_head_ablation_hook(head_indices: List[int], num_heads: int):
@@ -179,90 +177,3 @@ def perform_beam_search(model, tokenizer, prompt: str, beam_width: int = 3, max_
179
  # Ensure hooks are removed even if error occurs
180
  for hook in hooks:
181
  hook.remove()
182
-
183
- def compute_sequence_trajectory(activation_data: Dict[str, Any], model, tokenizer) -> Dict[int, List[float]]:
184
- """
185
- Compute the trajectory of the sequence score across layers.
186
-
187
- For each layer, calculates the probability assigned to the *actual* next token
188
- at each step of the sequence.
189
-
190
- Args:
191
- activation_data: Data from execute_forward_pass (must contain block_outputs for all layers)
192
- model: HuggingFace model
193
- tokenizer: HuggingFace tokenizer
194
-
195
- Returns:
196
- Dict mapping layer_num -> list of scores (one per step in the generated sequence)
197
- """
198
- if not activation_data or 'block_outputs' not in activation_data:
199
- return {}
200
-
201
- # Extract layer outputs
202
- block_outputs = activation_data['block_outputs']
203
- input_ids = activation_data['input_ids']
204
-
205
- if isinstance(input_ids, list):
206
- input_ids = torch.tensor(input_ids)
207
-
208
- # Identify tokens: input_ids shape is [1, seq_len]
209
- # The "generated" part starts after the prompt, but here we likely have the full sequence.
210
- # We want to evaluate P(token_t | tokens_<t) for the whole sequence or just the new part?
211
- # Usually, we visualize the whole sequence.
212
-
213
- # We need the logits from each layer
214
- # block_outputs keys are like "model.layers.0", "model.layers.1", etc.
215
-
216
- # Sort layers
217
- import re
218
- layer_info = sorted(
219
- [(int(re.findall(r'\d+', name)[0]), name)
220
- for name in block_outputs.keys() if re.findall(r'\d+', name)]
221
- )
222
-
223
- # Get norm parameter for logit lens
224
- norm_params = activation_data.get('norm_parameters', [])
225
- norm_parameter = norm_params[0] if norm_params else None
226
- final_norm = get_norm_layer_from_parameter(model, norm_parameter)
227
- lm_head = model.get_output_embeddings()
228
-
229
- trajectories = {}
230
-
231
- # We only care about predictions for positions 0 to N-1 (predicting 1 to N)
232
- target_ids = input_ids[0, 1:]
233
-
234
- with torch.no_grad():
235
- for layer_num, module_name in layer_info:
236
- output_data = block_outputs[module_name]['output']
237
-
238
- # Convert to tensor [batch, seq_len, hidden_dim]
239
- hidden = torch.tensor(output_data) if not isinstance(output_data, torch.Tensor) else output_data
240
- if hidden.dim() == 4: # PyVene sometimes returns [1, 1, seq_len, dim] ? No usually [1, seq, dim]
241
- # If shape is weird, adjust
242
- pass
243
-
244
- # Ensure batch dim
245
- if hidden.dim() == 2:
246
- hidden = hidden.unsqueeze(0)
247
-
248
- # Apply final norm
249
- if final_norm is not None:
250
- hidden = final_norm(hidden)
251
-
252
- # Project to logits
253
- logits = lm_head(hidden) # [batch, seq_len, vocab_size]
254
-
255
- # We want log probs of the *next* token
256
- # Logits at pos t predict token at t+1
257
- # So we take logits at [0, :-1, :] and gather targets [0, 1:]
258
-
259
- shift_logits = logits[0, :-1, :]
260
- log_probs = F.log_softmax(shift_logits, dim=-1)
261
-
262
- # Gather log probs of the actual target tokens
263
- # target_ids shape [seq_len-1]
264
- target_log_probs = log_probs.gather(1, target_ids.unsqueeze(1)).squeeze(1)
265
-
266
- trajectories[layer_num] = target_log_probs.tolist()
267
-
268
- return trajectories
 
4
 
5
  import torch
6
  import torch.nn.functional as F
7
+ from typing import List, Dict, Any, Optional
 
 
8
  import re
9
 
10
  def _make_head_ablation_hook(head_indices: List[int], num_heads: int):
 
177
  # Ensure hooks are removed even if error occurs
178
  for hook in hooks:
179
  hook.remove()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
utils/model_patterns.py CHANGED
@@ -657,174 +657,6 @@ def execute_forward_pass_with_multi_layer_head_ablation(model, tokenizer, prompt
657
  return result
658
 
659
 
660
- def execute_forward_pass_with_layer_ablation(model, tokenizer, prompt: str, config: Dict[str, Any],
661
- ablate_layer_num: int, reference_activation_data: Dict[str, Any]) -> Dict[str, Any]:
662
- """
663
- Execute forward pass with mean ablation on a specific layer.
664
-
665
- Args:
666
- model: Loaded transformer model
667
- tokenizer: Loaded tokenizer
668
- prompt: Input text prompt
669
- config: Dict with module lists like {"attention_modules": [...], "block_modules": [...], ...}
670
- ablate_layer_num: Layer number to ablate
671
- reference_activation_data: Original activation data containing the reference activations
672
-
673
- Returns:
674
- JSON-serializable dict with captured activations (with ablated layer)
675
- """
676
- # Extract module lists from config
677
- attention_modules = config.get("attention_modules", [])
678
- block_modules = config.get("block_modules", [])
679
- norm_parameters = config.get("norm_parameters", [])
680
- logit_lens_parameter = config.get("logit_lens_parameter")
681
-
682
- all_modules = attention_modules + block_modules
683
- if not all_modules:
684
- return {"error": "No modules specified"}
685
-
686
- # Find the target module for the layer to ablate
687
- target_module_name = None
688
- for mod_name in block_modules:
689
- layer_match = re.search(r'\.(\d+)(?:\.|$)', mod_name)
690
- if layer_match and int(layer_match.group(1)) == ablate_layer_num:
691
- target_module_name = mod_name
692
- break
693
-
694
- if not target_module_name:
695
- return {"error": f"Could not find module for layer {ablate_layer_num}"}
696
-
697
- # Get reference activations from ALL layers for mean computation
698
- block_outputs = reference_activation_data.get('block_outputs', {})
699
- if not block_outputs:
700
- return {"error": "No block outputs found in reference data"}
701
-
702
- # Collect all layer activations to compute global mean
703
- all_layer_tensors = []
704
- for mod_name, output_data in block_outputs.items():
705
- output = output_data['output']
706
- if isinstance(output, list):
707
- tensor = torch.tensor(output)
708
- else:
709
- tensor = output
710
- all_layer_tensors.append(tensor)
711
-
712
- # Stack all layers and compute mean across ALL layers and sequence positions
713
- # This gives us a single mean vector that represents the average activation
714
- stacked = torch.stack(all_layer_tensors, dim=0) # [num_layers, batch, seq_len, hidden_dim]
715
- # Compute mean across layers and sequence dimension
716
- mean_activation = stacked.mean(dim=(0, 2), keepdim=True) # [1, batch, 1, hidden_dim]
717
- mean_activation = mean_activation.squeeze(0) # [batch, 1, hidden_dim]
718
-
719
- # Prepare inputs
720
- inputs = tokenizer(prompt, return_tensors="pt")
721
- seq_len = inputs['input_ids'].shape[1]
722
-
723
- # Broadcast mean to match sequence length
724
- ablation_value = mean_activation.expand(-1, seq_len, -1) # [batch, seq_len, hidden_dim]
725
-
726
- # Build IntervenableConfig from module names
727
- intervenable_representations = []
728
- for mod_name in all_modules:
729
- layer_match = re.search(r'\.(\d+)(?:\.|$)', mod_name)
730
- if not layer_match:
731
- return {"error": f"Invalid module name format: {mod_name}"}
732
-
733
- if 'attn' in mod_name or 'attention' in mod_name:
734
- component = 'attention_output'
735
- else:
736
- component = 'block_output'
737
-
738
- intervenable_representations.append(
739
- RepresentationConfig(layer=int(layer_match.group(1)), component=component, unit="pos")
740
- )
741
-
742
- intervenable_config = IntervenableConfig(
743
- intervenable_representations=intervenable_representations
744
- )
745
- intervenable_model = IntervenableModel(intervenable_config, model)
746
-
747
- # Register hooks to capture activations
748
- captured = {}
749
- name_to_module = dict(intervenable_model.model.named_modules())
750
-
751
- def make_hook(mod_name: str):
752
- return lambda module, inputs, output: captured.update({mod_name: {"output": safe_to_serializable(output)}})
753
-
754
- # Register ablation hook for target module
755
- def ablation_hook(module, input, output):
756
- # Replace output with mean activation
757
- if isinstance(output, tuple):
758
- # For modules that return tuples (hidden_states, ...), replace first element
759
- ablated = (ablation_value,) + output[1:]
760
- return ablated
761
- else:
762
- return ablation_value
763
-
764
- hooks = []
765
- for mod_name in all_modules:
766
- if mod_name in name_to_module:
767
- if mod_name == target_module_name:
768
- # Apply ablation hook
769
- hooks.append(name_to_module[mod_name].register_forward_hook(ablation_hook))
770
- else:
771
- # Regular capture hook
772
- hooks.append(name_to_module[mod_name].register_forward_hook(make_hook(mod_name)))
773
-
774
- # Execute forward pass
775
- with torch.no_grad():
776
- model_output = intervenable_model.model(**inputs, use_cache=False)
777
-
778
- # Remove hooks
779
- for hook in hooks:
780
- hook.remove()
781
-
782
- # Capture ablated layer output as well
783
- captured[target_module_name] = {"output": safe_to_serializable(ablation_value)}
784
-
785
- # Separate outputs by type
786
- attention_outputs = {}
787
- block_outputs = {}
788
-
789
- for mod_name, output in captured.items():
790
- if 'attn' in mod_name or 'attention' in mod_name:
791
- attention_outputs[mod_name] = output
792
- else:
793
- block_outputs[mod_name] = output
794
-
795
- # Capture normalization parameters
796
- all_params = dict(model.named_parameters())
797
- norm_data = [safe_to_serializable(all_params[p]) for p in norm_parameters if p in all_params]
798
-
799
- # Extract predicted token from model output
800
- actual_output = None
801
- global_top5_tokens = []
802
- try:
803
- output_token, output_prob = get_actual_model_output(model_output, tokenizer)
804
- actual_output = {"token": output_token, "probability": output_prob}
805
- global_top5_tokens = compute_global_top5_tokens(model_output, tokenizer, top_k=5)
806
- except Exception as e:
807
- print(f"Warning: Could not extract model output: {e}")
808
-
809
- # Build output dictionary
810
- result = {
811
- "model": getattr(model.config, "name_or_path", "unknown"),
812
- "prompt": prompt,
813
- "input_ids": safe_to_serializable(inputs["input_ids"]),
814
- "attention_modules": list(attention_outputs.keys()),
815
- "attention_outputs": attention_outputs,
816
- "block_modules": list(block_outputs.keys()),
817
- "block_outputs": block_outputs,
818
- "norm_parameters": norm_parameters,
819
- "norm_data": norm_data,
820
- "actual_output": actual_output,
821
- "global_top5_tokens": global_top5_tokens,
822
- "ablated_layer": ablate_layer_num
823
- }
824
-
825
- return result
826
-
827
-
828
  def evaluate_sequence_ablation(model, tokenizer, sequence_text: str, config: Dict[str, Any],
829
  ablation_type: str, ablation_target: Any) -> Dict[str, Any]:
830
  """
@@ -1159,85 +991,6 @@ def _get_top_tokens(activation_data: Dict[str, Any], module_name: str, model, to
1159
  return None
1160
 
1161
 
1162
- def get_check_token_probabilities(activation_data: Dict[str, Any], model, tokenizer, check_token: str) -> Optional[Dict[str, Any]]:
1163
- """
1164
- Collect check token probabilities across all layers.
1165
-
1166
- Sums probabilities of token variants (with and without leading space).
1167
- Returns layer numbers and merged probabilities for plotting.
1168
- """
1169
- if not check_token or not check_token.strip():
1170
- return None
1171
-
1172
- try:
1173
- # Get block modules (all layers)
1174
- layer_modules = activation_data.get('block_modules', [])
1175
- if not layer_modules:
1176
- return None
1177
-
1178
- # Extract and sort layers
1179
- layer_info = sorted(
1180
- [(int(re.findall(r'\d+', name)[0]), name)
1181
- for name in layer_modules if re.findall(r'\d+', name)]
1182
- )
1183
-
1184
- # Try tokenizing with and without leading space
1185
- token_variants = [
1186
- (check_token.strip(), tokenizer.encode(check_token.strip(), add_special_tokens=False)),
1187
- (' ' + check_token.strip(), tokenizer.encode(' ' + check_token.strip(), add_special_tokens=False))
1188
- ]
1189
-
1190
- # Get token IDs for both variants (if they exist and differ)
1191
- target_token_ids = []
1192
- for variant_text, token_ids in token_variants:
1193
- if token_ids:
1194
- tid = token_ids[-1] # Use last sub-token
1195
- if tid not in target_token_ids:
1196
- target_token_ids.append(tid)
1197
-
1198
- if not target_token_ids:
1199
- return None
1200
-
1201
- # Get norm parameter
1202
- norm_params = activation_data.get('norm_parameters', [])
1203
- norm_parameter = norm_params[0] if norm_params else None
1204
- final_norm = get_norm_layer_from_parameter(model, norm_parameter)
1205
- lm_head = model.get_output_embeddings()
1206
-
1207
- # Collect probabilities for all layers (sum both variants)
1208
- layers = []
1209
- probabilities = []
1210
-
1211
- for layer_num, module_name in layer_info:
1212
- layer_output = activation_data['block_outputs'][module_name]['output']
1213
-
1214
- with torch.no_grad():
1215
- hidden = torch.tensor(layer_output) if not isinstance(layer_output, torch.Tensor) else layer_output
1216
- if hidden.dim() == 4:
1217
- hidden = hidden.squeeze(0)
1218
-
1219
- if final_norm is not None:
1220
- hidden = final_norm(hidden)
1221
-
1222
- logits = lm_head(hidden)
1223
- probs = F.softmax(logits[0, -1, :], dim=-1)
1224
-
1225
- # Sum probabilities of all variants
1226
- merged_prob = sum(probs[tid].item() for tid in target_token_ids)
1227
-
1228
- layers.append(layer_num)
1229
- probabilities.append(merged_prob)
1230
-
1231
- return {
1232
- 'token': check_token.strip(), # Return canonical form without leading space
1233
- 'layers': layers,
1234
- 'probabilities': probabilities
1235
- }
1236
- except Exception as e:
1237
- print(f"Error computing check token probabilities: {e}")
1238
- return None
1239
-
1240
-
1241
  def detect_significant_probability_increases(layer_wise_probs: Dict[int, Dict[str, float]],
1242
  layer_wise_deltas: Dict[int, Dict[str, float]],
1243
  actual_output_token: str,
@@ -1281,246 +1034,13 @@ def detect_significant_probability_increases(layer_wise_probs: Dict[int, Dict[st
1281
  return significant_layers
1282
 
1283
 
1284
- def _get_top_attended_tokens(activation_data: Dict[str, Any], layer_num: int, tokenizer, top_k: int = 3) -> Optional[List[Tuple[str, float]]]:
1285
- """
1286
- DEPRECATED: This function is deprecated and will be removed in a future version.
1287
- Use head categorization from head_detection.py instead for more meaningful attention analysis.
1288
-
1289
- Get top-K attended input tokens for the current position (last token) in a layer.
1290
- Averages attention across all heads.
1291
-
1292
- Args:
1293
- activation_data: Output from execute_forward_pass
1294
- layer_num: Layer number to analyze
1295
- tokenizer: Tokenizer for decoding tokens
1296
- top_k: Number of top attended tokens to return
1297
-
1298
- Returns:
1299
- List of (token_string, attention_weight) tuples, sorted by weight (highest first)
1300
- """
1301
- import warnings
1302
- warnings.warn(
1303
- "_get_top_attended_tokens is deprecated. Use categorize_all_heads() from head_detection.py instead.",
1304
- DeprecationWarning,
1305
- stacklevel=2
1306
- )
1307
- try:
1308
- attention_outputs = activation_data.get('attention_outputs', {})
1309
- input_ids = activation_data.get('input_ids', [])
1310
-
1311
- # print(f"DEBUG _get_top_attended_tokens: layer_num={layer_num}, attention_outputs keys={list(attention_outputs.keys())}")
1312
-
1313
- if not attention_outputs or not input_ids:
1314
- print(f"DEBUG _get_top_attended_tokens: Missing data - attention_outputs empty={not attention_outputs}, input_ids empty={not input_ids}")
1315
- return None
1316
-
1317
- # Find attention output for this layer
1318
- target_module = None
1319
- for module_name in attention_outputs.keys():
1320
- numbers = re.findall(r'\d+', module_name)
1321
- if numbers and int(numbers[0]) == layer_num:
1322
- target_module = module_name
1323
- break
1324
-
1325
- if not target_module:
1326
- return None
1327
-
1328
- attention_output = attention_outputs[target_module]['output']
1329
- if not isinstance(attention_output, list) or len(attention_output) < 2:
1330
- return None
1331
-
1332
- # Get attention weights: [batch, heads, seq_len, seq_len]
1333
- attention_weights = torch.tensor(attention_output[1])
1334
-
1335
- # Average across heads: [seq_len, seq_len]
1336
- avg_attention = attention_weights[0].mean(dim=0)
1337
-
1338
- # Get attention from last position to all positions
1339
- last_pos_attention = avg_attention[-1, :] # [seq_len]
1340
-
1341
- # Get top-K attended positions
1342
- top_values, top_indices = torch.topk(last_pos_attention, min(top_k, len(last_pos_attention)))
1343
-
1344
- # Convert to tokens
1345
- input_ids_tensor = torch.tensor(input_ids[0]) if isinstance(input_ids[0], list) else torch.tensor(input_ids)
1346
- result = []
1347
- for idx, weight in zip(top_indices, top_values):
1348
- token_id = input_ids_tensor[idx].item()
1349
- token_str = tokenizer.decode([token_id], skip_special_tokens=False)
1350
- result.append((token_str, weight.item()))
1351
-
1352
- return result
1353
-
1354
- except Exception as e:
1355
- print(f"Warning: Could not compute attended tokens for layer {layer_num}: {e}")
1356
- return None
1357
-
1358
-
1359
- def compute_position_layer_matrix(activation_data: Dict[str, Any], model, tokenizer) -> Dict[str, Any]:
1360
- """
1361
- Compute a 2D matrix of layer-to-layer deltas for each token position.
1362
-
1363
- This function computes the top-token probability delta at each (layer, position) pair,
1364
- creating a heatmap-ready data structure.
1365
-
1366
- Args:
1367
- activation_data: Activation data from forward pass
1368
- model: Transformer model for logit lens computation
1369
- tokenizer: Tokenizer for decoding tokens
1370
-
1371
- Returns:
1372
- Dict with:
1373
- - 'matrix': 2D list [num_layers, seq_len] of delta values
1374
- - 'tokens': List of token strings for X-axis labels
1375
- - 'layer_nums': List of layer numbers for Y-axis labels
1376
- - 'top_tokens': 2D list [num_layers, seq_len] of top token strings at each cell
1377
- """
1378
- import copy
1379
- import numpy as np
1380
-
1381
- input_ids = activation_data.get('input_ids', [[]])
1382
- if not input_ids or not input_ids[0]:
1383
- return {'matrix': [], 'tokens': [], 'layer_nums': [], 'top_tokens': []}
1384
-
1385
- seq_len = len(input_ids[0])
1386
-
1387
- # Get token strings for X-axis labels
1388
- tokens = [tokenizer.decode([tid]) for tid in input_ids[0]]
1389
-
1390
- # Get layer modules and sort by layer number
1391
- layer_modules = activation_data.get('block_modules', [])
1392
- if not layer_modules:
1393
- return {'matrix': [], 'tokens': tokens, 'layer_nums': [], 'top_tokens': []}
1394
-
1395
- layer_info = sorted(
1396
- [(int(re.findall(r'\d+', name)[0]), name)
1397
- for name in layer_modules if re.findall(r'\d+', name)]
1398
- )
1399
- layer_nums = [ln for ln, _ in layer_info]
1400
- num_layers = len(layer_nums)
1401
-
1402
- # Helper function to slice data to a specific position (adapted from app.py)
1403
- def slice_data(data, pos):
1404
- if not data:
1405
- return data
1406
- sliced = copy.deepcopy(data)
1407
-
1408
- # Slice Block Outputs: [batch, seq, hidden] -> [batch, 1, hidden]
1409
- if 'block_outputs' in sliced:
1410
- for mod in sliced['block_outputs']:
1411
- out = sliced['block_outputs'][mod]['output']
1412
- if isinstance(out, list) and len(out) > 0 and isinstance(out[0], list):
1413
- if pos < len(out[0]):
1414
- sliced['block_outputs'][mod]['output'] = [[out[0][pos]]]
1415
-
1416
- # Slice Attention Outputs: [batch, heads, seq, seq] -> [batch, heads, 1, seq]
1417
- if 'attention_outputs' in sliced:
1418
- for mod in sliced['attention_outputs']:
1419
- out = sliced['attention_outputs'][mod]['output']
1420
- if len(out) > 1:
1421
- attns = out[1]
1422
- if isinstance(attns, list) and len(attns) > 0:
1423
- batch_0 = attns[0]
1424
- new_batch_0 = []
1425
- for head in batch_0:
1426
- if pos < len(head):
1427
- new_batch_0.append([head[pos]])
1428
- sliced['attention_outputs'][mod]['output'] = [out[0], [new_batch_0]] + out[2:]
1429
-
1430
- # Slice input_ids
1431
- if 'input_ids' in sliced:
1432
- ids = sliced['input_ids'][0]
1433
- if pos < len(ids):
1434
- sliced['input_ids'][0] = ids[:pos+1]
1435
-
1436
- return sliced
1437
-
1438
- # Initialize matrix and top_tokens 2D array
1439
- matrix = [[0.0] * seq_len for _ in range(num_layers)]
1440
- top_tokens_matrix = [[''] * seq_len for _ in range(num_layers)]
1441
-
1442
- # Compute delta for each position
1443
- for pos in range(seq_len):
1444
- sliced = slice_data(activation_data, pos)
1445
- layer_data = extract_layer_data(sliced, model, tokenizer)
1446
-
1447
- if not layer_data:
1448
- continue
1449
-
1450
- # Fill in matrix for this position
1451
- for layer_info_item in layer_data:
1452
- layer_num = layer_info_item.get('layer_num')
1453
- if layer_num is None or layer_num not in layer_nums:
1454
- continue
1455
-
1456
- layer_idx = layer_nums.index(layer_num)
1457
-
1458
- # Get top token and its delta (layer-to-layer change)
1459
- top_token = layer_info_item.get('top_token', '')
1460
- deltas = layer_info_item.get('deltas', {})
1461
-
1462
- # The delta for the top token represents how much it changed from prev layer
1463
- delta = deltas.get(top_token, 0.0) if top_token else 0.0
1464
-
1465
- matrix[layer_idx][pos] = delta
1466
- top_tokens_matrix[layer_idx][pos] = top_token if top_token else ''
1467
-
1468
- return {
1469
- 'matrix': matrix,
1470
- 'tokens': tokens,
1471
- 'layer_nums': layer_nums,
1472
- 'top_tokens': top_tokens_matrix
1473
- }
1474
-
1475
-
1476
- def compute_layer_wise_summaries(layer_data: List[Dict[str, Any]], activation_data: Dict[str, Any]) -> Dict[str, Any]:
1477
- """
1478
- Compute summary structures from layer data for easy access.
1479
-
1480
- Args:
1481
- layer_data: List of layer data dicts from extract_layer_data()
1482
- activation_data: Activation data containing actual output token
1483
-
1484
- Returns:
1485
- Dict with: layer_wise_top5_probs, layer_wise_top5_deltas, significant_layers
1486
- """
1487
- layer_wise_top5_probs = {} # layer_num -> {token: prob}
1488
- layer_wise_top5_deltas = {} # layer_num -> {token: delta}
1489
-
1490
- for layer_info in layer_data:
1491
- layer_num = layer_info.get('layer_num')
1492
- if layer_num is not None:
1493
- layer_wise_top5_probs[layer_num] = layer_info.get('global_top5_probs', {})
1494
- layer_wise_top5_deltas[layer_num] = layer_info.get('global_top5_deltas', {})
1495
-
1496
- # Extract actual output token from activation data
1497
- actual_output = activation_data.get('actual_output', {})
1498
- actual_output_token = actual_output.get('token', '').strip() if actual_output else ''
1499
-
1500
- # Detect significant layers based on actual output token
1501
- significant_layers = []
1502
- if actual_output_token:
1503
- significant_layers = detect_significant_probability_increases(
1504
- layer_wise_top5_probs,
1505
- layer_wise_top5_deltas,
1506
- actual_output_token,
1507
- threshold=1.0
1508
- )
1509
-
1510
- return {
1511
- 'layer_wise_top5_probs': layer_wise_top5_probs,
1512
- 'layer_wise_top5_deltas': layer_wise_top5_deltas,
1513
- 'significant_layers': significant_layers
1514
- }
1515
-
1516
-
1517
  def extract_layer_data(activation_data: Dict[str, Any], model, tokenizer) -> List[Dict[str, Any]]:
1518
  """
1519
  Extract layer-by-layer data for accordion display with top-5, deltas, and attention.
1520
  Also tracks global top 5 tokens across all layers.
1521
 
1522
  Returns:
1523
- List of dicts with: layer_num, top_token, top_prob, top_5_tokens, deltas, top_attended_tokens,
1524
  global_top5_probs, global_top5_deltas
1525
  """
1526
  layer_modules = activation_data.get('block_modules', [])
@@ -1561,11 +1081,6 @@ def extract_layer_data(activation_data: Dict[str, Any], model, tokenizer) -> Lis
1561
  for layer_num, module_name in layer_info:
1562
  top_tokens = _get_top_tokens(activation_data, module_name, model, tokenizer, top_k=5) if can_compute_predictions else None
1563
 
1564
- # NOTE: top_attended_tokens is deprecated. Use categorize_all_heads() from
1565
- # head_detection.py instead for more meaningful attention analysis.
1566
- # Kept as None for backward compatibility with existing code.
1567
- top_attended = None
1568
-
1569
  # Get probabilities for global top 5 tokens at this layer
1570
  global_top5_probs = {}
1571
  global_top5_deltas = {}
@@ -1596,7 +1111,6 @@ def extract_layer_data(activation_data: Dict[str, Any], model, tokenizer) -> Lis
1596
  'top_3_tokens': top_tokens[:3], # Keep for backward compatibility
1597
  'top_5_tokens': top_tokens[:5], # New: top-5 for bar chart
1598
  'deltas': deltas,
1599
- 'top_attended_tokens': top_attended,
1600
  'global_top5_probs': global_top5_probs, # New: global top 5 probs at this layer
1601
  'global_top5_deltas': global_top5_deltas # New: global top 5 deltas
1602
  })
@@ -1613,7 +1127,6 @@ def extract_layer_data(activation_data: Dict[str, Any], model, tokenizer) -> Lis
1613
  'top_3_tokens': [],
1614
  'top_5_tokens': [],
1615
  'deltas': {},
1616
- 'top_attended_tokens': top_attended,
1617
  'global_top5_probs': {},
1618
  'global_top5_deltas': {}
1619
  })
@@ -1758,164 +1271,6 @@ def generate_bertviz_html(activation_data: Dict[str, Any], layer_index: int, vie
1758
  return f"<p>Error generating visualization: {str(e)}</p>"
1759
 
1760
 
1761
- def generate_category_bertviz_html(activation_data: Dict[str, Any], category_heads: List[Dict[str, Any]]) -> str:
1762
- """
1763
- Generate BertViz attention visualization HTML for a specific category of heads.
1764
-
1765
- Shows only the attention patterns for heads in the specified category.
1766
-
1767
- Args:
1768
- activation_data: Output from execute_forward_pass
1769
- category_heads: List of head info dicts for this category (from categorize_all_heads)
1770
-
1771
- Returns:
1772
- HTML string for the visualization
1773
- """
1774
- try:
1775
- from bertviz import head_view
1776
- from transformers import AutoTokenizer
1777
-
1778
- if not category_heads:
1779
- return "<p>No heads in this category.</p>"
1780
-
1781
- # Extract attention modules and sort by layer
1782
- attention_outputs = activation_data.get('attention_outputs', {})
1783
- if not attention_outputs:
1784
- return "<p>No attention data available</p>"
1785
-
1786
- # Build a map of layer -> head indices for this category
1787
- category_map = {} # layer_num -> list of head indices
1788
- for head_info in category_heads:
1789
- layer = head_info['layer']
1790
- head = head_info['head']
1791
- if layer not in category_map:
1792
- category_map[layer] = []
1793
- category_map[layer].append(head)
1794
-
1795
- # Sort attention modules by layer number and filter heads
1796
- # Track which layers we've already processed to avoid duplicates
1797
- layer_attention_pairs = []
1798
- processed_layers = set()
1799
-
1800
- for module_name in attention_outputs.keys():
1801
- numbers = re.findall(r'\d+', module_name)
1802
- if numbers:
1803
- layer_num = int(numbers[0])
1804
-
1805
- # Skip layers not in this category
1806
- if layer_num not in category_map:
1807
- continue
1808
-
1809
- # Skip if we've already processed this layer (prevents duplicate/mismatched tensors)
1810
- if layer_num in processed_layers:
1811
- continue
1812
-
1813
- attention_output = attention_outputs[module_name]['output']
1814
- if isinstance(attention_output, list) and len(attention_output) >= 2:
1815
- # Get attention weights (element 1 of the output tuple)
1816
- full_attention = torch.tensor(attention_output[1]) # [batch, heads, seq, seq]
1817
-
1818
- # Filter to only include heads in this category
1819
- head_indices = category_map[layer_num]
1820
- filtered_attention = full_attention[:, head_indices, :, :] # Select specific heads
1821
-
1822
- layer_attention_pairs.append((layer_num, filtered_attention))
1823
- processed_layers.add(layer_num)
1824
-
1825
- if not layer_attention_pairs:
1826
- return "<p>No valid attention data found for this category.</p>"
1827
-
1828
- # Sort by layer number and extract attention tensors
1829
- layer_attention_pairs.sort(key=lambda x: x[0])
1830
- attentions = tuple(attn for _, attn in layer_attention_pairs)
1831
-
1832
- # Get tokens
1833
- input_ids = torch.tensor(activation_data['input_ids'])
1834
- model_name = activation_data.get('model', 'unknown')
1835
-
1836
- # Load tokenizer and convert to tokens
1837
- tokenizer = AutoTokenizer.from_pretrained(model_name)
1838
- raw_tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
1839
- # Clean up tokens (remove special tokenizer artifacts like Ġ for GPT-2)
1840
- tokens = [token.replace('Ġ', ' ') if token.startswith('Ġ') else token for token in raw_tokens]
1841
-
1842
- # Generate visualization using head_view (better for showing specific heads)
1843
- html_result = head_view(attentions, tokens, html_action='return')
1844
- base_html = html_result.data if hasattr(html_result, 'data') else str(html_result)
1845
-
1846
- # Create a legend mapping head indices to their actual layer-head labels
1847
- legend_items = []
1848
- head_counter = 0
1849
- for layer_num, _ in layer_attention_pairs:
1850
- head_indices = category_map[layer_num]
1851
- for head_idx in head_indices:
1852
- legend_items.append(f"Head {head_counter}: L{layer_num}-H{head_idx}")
1853
- head_counter += 1
1854
-
1855
- legend_html = """
1856
- <div style="background-color: #f8f9fa; padding: 10px; margin-bottom: 10px; border-radius: 5px; border: 1px solid #dee2e6;">
1857
- <strong style="color: #495057;">Head Index Reference:</strong><br/>
1858
- <div style="font-size: 12px; color: #6c757d; margin-top: 5px; display: grid; grid-template-columns: repeat(auto-fill, minmax(150px, 1fr)); gap: 5px;">
1859
- {items}
1860
- </div>
1861
- </div>
1862
- """.format(items=''.join(f'<span>{item}</span>' for item in legend_items))
1863
-
1864
- # Prepend legend to the visualization
1865
- return legend_html + base_html
1866
-
1867
- except Exception as e:
1868
- import traceback
1869
- traceback.print_exc()
1870
- return f"<p>Error generating category visualization: {str(e)}</p>"
1871
-
1872
-
1873
- def generate_head_view_with_categories(activation_data: Dict[str, Any]) -> Dict[str, Any]:
1874
- """
1875
- Generate BertViz head view HTML along with head categorization data.
1876
-
1877
- Combines the head_view visualization with categorization from head_detection.py
1878
- to provide both visual attention patterns and semantic categorization.
1879
-
1880
- Args:
1881
- activation_data: Output from execute_forward_pass with attention data
1882
-
1883
- Returns:
1884
- Dict with:
1885
- - 'html': BertViz head_view HTML string
1886
- - 'categories': Dict from categorize_all_heads (category -> list of head info)
1887
- - 'summary': Formatted text summary of head categorization
1888
- - 'error': Error message if visualization failed (optional)
1889
- """
1890
- from .head_detection import categorize_all_heads, format_categorization_summary
1891
-
1892
- result = {
1893
- 'html': None,
1894
- 'categories': {},
1895
- 'summary': '',
1896
- 'error': None
1897
- }
1898
-
1899
- # Generate the base head_view visualization
1900
- try:
1901
- result['html'] = generate_bertviz_html(activation_data, layer_index=0, view_type='full')
1902
- except Exception as e:
1903
- result['error'] = f"Failed to generate head view: {str(e)}"
1904
- result['html'] = f"<p>Error generating visualization: {str(e)}</p>"
1905
-
1906
- # Generate head categorization
1907
- try:
1908
- result['categories'] = categorize_all_heads(activation_data)
1909
- result['summary'] = format_categorization_summary(result['categories'])
1910
- except Exception as e:
1911
- if result['error']:
1912
- result['error'] += f"; Categorization failed: {str(e)}"
1913
- else:
1914
- result['error'] = f"Categorization failed: {str(e)}"
1915
-
1916
- return result
1917
-
1918
-
1919
  def get_head_category_counts(activation_data: Dict[str, Any]) -> Dict[str, int]:
1920
  """
1921
  Get counts of attention heads in each category.
 
657
  return result
658
 
659
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
660
  def evaluate_sequence_ablation(model, tokenizer, sequence_text: str, config: Dict[str, Any],
661
  ablation_type: str, ablation_target: Any) -> Dict[str, Any]:
662
  """
 
991
  return None
992
 
993
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
994
  def detect_significant_probability_increases(layer_wise_probs: Dict[int, Dict[str, float]],
995
  layer_wise_deltas: Dict[int, Dict[str, float]],
996
  actual_output_token: str,
 
1034
  return significant_layers
1035
 
1036
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1037
  def extract_layer_data(activation_data: Dict[str, Any], model, tokenizer) -> List[Dict[str, Any]]:
1038
  """
1039
  Extract layer-by-layer data for accordion display with top-5, deltas, and attention.
1040
  Also tracks global top 5 tokens across all layers.
1041
 
1042
  Returns:
1043
+ List of dicts with: layer_num, top_token, top_prob, top_5_tokens, deltas,
1044
  global_top5_probs, global_top5_deltas
1045
  """
1046
  layer_modules = activation_data.get('block_modules', [])
 
1081
  for layer_num, module_name in layer_info:
1082
  top_tokens = _get_top_tokens(activation_data, module_name, model, tokenizer, top_k=5) if can_compute_predictions else None
1083
 
 
 
 
 
 
1084
  # Get probabilities for global top 5 tokens at this layer
1085
  global_top5_probs = {}
1086
  global_top5_deltas = {}
 
1111
  'top_3_tokens': top_tokens[:3], # Keep for backward compatibility
1112
  'top_5_tokens': top_tokens[:5], # New: top-5 for bar chart
1113
  'deltas': deltas,
 
1114
  'global_top5_probs': global_top5_probs, # New: global top 5 probs at this layer
1115
  'global_top5_deltas': global_top5_deltas # New: global top 5 deltas
1116
  })
 
1127
  'top_3_tokens': [],
1128
  'top_5_tokens': [],
1129
  'deltas': {},
 
1130
  'global_top5_probs': {},
1131
  'global_top5_deltas': {}
1132
  })
 
1271
  return f"<p>Error generating visualization: {str(e)}</p>"
1272
 
1273
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1274
  def get_head_category_counts(activation_data: Dict[str, Any]) -> Dict[str, int]:
1275
  """
1276
  Get counts of attention heads in each category.