Spaces:
Sleeping
Add ablation experiment before/after comparison with Reset button
Browse files- Added session-activation-store-original to preserve pre-ablation data
- Modified run_head_ablation callback to save original data before ablation
- Updated create_layer_accordions to detect ablation mode and show before/after sections for all layers
- Modified _create_single_prompt_chart and _create_token_probability_delta_chart to accept title suffixes
- Added before/after comparison for line graphs showing top 5 token probabilities across layers
- Added Reset Ablation button in sidebar (visible only during ablation)
- Implemented reset_ablation callback to restore original data and clear ablation
- Added visual polish with section headers, background colors (blue for before, orange for after), and explanatory tooltips
- All changes maintain backward compatibility with non-ablation mode
- app.py +209 -25
- components/__pycache__/sidebar.cpython-311.pyc +0 -0
- components/sidebar.py +14 -1
- utils/__pycache__/model_patterns.cpython-311.pyc +0 -0
|
@@ -105,6 +105,8 @@ app.layout = html.Div([
|
|
| 105 |
# Session storage for activation data
|
| 106 |
dcc.Store(id='session-activation-store', storage_type='session'),
|
| 107 |
dcc.Store(id='session-patterns-store', storage_type='session'),
|
|
|
|
|
|
|
| 108 |
# Sidebar collapse state (default: collapsed = True)
|
| 109 |
dcc.Store(id='sidebar-collapse-store', storage_type='session', data=True),
|
| 110 |
# Comparison mode state (default: not comparing)
|
|
@@ -496,12 +498,13 @@ def _create_top5_by_layer_graph(layer_wise_probs, significant_layers, global_top
|
|
| 496 |
return fig
|
| 497 |
|
| 498 |
|
| 499 |
-
def _create_single_prompt_chart(layer_data):
|
| 500 |
"""
|
| 501 |
Create a single prompt bar chart (existing functionality).
|
| 502 |
|
| 503 |
Args:
|
| 504 |
layer_data: Layer data dict (with top_5_tokens, deltas, certainty)
|
|
|
|
| 505 |
|
| 506 |
Returns:
|
| 507 |
Plotly Figure with horizontal bars
|
|
@@ -548,9 +551,14 @@ def _create_single_prompt_chart(layer_data):
|
|
| 548 |
)
|
| 549 |
])
|
| 550 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 551 |
fig.update_layout(
|
| 552 |
title={
|
| 553 |
-
'text':
|
| 554 |
'font': {'size': 14}
|
| 555 |
},
|
| 556 |
xaxis={'title': 'Probability', 'range': [0, max(probs) * 1.15]},
|
|
@@ -702,7 +710,7 @@ def _create_comparison_bar_chart(layer_data1, layer_data2, layer_num):
|
|
| 702 |
return fig
|
| 703 |
|
| 704 |
|
| 705 |
-
def _create_token_probability_delta_chart(layer_data, layer_num, global_top5_tokens):
|
| 706 |
"""
|
| 707 |
Create horizontal bar chart showing change in probabilities for global top 5 tokens.
|
| 708 |
|
|
@@ -710,6 +718,7 @@ def _create_token_probability_delta_chart(layer_data, layer_num, global_top5_tok
|
|
| 710 |
layer_data: Layer data dict with global_top5_deltas
|
| 711 |
layer_num: Layer number for title
|
| 712 |
global_top5_tokens: List of (token, prob) tuples for final global top 5
|
|
|
|
| 713 |
|
| 714 |
Returns:
|
| 715 |
Plotly Figure with horizontal bars (green for positive, red for negative)
|
|
@@ -759,9 +768,13 @@ def _create_token_probability_delta_chart(layer_data, layer_num, global_top5_tok
|
|
| 759 |
|
| 760 |
# Update layout
|
| 761 |
prev_layer_text = "Embedding" if layer_num == 0 else f"Layer {layer_num - 1}"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 762 |
fig.update_layout(
|
| 763 |
title={
|
| 764 |
-
'text':
|
| 765 |
'font': {'size': 13}
|
| 766 |
},
|
| 767 |
xaxis={'title': 'Probability Change', 'range': x_range, 'zeroline': True, 'zerolinewidth': 2, 'zerolinecolor': '#999'},
|
|
@@ -866,10 +879,11 @@ def _create_comparison_delta_chart(layer_data1, layer_data2, layer_num, global_t
|
|
| 866 |
@app.callback(
|
| 867 |
Output('layer-accordions-container', 'children'),
|
| 868 |
[Input('session-activation-store', 'data'),
|
| 869 |
-
Input('session-activation-store-2', 'data')
|
|
|
|
| 870 |
[State('model-dropdown', 'value')]
|
| 871 |
)
|
| 872 |
-
def create_layer_accordions(activation_data, activation_data2, model_name):
|
| 873 |
"""Create accordion panels for each layer with top-5 bar charts, deltas, and certainty."""
|
| 874 |
if not activation_data or not model_name:
|
| 875 |
return html.P("Run analysis to see layer-by-layer predictions.", className="placeholder-text")
|
|
@@ -881,18 +895,34 @@ def create_layer_accordions(activation_data, activation_data2, model_name):
|
|
| 881 |
model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation='eager')
|
| 882 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 883 |
|
| 884 |
-
#
|
|
|
|
|
|
|
|
|
|
| 885 |
layer_data = extract_layer_data(activation_data, model, tokenizer)
|
| 886 |
|
| 887 |
if not layer_data:
|
| 888 |
return html.P("No layer data available.", className="placeholder-text")
|
| 889 |
|
| 890 |
-
# Compute layer-wise probability tracking
|
| 891 |
tracking_data = compute_layer_wise_summaries(layer_data)
|
| 892 |
layer_wise_probs = tracking_data.get('layer_wise_top5_probs', {})
|
| 893 |
significant_layers = tracking_data.get('significant_layers', [])
|
| 894 |
global_top5 = activation_data.get('global_top5_tokens', [])
|
| 895 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 896 |
# Check if second prompt exists and extract its layer data
|
| 897 |
layer_data2 = None
|
| 898 |
layer_wise_probs2 = {}
|
|
@@ -1158,16 +1188,68 @@ def create_layer_accordions(activation_data, activation_data2, model_name):
|
|
| 1158 |
|
| 1159 |
# Add delta chart after attention head categories
|
| 1160 |
content_items.append(html.Hr(style={'margin': '15px 0'}))
|
| 1161 |
-
|
| 1162 |
-
|
| 1163 |
-
|
| 1164 |
-
|
| 1165 |
-
|
| 1166 |
-
|
| 1167 |
-
|
| 1168 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1169 |
else:
|
| 1170 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1171 |
|
| 1172 |
# Add "Explore These Changes" button after delta chart
|
| 1173 |
content_items.append(explore_button_section)
|
|
@@ -1187,7 +1269,60 @@ def create_layer_accordions(activation_data, activation_data2, model_name):
|
|
| 1187 |
# Create line graph(s) for top 5 tokens across layers
|
| 1188 |
line_graphs = []
|
| 1189 |
|
| 1190 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1191 |
fig = _create_top5_by_layer_graph(layer_wise_probs, significant_layers, global_top5)
|
| 1192 |
if fig:
|
| 1193 |
tooltip_text = ("This graph shows how the model's confidence in the final top 5 predictions "
|
|
@@ -1212,7 +1347,7 @@ def create_layer_accordions(activation_data, activation_data2, model_name):
|
|
| 1212 |
|
| 1213 |
line_graphs.append(graph_container)
|
| 1214 |
|
| 1215 |
-
# In comparison mode, create a second graph or side-by-side display
|
| 1216 |
if comparison_mode and layer_wise_probs2 and global_top5_2:
|
| 1217 |
fig2 = _create_top5_by_layer_graph(layer_wise_probs2, significant_layers2, global_top5_2)
|
| 1218 |
if fig2:
|
|
@@ -1610,6 +1745,7 @@ def handle_head_selection(n_clicks_list, selected_heads):
|
|
| 1610 |
# Run ablation experiment
|
| 1611 |
@app.callback(
|
| 1612 |
[Output('session-activation-store', 'data', allow_duplicate=True),
|
|
|
|
| 1613 |
Output('model-status', 'children', allow_duplicate=True)],
|
| 1614 |
Input({'type': 'run-ablation-btn', 'layer': ALL}, 'n_clicks'),
|
| 1615 |
[State({'type': 'selected-heads-store', 'layer': ALL}, 'data'),
|
|
@@ -1623,12 +1759,12 @@ def run_head_ablation(n_clicks_list, selected_heads_list, activation_data, model
|
|
| 1623 |
# Identify which button was clicked
|
| 1624 |
ctx = dash.callback_context
|
| 1625 |
if not ctx.triggered or not ctx.triggered_id:
|
| 1626 |
-
return no_update, no_update
|
| 1627 |
|
| 1628 |
# Get the layer number from the triggered button ID
|
| 1629 |
layer_num = ctx.triggered_id.get('layer')
|
| 1630 |
if layer_num is None:
|
| 1631 |
-
return no_update, no_update
|
| 1632 |
|
| 1633 |
# Find the index in the states_list that corresponds to this layer
|
| 1634 |
# ctx.states_list contains the State values in order
|
|
@@ -1643,7 +1779,7 @@ def run_head_ablation(n_clicks_list, selected_heads_list, activation_data, model
|
|
| 1643 |
# Fallback: if states_list doesn't work, try matching by iterating
|
| 1644 |
if button_index is None:
|
| 1645 |
# This shouldn't happen, but as a fallback, just return error
|
| 1646 |
-
return no_update, html.Div([
|
| 1647 |
html.I(className="fas fa-exclamation-circle", style={'marginRight': '8px', 'color': '#dc3545'}),
|
| 1648 |
f"Could not determine button index for layer {layer_num}"
|
| 1649 |
], className="status-error")
|
|
@@ -1654,12 +1790,16 @@ def run_head_ablation(n_clicks_list, selected_heads_list, activation_data, model
|
|
| 1654 |
selected_heads = selected_heads_list[button_index]
|
| 1655 |
|
| 1656 |
if not selected_heads or not activation_data:
|
| 1657 |
-
return no_update, no_update
|
| 1658 |
|
| 1659 |
try:
|
| 1660 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 1661 |
from utils import execute_forward_pass_with_head_ablation
|
| 1662 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1663 |
# Load model and tokenizer
|
| 1664 |
model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation='eager')
|
| 1665 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
|
@@ -1694,7 +1834,7 @@ def run_head_ablation(n_clicks_list, selected_heads_list, activation_data, model
|
|
| 1694 |
f"Ablation complete: Layer {layer_num}, Heads {heads_str} removed"
|
| 1695 |
], className="status-success")
|
| 1696 |
|
| 1697 |
-
return ablated_data, success_message
|
| 1698 |
|
| 1699 |
except Exception as e:
|
| 1700 |
print(f"Error running ablation: {e}")
|
|
@@ -1706,7 +1846,51 @@ def run_head_ablation(n_clicks_list, selected_heads_list, activation_data, model
|
|
| 1706 |
f"Ablation error: {str(e)}"
|
| 1707 |
], className="status-error")
|
| 1708 |
|
| 1709 |
-
return no_update, error_message
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1710 |
|
| 1711 |
|
| 1712 |
if __name__ == '__main__':
|
|
|
|
| 105 |
# Session storage for activation data
|
| 106 |
dcc.Store(id='session-activation-store', storage_type='session'),
|
| 107 |
dcc.Store(id='session-patterns-store', storage_type='session'),
|
| 108 |
+
# Store original activation data before ablation for comparison
|
| 109 |
+
dcc.Store(id='session-activation-store-original', storage_type='session'),
|
| 110 |
# Sidebar collapse state (default: collapsed = True)
|
| 111 |
dcc.Store(id='sidebar-collapse-store', storage_type='session', data=True),
|
| 112 |
# Comparison mode state (default: not comparing)
|
|
|
|
| 498 |
return fig
|
| 499 |
|
| 500 |
|
| 501 |
+
def _create_single_prompt_chart(layer_data, title_suffix=''):
|
| 502 |
"""
|
| 503 |
Create a single prompt bar chart (existing functionality).
|
| 504 |
|
| 505 |
Args:
|
| 506 |
layer_data: Layer data dict (with top_5_tokens, deltas, certainty)
|
| 507 |
+
title_suffix: Optional suffix to add to title (e.g., "Before Ablation", "After Ablation")
|
| 508 |
|
| 509 |
Returns:
|
| 510 |
Plotly Figure with horizontal bars
|
|
|
|
| 551 |
)
|
| 552 |
])
|
| 553 |
|
| 554 |
+
# Build title with optional suffix
|
| 555 |
+
title_text = f'Top 5 Predictions (Certainty: {certainty:.2f})'
|
| 556 |
+
if title_suffix:
|
| 557 |
+
title_text = f'Top 5 Predictions {title_suffix} (Certainty: {certainty:.2f})'
|
| 558 |
+
|
| 559 |
fig.update_layout(
|
| 560 |
title={
|
| 561 |
+
'text': title_text,
|
| 562 |
'font': {'size': 14}
|
| 563 |
},
|
| 564 |
xaxis={'title': 'Probability', 'range': [0, max(probs) * 1.15]},
|
|
|
|
| 710 |
return fig
|
| 711 |
|
| 712 |
|
| 713 |
+
def _create_token_probability_delta_chart(layer_data, layer_num, global_top5_tokens, title_suffix=''):
|
| 714 |
"""
|
| 715 |
Create horizontal bar chart showing change in probabilities for global top 5 tokens.
|
| 716 |
|
|
|
|
| 718 |
layer_data: Layer data dict with global_top5_deltas
|
| 719 |
layer_num: Layer number for title
|
| 720 |
global_top5_tokens: List of (token, prob) tuples for final global top 5
|
| 721 |
+
title_suffix: Optional suffix to add to title (e.g., "Before Ablation", "After Ablation")
|
| 722 |
|
| 723 |
Returns:
|
| 724 |
Plotly Figure with horizontal bars (green for positive, red for negative)
|
|
|
|
| 768 |
|
| 769 |
# Update layout
|
| 770 |
prev_layer_text = "Embedding" if layer_num == 0 else f"Layer {layer_num - 1}"
|
| 771 |
+
title_text = f'Change in Token Probabilities (from {prev_layer_text} to Layer {layer_num})'
|
| 772 |
+
if title_suffix:
|
| 773 |
+
title_text = f'Change in Token Probabilities {title_suffix} (from {prev_layer_text} to Layer {layer_num})'
|
| 774 |
+
|
| 775 |
fig.update_layout(
|
| 776 |
title={
|
| 777 |
+
'text': title_text,
|
| 778 |
'font': {'size': 13}
|
| 779 |
},
|
| 780 |
xaxis={'title': 'Probability Change', 'range': x_range, 'zeroline': True, 'zerolinewidth': 2, 'zerolinecolor': '#999'},
|
|
|
|
| 879 |
@app.callback(
|
| 880 |
Output('layer-accordions-container', 'children'),
|
| 881 |
[Input('session-activation-store', 'data'),
|
| 882 |
+
Input('session-activation-store-2', 'data'),
|
| 883 |
+
Input('session-activation-store-original', 'data')],
|
| 884 |
[State('model-dropdown', 'value')]
|
| 885 |
)
|
| 886 |
+
def create_layer_accordions(activation_data, activation_data2, original_activation_data, model_name):
|
| 887 |
"""Create accordion panels for each layer with top-5 bar charts, deltas, and certainty."""
|
| 888 |
if not activation_data or not model_name:
|
| 889 |
return html.P("Run analysis to see layer-by-layer predictions.", className="placeholder-text")
|
|
|
|
| 895 |
model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation='eager')
|
| 896 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 897 |
|
| 898 |
+
# Check if we're in ablation mode
|
| 899 |
+
ablation_mode = activation_data.get('ablated', False) and original_activation_data
|
| 900 |
+
|
| 901 |
+
# Extract layer data for current activation (may be ablated)
|
| 902 |
layer_data = extract_layer_data(activation_data, model, tokenizer)
|
| 903 |
|
| 904 |
if not layer_data:
|
| 905 |
return html.P("No layer data available.", className="placeholder-text")
|
| 906 |
|
| 907 |
+
# Compute layer-wise probability tracking
|
| 908 |
tracking_data = compute_layer_wise_summaries(layer_data)
|
| 909 |
layer_wise_probs = tracking_data.get('layer_wise_top5_probs', {})
|
| 910 |
significant_layers = tracking_data.get('significant_layers', [])
|
| 911 |
global_top5 = activation_data.get('global_top5_tokens', [])
|
| 912 |
|
| 913 |
+
# If in ablation mode, also extract original layer data
|
| 914 |
+
original_layer_data = None
|
| 915 |
+
original_layer_wise_probs = {}
|
| 916 |
+
original_significant_layers = []
|
| 917 |
+
original_global_top5 = []
|
| 918 |
+
|
| 919 |
+
if ablation_mode:
|
| 920 |
+
original_layer_data = extract_layer_data(original_activation_data, model, tokenizer)
|
| 921 |
+
original_tracking_data = compute_layer_wise_summaries(original_layer_data)
|
| 922 |
+
original_layer_wise_probs = original_tracking_data.get('layer_wise_top5_probs', {})
|
| 923 |
+
original_significant_layers = original_tracking_data.get('significant_layers', [])
|
| 924 |
+
original_global_top5 = original_activation_data.get('global_top5_tokens', [])
|
| 925 |
+
|
| 926 |
# Check if second prompt exists and extract its layer data
|
| 927 |
layer_data2 = None
|
| 928 |
layer_wise_probs2 = {}
|
|
|
|
| 1188 |
|
| 1189 |
# Add delta chart after attention head categories
|
| 1190 |
content_items.append(html.Hr(style={'margin': '15px 0'}))
|
| 1191 |
+
|
| 1192 |
+
# If in ablation mode, show before/after comparison
|
| 1193 |
+
if ablation_mode and original_layer_data:
|
| 1194 |
+
# Find corresponding original layer
|
| 1195 |
+
original_layer = next((l for l in original_layer_data if l['layer_num'] == layer_num), None)
|
| 1196 |
+
|
| 1197 |
+
if original_layer:
|
| 1198 |
+
# Add explanatory note about ablation comparison
|
| 1199 |
+
content_items.append(html.Div([
|
| 1200 |
+
html.I(className="fas fa-info-circle", style={'marginRight': '8px', 'color': '#667eea'}),
|
| 1201 |
+
f"Comparing probabilities before and after ablating Layer {activation_data.get('ablated_layer')}, " +
|
| 1202 |
+
f"Heads {', '.join([f'H{h}' for h in sorted(activation_data.get('ablated_heads', []))])}"
|
| 1203 |
+
], style={'fontSize': '12px', 'color': '#6c757d', 'marginBottom': '15px', 'padding': '10px',
|
| 1204 |
+
'backgroundColor': '#f8f9fa', 'borderRadius': '6px', 'border': '1px solid #dee2e6'}))
|
| 1205 |
+
|
| 1206 |
+
# Before Ablation Section
|
| 1207 |
+
content_items.append(html.Div([
|
| 1208 |
+
html.H6("Before Ablation", style={
|
| 1209 |
+
'marginBottom': '10px', 'color': '#495057', 'fontSize': '14px',
|
| 1210 |
+
'fontWeight': '600', 'borderLeft': '4px solid #74b9ff', 'paddingLeft': '10px'
|
| 1211 |
+
}),
|
| 1212 |
+
dcc.Graph(
|
| 1213 |
+
figure=_create_token_probability_delta_chart(original_layer, layer_num, original_global_top5, '(Before Ablation)'),
|
| 1214 |
+
config={'displayModeBar': False},
|
| 1215 |
+
style={'marginBottom': '10px'}
|
| 1216 |
+
)
|
| 1217 |
+
], style={'padding': '15px', 'backgroundColor': '#e3f2fd', 'borderRadius': '8px', 'marginBottom': '15px'}))
|
| 1218 |
+
|
| 1219 |
+
# After Ablation Section
|
| 1220 |
+
content_items.append(html.Div([
|
| 1221 |
+
html.H6("After Ablation", style={
|
| 1222 |
+
'marginBottom': '10px', 'color': '#495057', 'fontSize': '14px',
|
| 1223 |
+
'fontWeight': '600', 'borderLeft': '4px solid #ffb74d', 'paddingLeft': '10px'
|
| 1224 |
+
}),
|
| 1225 |
+
dcc.Graph(
|
| 1226 |
+
figure=_create_token_probability_delta_chart(layer, layer_num, global_top5, '(After Ablation)'),
|
| 1227 |
+
config={'displayModeBar': False},
|
| 1228 |
+
style={'marginBottom': '10px'}
|
| 1229 |
+
)
|
| 1230 |
+
], style={'padding': '15px', 'backgroundColor': '#fff3e0', 'borderRadius': '8px', 'marginBottom': '15px'}))
|
| 1231 |
+
else:
|
| 1232 |
+
# Fallback if original layer not found
|
| 1233 |
+
if delta_fig:
|
| 1234 |
+
content_items.append(
|
| 1235 |
+
dcc.Graph(
|
| 1236 |
+
figure=delta_fig,
|
| 1237 |
+
config={'displayModeBar': False},
|
| 1238 |
+
style={'marginBottom': '15px'}
|
| 1239 |
+
)
|
| 1240 |
+
)
|
| 1241 |
else:
|
| 1242 |
+
# Normal mode (not ablation): show single delta chart
|
| 1243 |
+
if delta_fig:
|
| 1244 |
+
content_items.append(
|
| 1245 |
+
dcc.Graph(
|
| 1246 |
+
figure=delta_fig,
|
| 1247 |
+
config={'displayModeBar': False},
|
| 1248 |
+
style={'marginBottom': '15px'}
|
| 1249 |
+
)
|
| 1250 |
+
)
|
| 1251 |
+
else:
|
| 1252 |
+
content_items.append(html.P("No probability changes available", style={'color': '#6c757d', 'fontSize': '13px'}))
|
| 1253 |
|
| 1254 |
# Add "Explore These Changes" button after delta chart
|
| 1255 |
content_items.append(explore_button_section)
|
|
|
|
| 1269 |
# Create line graph(s) for top 5 tokens across layers
|
| 1270 |
line_graphs = []
|
| 1271 |
|
| 1272 |
+
# If in ablation mode, show before/after comparison
|
| 1273 |
+
if ablation_mode and original_layer_wise_probs and original_global_top5:
|
| 1274 |
+
# Before Ablation Graph
|
| 1275 |
+
fig_before = _create_top5_by_layer_graph(original_layer_wise_probs, original_significant_layers, original_global_top5)
|
| 1276 |
+
if fig_before:
|
| 1277 |
+
# Update title to indicate "Before Ablation"
|
| 1278 |
+
fig_before.update_layout(title="Top 5 Token Probabilities Across Layers (Before Ablation)")
|
| 1279 |
+
|
| 1280 |
+
graph_container_before = html.Div([
|
| 1281 |
+
html.H5("Before Ablation", style={
|
| 1282 |
+
'marginBottom': '10px', 'color': '#495057', 'fontSize': '16px',
|
| 1283 |
+
'fontWeight': '600', 'borderLeft': '4px solid #74b9ff', 'paddingLeft': '10px'
|
| 1284 |
+
}),
|
| 1285 |
+
html.Div([
|
| 1286 |
+
html.I(className="fas fa-info-circle",
|
| 1287 |
+
style={'marginRight': '8px', 'color': '#667eea'}),
|
| 1288 |
+
"This graph shows how the model's confidence in the final top 5 predictions evolves through each layer before ablation."
|
| 1289 |
+
], style={'fontSize': '13px', 'color': '#6c757d', 'marginBottom': '10px', 'lineHeight': '1.5'}),
|
| 1290 |
+
dcc.Graph(figure=fig_before, config={'displayModeBar': False})
|
| 1291 |
+
], style={'padding': '15px', 'backgroundColor': '#e3f2fd', 'borderRadius': '8px', 'marginBottom': '20px'})
|
| 1292 |
+
|
| 1293 |
+
line_graphs.append(graph_container_before)
|
| 1294 |
+
|
| 1295 |
+
# After Ablation Graph
|
| 1296 |
+
if layer_wise_probs and global_top5:
|
| 1297 |
+
fig_after = _create_top5_by_layer_graph(layer_wise_probs, significant_layers, global_top5)
|
| 1298 |
+
if fig_after:
|
| 1299 |
+
# Update title to indicate "After Ablation"
|
| 1300 |
+
fig_after.update_layout(title="Top 5 Token Probabilities Across Layers (After Ablation)")
|
| 1301 |
+
|
| 1302 |
+
ablated_layer = activation_data.get('ablated_layer')
|
| 1303 |
+
ablated_heads = activation_data.get('ablated_heads', [])
|
| 1304 |
+
heads_str = ', '.join([f'H{h}' for h in sorted(ablated_heads)])
|
| 1305 |
+
|
| 1306 |
+
graph_container_after = html.Div([
|
| 1307 |
+
html.H5("After Ablation", style={
|
| 1308 |
+
'marginBottom': '10px', 'color': '#495057', 'fontSize': '16px',
|
| 1309 |
+
'fontWeight': '600', 'borderLeft': '4px solid #ffb74d', 'paddingLeft': '10px'
|
| 1310 |
+
}),
|
| 1311 |
+
html.Div([
|
| 1312 |
+
html.I(className="fas fa-info-circle",
|
| 1313 |
+
style={'marginRight': '8px', 'color': '#f57c00'}),
|
| 1314 |
+
f"This graph shows how probabilities changed after removing Layer {ablated_layer}, Heads {heads_str}. " +
|
| 1315 |
+
"Compare with the graph above to see the impact of the ablation."
|
| 1316 |
+
], style={'fontSize': '13px', 'color': '#6c757d', 'marginBottom': '10px', 'lineHeight': '1.5'}),
|
| 1317 |
+
dcc.Graph(figure=fig_after, config={'displayModeBar': False}),
|
| 1318 |
+
html.Small("Note: Tokens with and without leading spaces (e.g., ' cat' and 'cat') are automatically merged.",
|
| 1319 |
+
style={'fontSize': '11px', 'color': '#6c757d', 'fontStyle': 'italic'})
|
| 1320 |
+
], style={'padding': '15px', 'backgroundColor': '#fff3e0', 'borderRadius': '8px', 'marginBottom': '20px'})
|
| 1321 |
+
|
| 1322 |
+
line_graphs.append(graph_container_after)
|
| 1323 |
+
|
| 1324 |
+
# Normal mode (not ablation): show single line graph
|
| 1325 |
+
elif layer_wise_probs and global_top5:
|
| 1326 |
fig = _create_top5_by_layer_graph(layer_wise_probs, significant_layers, global_top5)
|
| 1327 |
if fig:
|
| 1328 |
tooltip_text = ("This graph shows how the model's confidence in the final top 5 predictions "
|
|
|
|
| 1347 |
|
| 1348 |
line_graphs.append(graph_container)
|
| 1349 |
|
| 1350 |
+
# In comparison mode (two prompts), create a second graph or side-by-side display
|
| 1351 |
if comparison_mode and layer_wise_probs2 and global_top5_2:
|
| 1352 |
fig2 = _create_top5_by_layer_graph(layer_wise_probs2, significant_layers2, global_top5_2)
|
| 1353 |
if fig2:
|
|
|
|
| 1745 |
# Run ablation experiment
|
| 1746 |
@app.callback(
|
| 1747 |
[Output('session-activation-store', 'data', allow_duplicate=True),
|
| 1748 |
+
Output('session-activation-store-original', 'data'),
|
| 1749 |
Output('model-status', 'children', allow_duplicate=True)],
|
| 1750 |
Input({'type': 'run-ablation-btn', 'layer': ALL}, 'n_clicks'),
|
| 1751 |
[State({'type': 'selected-heads-store', 'layer': ALL}, 'data'),
|
|
|
|
| 1759 |
# Identify which button was clicked
|
| 1760 |
ctx = dash.callback_context
|
| 1761 |
if not ctx.triggered or not ctx.triggered_id:
|
| 1762 |
+
return no_update, no_update, no_update
|
| 1763 |
|
| 1764 |
# Get the layer number from the triggered button ID
|
| 1765 |
layer_num = ctx.triggered_id.get('layer')
|
| 1766 |
if layer_num is None:
|
| 1767 |
+
return no_update, no_update, no_update
|
| 1768 |
|
| 1769 |
# Find the index in the states_list that corresponds to this layer
|
| 1770 |
# ctx.states_list contains the State values in order
|
|
|
|
| 1779 |
# Fallback: if states_list doesn't work, try matching by iterating
|
| 1780 |
if button_index is None:
|
| 1781 |
# This shouldn't happen, but as a fallback, just return error
|
| 1782 |
+
return no_update, no_update, html.Div([
|
| 1783 |
html.I(className="fas fa-exclamation-circle", style={'marginRight': '8px', 'color': '#dc3545'}),
|
| 1784 |
f"Could not determine button index for layer {layer_num}"
|
| 1785 |
], className="status-error")
|
|
|
|
| 1790 |
selected_heads = selected_heads_list[button_index]
|
| 1791 |
|
| 1792 |
if not selected_heads or not activation_data:
|
| 1793 |
+
return no_update, no_update, no_update
|
| 1794 |
|
| 1795 |
try:
|
| 1796 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 1797 |
from utils import execute_forward_pass_with_head_ablation
|
| 1798 |
|
| 1799 |
+
# Save original activation data before ablation
|
| 1800 |
+
import copy
|
| 1801 |
+
original_data = copy.deepcopy(activation_data)
|
| 1802 |
+
|
| 1803 |
# Load model and tokenizer
|
| 1804 |
model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation='eager')
|
| 1805 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
|
|
|
| 1834 |
f"Ablation complete: Layer {layer_num}, Heads {heads_str} removed"
|
| 1835 |
], className="status-success")
|
| 1836 |
|
| 1837 |
+
return ablated_data, original_data, success_message
|
| 1838 |
|
| 1839 |
except Exception as e:
|
| 1840 |
print(f"Error running ablation: {e}")
|
|
|
|
| 1846 |
f"Ablation error: {str(e)}"
|
| 1847 |
], className="status-error")
|
| 1848 |
|
| 1849 |
+
return no_update, no_update, error_message
|
| 1850 |
+
|
| 1851 |
+
|
| 1852 |
+
# Show/hide Reset Ablation button based on ablation mode
|
| 1853 |
+
@app.callback(
|
| 1854 |
+
Output('reset-ablation-container', 'style'),
|
| 1855 |
+
Input('session-activation-store', 'data'),
|
| 1856 |
+
prevent_initial_call=False
|
| 1857 |
+
)
|
| 1858 |
+
def toggle_reset_ablation_button(activation_data):
|
| 1859 |
+
"""Show Reset Ablation button when in ablation mode, hide otherwise."""
|
| 1860 |
+
if activation_data and activation_data.get('ablated', False):
|
| 1861 |
+
return {'display': 'block'}
|
| 1862 |
+
else:
|
| 1863 |
+
return {'display': 'none'}
|
| 1864 |
+
|
| 1865 |
+
|
| 1866 |
+
# Reset ablation experiment
|
| 1867 |
+
@app.callback(
|
| 1868 |
+
[Output('session-activation-store', 'data', allow_duplicate=True),
|
| 1869 |
+
Output('session-activation-store-original', 'data', allow_duplicate=True),
|
| 1870 |
+
Output('model-status', 'children', allow_duplicate=True)],
|
| 1871 |
+
Input('reset-ablation-btn', 'n_clicks'),
|
| 1872 |
+
[State('session-activation-store-original', 'data')],
|
| 1873 |
+
prevent_initial_call=True
|
| 1874 |
+
)
|
| 1875 |
+
def reset_ablation(n_clicks, original_data):
|
| 1876 |
+
"""Reset ablation by restoring original data and clearing the original store."""
|
| 1877 |
+
if not n_clicks:
|
| 1878 |
+
return no_update, no_update, no_update
|
| 1879 |
+
|
| 1880 |
+
if not original_data:
|
| 1881 |
+
error_message = html.Div([
|
| 1882 |
+
html.I(className="fas fa-exclamation-circle", style={'marginRight': '8px', 'color': '#dc3545'}),
|
| 1883 |
+
"No original data to restore"
|
| 1884 |
+
], className="status-error")
|
| 1885 |
+
return no_update, no_update, error_message
|
| 1886 |
+
|
| 1887 |
+
# Restore original data to main store and clear original store
|
| 1888 |
+
success_message = html.Div([
|
| 1889 |
+
html.I(className="fas fa-undo", style={'marginRight': '8px', 'color': '#28a745'}),
|
| 1890 |
+
"Ablation reset - original data restored"
|
| 1891 |
+
], className="status-success")
|
| 1892 |
+
|
| 1893 |
+
return original_data, {}, success_message
|
| 1894 |
|
| 1895 |
|
| 1896 |
if __name__ == '__main__':
|
|
Binary files a/components/__pycache__/sidebar.cpython-311.pyc and b/components/__pycache__/sidebar.cpython-311.pyc differ
|
|
|
|
@@ -93,7 +93,20 @@ def create_sidebar():
|
|
| 93 |
id="clear-selections-btn",
|
| 94 |
className="action-button secondary-button"
|
| 95 |
)
|
| 96 |
-
], className="button-container")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
], id="sidebar-content", className="sidebar-content")
|
| 98 |
|
| 99 |
])
|
|
|
|
| 93 |
id="clear-selections-btn",
|
| 94 |
className="action-button secondary-button"
|
| 95 |
)
|
| 96 |
+
], className="button-container"),
|
| 97 |
+
|
| 98 |
+
# Reset Ablation button (hidden by default, shown when in ablation mode)
|
| 99 |
+
html.Div([
|
| 100 |
+
html.Button(
|
| 101 |
+
[
|
| 102 |
+
html.I(className="fas fa-undo", style={'marginRight': '8px'}),
|
| 103 |
+
"Reset Ablation"
|
| 104 |
+
],
|
| 105 |
+
id="reset-ablation-btn",
|
| 106 |
+
className="action-button warning-button",
|
| 107 |
+
style={'marginTop': '10px'}
|
| 108 |
+
)
|
| 109 |
+
], id="reset-ablation-container", style={'display': 'none'})
|
| 110 |
], id="sidebar-content", className="sidebar-content")
|
| 111 |
|
| 112 |
])
|
|
Binary files a/utils/__pycache__/model_patterns.cpython-311.pyc and b/utils/__pycache__/model_patterns.cpython-311.pyc differ
|
|
|