cdpearlman commited on
Commit
2ad1c2e
·
1 Parent(s): c19d5a8

Ablation updated for full sequence, needs refactor for front-end and workflow

Browse files
app.py CHANGED
@@ -11,7 +11,8 @@ import json
11
  import torch
12
  from utils import (load_model_and_get_patterns, execute_forward_pass, extract_layer_data,
13
  categorize_single_layer_heads, format_categorization_summary,
14
- compute_layer_wise_summaries, perform_beam_search, compute_sequence_trajectory)
 
15
  from utils.model_config import get_auto_selections, get_model_family
16
 
17
  # Import modular components
@@ -1729,10 +1730,11 @@ def handle_head_selection(n_clicks_list, selected_heads):
1729
  [State({'type': 'selected-heads-store', 'layer': ALL}, 'data'),
1730
  State('session-activation-store', 'data'),
1731
  State('model-dropdown', 'value'),
1732
- State('prompt-input', 'value')],
 
1733
  prevent_initial_call=True
1734
  )
1735
- def run_head_ablation(n_clicks_list, selected_heads_list, activation_data, model_name, prompt):
1736
  """Run forward pass with selected heads ablated."""
1737
  # Identify which button was clicked
1738
  ctx = dash.callback_context
@@ -1745,7 +1747,6 @@ def run_head_ablation(n_clicks_list, selected_heads_list, activation_data, model
1745
  return no_update, no_update, no_update
1746
 
1747
  # Find the index in the states_list that corresponds to this layer
1748
- # ctx.states_list contains the State values in order
1749
  button_index = None
1750
  if hasattr(ctx, 'states_list') and ctx.states_list:
1751
  # states_list[0] corresponds to selected-heads-store
@@ -1756,7 +1757,6 @@ def run_head_ablation(n_clicks_list, selected_heads_list, activation_data, model
1756
 
1757
  # Fallback: if states_list doesn't work, try matching by iterating
1758
  if button_index is None:
1759
- # This shouldn't happen, but as a fallback, just return error
1760
  return no_update, no_update, html.Div([
1761
  html.I(className="fas fa-exclamation-circle", style={'marginRight': '8px', 'color': '#dc3545'}),
1762
  f"Could not determine button index for layer {layer_num}"
@@ -1772,12 +1772,14 @@ def run_head_ablation(n_clicks_list, selected_heads_list, activation_data, model
1772
 
1773
  try:
1774
  from transformers import AutoModelForCausalLM, AutoTokenizer
1775
- from utils import execute_forward_pass_with_head_ablation
1776
 
1777
  # Save original activation data before ablation
1778
  import copy
1779
  original_data = copy.deepcopy(activation_data)
1780
 
 
 
 
1781
  # Load model and tokenizer
1782
  model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation='eager')
1783
  tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -1789,18 +1791,86 @@ def run_head_ablation(n_clicks_list, selected_heads_list, activation_data, model
1789
  'norm_parameters': activation_data.get('norm_parameters', [])
1790
  }
1791
 
1792
- # Run ablation
1793
  ablated_data = execute_forward_pass_with_head_ablation(
1794
- model, tokenizer, prompt, config, layer_num, selected_heads
 
 
 
 
 
 
 
 
1795
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1796
 
1797
  # Update activation data with ablated results
1798
- # Mark as ablated for visual indication
1799
  ablated_data['ablated'] = True
1800
  ablated_data['ablated_layer'] = layer_num
1801
  ablated_data['ablated_heads'] = selected_heads
1802
 
1803
- # Preserve input_ids from original data if not present (prompt is unchanged)
1804
  if 'input_ids' not in ablated_data and 'input_ids' in activation_data:
1805
  ablated_data['input_ids'] = activation_data['input_ids']
1806
 
@@ -1808,7 +1878,7 @@ def run_head_ablation(n_clicks_list, selected_heads_list, activation_data, model
1808
  heads_str = ', '.join([f"H{h}" for h in sorted(selected_heads)])
1809
  success_message = html.Div([
1810
  html.I(className="fas fa-check-circle", style={'marginRight': '8px', 'color': '#28a745'}),
1811
- f"Ablation complete: Layer {layer_num}, Heads {heads_str} removed"
1812
  ], className="status-success")
1813
 
1814
  return ablated_data, original_data, success_message
@@ -1870,5 +1940,133 @@ def reset_ablation(n_clicks, original_data):
1870
  return original_data, {}, success_message
1871
 
1872
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1873
  if __name__ == '__main__':
1874
  app.run(debug=True, port=8050)
 
11
  import torch
12
  from utils import (load_model_and_get_patterns, execute_forward_pass, extract_layer_data,
13
  categorize_single_layer_heads, format_categorization_summary,
14
+ compute_layer_wise_summaries, perform_beam_search, compute_sequence_trajectory,
15
+ execute_forward_pass_with_head_ablation, evaluate_sequence_ablation, score_sequence)
16
  from utils.model_config import get_auto_selections, get_model_family
17
 
18
  # Import modular components
 
1730
  [State({'type': 'selected-heads-store', 'layer': ALL}, 'data'),
1731
  State('session-activation-store', 'data'),
1732
  State('model-dropdown', 'value'),
1733
+ State('prompt-input', 'value'),
1734
+ State('generation-results-store', 'data')],
1735
  prevent_initial_call=True
1736
  )
1737
+ def run_head_ablation(n_clicks_list, selected_heads_list, activation_data, model_name, prompt_input, generation_results):
1738
  """Run forward pass with selected heads ablated."""
1739
  # Identify which button was clicked
1740
  ctx = dash.callback_context
 
1747
  return no_update, no_update, no_update
1748
 
1749
  # Find the index in the states_list that corresponds to this layer
 
1750
  button_index = None
1751
  if hasattr(ctx, 'states_list') and ctx.states_list:
1752
  # states_list[0] corresponds to selected-heads-store
 
1757
 
1758
  # Fallback: if states_list doesn't work, try matching by iterating
1759
  if button_index is None:
 
1760
  return no_update, no_update, html.Div([
1761
  html.I(className="fas fa-exclamation-circle", style={'marginRight': '8px', 'color': '#dc3545'}),
1762
  f"Could not determine button index for layer {layer_num}"
 
1772
 
1773
  try:
1774
  from transformers import AutoModelForCausalLM, AutoTokenizer
 
1775
 
1776
  # Save original activation data before ablation
1777
  import copy
1778
  original_data = copy.deepcopy(activation_data)
1779
 
1780
+ # Determine the sequence to analyze (prefer activation data prompt over input box)
1781
+ sequence_text = activation_data.get('prompt', prompt_input)
1782
+
1783
  # Load model and tokenizer
1784
  model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation='eager')
1785
  tokenizer = AutoTokenizer.from_pretrained(model_name)
 
1791
  'norm_parameters': activation_data.get('norm_parameters', [])
1792
  }
1793
 
1794
+ # 1. Run Standard Ablation (Forward Pass)
1795
  ablated_data = execute_forward_pass_with_head_ablation(
1796
+ model, tokenizer, sequence_text, config, layer_num, selected_heads
1797
+ )
1798
+
1799
+ # 2. Compute Full Sequence Metrics (KL Divergence, Delta Probs)
1800
+ # This requires re-running passes (Original & Ablated) on the full sequence
1801
+ # We use a helper that handles the ablation hooking internally for the metric pass
1802
+ seq_metrics = evaluate_sequence_ablation(
1803
+ model, tokenizer, sequence_text, config,
1804
+ ablation_type='head', ablation_target=(layer_num, selected_heads)
1805
  )
1806
+ ablated_data['sequence_metrics'] = seq_metrics
1807
+
1808
+ # 3. Re-score Top Generated Sequences (if available)
1809
+ if generation_results:
1810
+ top_sequences_comparison = []
1811
+
1812
+ # Helper to run ablation for scoring (we need to apply hook again)
1813
+ # Since we can't easily pass 'ablated model' around, we re-apply hooks
1814
+ # Simplification: We already have 'evaluate_sequence_ablation'.
1815
+ # But that compares Ref vs Abl.
1816
+ # Here we just want Ablated Score.
1817
+ # Actually, `score_sequence` runs valid forward pass.
1818
+ # We need to apply ablation hooks to validly score.
1819
+
1820
+ # Define localized hook manager for scoring
1821
+ def get_ablated_score(seq_text):
1822
+ # Apply hook
1823
+ hooks = []
1824
+ def head_ablation_hook(module, input, output):
1825
+ # Similar to evaluate_sequence_ablation hook
1826
+ if isinstance(output, tuple): h = output[0]
1827
+ else: h = output
1828
+ if not isinstance(h, torch.Tensor): h = torch.tensor(h)
1829
+
1830
+ num_heads = model.config.num_attention_heads
1831
+ head_dim = h.shape[-1] // num_heads
1832
+ new_shape = h.shape[:-1] + (num_heads, head_dim)
1833
+ reshaped = h.view(new_shape).clone()
1834
+ for h_idx in selected_heads: reshaped[..., h_idx, :] = 0
1835
+ ablated = reshaped.view(h.shape)
1836
+ return (ablated,) + output[1:] if isinstance(output, tuple) else ablated
1837
+
1838
+ # Register
1839
+ target_module = None
1840
+ for name, mod in model.named_modules():
1841
+ if f"layers.{layer_num}.self_attn" in name or f"h.{layer_num}.attn" in name:
1842
+ if "k_proj" not in name:
1843
+ target_module = mod; break
1844
+
1845
+ if target_module:
1846
+ hooks.append(target_module.register_forward_hook(head_ablation_hook))
1847
+
1848
+ try:
1849
+ score = score_sequence(model, tokenizer, seq_text)
1850
+ finally:
1851
+ for hook in hooks: hook.remove()
1852
+ return score
1853
+
1854
+ for res in generation_results:
1855
+ txt = res['text']
1856
+ orig_score = res['score']
1857
+ new_score = get_ablated_score(txt)
1858
+ top_sequences_comparison.append({
1859
+ 'text': txt,
1860
+ 'original_score': orig_score,
1861
+ 'ablated_score': new_score,
1862
+ 'delta': new_score - orig_score
1863
+ })
1864
+
1865
+ ablated_data['top_sequences_comparison'] = top_sequences_comparison
1866
+
1867
 
1868
  # Update activation data with ablated results
 
1869
  ablated_data['ablated'] = True
1870
  ablated_data['ablated_layer'] = layer_num
1871
  ablated_data['ablated_heads'] = selected_heads
1872
 
1873
+ # Preserve input_ids if needed
1874
  if 'input_ids' not in ablated_data and 'input_ids' in activation_data:
1875
  ablated_data['input_ids'] = activation_data['input_ids']
1876
 
 
1878
  heads_str = ', '.join([f"H{h}" for h in sorted(selected_heads)])
1879
  success_message = html.Div([
1880
  html.I(className="fas fa-check-circle", style={'marginRight': '8px', 'color': '#28a745'}),
1881
+ f"Ablation complete: Layer {layer_num}, Heads {heads_str} removed. Scroll down for sequence analysis."
1882
  ], className="status-success")
1883
 
1884
  return ablated_data, original_data, success_message
 
1940
  return original_data, {}, success_message
1941
 
1942
 
1943
+
1944
# Callback to update the sequence ablation analysis view.
# Re-renders whenever the session activation store changes; the container
# stays hidden until an ablation run has flagged the stored data.
@app.callback(
    [Output('sequence-ablation-results-container', 'children'),
     Output('sequence-ablation-results-container', 'style')],
    Input('session-activation-store', 'data'),
    prevent_initial_call=False
)
def update_sequence_ablation_view(activation_data):
    """Render the full-sequence ablation results panel.

    Builds up to three sections from ``activation_data``:
      1. A table comparing original vs. ablated scores of the top generated
         sequences (``top_sequences_comparison``).
      2. A per-position KL-divergence line chart
         (``sequence_metrics['kl_divergence']``).
      3. A per-target-token probability-delta bar chart
         (``sequence_metrics['probability_deltas']``).

    Returns:
        (children, style) pair for the results container; the container is
        hidden when no ablation result is present.
    """
    # Nothing to show until an ablation pass has marked the store as ablated.
    if not activation_data or not activation_data.get('ablated', False):
        return [], {'display': 'none'}

    try:
        import plotly.graph_objs as go
        from dash import html, dcc

        children = []

        # 1. Section header
        children.append(html.H3("Full Sequence Ablation Analysis", style={'marginTop': '0', 'marginBottom': '20px', 'color': '#2d3748'}))

        # 2. Top sequences comparison table (only when re-scored sequences exist)
        top_seqs = activation_data.get('top_sequences_comparison', [])
        if top_seqs:
            rows = []
            for i, seq in enumerate(top_seqs):
                delta = seq['delta']
                # Green for improved score, red for degraded, grey for unchanged.
                delta_color = '#28a745' if delta > 0 else '#dc3545' if delta < 0 else '#6c757d'

                rows.append(html.Tr([
                    html.Td(f"#{i+1}", style={'fontWeight': 'bold'}),
                    html.Td(seq['text'], style={'fontFamily': 'monospace', 'maxWidth': '400px', 'overflow': 'hidden', 'textOverflow': 'ellipsis', 'whiteSpace': 'nowrap'}),
                    html.Td(f"{seq['original_score']:.4f}"),
                    html.Td(f"{seq['ablated_score']:.4f}"),
                    html.Td(f"{delta:+.4f}", style={'color': delta_color, 'fontWeight': 'bold'})
                ]))

            table_header = html.Thead(html.Tr([
                html.Th("Rank"), html.Th("Sequence"), html.Th("Original Score"), html.Th("Ablated Score"), html.Th("Delta")
            ]))
            table_body = html.Tbody(rows)

            children.append(html.Div([
                html.H5("Top Sequences Impact", style={'marginBottom': '10px'}),
                html.Table([table_header, table_body], className="table table-striped table-bordered")
            ], style={'marginBottom': '30px', 'padding': '15px', 'backgroundColor': '#fff', 'borderRadius': '8px', 'boxShadow': '0 2px 4px rgba(0,0,0,0.05)'}))

        # 3. KL divergence per position (distribution shift caused by ablation)
        seq_metrics = activation_data.get('sequence_metrics', {})
        kl_divs = seq_metrics.get('kl_divergence', [])
        tokens = seq_metrics.get('tokens', [])

        if kl_divs:
            fig_kl = go.Figure()
            fig_kl.add_trace(go.Scatter(
                x=list(range(len(kl_divs))),
                y=kl_divs,
                mode='lines+markers',
                name='KL Divergence',
                line=dict(color='#6610f2', width=2),
                # NOTE(review): zip() silently truncates if tokens and kl_divs
                # differ in length — assumes upstream keeps them aligned.
                hovertext=[f"Token: {t}<br>KL: {v:.4f}" for t, v in zip(tokens, kl_divs)],
                hoverinfo='text'
            ))

            fig_kl.update_layout(
                title="KL Divergence per Position (Distribution Shift)",
                xaxis_title="Position / Token",
                yaxis_title="KL Divergence (nats)",
                margin=dict(l=20, r=20, t=40, b=20),
                height=300,
                xaxis=dict(
                    tickmode='array',
                    tickvals=list(range(len(tokens))),
                    ticktext=tokens
                )
            )

            children.append(html.Div([
                dcc.Graph(figure=fig_kl, config={'displayModeBar': False})
            ], style={'marginBottom': '20px', 'padding': '15px', 'backgroundColor': '#fff', 'borderRadius': '8px', 'boxShadow': '0 2px 4px rgba(0,0,0,0.05)'}))

        # 4. Probability change of each actual target token
        prob_deltas = seq_metrics.get('probability_deltas', [])
        if prob_deltas:
            # Deltas describe the prediction of the NEXT token: delta[0] is the
            # change in P(T1|T0), so the x-axis labels start at token index 1.
            target_tokens = tokens[1:] if len(tokens) > 1 else []

            fig_delta = go.Figure()
            fig_delta.add_trace(go.Bar(
                x=list(range(len(prob_deltas))),
                y=prob_deltas,
                name='Prob Delta',
                # Green bars where ablation increased the target's probability.
                marker_color=['#28a745' if v >= 0 else '#dc3545' for v in prob_deltas],
                hovertext=[f"Target: {t}<br>Change: {v:+.4f}" for t, v in zip(target_tokens, prob_deltas)],
                hoverinfo='text'
            ))

            fig_delta.update_layout(
                title="Target Probability Change per Position",
                xaxis_title="Target Token",
                yaxis_title="Probability Delta",
                margin=dict(l=20, r=20, t=40, b=20),
                height=300,
                xaxis=dict(
                    tickmode='array',
                    tickvals=list(range(len(target_tokens))),
                    ticktext=target_tokens
                )
            )

            children.append(html.Div([
                dcc.Graph(figure=fig_delta, config={'displayModeBar': False})
            ], style={'marginBottom': '20px', 'padding': '15px', 'backgroundColor': '#fff', 'borderRadius': '8px', 'boxShadow': '0 2px 4px rgba(0,0,0,0.05)'}))

        return children, {'display': 'block', 'marginTop': '30px', 'paddingTop': '30px', 'borderTop': '1px solid #dee2e6'}

    except Exception as e:
        # Surface the failure in the UI rather than leaving a blank panel.
        print(f"Error in ablation view: {e}")
        import traceback
        traceback.print_exc()
        return html.Div(f"Error loading visualization: {str(e)}"), {'display': 'block'}
2069
+
2070
+
2071
  if __name__ == '__main__':
2072
  app.run(debug=True, port=8050)
components/main_panel.py CHANGED
@@ -105,7 +105,10 @@ def create_main_panel():
105
  html.I(className="fas fa-spinner fa-spin", style={'fontSize': '24px', 'color': '#667eea', 'marginRight': '10px'}),
106
  html.Span("Loading visuals...", style={'fontSize': '16px', 'color': '#495057'})
107
  ], style={'display': 'flex', 'alignItems': 'center', 'justifyContent': 'center', 'padding': '2rem'})
108
- )
 
 
 
109
  ], className="visualization-section")
