cdpearlman committed on
Commit
9d22c3c
·
1 Parent(s): b828e58

Update layer highlighting to 100% threshold for actual output token only

Browse files

- Changed detect_significant_probability_increases to only check actual output token instead of all top 5 tokens
- Updated threshold from 50% to 100% (doubled probability) for highlighting layers
- Modified compute_layer_wise_summaries to accept activation_data and extract actual output token
- Updated all calls to compute_layer_wise_summaries in app.py to pass activation_data
- Added _create_actual_output_display helper function to show actual model output
- Display actual output token below final output graphs (normal mode, ablation mode, comparison mode)
- Added tooltip explaining why actual output may differ from highest probability token in final layer due to residual connections
- Updated graph tooltip text to reflect 100% threshold change

app.py CHANGED
@@ -855,6 +855,78 @@ def _create_comparison_delta_chart(layer_data1, layer_data2, layer_num, global_t
855
  return fig
856
 
857
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
858
  # Callback to create accordion panels from layer data
859
  @app.callback(
860
  Output('layer-accordions-container', 'children'),
@@ -885,7 +957,7 @@ def create_layer_accordions(activation_data, activation_data2, original_activati
885
  return html.P("No layer data available.", className="placeholder-text")
886
 
887
  # Compute layer-wise probability tracking
888
- tracking_data = compute_layer_wise_summaries(layer_data)
889
  layer_wise_probs = tracking_data.get('layer_wise_top5_probs', {})
890
  significant_layers = tracking_data.get('significant_layers', [])
891
  global_top5 = activation_data.get('global_top5_tokens', [])
@@ -898,7 +970,7 @@ def create_layer_accordions(activation_data, activation_data2, original_activati
898
 
899
  if ablation_mode:
900
  original_layer_data = extract_layer_data(original_activation_data, model, tokenizer)
901
- original_tracking_data = compute_layer_wise_summaries(original_layer_data)
902
  original_layer_wise_probs = original_tracking_data.get('layer_wise_top5_probs', {})
903
  original_significant_layers = original_tracking_data.get('significant_layers', [])
904
  original_global_top5 = original_activation_data.get('global_top5_tokens', [])
@@ -912,7 +984,7 @@ def create_layer_accordions(activation_data, activation_data2, original_activati
912
 
913
  if comparison_mode:
914
  layer_data2 = extract_layer_data(activation_data2, model, tokenizer)
915
- tracking_data2 = compute_layer_wise_summaries(layer_data2)
916
  layer_wise_probs2 = tracking_data2.get('layer_wise_top5_probs', {})
917
  significant_layers2 = tracking_data2.get('significant_layers', [])
918
  global_top5_2 = activation_data2.get('global_top5_tokens', [])
@@ -1271,7 +1343,8 @@ def create_layer_accordions(activation_data, activation_data2, original_activati
1271
  # Update title to indicate "Before Ablation"
1272
  fig_before.update_layout(title="Top 5 Token Probabilities Across Layers (Before Ablation)")
1273
 
1274
- graph_container_before = html.Div([
 
1275
  html.H5("Before Ablation", style={
1276
  'marginBottom': '10px', 'color': '#495057', 'fontSize': '16px',
1277
  'fontWeight': '600', 'borderLeft': '4px solid #74b9ff', 'paddingLeft': '10px'
@@ -1282,7 +1355,15 @@ def create_layer_accordions(activation_data, activation_data2, original_activati
1282
  "This graph shows how the model's confidence in the final top 5 predictions evolves through each layer before ablation."
1283
  ], style={'fontSize': '13px', 'color': '#6c757d', 'marginBottom': '10px', 'lineHeight': '1.5'}),
1284
  dcc.Graph(figure=fig_before, config={'displayModeBar': False})
1285
- ], style={'padding': '15px', 'backgroundColor': '#e3f2fd', 'borderRadius': '8px', 'marginBottom': '20px'})
 
 
 
 
 
 
 
 
1286
 
1287
  line_graphs.append(graph_container_before)
1288
 
@@ -1297,7 +1378,8 @@ def create_layer_accordions(activation_data, activation_data2, original_activati
1297
  ablated_heads = activation_data.get('ablated_heads', [])
1298
  heads_str = ', '.join([f'H{h}' for h in sorted(ablated_heads)])
1299
 
1300
- graph_container_after = html.Div([
 
1301
  html.H5("After Ablation", style={
1302
  'marginBottom': '10px', 'color': '#495057', 'fontSize': '16px',
1303
  'fontWeight': '600', 'borderLeft': '4px solid #ffb74d', 'paddingLeft': '10px'
@@ -1308,10 +1390,22 @@ def create_layer_accordions(activation_data, activation_data2, original_activati
1308
  f"This graph shows how probabilities changed after removing Layer {ablated_layer}, Heads {heads_str}. " +
1309
  "Compare with the graph above to see the impact of the ablation."
1310
  ], style={'fontSize': '13px', 'color': '#6c757d', 'marginBottom': '10px', 'lineHeight': '1.5'}),
1311
- dcc.Graph(figure=fig_after, config={'displayModeBar': False}),
 
 
 
 
 
 
 
 
 
1312
  html.Small("Note: Tokens with and without leading spaces (e.g., ' cat' and 'cat') are automatically merged.",
1313
  style={'fontSize': '11px', 'color': '#6c757d', 'fontStyle': 'italic'})
1314
- ], style={'padding': '15px', 'backgroundColor': '#fff3e0', 'borderRadius': '8px', 'marginBottom': '20px'})
 
 
 
1315
 
1316
  line_graphs.append(graph_container_after)
1317
 
@@ -1320,23 +1414,35 @@ def create_layer_accordions(activation_data, activation_data2, original_activati
1320
  fig = _create_top5_by_layer_graph(layer_wise_probs, significant_layers, global_top5)
1321
  if fig:
1322
  tooltip_text = ("This graph shows how confident the model is in its top 5 predictions as it processes through each layer. "
1323
- "Yellow highlights mark layers where the model's confidence jumped significantly (50% or more increase). "
1324
  "These are the layers where the model made important decisions. "
1325
- "Click on the Transformer Layers section below to see what each layer did.")
1326
 
1327
  merge_note = ("Note: Some tokens appear with a space before them (like ' cat') and some without (like 'cat'). "
1328
  "We automatically combine these to make the graph easier to read.")
1329
 
1330
- graph_container = html.Div([
 
1331
  html.Div([
1332
  html.I(className="fas fa-info-circle",
1333
  style={'marginRight': '8px', 'color': '#667eea'}),
1334
  tooltip_text
1335
  ], style={'fontSize': '13px', 'color': '#6c757d', 'marginBottom': '10px', 'lineHeight': '1.5'}),
1336
- dcc.Graph(figure=fig, config={'displayModeBar': False}),
 
 
 
 
 
 
 
 
 
1337
  html.Small(merge_note,
1338
  style={'fontSize': '11px', 'color': '#6c757d', 'fontStyle': 'italic'})
1339
- ], style={'marginBottom': '20px'})
 
 
1340
 
1341
  line_graphs.append(graph_container)
1342
 
@@ -1344,10 +1450,18 @@ def create_layer_accordions(activation_data, activation_data2, original_activati
1344
  if comparison_mode and layer_wise_probs2 and global_top5_2:
1345
  fig2 = _create_top5_by_layer_graph(layer_wise_probs2, significant_layers2, global_top5_2)
1346
  if fig2:
1347
- graph_container2 = html.Div([
 
1348
  html.H6("Prompt 2", style={'color': '#495057', 'marginBottom': '10px'}),
1349
  dcc.Graph(figure=fig2, config={'displayModeBar': False})
1350
- ], style={'marginTop': '20px'})
 
 
 
 
 
 
 
1351
  line_graphs.append(graph_container2)
1352
 
1353
  # Create stacked visual representation for collapsed state
 
855
  return fig
856
 
857
 
858
+ def _create_actual_output_display(activation_data):
859
+ """
860
+ Create a display element showing the actual output token with tooltip.
861
+
862
+ Args:
863
+ activation_data: Activation data containing actual_output
864
+
865
+ Returns:
866
+ Dash HTML component displaying the actual output token
867
+ """
868
+ actual_output = activation_data.get('actual_output')
869
+ if not actual_output:
870
+ return None
871
+
872
+ token = actual_output.get('token', 'N/A')
873
+ probability = actual_output.get('probability', 0.0)
874
+
875
+ tooltip_text = ("The actual output token may differ from the highest probability token shown in the final layer. "
876
+ "This is because the model uses residual connections (skip links) that add information across layers. "
877
+ "The final output is determined after all residual streams are combined. "
878
+ "See transformer layer implementations for details.")
879
+
880
+ return html.Div([
881
+ html.Div([
882
+ html.Strong("Actual Model Output: ", style={'color': '#495057', 'fontSize': '14px'}),
883
+ html.Span(f'"{token}"', style={
884
+ 'backgroundColor': '#e8f5e9',
885
+ 'padding': '4px 10px',
886
+ 'borderRadius': '4px',
887
+ 'fontFamily': 'monospace',
888
+ 'fontSize': '14px',
889
+ 'fontWeight': '600',
890
+ 'color': '#2e7d32',
891
+ 'border': '1px solid #4caf50'
892
+ }),
893
+ html.Span(f" (probability: {probability:.4f})", style={
894
+ 'color': '#6c757d',
895
+ 'fontSize': '13px',
896
+ 'marginLeft': '8px'
897
+ }),
898
+ html.I(
899
+ className="fas fa-info-circle",
900
+ id="actual-output-info-icon",
901
+ style={
902
+ 'marginLeft': '10px',
903
+ 'color': '#667eea',
904
+ 'cursor': 'pointer',
905
+ 'fontSize': '14px'
906
+ }
907
+ )
908
+ ], style={'marginBottom': '8px'}),
909
+ html.Div([
910
+ html.I(className="fas fa-lightbulb", style={'marginRight': '6px', 'color': '#ffa726'}),
911
+ tooltip_text
912
+ ], style={
913
+ 'fontSize': '12px',
914
+ 'color': '#6c757d',
915
+ 'backgroundColor': '#fff8e1',
916
+ 'padding': '10px',
917
+ 'borderRadius': '5px',
918
+ 'borderLeft': '3px solid #ffa726',
919
+ 'lineHeight': '1.6'
920
+ })
921
+ ], style={
922
+ 'marginTop': '15px',
923
+ 'padding': '12px',
924
+ 'backgroundColor': '#f8f9fa',
925
+ 'borderRadius': '6px',
926
+ 'border': '1px solid #dee2e6'
927
+ })
928
+
929
+
930
  # Callback to create accordion panels from layer data
931
  @app.callback(
932
  Output('layer-accordions-container', 'children'),
 
957
  return html.P("No layer data available.", className="placeholder-text")
958
 
959
  # Compute layer-wise probability tracking
960
+ tracking_data = compute_layer_wise_summaries(layer_data, activation_data)
961
  layer_wise_probs = tracking_data.get('layer_wise_top5_probs', {})
962
  significant_layers = tracking_data.get('significant_layers', [])
963
  global_top5 = activation_data.get('global_top5_tokens', [])
 
970
 
971
  if ablation_mode:
972
  original_layer_data = extract_layer_data(original_activation_data, model, tokenizer)
973
+ original_tracking_data = compute_layer_wise_summaries(original_layer_data, original_activation_data)
974
  original_layer_wise_probs = original_tracking_data.get('layer_wise_top5_probs', {})
975
  original_significant_layers = original_tracking_data.get('significant_layers', [])
976
  original_global_top5 = original_activation_data.get('global_top5_tokens', [])
 
984
 
985
  if comparison_mode:
986
  layer_data2 = extract_layer_data(activation_data2, model, tokenizer)
987
+ tracking_data2 = compute_layer_wise_summaries(layer_data2, activation_data2)
988
  layer_wise_probs2 = tracking_data2.get('layer_wise_top5_probs', {})
989
  significant_layers2 = tracking_data2.get('significant_layers', [])
990
  global_top5_2 = activation_data2.get('global_top5_tokens', [])
 
1343
  # Update title to indicate "Before Ablation"
1344
  fig_before.update_layout(title="Top 5 Token Probabilities Across Layers (Before Ablation)")
1345
 
1346
+ # Build children for before ablation graph
1347
+ before_children = [
1348
  html.H5("Before Ablation", style={
1349
  'marginBottom': '10px', 'color': '#495057', 'fontSize': '16px',
1350
  'fontWeight': '600', 'borderLeft': '4px solid #74b9ff', 'paddingLeft': '10px'
 
1355
  "This graph shows how the model's confidence in the final top 5 predictions evolves through each layer before ablation."
1356
  ], style={'fontSize': '13px', 'color': '#6c757d', 'marginBottom': '10px', 'lineHeight': '1.5'}),
1357
  dcc.Graph(figure=fig_before, config={'displayModeBar': False})
1358
+ ]
1359
+
1360
+ # Add actual output display
1361
+ actual_output_display_before = _create_actual_output_display(original_activation_data)
1362
+ if actual_output_display_before:
1363
+ before_children.append(actual_output_display_before)
1364
+
1365
+ graph_container_before = html.Div(before_children,
1366
+ style={'padding': '15px', 'backgroundColor': '#e3f2fd', 'borderRadius': '8px', 'marginBottom': '20px'})
1367
 
1368
  line_graphs.append(graph_container_before)
1369
 
 
1378
  ablated_heads = activation_data.get('ablated_heads', [])
1379
  heads_str = ', '.join([f'H{h}' for h in sorted(ablated_heads)])
1380
 
1381
+ # Build children for after ablation graph
1382
+ after_children = [
1383
  html.H5("After Ablation", style={
1384
  'marginBottom': '10px', 'color': '#495057', 'fontSize': '16px',
1385
  'fontWeight': '600', 'borderLeft': '4px solid #ffb74d', 'paddingLeft': '10px'
 
1390
  f"This graph shows how probabilities changed after removing Layer {ablated_layer}, Heads {heads_str}. " +
1391
  "Compare with the graph above to see the impact of the ablation."
1392
  ], style={'fontSize': '13px', 'color': '#6c757d', 'marginBottom': '10px', 'lineHeight': '1.5'}),
1393
+ dcc.Graph(figure=fig_after, config={'displayModeBar': False})
1394
+ ]
1395
+
1396
+ # Add actual output display
1397
+ actual_output_display_after = _create_actual_output_display(activation_data)
1398
+ if actual_output_display_after:
1399
+ after_children.append(actual_output_display_after)
1400
+
1401
+ # Add merge note at the end
1402
+ after_children.append(
1403
  html.Small("Note: Tokens with and without leading spaces (e.g., ' cat' and 'cat') are automatically merged.",
1404
  style={'fontSize': '11px', 'color': '#6c757d', 'fontStyle': 'italic'})
1405
+ )
1406
+
1407
+ graph_container_after = html.Div(after_children,
1408
+ style={'padding': '15px', 'backgroundColor': '#fff3e0', 'borderRadius': '8px', 'marginBottom': '20px'})
1409
 
1410
  line_graphs.append(graph_container_after)
1411
 
 
1414
  fig = _create_top5_by_layer_graph(layer_wise_probs, significant_layers, global_top5)
1415
  if fig:
1416
  tooltip_text = ("This graph shows how confident the model is in its top 5 predictions as it processes through each layer. "
1417
+ "Yellow highlights mark layers where the model's confidence in the actual output token doubled (100% or more increase). "
1418
  "These are the layers where the model made important decisions. "
1419
+ "Click on the Transformer Layers section to see what each layer did.")
1420
 
1421
  merge_note = ("Note: Some tokens appear with a space before them (like ' cat') and some without (like 'cat'). "
1422
  "We automatically combine these to make the graph easier to read.")
1423
 
1424
+ # Create list of children for graph container
1425
+ graph_children = [
1426
  html.Div([
1427
  html.I(className="fas fa-info-circle",
1428
  style={'marginRight': '8px', 'color': '#667eea'}),
1429
  tooltip_text
1430
  ], style={'fontSize': '13px', 'color': '#6c757d', 'marginBottom': '10px', 'lineHeight': '1.5'}),
1431
+ dcc.Graph(figure=fig, config={'displayModeBar': False})
1432
+ ]
1433
+
1434
+ # Add actual output display
1435
+ actual_output_display = _create_actual_output_display(activation_data)
1436
+ if actual_output_display:
1437
+ graph_children.append(actual_output_display)
1438
+
1439
+ # Add merge note at the end
1440
+ graph_children.append(
1441
  html.Small(merge_note,
1442
  style={'fontSize': '11px', 'color': '#6c757d', 'fontStyle': 'italic'})
1443
+ )
1444
+
1445
+ graph_container = html.Div(graph_children, style={'marginBottom': '20px'})
1446
 
1447
  line_graphs.append(graph_container)
1448
 
 
1450
  if comparison_mode and layer_wise_probs2 and global_top5_2:
1451
  fig2 = _create_top5_by_layer_graph(layer_wise_probs2, significant_layers2, global_top5_2)
1452
  if fig2:
1453
+ # Build children for second prompt graph
1454
+ children2 = [
1455
  html.H6("Prompt 2", style={'color': '#495057', 'marginBottom': '10px'}),
1456
  dcc.Graph(figure=fig2, config={'displayModeBar': False})
1457
+ ]
1458
+
1459
+ # Add actual output display for second prompt
1460
+ actual_output_display2 = _create_actual_output_display(activation_data2)
1461
+ if actual_output_display2:
1462
+ children2.append(actual_output_display2)
1463
+
1464
+ graph_container2 = html.Div(children2, style={'marginTop': '20px'})
1465
  line_graphs.append(graph_container2)
1466
 
1467
  # Create stacked visual representation for collapsed state
components/__pycache__/sidebar.cpython-311.pyc CHANGED
Binary files a/components/__pycache__/sidebar.cpython-311.pyc and b/components/__pycache__/sidebar.cpython-311.pyc differ
 
components/__pycache__/tokenization_panel.cpython-311.pyc CHANGED
Binary files a/components/__pycache__/tokenization_panel.cpython-311.pyc and b/components/__pycache__/tokenization_panel.cpython-311.pyc differ
 
utils/__pycache__/head_detection.cpython-311.pyc CHANGED
Binary files a/utils/__pycache__/head_detection.cpython-311.pyc and b/utils/__pycache__/head_detection.cpython-311.pyc differ
 
utils/__pycache__/model_config.cpython-311.pyc CHANGED
Binary files a/utils/__pycache__/model_config.cpython-311.pyc and b/utils/__pycache__/model_config.cpython-311.pyc differ
 
utils/__pycache__/model_patterns.cpython-311.pyc CHANGED
Binary files a/utils/__pycache__/model_patterns.cpython-311.pyc and b/utils/__pycache__/model_patterns.cpython-311.pyc differ
 
utils/__pycache__/prompt_comparison.cpython-311.pyc CHANGED
Binary files a/utils/__pycache__/prompt_comparison.cpython-311.pyc and b/utils/__pycache__/prompt_comparison.cpython-311.pyc differ
 
utils/model_patterns.py CHANGED
@@ -863,24 +863,25 @@ def get_check_token_probabilities(activation_data: Dict[str, Any], model, tokeni
863
 
864
  def detect_significant_probability_increases(layer_wise_probs: Dict[int, Dict[str, float]],
865
  layer_wise_deltas: Dict[int, Dict[str, float]],
866
- threshold: float = 0.50) -> List[int]:
 
867
  """
868
- Detect layers where any global top 5 token has significant probability increase.
869
 
870
- A layer is significant if any token has ≥50% relative increase from previous layer.
871
- Example: 0.20 → 0.30 is (0.30-0.20)/0.20 = 50% increase.
872
 
873
- This threshold balances sensitivity (catching meaningful changes) with specificity
874
- (avoiding too many flagged layers). A 50% increase represents a substantial shift
875
- in the model's confidence that is pedagogically useful to highlight.
876
 
877
  Args:
878
  layer_wise_probs: Dict mapping layer_num → {token: prob}
879
  layer_wise_deltas: Dict mapping layer_num → {token: delta}
880
- threshold: Relative increase threshold (default: 0.50 = 50%)
 
881
 
882
  Returns:
883
- List of layer numbers with significant increases
884
  """
885
  significant_layers = []
886
 
@@ -888,8 +889,10 @@ def detect_significant_probability_increases(layer_wise_probs: Dict[int, Dict[st
888
  probs = layer_wise_probs[layer_num]
889
  deltas = layer_wise_deltas.get(layer_num, {})
890
 
891
- for token, prob in probs.items():
892
- delta = deltas.get(token, 0.0)
 
 
893
  prev_prob = prob - delta
894
 
895
  # Check for significant relative increase (avoid division by zero)
@@ -897,7 +900,6 @@ def detect_significant_probability_increases(layer_wise_probs: Dict[int, Dict[st
897
  relative_increase = delta / prev_prob
898
  if relative_increase >= threshold:
899
  significant_layers.append(layer_num)
900
- break # Only need to flag layer once
901
 
902
  return significant_layers
903
 
@@ -968,12 +970,13 @@ def _get_top_attended_tokens(activation_data: Dict[str, Any], layer_num: int, to
968
  return None
969
 
970
 
971
- def compute_layer_wise_summaries(layer_data: List[Dict[str, Any]]) -> Dict[str, Any]:
972
  """
973
  Compute summary structures from layer data for easy access.
974
 
975
  Args:
976
  layer_data: List of layer data dicts from extract_layer_data()
 
977
 
978
  Returns:
979
  Dict with: layer_wise_top5_probs, layer_wise_top5_deltas, significant_layers
@@ -987,12 +990,19 @@ def compute_layer_wise_summaries(layer_data: List[Dict[str, Any]]) -> Dict[str,
987
  layer_wise_top5_probs[layer_num] = layer_info.get('global_top5_probs', {})
988
  layer_wise_top5_deltas[layer_num] = layer_info.get('global_top5_deltas', {})
989
 
990
- # Detect significant layers
991
- significant_layers = detect_significant_probability_increases(
992
- layer_wise_top5_probs,
993
- layer_wise_top5_deltas,
994
- threshold=0.50
995
- )
 
 
 
 
 
 
 
996
 
997
  return {
998
  'layer_wise_top5_probs': layer_wise_top5_probs,
 
863
 
864
  def detect_significant_probability_increases(layer_wise_probs: Dict[int, Dict[str, float]],
865
  layer_wise_deltas: Dict[int, Dict[str, float]],
866
+ actual_output_token: str,
867
+ threshold: float = 1.0) -> List[int]:
868
  """
869
+ Detect layers where the actual output token has significant probability increase.
870
 
871
+ A layer is significant if the actual output token has ≥100% relative increase from previous layer.
872
+ Example: 0.20 → 0.40 is (0.40-0.20)/0.20 = 100% increase.
873
 
874
+ This threshold highlights layers where the model's confidence in the actual output
875
+ doubles, representing a pedagogically significant shift in the prediction.
 
876
 
877
  Args:
878
  layer_wise_probs: Dict mapping layer_num → {token: prob}
879
  layer_wise_deltas: Dict mapping layer_num → {token: delta}
880
+ actual_output_token: The token that the model actually outputs (predicted token)
881
+ threshold: Relative increase threshold (default: 1.0 = 100%)
882
 
883
  Returns:
884
+ List of layer numbers with significant increases in the actual output token
885
  """
886
  significant_layers = []
887
 
 
889
  probs = layer_wise_probs[layer_num]
890
  deltas = layer_wise_deltas.get(layer_num, {})
891
 
892
+ # Only check the actual output token
893
+ if actual_output_token in probs:
894
+ prob = probs[actual_output_token]
895
+ delta = deltas.get(actual_output_token, 0.0)
896
  prev_prob = prob - delta
897
 
898
  # Check for significant relative increase (avoid division by zero)
 
900
  relative_increase = delta / prev_prob
901
  if relative_increase >= threshold:
902
  significant_layers.append(layer_num)
 
903
 
904
  return significant_layers
905
 
 
970
  return None
971
 
972
 
973
+ def compute_layer_wise_summaries(layer_data: List[Dict[str, Any]], activation_data: Dict[str, Any]) -> Dict[str, Any]:
974
  """
975
  Compute summary structures from layer data for easy access.
976
 
977
  Args:
978
  layer_data: List of layer data dicts from extract_layer_data()
979
+ activation_data: Activation data containing actual output token
980
 
981
  Returns:
982
  Dict with: layer_wise_top5_probs, layer_wise_top5_deltas, significant_layers
 
990
  layer_wise_top5_probs[layer_num] = layer_info.get('global_top5_probs', {})
991
  layer_wise_top5_deltas[layer_num] = layer_info.get('global_top5_deltas', {})
992
 
993
+ # Extract actual output token from activation data
994
+ actual_output = activation_data.get('actual_output', {})
995
+ actual_output_token = actual_output.get('token', '').strip() if actual_output else ''
996
+
997
+ # Detect significant layers based on actual output token
998
+ significant_layers = []
999
+ if actual_output_token:
1000
+ significant_layers = detect_significant_probability_increases(
1001
+ layer_wise_top5_probs,
1002
+ layer_wise_top5_deltas,
1003
+ actual_output_token,
1004
+ threshold=1.0
1005
+ )
1006
 
1007
  return {
1008
  'layer_wise_top5_probs': layer_wise_top5_probs,