cdpearlman committed on
Commit
d701d46
·
1 Parent(s): f4fa674

Add ablation experiment before/after comparison with Reset button

Browse files

- Added session-activation-store-original to preserve pre-ablation data
- Modified run_head_ablation callback to save original data before ablation
- Updated create_layer_accordions to detect ablation mode and show before/after sections for all layers
- Modified _create_single_prompt_chart and _create_token_probability_delta_chart to accept title suffixes
- Added before/after comparison for line graphs showing top 5 token probabilities across layers
- Added Reset Ablation button in sidebar (visible only during ablation)
- Implemented reset_ablation callback to restore original data and clear ablation
- Added visual polish with section headers, background colors (blue for before, orange for after), and explanatory tooltips
- All changes maintain backward compatibility with non-ablation mode

app.py CHANGED
@@ -105,6 +105,8 @@ app.layout = html.Div([
105
  # Session storage for activation data
106
  dcc.Store(id='session-activation-store', storage_type='session'),
107
  dcc.Store(id='session-patterns-store', storage_type='session'),
 
 
108
  # Sidebar collapse state (default: collapsed = True)
109
  dcc.Store(id='sidebar-collapse-store', storage_type='session', data=True),
110
  # Comparison mode state (default: not comparing)
@@ -496,12 +498,13 @@ def _create_top5_by_layer_graph(layer_wise_probs, significant_layers, global_top
496
  return fig
497
 
498
 
499
- def _create_single_prompt_chart(layer_data):
500
  """
501
  Create a single prompt bar chart (existing functionality).
502
 
503
  Args:
504
  layer_data: Layer data dict (with top_5_tokens, deltas, certainty)
 
505
 
506
  Returns:
507
  Plotly Figure with horizontal bars
@@ -548,9 +551,14 @@ def _create_single_prompt_chart(layer_data):
548
  )
549
  ])
550
 
 
 
 
 
 
551
  fig.update_layout(
552
  title={
553
- 'text': f'Top 5 Predictions (Certainty: {certainty:.2f})',
554
  'font': {'size': 14}
555
  },
556
  xaxis={'title': 'Probability', 'range': [0, max(probs) * 1.15]},
@@ -702,7 +710,7 @@ def _create_comparison_bar_chart(layer_data1, layer_data2, layer_num):
702
  return fig
703
 
704
 
705
- def _create_token_probability_delta_chart(layer_data, layer_num, global_top5_tokens):
706
  """
707
  Create horizontal bar chart showing change in probabilities for global top 5 tokens.
708
 
@@ -710,6 +718,7 @@ def _create_token_probability_delta_chart(layer_data, layer_num, global_top5_tok
710
  layer_data: Layer data dict with global_top5_deltas
711
  layer_num: Layer number for title
712
  global_top5_tokens: List of (token, prob) tuples for final global top 5
 
713
 
714
  Returns:
715
  Plotly Figure with horizontal bars (green for positive, red for negative)
@@ -759,9 +768,13 @@ def _create_token_probability_delta_chart(layer_data, layer_num, global_top5_tok
759
 
760
  # Update layout
761
  prev_layer_text = "Embedding" if layer_num == 0 else f"Layer {layer_num - 1}"
 
 
 
 
762
  fig.update_layout(
763
  title={
764
- 'text': f'Change in Token Probabilities (from {prev_layer_text} to Layer {layer_num})',
765
  'font': {'size': 13}
766
  },
767
  xaxis={'title': 'Probability Change', 'range': x_range, 'zeroline': True, 'zerolinewidth': 2, 'zerolinecolor': '#999'},
@@ -866,10 +879,11 @@ def _create_comparison_delta_chart(layer_data1, layer_data2, layer_num, global_t
866
  @app.callback(
867
  Output('layer-accordions-container', 'children'),
868
  [Input('session-activation-store', 'data'),
869
- Input('session-activation-store-2', 'data')],
 
870
  [State('model-dropdown', 'value')]
871
  )
872
- def create_layer_accordions(activation_data, activation_data2, model_name):
873
  """Create accordion panels for each layer with top-5 bar charts, deltas, and certainty."""
874
  if not activation_data or not model_name:
875
  return html.P("Run analysis to see layer-by-layer predictions.", className="placeholder-text")
@@ -881,18 +895,34 @@ def create_layer_accordions(activation_data, activation_data2, model_name):
881
  model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation='eager')
882
  tokenizer = AutoTokenizer.from_pretrained(model_name)
883
 
884
- # Extract layer data for first prompt
 
 
 
885
  layer_data = extract_layer_data(activation_data, model, tokenizer)
886
 
887
  if not layer_data:
888
  return html.P("No layer data available.", className="placeholder-text")
889
 
890
- # Compute layer-wise probability tracking for first prompt
891
  tracking_data = compute_layer_wise_summaries(layer_data)
892
  layer_wise_probs = tracking_data.get('layer_wise_top5_probs', {})
893
  significant_layers = tracking_data.get('significant_layers', [])
894
  global_top5 = activation_data.get('global_top5_tokens', [])
895
 
 
 
 
 
 
 
 
 
 
 
 
 
 
896
  # Check if second prompt exists and extract its layer data
897
  layer_data2 = None
898
  layer_wise_probs2 = {}
@@ -1158,16 +1188,68 @@ def create_layer_accordions(activation_data, activation_data2, model_name):
1158
 
1159
  # Add delta chart after attention head categories
1160
  content_items.append(html.Hr(style={'margin': '15px 0'}))
1161
- if delta_fig:
1162
- content_items.append(
1163
- dcc.Graph(
1164
- figure=delta_fig,
1165
- config={'displayModeBar': False},
1166
- style={'marginBottom': '15px'}
1167
- )
1168
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1169
  else:
1170
- content_items.append(html.P("No probability changes available", style={'color': '#6c757d', 'fontSize': '13px'}))
 
 
 
 
 
 
 
 
 
 
1171
 
1172
  # Add "Explore These Changes" button after delta chart
1173
  content_items.append(explore_button_section)
@@ -1187,7 +1269,60 @@ def create_layer_accordions(activation_data, activation_data2, model_name):
1187
  # Create line graph(s) for top 5 tokens across layers
1188
  line_graphs = []
1189
 
1190
- if layer_wise_probs and global_top5:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1191
  fig = _create_top5_by_layer_graph(layer_wise_probs, significant_layers, global_top5)
1192
  if fig:
1193
  tooltip_text = ("This graph shows how the model's confidence in the final top 5 predictions "
@@ -1212,7 +1347,7 @@ def create_layer_accordions(activation_data, activation_data2, model_name):
1212
 
1213
  line_graphs.append(graph_container)
1214
 
1215
- # In comparison mode, create a second graph or side-by-side display
1216
  if comparison_mode and layer_wise_probs2 and global_top5_2:
1217
  fig2 = _create_top5_by_layer_graph(layer_wise_probs2, significant_layers2, global_top5_2)
1218
  if fig2:
@@ -1610,6 +1745,7 @@ def handle_head_selection(n_clicks_list, selected_heads):
1610
  # Run ablation experiment
1611
  @app.callback(
1612
  [Output('session-activation-store', 'data', allow_duplicate=True),
 
1613
  Output('model-status', 'children', allow_duplicate=True)],
1614
  Input({'type': 'run-ablation-btn', 'layer': ALL}, 'n_clicks'),
1615
  [State({'type': 'selected-heads-store', 'layer': ALL}, 'data'),
@@ -1623,12 +1759,12 @@ def run_head_ablation(n_clicks_list, selected_heads_list, activation_data, model
1623
  # Identify which button was clicked
1624
  ctx = dash.callback_context
1625
  if not ctx.triggered or not ctx.triggered_id:
1626
- return no_update, no_update
1627
 
1628
  # Get the layer number from the triggered button ID
1629
  layer_num = ctx.triggered_id.get('layer')
1630
  if layer_num is None:
1631
- return no_update, no_update
1632
 
1633
  # Find the index in the states_list that corresponds to this layer
1634
  # ctx.states_list contains the State values in order
@@ -1643,7 +1779,7 @@ def run_head_ablation(n_clicks_list, selected_heads_list, activation_data, model
1643
  # Fallback: if states_list doesn't work, try matching by iterating
1644
  if button_index is None:
1645
  # This shouldn't happen, but as a fallback, just return error
1646
- return no_update, html.Div([
1647
  html.I(className="fas fa-exclamation-circle", style={'marginRight': '8px', 'color': '#dc3545'}),
1648
  f"Could not determine button index for layer {layer_num}"
1649
  ], className="status-error")
@@ -1654,12 +1790,16 @@ def run_head_ablation(n_clicks_list, selected_heads_list, activation_data, model
1654
  selected_heads = selected_heads_list[button_index]
1655
 
1656
  if not selected_heads or not activation_data:
1657
- return no_update, no_update
1658
 
1659
  try:
1660
  from transformers import AutoModelForCausalLM, AutoTokenizer
1661
  from utils import execute_forward_pass_with_head_ablation
1662
 
 
 
 
 
1663
  # Load model and tokenizer
1664
  model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation='eager')
1665
  tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -1694,7 +1834,7 @@ def run_head_ablation(n_clicks_list, selected_heads_list, activation_data, model
1694
  f"Ablation complete: Layer {layer_num}, Heads {heads_str} removed"
1695
  ], className="status-success")
1696
 
1697
- return ablated_data, success_message
1698
 
1699
  except Exception as e:
1700
  print(f"Error running ablation: {e}")
@@ -1706,7 +1846,51 @@ def run_head_ablation(n_clicks_list, selected_heads_list, activation_data, model
1706
  f"Ablation error: {str(e)}"
1707
  ], className="status-error")
1708
 
1709
- return no_update, error_message
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1710
 
1711
 
1712
  if __name__ == '__main__':
 
105
  # Session storage for activation data
106
  dcc.Store(id='session-activation-store', storage_type='session'),
107
  dcc.Store(id='session-patterns-store', storage_type='session'),
108
+ # Store original activation data before ablation for comparison
109
+ dcc.Store(id='session-activation-store-original', storage_type='session'),
110
  # Sidebar collapse state (default: collapsed = True)
111
  dcc.Store(id='sidebar-collapse-store', storage_type='session', data=True),
112
  # Comparison mode state (default: not comparing)
 
498
  return fig
499
 
500
 
501
+ def _create_single_prompt_chart(layer_data, title_suffix=''):
502
  """
503
  Create a single prompt bar chart (existing functionality).
504
 
505
  Args:
506
  layer_data: Layer data dict (with top_5_tokens, deltas, certainty)
507
+ title_suffix: Optional suffix to add to title (e.g., "Before Ablation", "After Ablation")
508
 
509
  Returns:
510
  Plotly Figure with horizontal bars
 
551
  )
552
  ])
553
 
554
+ # Build title with optional suffix
555
+ title_text = f'Top 5 Predictions (Certainty: {certainty:.2f})'
556
+ if title_suffix:
557
+ title_text = f'Top 5 Predictions {title_suffix} (Certainty: {certainty:.2f})'
558
+
559
  fig.update_layout(
560
  title={
561
+ 'text': title_text,
562
  'font': {'size': 14}
563
  },
564
  xaxis={'title': 'Probability', 'range': [0, max(probs) * 1.15]},
 
710
  return fig
711
 
712
 
713
+ def _create_token_probability_delta_chart(layer_data, layer_num, global_top5_tokens, title_suffix=''):
714
  """
715
  Create horizontal bar chart showing change in probabilities for global top 5 tokens.
716
 
 
718
  layer_data: Layer data dict with global_top5_deltas
719
  layer_num: Layer number for title
720
  global_top5_tokens: List of (token, prob) tuples for final global top 5
721
+ title_suffix: Optional suffix to add to title (e.g., "Before Ablation", "After Ablation")
722
 
723
  Returns:
724
  Plotly Figure with horizontal bars (green for positive, red for negative)
 
768
 
769
  # Update layout
770
  prev_layer_text = "Embedding" if layer_num == 0 else f"Layer {layer_num - 1}"
771
+ title_text = f'Change in Token Probabilities (from {prev_layer_text} to Layer {layer_num})'
772
+ if title_suffix:
773
+ title_text = f'Change in Token Probabilities {title_suffix} (from {prev_layer_text} to Layer {layer_num})'
774
+
775
  fig.update_layout(
776
  title={
777
+ 'text': title_text,
778
  'font': {'size': 13}
779
  },
780
  xaxis={'title': 'Probability Change', 'range': x_range, 'zeroline': True, 'zerolinewidth': 2, 'zerolinecolor': '#999'},
 
879
  @app.callback(
880
  Output('layer-accordions-container', 'children'),
881
  [Input('session-activation-store', 'data'),
882
+ Input('session-activation-store-2', 'data'),
883
+ Input('session-activation-store-original', 'data')],
884
  [State('model-dropdown', 'value')]
885
  )
886
+ def create_layer_accordions(activation_data, activation_data2, original_activation_data, model_name):
887
  """Create accordion panels for each layer with top-5 bar charts, deltas, and certainty."""
888
  if not activation_data or not model_name:
889
  return html.P("Run analysis to see layer-by-layer predictions.", className="placeholder-text")
 
895
  model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation='eager')
896
  tokenizer = AutoTokenizer.from_pretrained(model_name)
897
 
898
+ # Check if we're in ablation mode
899
+ ablation_mode = activation_data.get('ablated', False) and original_activation_data
900
+
901
+ # Extract layer data for current activation (may be ablated)
902
  layer_data = extract_layer_data(activation_data, model, tokenizer)
903
 
904
  if not layer_data:
905
  return html.P("No layer data available.", className="placeholder-text")
906
 
907
+ # Compute layer-wise probability tracking
908
  tracking_data = compute_layer_wise_summaries(layer_data)
909
  layer_wise_probs = tracking_data.get('layer_wise_top5_probs', {})
910
  significant_layers = tracking_data.get('significant_layers', [])
911
  global_top5 = activation_data.get('global_top5_tokens', [])
912
 
913
+ # If in ablation mode, also extract original layer data
914
+ original_layer_data = None
915
+ original_layer_wise_probs = {}
916
+ original_significant_layers = []
917
+ original_global_top5 = []
918
+
919
+ if ablation_mode:
920
+ original_layer_data = extract_layer_data(original_activation_data, model, tokenizer)
921
+ original_tracking_data = compute_layer_wise_summaries(original_layer_data)
922
+ original_layer_wise_probs = original_tracking_data.get('layer_wise_top5_probs', {})
923
+ original_significant_layers = original_tracking_data.get('significant_layers', [])
924
+ original_global_top5 = original_activation_data.get('global_top5_tokens', [])
925
+
926
  # Check if second prompt exists and extract its layer data
927
  layer_data2 = None
928
  layer_wise_probs2 = {}
 
1188
 
1189
  # Add delta chart after attention head categories
1190
  content_items.append(html.Hr(style={'margin': '15px 0'}))
1191
+
1192
+ # If in ablation mode, show before/after comparison
1193
+ if ablation_mode and original_layer_data:
1194
+ # Find corresponding original layer
1195
+ original_layer = next((l for l in original_layer_data if l['layer_num'] == layer_num), None)
1196
+
1197
+ if original_layer:
1198
+ # Add explanatory note about ablation comparison
1199
+ content_items.append(html.Div([
1200
+ html.I(className="fas fa-info-circle", style={'marginRight': '8px', 'color': '#667eea'}),
1201
+ f"Comparing probabilities before and after ablating Layer {activation_data.get('ablated_layer')}, " +
1202
+ f"Heads {', '.join([f'H{h}' for h in sorted(activation_data.get('ablated_heads', []))])}"
1203
+ ], style={'fontSize': '12px', 'color': '#6c757d', 'marginBottom': '15px', 'padding': '10px',
1204
+ 'backgroundColor': '#f8f9fa', 'borderRadius': '6px', 'border': '1px solid #dee2e6'}))
1205
+
1206
+ # Before Ablation Section
1207
+ content_items.append(html.Div([
1208
+ html.H6("Before Ablation", style={
1209
+ 'marginBottom': '10px', 'color': '#495057', 'fontSize': '14px',
1210
+ 'fontWeight': '600', 'borderLeft': '4px solid #74b9ff', 'paddingLeft': '10px'
1211
+ }),
1212
+ dcc.Graph(
1213
+ figure=_create_token_probability_delta_chart(original_layer, layer_num, original_global_top5, '(Before Ablation)'),
1214
+ config={'displayModeBar': False},
1215
+ style={'marginBottom': '10px'}
1216
+ )
1217
+ ], style={'padding': '15px', 'backgroundColor': '#e3f2fd', 'borderRadius': '8px', 'marginBottom': '15px'}))
1218
+
1219
+ # After Ablation Section
1220
+ content_items.append(html.Div([
1221
+ html.H6("After Ablation", style={
1222
+ 'marginBottom': '10px', 'color': '#495057', 'fontSize': '14px',
1223
+ 'fontWeight': '600', 'borderLeft': '4px solid #ffb74d', 'paddingLeft': '10px'
1224
+ }),
1225
+ dcc.Graph(
1226
+ figure=_create_token_probability_delta_chart(layer, layer_num, global_top5, '(After Ablation)'),
1227
+ config={'displayModeBar': False},
1228
+ style={'marginBottom': '10px'}
1229
+ )
1230
+ ], style={'padding': '15px', 'backgroundColor': '#fff3e0', 'borderRadius': '8px', 'marginBottom': '15px'}))
1231
+ else:
1232
+ # Fallback if original layer not found
1233
+ if delta_fig:
1234
+ content_items.append(
1235
+ dcc.Graph(
1236
+ figure=delta_fig,
1237
+ config={'displayModeBar': False},
1238
+ style={'marginBottom': '15px'}
1239
+ )
1240
+ )
1241
  else:
1242
+ # Normal mode (not ablation): show single delta chart
1243
+ if delta_fig:
1244
+ content_items.append(
1245
+ dcc.Graph(
1246
+ figure=delta_fig,
1247
+ config={'displayModeBar': False},
1248
+ style={'marginBottom': '15px'}
1249
+ )
1250
+ )
1251
+ else:
1252
+ content_items.append(html.P("No probability changes available", style={'color': '#6c757d', 'fontSize': '13px'}))
1253
 
1254
  # Add "Explore These Changes" button after delta chart
1255
  content_items.append(explore_button_section)
 
1269
  # Create line graph(s) for top 5 tokens across layers
1270
  line_graphs = []
1271
 
1272
+ # If in ablation mode, show before/after comparison
1273
+ if ablation_mode and original_layer_wise_probs and original_global_top5:
1274
+ # Before Ablation Graph
1275
+ fig_before = _create_top5_by_layer_graph(original_layer_wise_probs, original_significant_layers, original_global_top5)
1276
+ if fig_before:
1277
+ # Update title to indicate "Before Ablation"
1278
+ fig_before.update_layout(title="Top 5 Token Probabilities Across Layers (Before Ablation)")
1279
+
1280
+ graph_container_before = html.Div([
1281
+ html.H5("Before Ablation", style={
1282
+ 'marginBottom': '10px', 'color': '#495057', 'fontSize': '16px',
1283
+ 'fontWeight': '600', 'borderLeft': '4px solid #74b9ff', 'paddingLeft': '10px'
1284
+ }),
1285
+ html.Div([
1286
+ html.I(className="fas fa-info-circle",
1287
+ style={'marginRight': '8px', 'color': '#667eea'}),
1288
+ "This graph shows how the model's confidence in the final top 5 predictions evolves through each layer before ablation."
1289
+ ], style={'fontSize': '13px', 'color': '#6c757d', 'marginBottom': '10px', 'lineHeight': '1.5'}),
1290
+ dcc.Graph(figure=fig_before, config={'displayModeBar': False})
1291
+ ], style={'padding': '15px', 'backgroundColor': '#e3f2fd', 'borderRadius': '8px', 'marginBottom': '20px'})
1292
+
1293
+ line_graphs.append(graph_container_before)
1294
+
1295
+ # After Ablation Graph
1296
+ if layer_wise_probs and global_top5:
1297
+ fig_after = _create_top5_by_layer_graph(layer_wise_probs, significant_layers, global_top5)
1298
+ if fig_after:
1299
+ # Update title to indicate "After Ablation"
1300
+ fig_after.update_layout(title="Top 5 Token Probabilities Across Layers (After Ablation)")
1301
+
1302
+ ablated_layer = activation_data.get('ablated_layer')
1303
+ ablated_heads = activation_data.get('ablated_heads', [])
1304
+ heads_str = ', '.join([f'H{h}' for h in sorted(ablated_heads)])
1305
+
1306
+ graph_container_after = html.Div([
1307
+ html.H5("After Ablation", style={
1308
+ 'marginBottom': '10px', 'color': '#495057', 'fontSize': '16px',
1309
+ 'fontWeight': '600', 'borderLeft': '4px solid #ffb74d', 'paddingLeft': '10px'
1310
+ }),
1311
+ html.Div([
1312
+ html.I(className="fas fa-info-circle",
1313
+ style={'marginRight': '8px', 'color': '#f57c00'}),
1314
+ f"This graph shows how probabilities changed after removing Layer {ablated_layer}, Heads {heads_str}. " +
1315
+ "Compare with the graph above to see the impact of the ablation."
1316
+ ], style={'fontSize': '13px', 'color': '#6c757d', 'marginBottom': '10px', 'lineHeight': '1.5'}),
1317
+ dcc.Graph(figure=fig_after, config={'displayModeBar': False}),
1318
+ html.Small("Note: Tokens with and without leading spaces (e.g., ' cat' and 'cat') are automatically merged.",
1319
+ style={'fontSize': '11px', 'color': '#6c757d', 'fontStyle': 'italic'})
1320
+ ], style={'padding': '15px', 'backgroundColor': '#fff3e0', 'borderRadius': '8px', 'marginBottom': '20px'})
1321
+
1322
+ line_graphs.append(graph_container_after)
1323
+
1324
+ # Normal mode (not ablation): show single line graph
1325
+ elif layer_wise_probs and global_top5:
1326
  fig = _create_top5_by_layer_graph(layer_wise_probs, significant_layers, global_top5)
1327
  if fig:
1328
  tooltip_text = ("This graph shows how the model's confidence in the final top 5 predictions "
 
1347
 
1348
  line_graphs.append(graph_container)
1349
 
1350
+ # In comparison mode (two prompts), create a second graph or side-by-side display
1351
  if comparison_mode and layer_wise_probs2 and global_top5_2:
1352
  fig2 = _create_top5_by_layer_graph(layer_wise_probs2, significant_layers2, global_top5_2)
1353
  if fig2:
 
1745
  # Run ablation experiment
1746
  @app.callback(
1747
  [Output('session-activation-store', 'data', allow_duplicate=True),
1748
+ Output('session-activation-store-original', 'data'),
1749
  Output('model-status', 'children', allow_duplicate=True)],
1750
  Input({'type': 'run-ablation-btn', 'layer': ALL}, 'n_clicks'),
1751
  [State({'type': 'selected-heads-store', 'layer': ALL}, 'data'),
 
1759
  # Identify which button was clicked
1760
  ctx = dash.callback_context
1761
  if not ctx.triggered or not ctx.triggered_id:
1762
+ return no_update, no_update, no_update
1763
 
1764
  # Get the layer number from the triggered button ID
1765
  layer_num = ctx.triggered_id.get('layer')
1766
  if layer_num is None:
1767
+ return no_update, no_update, no_update
1768
 
1769
  # Find the index in the states_list that corresponds to this layer
1770
  # ctx.states_list contains the State values in order
 
1779
  # Fallback: if states_list doesn't work, try matching by iterating
1780
  if button_index is None:
1781
  # This shouldn't happen, but as a fallback, just return error
1782
+ return no_update, no_update, html.Div([
1783
  html.I(className="fas fa-exclamation-circle", style={'marginRight': '8px', 'color': '#dc3545'}),
1784
  f"Could not determine button index for layer {layer_num}"
1785
  ], className="status-error")
 
1790
  selected_heads = selected_heads_list[button_index]
1791
 
1792
  if not selected_heads or not activation_data:
1793
+ return no_update, no_update, no_update
1794
 
1795
  try:
1796
  from transformers import AutoModelForCausalLM, AutoTokenizer
1797
  from utils import execute_forward_pass_with_head_ablation
1798
 
1799
+ # Save original activation data before ablation
1800
+ import copy
1801
+ original_data = copy.deepcopy(activation_data)
1802
+
1803
  # Load model and tokenizer
1804
  model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation='eager')
1805
  tokenizer = AutoTokenizer.from_pretrained(model_name)
 
1834
  f"Ablation complete: Layer {layer_num}, Heads {heads_str} removed"
1835
  ], className="status-success")
1836
 
1837
+ return ablated_data, original_data, success_message
1838
 
1839
  except Exception as e:
1840
  print(f"Error running ablation: {e}")
 
1846
  f"Ablation error: {str(e)}"
1847
  ], className="status-error")
1848
 
1849
+ return no_update, no_update, error_message
1850
+
1851
+
1852
+ # Show/hide Reset Ablation button based on ablation mode
1853
+ @app.callback(
1854
+ Output('reset-ablation-container', 'style'),
1855
+ Input('session-activation-store', 'data'),
1856
+ prevent_initial_call=False
1857
+ )
1858
+ def toggle_reset_ablation_button(activation_data):
1859
+ """Show Reset Ablation button when in ablation mode, hide otherwise."""
1860
+ if activation_data and activation_data.get('ablated', False):
1861
+ return {'display': 'block'}
1862
+ else:
1863
+ return {'display': 'none'}
1864
+
1865
+
1866
+ # Reset ablation experiment
1867
+ @app.callback(
1868
+ [Output('session-activation-store', 'data', allow_duplicate=True),
1869
+ Output('session-activation-store-original', 'data', allow_duplicate=True),
1870
+ Output('model-status', 'children', allow_duplicate=True)],
1871
+ Input('reset-ablation-btn', 'n_clicks'),
1872
+ [State('session-activation-store-original', 'data')],
1873
+ prevent_initial_call=True
1874
+ )
1875
+ def reset_ablation(n_clicks, original_data):
1876
+ """Reset ablation by restoring original data and clearing the original store."""
1877
+ if not n_clicks:
1878
+ return no_update, no_update, no_update
1879
+
1880
+ if not original_data:
1881
+ error_message = html.Div([
1882
+ html.I(className="fas fa-exclamation-circle", style={'marginRight': '8px', 'color': '#dc3545'}),
1883
+ "No original data to restore"
1884
+ ], className="status-error")
1885
+ return no_update, no_update, error_message
1886
+
1887
+ # Restore original data to main store and clear original store
1888
+ success_message = html.Div([
1889
+ html.I(className="fas fa-undo", style={'marginRight': '8px', 'color': '#28a745'}),
1890
+ "Ablation reset - original data restored"
1891
+ ], className="status-success")
1892
+
1893
+ return original_data, {}, success_message
1894
 
1895
 
1896
  if __name__ == '__main__':
components/__pycache__/sidebar.cpython-311.pyc CHANGED
Binary files a/components/__pycache__/sidebar.cpython-311.pyc and b/components/__pycache__/sidebar.cpython-311.pyc differ
 
components/sidebar.py CHANGED
@@ -93,7 +93,20 @@ def create_sidebar():
93
  id="clear-selections-btn",
94
  className="action-button secondary-button"
95
  )
96
- ], className="button-container")
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  ], id="sidebar-content", className="sidebar-content")
98
 
99
  ])
 
93
  id="clear-selections-btn",
94
  className="action-button secondary-button"
95
  )
96
+ ], className="button-container"),
97
+
98
+ # Reset Ablation button (hidden by default, shown when in ablation mode)
99
+ html.Div([
100
+ html.Button(
101
+ [
102
+ html.I(className="fas fa-undo", style={'marginRight': '8px'}),
103
+ "Reset Ablation"
104
+ ],
105
+ id="reset-ablation-btn",
106
+ className="action-button warning-button",
107
+ style={'marginTop': '10px'}
108
+ )
109
+ ], id="reset-ablation-container", style={'display': 'none'})
110
  ], id="sidebar-content", className="sidebar-content")
111
 
112
  ])
utils/__pycache__/model_patterns.cpython-311.pyc CHANGED
Binary files a/utils/__pycache__/model_patterns.cpython-311.pyc and b/utils/__pycache__/model_patterns.cpython-311.pyc differ