110
  ])
111
  ], id="analysis-view-container", style={'display': 'none'}) # Hidden by default
 
105
  html.I(className="fas fa-spinner fa-spin", style={'fontSize': '24px', 'color': '#667eea', 'marginRight': '10px'}),
106
  html.Span("Loading visuals...", style={'fontSize': '16px', 'color': '#495057'})
107
  ], style={'display': 'flex', 'alignItems': 'center', 'justifyContent': 'center', 'padding': '2rem'})
108
+ ),
109
+
110
+ # Sequence Ablation Results (New)
111
+ html.Div(id='sequence-ablation-results-container', style={'marginTop': '30px', 'display': 'none'})
112
  ], className="visualization-section")
113
  ])
114
  ], id="analysis-view-container", style={'display': 'none'}) # Hidden by default
utils/__init__.py CHANGED
@@ -1,14 +1,17 @@
1
- from .model_patterns import load_model_and_get_patterns, execute_forward_pass, logit_lens_transformation, extract_layer_data, generate_bertviz_html, generate_category_bertviz_html, get_check_token_probabilities, execute_forward_pass_with_layer_ablation, execute_forward_pass_with_head_ablation, merge_token_probabilities, compute_global_top5_tokens, detect_significant_probability_increases, compute_layer_wise_summaries
2
  from .model_config import get_model_family, get_family_config, get_auto_selections, MODEL_TO_FAMILY, MODEL_FAMILIES
