Spaces:
Sleeping
Sleeping
Commit ·
9d22c3c
1
Parent(s): b828e58
Update layer highlighting to 100% threshold for actual output token only
Browse files
- Changed detect_significant_probability_increases to only check actual output token instead of all top 5 tokens
- Updated threshold from 50% to 100% (doubled probability) for highlighting layers
- Modified compute_layer_wise_summaries to accept activation_data and extract actual output token
- Updated all calls to compute_layer_wise_summaries in app.py to pass activation_data
- Added _create_actual_output_display helper function to show actual model output
- Display actual output token below final output graphs (normal mode, ablation mode, comparison mode)
- Added tooltip explaining why actual output may differ from highest probability token in final layer due to residual connections
- Updated graph tooltip text to reflect 100% threshold change
- app.py +129 -15
- components/__pycache__/sidebar.cpython-311.pyc +0 -0
- components/__pycache__/tokenization_panel.cpython-311.pyc +0 -0
- utils/__pycache__/head_detection.cpython-311.pyc +0 -0
- utils/__pycache__/model_config.cpython-311.pyc +0 -0
- utils/__pycache__/model_patterns.cpython-311.pyc +0 -0
- utils/__pycache__/prompt_comparison.cpython-311.pyc +0 -0
- utils/model_patterns.py +29 -19
app.py
CHANGED
|
@@ -855,6 +855,78 @@ def _create_comparison_delta_chart(layer_data1, layer_data2, layer_num, global_t
|
|
| 855 |
return fig
|
| 856 |
|
| 857 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 858 |
# Callback to create accordion panels from layer data
|
| 859 |
@app.callback(
|
| 860 |
Output('layer-accordions-container', 'children'),
|
|
@@ -885,7 +957,7 @@ def create_layer_accordions(activation_data, activation_data2, original_activati
|
|
| 885 |
return html.P("No layer data available.", className="placeholder-text")
|
| 886 |
|
| 887 |
# Compute layer-wise probability tracking
|
| 888 |
-
tracking_data = compute_layer_wise_summaries(layer_data)
|
| 889 |
layer_wise_probs = tracking_data.get('layer_wise_top5_probs', {})
|
| 890 |
significant_layers = tracking_data.get('significant_layers', [])
|
| 891 |
global_top5 = activation_data.get('global_top5_tokens', [])
|
|
@@ -898,7 +970,7 @@ def create_layer_accordions(activation_data, activation_data2, original_activati
|
|
| 898 |
|
| 899 |
if ablation_mode:
|
| 900 |
original_layer_data = extract_layer_data(original_activation_data, model, tokenizer)
|
| 901 |
-
original_tracking_data = compute_layer_wise_summaries(original_layer_data)
|
| 902 |
original_layer_wise_probs = original_tracking_data.get('layer_wise_top5_probs', {})
|
| 903 |
original_significant_layers = original_tracking_data.get('significant_layers', [])
|
| 904 |
original_global_top5 = original_activation_data.get('global_top5_tokens', [])
|
|
@@ -912,7 +984,7 @@ def create_layer_accordions(activation_data, activation_data2, original_activati
|
|
| 912 |
|
| 913 |
if comparison_mode:
|
| 914 |
layer_data2 = extract_layer_data(activation_data2, model, tokenizer)
|
| 915 |
-
tracking_data2 = compute_layer_wise_summaries(layer_data2)
|
| 916 |
layer_wise_probs2 = tracking_data2.get('layer_wise_top5_probs', {})
|
| 917 |
significant_layers2 = tracking_data2.get('significant_layers', [])
|
| 918 |
global_top5_2 = activation_data2.get('global_top5_tokens', [])
|
|
@@ -1271,7 +1343,8 @@ def create_layer_accordions(activation_data, activation_data2, original_activati
|
|
| 1271 |
# Update title to indicate "Before Ablation"
|
| 1272 |
fig_before.update_layout(title="Top 5 Token Probabilities Across Layers (Before Ablation)")
|
| 1273 |
|
| 1274 |
-
|
|
|
|
| 1275 |
html.H5("Before Ablation", style={
|
| 1276 |
'marginBottom': '10px', 'color': '#495057', 'fontSize': '16px',
|
| 1277 |
'fontWeight': '600', 'borderLeft': '4px solid #74b9ff', 'paddingLeft': '10px'
|
|
@@ -1282,7 +1355,15 @@ def create_layer_accordions(activation_data, activation_data2, original_activati
|
|
| 1282 |
"This graph shows how the model's confidence in the final top 5 predictions evolves through each layer before ablation."
|
| 1283 |
], style={'fontSize': '13px', 'color': '#6c757d', 'marginBottom': '10px', 'lineHeight': '1.5'}),
|
| 1284 |
dcc.Graph(figure=fig_before, config={'displayModeBar': False})
|
| 1285 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1286 |
|
| 1287 |
line_graphs.append(graph_container_before)
|
| 1288 |
|
|
@@ -1297,7 +1378,8 @@ def create_layer_accordions(activation_data, activation_data2, original_activati
|
|
| 1297 |
ablated_heads = activation_data.get('ablated_heads', [])
|
| 1298 |
heads_str = ', '.join([f'H{h}' for h in sorted(ablated_heads)])
|
| 1299 |
|
| 1300 |
-
|
|
|
|
| 1301 |
html.H5("After Ablation", style={
|
| 1302 |
'marginBottom': '10px', 'color': '#495057', 'fontSize': '16px',
|
| 1303 |
'fontWeight': '600', 'borderLeft': '4px solid #ffb74d', 'paddingLeft': '10px'
|
|
@@ -1308,10 +1390,22 @@ def create_layer_accordions(activation_data, activation_data2, original_activati
|
|
| 1308 |
f"This graph shows how probabilities changed after removing Layer {ablated_layer}, Heads {heads_str}. " +
|
| 1309 |
"Compare with the graph above to see the impact of the ablation."
|
| 1310 |
], style={'fontSize': '13px', 'color': '#6c757d', 'marginBottom': '10px', 'lineHeight': '1.5'}),
|
| 1311 |
-
dcc.Graph(figure=fig_after, config={'displayModeBar': False})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1312 |
html.Small("Note: Tokens with and without leading spaces (e.g., ' cat' and 'cat') are automatically merged.",
|
| 1313 |
style={'fontSize': '11px', 'color': '#6c757d', 'fontStyle': 'italic'})
|
| 1314 |
-
|
|
|
|
|
|
|
|
|
|
| 1315 |
|
| 1316 |
line_graphs.append(graph_container_after)
|
| 1317 |
|
|
@@ -1320,23 +1414,35 @@ def create_layer_accordions(activation_data, activation_data2, original_activati
|
|
| 1320 |
fig = _create_top5_by_layer_graph(layer_wise_probs, significant_layers, global_top5)
|
| 1321 |
if fig:
|
| 1322 |
tooltip_text = ("This graph shows how confident the model is in its top 5 predictions as it processes through each layer. "
|
| 1323 |
-
"Yellow highlights mark layers where the model's confidence
|
| 1324 |
"These are the layers where the model made important decisions. "
|
| 1325 |
-
"Click on the Transformer Layers section
|
| 1326 |
|
| 1327 |
merge_note = ("Note: Some tokens appear with a space before them (like ' cat') and some without (like 'cat'). "
|
| 1328 |
"We automatically combine these to make the graph easier to read.")
|
| 1329 |
|
| 1330 |
-
|
|
|
|
| 1331 |
html.Div([
|
| 1332 |
html.I(className="fas fa-info-circle",
|
| 1333 |
style={'marginRight': '8px', 'color': '#667eea'}),
|
| 1334 |
tooltip_text
|
| 1335 |
], style={'fontSize': '13px', 'color': '#6c757d', 'marginBottom': '10px', 'lineHeight': '1.5'}),
|
| 1336 |
-
dcc.Graph(figure=fig, config={'displayModeBar': False})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1337 |
html.Small(merge_note,
|
| 1338 |
style={'fontSize': '11px', 'color': '#6c757d', 'fontStyle': 'italic'})
|
| 1339 |
-
|
|
|
|
|
|
|
| 1340 |
|
| 1341 |
line_graphs.append(graph_container)
|
| 1342 |
|
|
@@ -1344,10 +1450,18 @@ def create_layer_accordions(activation_data, activation_data2, original_activati
|
|
| 1344 |
if comparison_mode and layer_wise_probs2 and global_top5_2:
|
| 1345 |
fig2 = _create_top5_by_layer_graph(layer_wise_probs2, significant_layers2, global_top5_2)
|
| 1346 |
if fig2:
|
| 1347 |
-
|
|
|
|
| 1348 |
html.H6("Prompt 2", style={'color': '#495057', 'marginBottom': '10px'}),
|
| 1349 |
dcc.Graph(figure=fig2, config={'displayModeBar': False})
|
| 1350 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1351 |
line_graphs.append(graph_container2)
|
| 1352 |
|
| 1353 |
# Create stacked visual representation for collapsed state
|
|
|
|
| 855 |
return fig
|
| 856 |
|
| 857 |
|
| 858 |
+
def _create_actual_output_display(activation_data, icon_id="actual-output-info-icon"):
    """
    Create a display element showing the actual output token with an explanatory note.

    Args:
        activation_data: Activation data dict. Its 'actual_output' entry is read
            for the 'token' and 'probability' keys. May be None or empty.
        icon_id: Component id assigned to the info icon. Defaults to the
            original fixed id. This helper is rendered more than once in the
            same layout (before/after ablation, prompt 1/prompt 2), so callers
            that need unique component ids should pass a distinct value per
            call, or None to omit the id entirely.

    Returns:
        Dash HTML component displaying the actual output token, or None when
        no actual output is available.
    """
    # Nothing to show when the activation data itself is missing.
    if not activation_data:
        return None

    actual_output = activation_data.get('actual_output')
    if not actual_output:
        return None

    token = actual_output.get('token', 'N/A')

    # A None or non-numeric probability must not raise here: this helper runs
    # inside the accordion callback, and a formatting error would abort it.
    try:
        probability_label = f" (probability: {float(actual_output.get('probability', 0.0)):.4f})"
    except (TypeError, ValueError):
        probability_label = " (probability: N/A)"

    tooltip_text = ("The actual output token may differ from the highest probability token shown in the final layer. "
                    "This is because the model uses residual connections (skip links) that add information across layers. "
                    "The final output is determined after all residual streams are combined. "
                    "See transformer layer implementations for details.")

    # Only pass `id` when one is requested, so the id can be omitted cleanly.
    icon_kwargs = {
        'className': "fas fa-info-circle",
        'style': {
            'marginLeft': '10px',
            'color': '#667eea',
            'cursor': 'pointer',
            'fontSize': '14px'
        }
    }
    if icon_id is not None:
        icon_kwargs['id'] = icon_id

    # Header row: label, highlighted token, probability, info icon.
    header_row = html.Div([
        html.Strong("Actual Model Output: ", style={'color': '#495057', 'fontSize': '14px'}),
        html.Span(f'"{token}"', style={
            'backgroundColor': '#e8f5e9',
            'padding': '4px 10px',
            'borderRadius': '4px',
            'fontFamily': 'monospace',
            'fontSize': '14px',
            'fontWeight': '600',
            'color': '#2e7d32',
            'border': '1px solid #4caf50'
        }),
        html.Span(probability_label, style={
            'color': '#6c757d',
            'fontSize': '13px',
            'marginLeft': '8px'
        }),
        html.I(**icon_kwargs)
    ], style={'marginBottom': '8px'})

    # Inline explanation of why the output can differ from the final-layer top token.
    explanation_box = html.Div([
        html.I(className="fas fa-lightbulb", style={'marginRight': '6px', 'color': '#ffa726'}),
        tooltip_text
    ], style={
        'fontSize': '12px',
        'color': '#6c757d',
        'backgroundColor': '#fff8e1',
        'padding': '10px',
        'borderRadius': '5px',
        'borderLeft': '3px solid #ffa726',
        'lineHeight': '1.6'
    })

    return html.Div([header_row, explanation_box], style={
        'marginTop': '15px',
        'padding': '12px',
        'backgroundColor': '#f8f9fa',
        'borderRadius': '6px',
        'border': '1px solid #dee2e6'
    })
|
| 928 |
+
|
| 929 |
+
|
| 930 |
# Callback to create accordion panels from layer data
|
| 931 |
@app.callback(
|
| 932 |
Output('layer-accordions-container', 'children'),
|
|
|
|
| 957 |
return html.P("No layer data available.", className="placeholder-text")
|
| 958 |
|
| 959 |
# Compute layer-wise probability tracking
|
| 960 |
+
tracking_data = compute_layer_wise_summaries(layer_data, activation_data)
|
| 961 |
layer_wise_probs = tracking_data.get('layer_wise_top5_probs', {})
|
| 962 |
significant_layers = tracking_data.get('significant_layers', [])
|
| 963 |
global_top5 = activation_data.get('global_top5_tokens', [])
|
|
|
|
| 970 |
|
| 971 |
if ablation_mode:
|
| 972 |
original_layer_data = extract_layer_data(original_activation_data, model, tokenizer)
|
| 973 |
+
original_tracking_data = compute_layer_wise_summaries(original_layer_data, original_activation_data)
|
| 974 |
original_layer_wise_probs = original_tracking_data.get('layer_wise_top5_probs', {})
|
| 975 |
original_significant_layers = original_tracking_data.get('significant_layers', [])
|
| 976 |
original_global_top5 = original_activation_data.get('global_top5_tokens', [])
|
|
|
|
| 984 |
|
| 985 |
if comparison_mode:
|
| 986 |
layer_data2 = extract_layer_data(activation_data2, model, tokenizer)
|
| 987 |
+
tracking_data2 = compute_layer_wise_summaries(layer_data2, activation_data2)
|
| 988 |
layer_wise_probs2 = tracking_data2.get('layer_wise_top5_probs', {})
|
| 989 |
significant_layers2 = tracking_data2.get('significant_layers', [])
|
| 990 |
global_top5_2 = activation_data2.get('global_top5_tokens', [])
|
|
|
|
| 1343 |
# Update title to indicate "Before Ablation"
|
| 1344 |
fig_before.update_layout(title="Top 5 Token Probabilities Across Layers (Before Ablation)")
|
| 1345 |
|
| 1346 |
+
# Build children for before ablation graph
|
| 1347 |
+
before_children = [
|
| 1348 |
html.H5("Before Ablation", style={
|
| 1349 |
'marginBottom': '10px', 'color': '#495057', 'fontSize': '16px',
|
| 1350 |
'fontWeight': '600', 'borderLeft': '4px solid #74b9ff', 'paddingLeft': '10px'
|
|
|
|
| 1355 |
"This graph shows how the model's confidence in the final top 5 predictions evolves through each layer before ablation."
|
| 1356 |
], style={'fontSize': '13px', 'color': '#6c757d', 'marginBottom': '10px', 'lineHeight': '1.5'}),
|
| 1357 |
dcc.Graph(figure=fig_before, config={'displayModeBar': False})
|
| 1358 |
+
]
|
| 1359 |
+
|
| 1360 |
+
# Add actual output display
|
| 1361 |
+
actual_output_display_before = _create_actual_output_display(original_activation_data)
|
| 1362 |
+
if actual_output_display_before:
|
| 1363 |
+
before_children.append(actual_output_display_before)
|
| 1364 |
+
|
| 1365 |
+
graph_container_before = html.Div(before_children,
|
| 1366 |
+
style={'padding': '15px', 'backgroundColor': '#e3f2fd', 'borderRadius': '8px', 'marginBottom': '20px'})
|
| 1367 |
|
| 1368 |
line_graphs.append(graph_container_before)
|
| 1369 |
|
|
|
|
| 1378 |
ablated_heads = activation_data.get('ablated_heads', [])
|
| 1379 |
heads_str = ', '.join([f'H{h}' for h in sorted(ablated_heads)])
|
| 1380 |
|
| 1381 |
+
# Build children for after ablation graph
|
| 1382 |
+
after_children = [
|
| 1383 |
html.H5("After Ablation", style={
|
| 1384 |
'marginBottom': '10px', 'color': '#495057', 'fontSize': '16px',
|
| 1385 |
'fontWeight': '600', 'borderLeft': '4px solid #ffb74d', 'paddingLeft': '10px'
|
|
|
|
| 1390 |
f"This graph shows how probabilities changed after removing Layer {ablated_layer}, Heads {heads_str}. " +
|
| 1391 |
"Compare with the graph above to see the impact of the ablation."
|
| 1392 |
], style={'fontSize': '13px', 'color': '#6c757d', 'marginBottom': '10px', 'lineHeight': '1.5'}),
|
| 1393 |
+
dcc.Graph(figure=fig_after, config={'displayModeBar': False})
|
| 1394 |
+
]
|
| 1395 |
+
|
| 1396 |
+
# Add actual output display
|
| 1397 |
+
actual_output_display_after = _create_actual_output_display(activation_data)
|
| 1398 |
+
if actual_output_display_after:
|
| 1399 |
+
after_children.append(actual_output_display_after)
|
| 1400 |
+
|
| 1401 |
+
# Add merge note at the end
|
| 1402 |
+
after_children.append(
|
| 1403 |
html.Small("Note: Tokens with and without leading spaces (e.g., ' cat' and 'cat') are automatically merged.",
|
| 1404 |
style={'fontSize': '11px', 'color': '#6c757d', 'fontStyle': 'italic'})
|
| 1405 |
+
)
|
| 1406 |
+
|
| 1407 |
+
graph_container_after = html.Div(after_children,
|
| 1408 |
+
style={'padding': '15px', 'backgroundColor': '#fff3e0', 'borderRadius': '8px', 'marginBottom': '20px'})
|
| 1409 |
|
| 1410 |
line_graphs.append(graph_container_after)
|
| 1411 |
|
|
|
|
| 1414 |
fig = _create_top5_by_layer_graph(layer_wise_probs, significant_layers, global_top5)
|
| 1415 |
if fig:
|
| 1416 |
tooltip_text = ("This graph shows how confident the model is in its top 5 predictions as it processes through each layer. "
|
| 1417 |
+
"Yellow highlights mark layers where the model's confidence in the actual output token doubled (100% or more increase). "
|
| 1418 |
"These are the layers where the model made important decisions. "
|
| 1419 |
+
"Click on the Transformer Layers section to see what each layer did.")
|
| 1420 |
|
| 1421 |
merge_note = ("Note: Some tokens appear with a space before them (like ' cat') and some without (like 'cat'). "
|
| 1422 |
"We automatically combine these to make the graph easier to read.")
|
| 1423 |
|
| 1424 |
+
# Create list of children for graph container
|
| 1425 |
+
graph_children = [
|
| 1426 |
html.Div([
|
| 1427 |
html.I(className="fas fa-info-circle",
|
| 1428 |
style={'marginRight': '8px', 'color': '#667eea'}),
|
| 1429 |
tooltip_text
|
| 1430 |
], style={'fontSize': '13px', 'color': '#6c757d', 'marginBottom': '10px', 'lineHeight': '1.5'}),
|
| 1431 |
+
dcc.Graph(figure=fig, config={'displayModeBar': False})
|
| 1432 |
+
]
|
| 1433 |
+
|
| 1434 |
+
# Add actual output display
|
| 1435 |
+
actual_output_display = _create_actual_output_display(activation_data)
|
| 1436 |
+
if actual_output_display:
|
| 1437 |
+
graph_children.append(actual_output_display)
|
| 1438 |
+
|
| 1439 |
+
# Add merge note at the end
|
| 1440 |
+
graph_children.append(
|
| 1441 |
html.Small(merge_note,
|
| 1442 |
style={'fontSize': '11px', 'color': '#6c757d', 'fontStyle': 'italic'})
|
| 1443 |
+
)
|
| 1444 |
+
|
| 1445 |
+
graph_container = html.Div(graph_children, style={'marginBottom': '20px'})
|
| 1446 |
|
| 1447 |
line_graphs.append(graph_container)
|
| 1448 |
|
|
|
|
| 1450 |
if comparison_mode and layer_wise_probs2 and global_top5_2:
|
| 1451 |
fig2 = _create_top5_by_layer_graph(layer_wise_probs2, significant_layers2, global_top5_2)
|
| 1452 |
if fig2:
|
| 1453 |
+
# Build children for second prompt graph
|
| 1454 |
+
children2 = [
|
| 1455 |
html.H6("Prompt 2", style={'color': '#495057', 'marginBottom': '10px'}),
|
| 1456 |
dcc.Graph(figure=fig2, config={'displayModeBar': False})
|
| 1457 |
+
]
|
| 1458 |
+
|
| 1459 |
+
# Add actual output display for second prompt
|
| 1460 |
+
actual_output_display2 = _create_actual_output_display(activation_data2)
|
| 1461 |
+
if actual_output_display2:
|
| 1462 |
+
children2.append(actual_output_display2)
|
| 1463 |
+
|
| 1464 |
+
graph_container2 = html.Div(children2, style={'marginTop': '20px'})
|
| 1465 |
line_graphs.append(graph_container2)
|
| 1466 |
|
| 1467 |
# Create stacked visual representation for collapsed state
|
components/__pycache__/sidebar.cpython-311.pyc
CHANGED
|
Binary files a/components/__pycache__/sidebar.cpython-311.pyc and b/components/__pycache__/sidebar.cpython-311.pyc differ
|
|
|
components/__pycache__/tokenization_panel.cpython-311.pyc
CHANGED
|
Binary files a/components/__pycache__/tokenization_panel.cpython-311.pyc and b/components/__pycache__/tokenization_panel.cpython-311.pyc differ
|
|
|
utils/__pycache__/head_detection.cpython-311.pyc
CHANGED
|
Binary files a/utils/__pycache__/head_detection.cpython-311.pyc and b/utils/__pycache__/head_detection.cpython-311.pyc differ
|
|
|
utils/__pycache__/model_config.cpython-311.pyc
CHANGED
|
Binary files a/utils/__pycache__/model_config.cpython-311.pyc and b/utils/__pycache__/model_config.cpython-311.pyc differ
|
|
|
utils/__pycache__/model_patterns.cpython-311.pyc
CHANGED
|
Binary files a/utils/__pycache__/model_patterns.cpython-311.pyc and b/utils/__pycache__/model_patterns.cpython-311.pyc differ
|
|
|
utils/__pycache__/prompt_comparison.cpython-311.pyc
CHANGED
|
Binary files a/utils/__pycache__/prompt_comparison.cpython-311.pyc and b/utils/__pycache__/prompt_comparison.cpython-311.pyc differ
|
|
|
utils/model_patterns.py
CHANGED
|
@@ -863,24 +863,25 @@ def get_check_token_probabilities(activation_data: Dict[str, Any], model, tokeni
|
|
| 863 |
|
| 864 |
def detect_significant_probability_increases(layer_wise_probs: Dict[int, Dict[str, float]],
|
| 865 |
layer_wise_deltas: Dict[int, Dict[str, float]],
|
| 866 |
-
|
|
|
|
| 867 |
"""
|
| 868 |
-
Detect layers where
|
| 869 |
|
| 870 |
-
A layer is significant if
|
| 871 |
-
Example: 0.20 → 0.
|
| 872 |
|
| 873 |
-
This threshold
|
| 874 |
-
|
| 875 |
-
in the model's confidence that is pedagogically useful to highlight.
|
| 876 |
|
| 877 |
Args:
|
| 878 |
layer_wise_probs: Dict mapping layer_num → {token: prob}
|
| 879 |
layer_wise_deltas: Dict mapping layer_num → {token: delta}
|
| 880 |
-
|
|
|
|
| 881 |
|
| 882 |
Returns:
|
| 883 |
-
List of layer numbers with significant increases
|
| 884 |
"""
|
| 885 |
significant_layers = []
|
| 886 |
|
|
@@ -888,8 +889,10 @@ def detect_significant_probability_increases(layer_wise_probs: Dict[int, Dict[st
|
|
| 888 |
probs = layer_wise_probs[layer_num]
|
| 889 |
deltas = layer_wise_deltas.get(layer_num, {})
|
| 890 |
|
| 891 |
-
|
| 892 |
-
|
|
|
|
|
|
|
| 893 |
prev_prob = prob - delta
|
| 894 |
|
| 895 |
# Check for significant relative increase (avoid division by zero)
|
|
@@ -897,7 +900,6 @@ def detect_significant_probability_increases(layer_wise_probs: Dict[int, Dict[st
|
|
| 897 |
relative_increase = delta / prev_prob
|
| 898 |
if relative_increase >= threshold:
|
| 899 |
significant_layers.append(layer_num)
|
| 900 |
-
break # Only need to flag layer once
|
| 901 |
|
| 902 |
return significant_layers
|
| 903 |
|
|
@@ -968,12 +970,13 @@ def _get_top_attended_tokens(activation_data: Dict[str, Any], layer_num: int, to
|
|
| 968 |
return None
|
| 969 |
|
| 970 |
|
| 971 |
-
def compute_layer_wise_summaries(layer_data: List[Dict[str, Any]]) -> Dict[str, Any]:
|
| 972 |
"""
|
| 973 |
Compute summary structures from layer data for easy access.
|
| 974 |
|
| 975 |
Args:
|
| 976 |
layer_data: List of layer data dicts from extract_layer_data()
|
|
|
|
| 977 |
|
| 978 |
Returns:
|
| 979 |
Dict with: layer_wise_top5_probs, layer_wise_top5_deltas, significant_layers
|
|
@@ -987,12 +990,19 @@ def compute_layer_wise_summaries(layer_data: List[Dict[str, Any]]) -> Dict[str,
|
|
| 987 |
layer_wise_top5_probs[layer_num] = layer_info.get('global_top5_probs', {})
|
| 988 |
layer_wise_top5_deltas[layer_num] = layer_info.get('global_top5_deltas', {})
|
| 989 |
|
| 990 |
-
#
|
| 991 |
-
|
| 992 |
-
|
| 993 |
-
|
| 994 |
-
|
| 995 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 996 |
|
| 997 |
return {
|
| 998 |
'layer_wise_top5_probs': layer_wise_top5_probs,
|
|
|
|
| 863 |
|
| 864 |
def detect_significant_probability_increases(layer_wise_probs: Dict[int, Dict[str, float]],
|
| 865 |
layer_wise_deltas: Dict[int, Dict[str, float]],
|
| 866 |
+
actual_output_token: str,
|
| 867 |
+
threshold: float = 1.0) -> List[int]:
|
| 868 |
"""
|
| 869 |
+
Detect layers where the actual output token has significant probability increase.
|
| 870 |
|
| 871 |
+
A layer is significant if the actual output token has ≥100% relative increase from previous layer.
|
| 872 |
+
Example: 0.20 → 0.40 is (0.40-0.20)/0.20 = 100% increase.
|
| 873 |
|
| 874 |
+
This threshold highlights layers where the model's confidence in the actual output
|
| 875 |
+
doubles, representing a pedagogically significant shift in the prediction.
|
|
|
|
| 876 |
|
| 877 |
Args:
|
| 878 |
layer_wise_probs: Dict mapping layer_num → {token: prob}
|
| 879 |
layer_wise_deltas: Dict mapping layer_num → {token: delta}
|
| 880 |
+
actual_output_token: The token that the model actually outputs (predicted token)
|
| 881 |
+
threshold: Relative increase threshold (default: 1.0 = 100%)
|
| 882 |
|
| 883 |
Returns:
|
| 884 |
+
List of layer numbers with significant increases in the actual output token
|
| 885 |
"""
|
| 886 |
significant_layers = []
|
| 887 |
|
|
|
|
| 889 |
probs = layer_wise_probs[layer_num]
|
| 890 |
deltas = layer_wise_deltas.get(layer_num, {})
|
| 891 |
|
| 892 |
+
# Only check the actual output token
|
| 893 |
+
if actual_output_token in probs:
|
| 894 |
+
prob = probs[actual_output_token]
|
| 895 |
+
delta = deltas.get(actual_output_token, 0.0)
|
| 896 |
prev_prob = prob - delta
|
| 897 |
|
| 898 |
# Check for significant relative increase (avoid division by zero)
|
|
|
|
| 900 |
relative_increase = delta / prev_prob
|
| 901 |
if relative_increase >= threshold:
|
| 902 |
significant_layers.append(layer_num)
|
|
|
|
| 903 |
|
| 904 |
return significant_layers
|
| 905 |
|
|
|
|
| 970 |
return None
|
| 971 |
|
| 972 |
|
| 973 |
+
def compute_layer_wise_summaries(layer_data: List[Dict[str, Any]], activation_data: Dict[str, Any]) -> Dict[str, Any]:
|
| 974 |
"""
|
| 975 |
Compute summary structures from layer data for easy access.
|
| 976 |
|
| 977 |
Args:
|
| 978 |
layer_data: List of layer data dicts from extract_layer_data()
|
| 979 |
+
activation_data: Activation data containing actual output token
|
| 980 |
|
| 981 |
Returns:
|
| 982 |
Dict with: layer_wise_top5_probs, layer_wise_top5_deltas, significant_layers
|
|
|
|
| 990 |
layer_wise_top5_probs[layer_num] = layer_info.get('global_top5_probs', {})
|
| 991 |
layer_wise_top5_deltas[layer_num] = layer_info.get('global_top5_deltas', {})
|
| 992 |
|
| 993 |
+
# Extract actual output token from activation data
|
| 994 |
+
actual_output = activation_data.get('actual_output', {})
|
| 995 |
+
actual_output_token = actual_output.get('token', '').strip() if actual_output else ''
|
| 996 |
+
|
| 997 |
+
# Detect significant layers based on actual output token
|
| 998 |
+
significant_layers = []
|
| 999 |
+
if actual_output_token:
|
| 1000 |
+
significant_layers = detect_significant_probability_increases(
|
| 1001 |
+
layer_wise_top5_probs,
|
| 1002 |
+
layer_wise_top5_deltas,
|
| 1003 |
+
actual_output_token,
|
| 1004 |
+
threshold=1.0
|
| 1005 |
+
)
|
| 1006 |
|
| 1007 |
return {
|
| 1008 |
'layer_wise_top5_probs': layer_wise_top5_probs,
|