3
  from .head_detection import categorize_all_heads, categorize_single_layer_heads, format_categorization_summary, HeadCategorizationConfig
4
  from .prompt_comparison import compare_attention_layers, compare_output_probabilities, format_comparison_summary, ComparisonConfig
5
  from .beam_search import perform_beam_search, compute_sequence_trajectory
 
 
6
 
7
  __all__ = [
8
  'load_model_and_get_patterns',
9
  'execute_forward_pass',
10
  'execute_forward_pass_with_layer_ablation',
11
  'execute_forward_pass_with_head_ablation',
 
12
  'logit_lens_transformation',
13
  'extract_layer_data',
14
  'generate_bertviz_html',
@@ -32,5 +35,9 @@ __all__ = [
32
  'format_comparison_summary',
33
  'ComparisonConfig',
34
  'perform_beam_search',
35
- 'compute_sequence_trajectory'
 
 
 
 
36
  ]
 
1
+ from .model_patterns import load_model_and_get_patterns, execute_forward_pass, logit_lens_transformation, extract_layer_data, generate_bertviz_html, generate_category_bertviz_html, get_check_token_probabilities, execute_forward_pass_with_layer_ablation, execute_forward_pass_with_head_ablation, merge_token_probabilities, compute_global_top5_tokens, detect_significant_probability_increases, compute_layer_wise_summaries, evaluate_sequence_ablation
2
  from .model_config import get_model_family, get_family_config, get_auto_selections, MODEL_TO_FAMILY, MODEL_FAMILIES
3
  from .head_detection import categorize_all_heads, categorize_single_layer_heads, format_categorization_summary, HeadCategorizationConfig
4
  from .prompt_comparison import compare_attention_layers, compare_output_probabilities, format_comparison_summary, ComparisonConfig
5
  from .beam_search import perform_beam_search, compute_sequence_trajectory
6
+ from .ablation_metrics import compute_kl_divergence, score_sequence, get_token_probability_deltas
7
+
8
 
9
  __all__ = [
10
  'load_model_and_get_patterns',
11
  'execute_forward_pass',
12
  'execute_forward_pass_with_layer_ablation',
13
  'execute_forward_pass_with_head_ablation',
14
+ 'evaluate_sequence_ablation',
15
  'logit_lens_transformation',
16
  'extract_layer_data',
17
  'generate_bertviz_html',
 
35
  'format_comparison_summary',
36
  'ComparisonConfig',
37
  'perform_beam_search',
38
+ 'perform_beam_search',
39
+ 'compute_sequence_trajectory',
40
+ 'compute_kl_divergence',
41
+ 'score_sequence',
42
+ 'get_token_probability_deltas'
43
  ]
utils/ablation_metrics.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from typing import List, Dict, Any, Tuple, Optional
5
+
6
def compute_kl_divergence(logits_p: torch.Tensor, logits_q: torch.Tensor) -> List[float]:
    """
    Compute per-position KL divergence KL(P || Q) over the vocabulary.

    P is the reference distribution (from ``logits_p``) and Q the ablated
    distribution (from ``logits_q``); distributions are obtained by a softmax
    over the last (vocab) dimension.

    Args:
        logits_p: Reference logits, ``[seq_len, vocab]`` or ``[batch, seq_len, vocab]``.
        logits_q: Ablated logits with the same shape.

    Returns:
        One KL value (in nats) per sequence position, always a flat list of
        floats. When a batch dimension is present, only the first batch
        element is scored. (The previous ``squeeze(0)`` was a no-op for
        batch > 1 and silently returned nested lists, breaking the declared
        ``List[float]`` contract.)
    """
    with torch.no_grad():
        # Select the first batch element explicitly instead of squeeze(0),
        # which does nothing when batch > 1.
        if logits_p.dim() == 3:
            logits_p = logits_p[0]
        if logits_q.dim() == 3:
            logits_q = logits_q[0]

        # log_softmax keeps the computation numerically stable.
        log_probs_p = F.log_softmax(logits_p, dim=-1)
        log_probs_q = F.log_softmax(logits_q, dim=-1)

        # KL(P||Q) = sum_v P(v) * (log P(v) - log Q(v)), per position.
        kl_divs = torch.sum(log_probs_p.exp() * (log_probs_p - log_probs_q), dim=-1)

        return kl_divs.tolist()
38
+
39
def score_sequence(model, tokenizer, text: str) -> float:
    """
    Return the total log-probability the model assigns to *text*.

    Sums log P(token_i | tokens_<i) over every position after the first,
    using one teacher-forced forward pass.

    Args:
        model: HuggingFace causal language model.
        tokenizer: Matching tokenizer.
        text: The sequence to score.

    Returns:
        Sum of target-token log probabilities (0.0 for a single-token input).
    """
    encoded = tokenizer(text, return_tensors="pt")
    token_ids = encoded["input_ids"].to(model.device)

    with torch.no_grad():
        logits = model(token_ids).logits  # [1, seq_len, vocab_size]

    # Logits at position i predict token i+1: drop the final prediction and
    # the first label so predictions and targets line up.
    prediction_logits = logits[0, :-1, :].contiguous()
    labels = token_ids[0, 1:].contiguous()

    # Per-token log probabilities, then pick out each actual target token.
    log_probs = F.log_softmax(prediction_logits, dim=-1)
    chosen = log_probs.gather(1, labels.unsqueeze(1)).squeeze(1)

    return chosen.sum().item()
76
+
77
def get_token_probability_deltas(logits_ref: torch.Tensor, logits_abl: torch.Tensor, input_ids: torch.Tensor) -> List[float]:
    """
    Change in probability (ablated minus reference) for each actual target token.

    Args:
        logits_ref: Reference logits, ``[seq_len, vocab]`` or ``[1, seq_len, vocab]``.
        logits_abl: Ablated logits with the same shape.
        input_ids: Token ids of the scored sequence, ``[1, seq_len]``.

    Returns:
        One delta per predicted position (``seq_len - 1`` values).
    """
    with torch.no_grad():
        ref = logits_ref.squeeze(0) if logits_ref.dim() == 3 else logits_ref
        abl = logits_abl.squeeze(0) if logits_abl.dim() == 3 else logits_abl

        # Position i predicts token i+1, so targets start at index 1 and the
        # final logit row (which predicts beyond the sequence) is dropped.
        targets = input_ids[0, 1:].unsqueeze(1)

        ref_target_probs = F.softmax(ref[:-1], dim=-1).gather(1, targets).squeeze(1)
        abl_target_probs = F.softmax(abl[:-1], dim=-1).gather(1, targets).squeeze(1)

        return (abl_target_probs - ref_target_probs).tolist()
utils/model_patterns.py CHANGED
@@ -614,6 +614,160 @@ def execute_forward_pass_with_layer_ablation(model, tokenizer, prompt: str, conf
614
  return result
615
 
616
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
617
  def logit_lens_transformation(layer_output: Any, norm_data: List[Any], model, tokenizer, norm_parameter: Optional[str] = None, top_k: int = 5) -> List[Tuple[str, float]]:
618
  """
619
  Transform layer output to top K token probabilities using logit lens.
 
614
  return result
615
 
616
 
617
def evaluate_sequence_ablation(model, tokenizer, sequence_text: str, config: Dict[str, Any],
                               ablation_type: str, ablation_target: Any) -> Dict[str, Any]:
    """
    Evaluate the impact of an ablation over a full text sequence.

    Runs two teacher-forced forward passes over the complete sequence:
      1. Reference pass (unmodified model).
      2. Ablated pass (temporary forward hooks zero out heads or skip a layer).
    Then compares the two output distributions.

    Args:
        model: Loaded transformer model.
        tokenizer: Matching tokenizer.
        sequence_text: The full text sequence to evaluate.
        config: Module configuration (currently unused here; kept for API
            parity with the other ablation helpers).
        ablation_type: 'head' or 'layer'.
        ablation_target: ``(layer_num, head_indices)`` for 'head', or an int
            ``layer_num`` for 'layer'.

    Returns:
        Dict with 'kl_divergence' (per position), 'probability_deltas'
        (per predicted token) and 'tokens' (decoded input tokens).
    """
    from .ablation_metrics import compute_kl_divergence, get_token_probability_deltas

    print(f"Evaluating sequence ablation: Type={ablation_type}, Target={ablation_target}")

    inputs = tokenizer(sequence_text, return_tensors="pt")
    input_ids = inputs["input_ids"].to(model.device)

    # --- 1. Reference pass (unmodified model) ---
    with torch.no_grad():
        outputs_ref = model(input_ids)
        logits_ref = outputs_ref.logits  # [1, seq_len, vocab_size]

    # --- 2. Ablated pass: register temporary forward hooks ---
    hooks = []

    def head_ablation_hook_factory(layer_idx, head_indices):
        """Build a hook that zeroes the given heads' slices of the attention output."""
        def hook(module, input, output):
            # Attention modules may return a bare tensor or a tuple whose
            # first element is the hidden states.
            if isinstance(output, tuple):
                hidden_states = output[0]
            else:
                hidden_states = output

            if not isinstance(hidden_states, torch.Tensor):
                if isinstance(hidden_states, list):
                    hidden_states = torch.tensor(hidden_states)

            num_heads = model.config.num_attention_heads
            head_dim = hidden_states.shape[-1] // num_heads

            # [batch, seq, hidden] -> [batch, seq, heads, head_dim].
            # NOTE(review): this hook fires on the attention module's final
            # output (post o_proj), where head slices are already mixed —
            # mirrors the existing ablation helpers; verify intent.
            new_shape = hidden_states.shape[:-1] + (num_heads, head_dim)
            # Clone before mutating: the original tensor may be reused elsewhere.
            reshaped = hidden_states.view(new_shape).clone()

            for h_idx in head_indices:
                reshaped[..., h_idx, :] = 0

            ablated_hidden = reshaped.view(hidden_states.shape)

            if isinstance(output, tuple):
                return (ablated_hidden,) + output[1:]
            return ablated_hidden
        return hook

    def skip_layer_hook(module, input, output):
        """Replace a layer block's output with its input (identity / skip ablation).

        Preserves the output tuple structure: decoder layers typically return
        (hidden_states, ...) and downstream code indexes element 0, so the
        hook must not hand back a bare tensor.
        """
        hidden_in = input[0] if isinstance(input, tuple) else input
        if isinstance(output, tuple):
            return (hidden_in,) + output[1:]
        return hidden_in

    try:
        if ablation_type == 'head':
            layer_num, head_indices = ablation_target
            # Locate the attention module via common HuggingFace naming
            # patterns ('layers.X.self_attn', 'h.X.attn', 'blocks.X.attn'),
            # skipping the q/k/v projection submodules.
            target_module = None
            for name, mod in model.named_modules():
                if f"layers.{layer_num}.self_attn" in name or f"h.{layer_num}.attn" in name or f"blocks.{layer_num}.attn" in name:
                    if "k_proj" not in name and "v_proj" not in name and "q_proj" not in name:
                        target_module = mod
                        break

            if target_module:
                hooks.append(target_module.register_forward_hook(head_ablation_hook_factory(layer_num, head_indices)))
            else:
                print(f"Warning: Could not find attention module for layer {layer_num}")

        elif ablation_type == 'layer':
            layer_num = ablation_target
            # Locate the top-level layer block ('model.layers.X' /
            # 'transformer.h.X'), not submodules like 'layers.X.mlp'.
            target_module = None
            for name, mod in model.named_modules():
                if (f"layers.{layer_num}" in name or f"h.{layer_num}" in name) and name.count('.') <= 2:
                    target_module = mod
                    break

            if target_module:
                # BUGFIX: the previous lambda returned the bare input tensor;
                # callers indexing the layer output tuple at [0] would then
                # slice the batch dimension instead of unpacking the tuple.
                hooks.append(target_module.register_forward_hook(skip_layer_hook))

        # Run the ablated pass with the hooks in place.
        with torch.no_grad():
            outputs_abl = model(input_ids)
            logits_abl = outputs_abl.logits

    finally:
        # Always detach the hooks so the model is left unmodified.
        for hook in hooks:
            hook.remove()

    # --- 3. Compute comparison metrics ---
    kl_div = compute_kl_divergence(logits_ref, logits_abl)
    prob_deltas = get_token_probability_deltas(logits_ref, logits_abl, input_ids)

    return {
        "kl_divergence": kl_div,
        "probability_deltas": prob_deltas,
        "tokens": [tokenizer.decode([tid]) for tid in input_ids[0].tolist()]
    }
768
+ }
769
+
770
+
771
  def logit_lens_transformation(layer_output: Any, norm_data: List[Any], model, tokenizer, norm_parameter: Optional[str] = None, top_k: int = 5) -> List[Tuple[str, float]]:
772
  """
773
  Transform layer output to top K token probabilities using logit lens.