Spaces:
Sleeping
Sleeping
Commit ·
2ad1c2e
1
Parent(s): c19d5a8
Ablation updated for full sequence, needs refactor for front-end and workflow
Browse files- app.py +209 -11
- components/main_panel.py +4 -1
- utils/__init__.py +9 -2
- utils/ablation_metrics.py +105 -0
- utils/model_patterns.py +154 -0
app.py
CHANGED
|
@@ -11,7 +11,8 @@ import json
|
|
| 11 |
import torch
|
| 12 |
from utils import (load_model_and_get_patterns, execute_forward_pass, extract_layer_data,
|
| 13 |
categorize_single_layer_heads, format_categorization_summary,
|
| 14 |
-
compute_layer_wise_summaries, perform_beam_search, compute_sequence_trajectory
|
|
|
|
| 15 |
from utils.model_config import get_auto_selections, get_model_family
|
| 16 |
|
| 17 |
# Import modular components
|
|
@@ -1729,10 +1730,11 @@ def handle_head_selection(n_clicks_list, selected_heads):
|
|
| 1729 |
[State({'type': 'selected-heads-store', 'layer': ALL}, 'data'),
|
| 1730 |
State('session-activation-store', 'data'),
|
| 1731 |
State('model-dropdown', 'value'),
|
| 1732 |
-
State('prompt-input', 'value')
|
|
|
|
| 1733 |
prevent_initial_call=True
|
| 1734 |
)
|
| 1735 |
-
def run_head_ablation(n_clicks_list, selected_heads_list, activation_data, model_name,
|
| 1736 |
"""Run forward pass with selected heads ablated."""
|
| 1737 |
# Identify which button was clicked
|
| 1738 |
ctx = dash.callback_context
|
|
@@ -1745,7 +1747,6 @@ def run_head_ablation(n_clicks_list, selected_heads_list, activation_data, model
|
|
| 1745 |
return no_update, no_update, no_update
|
| 1746 |
|
| 1747 |
# Find the index in the states_list that corresponds to this layer
|
| 1748 |
-
# ctx.states_list contains the State values in order
|
| 1749 |
button_index = None
|
| 1750 |
if hasattr(ctx, 'states_list') and ctx.states_list:
|
| 1751 |
# states_list[0] corresponds to selected-heads-store
|
|
@@ -1756,7 +1757,6 @@ def run_head_ablation(n_clicks_list, selected_heads_list, activation_data, model
|
|
| 1756 |
|
| 1757 |
# Fallback: if states_list doesn't work, try matching by iterating
|
| 1758 |
if button_index is None:
|
| 1759 |
-
# This shouldn't happen, but as a fallback, just return error
|
| 1760 |
return no_update, no_update, html.Div([
|
| 1761 |
html.I(className="fas fa-exclamation-circle", style={'marginRight': '8px', 'color': '#dc3545'}),
|
| 1762 |
f"Could not determine button index for layer {layer_num}"
|
|
@@ -1772,12 +1772,14 @@ def run_head_ablation(n_clicks_list, selected_heads_list, activation_data, model
|
|
| 1772 |
|
| 1773 |
try:
|
| 1774 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 1775 |
-
from utils import execute_forward_pass_with_head_ablation
|
| 1776 |
|
| 1777 |
# Save original activation data before ablation
|
| 1778 |
import copy
|
| 1779 |
original_data = copy.deepcopy(activation_data)
|
| 1780 |
|
|
|
|
|
|
|
|
|
|
| 1781 |
# Load model and tokenizer
|
| 1782 |
model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation='eager')
|
| 1783 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
|
@@ -1789,18 +1791,86 @@ def run_head_ablation(n_clicks_list, selected_heads_list, activation_data, model
|
|
| 1789 |
'norm_parameters': activation_data.get('norm_parameters', [])
|
| 1790 |
}
|
| 1791 |
|
| 1792 |
-
# Run
|
| 1793 |
ablated_data = execute_forward_pass_with_head_ablation(
|
| 1794 |
-
model, tokenizer,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1795 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1796 |
|
| 1797 |
# Update activation data with ablated results
|
| 1798 |
-
# Mark as ablated for visual indication
|
| 1799 |
ablated_data['ablated'] = True
|
| 1800 |
ablated_data['ablated_layer'] = layer_num
|
| 1801 |
ablated_data['ablated_heads'] = selected_heads
|
| 1802 |
|
| 1803 |
-
# Preserve input_ids
|
| 1804 |
if 'input_ids' not in ablated_data and 'input_ids' in activation_data:
|
| 1805 |
ablated_data['input_ids'] = activation_data['input_ids']
|
| 1806 |
|
|
@@ -1808,7 +1878,7 @@ def run_head_ablation(n_clicks_list, selected_heads_list, activation_data, model
|
|
| 1808 |
heads_str = ', '.join([f"H{h}" for h in sorted(selected_heads)])
|
| 1809 |
success_message = html.Div([
|
| 1810 |
html.I(className="fas fa-check-circle", style={'marginRight': '8px', 'color': '#28a745'}),
|
| 1811 |
-
f"Ablation complete: Layer {layer_num}, Heads {heads_str} removed"
|
| 1812 |
], className="status-success")
|
| 1813 |
|
| 1814 |
return ablated_data, original_data, success_message
|
|
@@ -1870,5 +1940,133 @@ def reset_ablation(n_clicks, original_data):
|
|
| 1870 |
return original_data, {}, success_message
|
| 1871 |
|
| 1872 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1873 |
if __name__ == '__main__':
|
| 1874 |
app.run(debug=True, port=8050)
|
|
|
|
| 11 |
import torch
|
| 12 |
from utils import (load_model_and_get_patterns, execute_forward_pass, extract_layer_data,
|
| 13 |
categorize_single_layer_heads, format_categorization_summary,
|
| 14 |
+
compute_layer_wise_summaries, perform_beam_search, compute_sequence_trajectory,
|
| 15 |
+
execute_forward_pass_with_head_ablation, evaluate_sequence_ablation, score_sequence)
|
| 16 |
from utils.model_config import get_auto_selections, get_model_family
|
| 17 |
|
| 18 |
# Import modular components
|
|
|
|
| 1730 |
[State({'type': 'selected-heads-store', 'layer': ALL}, 'data'),
|
| 1731 |
State('session-activation-store', 'data'),
|
| 1732 |
State('model-dropdown', 'value'),
|
| 1733 |
+
State('prompt-input', 'value'),
|
| 1734 |
+
State('generation-results-store', 'data')],
|
| 1735 |
prevent_initial_call=True
|
| 1736 |
)
|
| 1737 |
+
def run_head_ablation(n_clicks_list, selected_heads_list, activation_data, model_name, prompt_input, generation_results):
|
| 1738 |
"""Run forward pass with selected heads ablated."""
|
| 1739 |
# Identify which button was clicked
|
| 1740 |
ctx = dash.callback_context
|
|
|
|
| 1747 |
return no_update, no_update, no_update
|
| 1748 |
|
| 1749 |
# Find the index in the states_list that corresponds to this layer
|
|
|
|
| 1750 |
button_index = None
|
| 1751 |
if hasattr(ctx, 'states_list') and ctx.states_list:
|
| 1752 |
# states_list[0] corresponds to selected-heads-store
|
|
|
|
| 1757 |
|
| 1758 |
# Fallback: if states_list doesn't work, try matching by iterating
|
| 1759 |
if button_index is None:
|
|
|
|
| 1760 |
return no_update, no_update, html.Div([
|
| 1761 |
html.I(className="fas fa-exclamation-circle", style={'marginRight': '8px', 'color': '#dc3545'}),
|
| 1762 |
f"Could not determine button index for layer {layer_num}"
|
|
|
|
| 1772 |
|
| 1773 |
try:
|
| 1774 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
|
| 1775 |
|
| 1776 |
# Save original activation data before ablation
|
| 1777 |
import copy
|
| 1778 |
original_data = copy.deepcopy(activation_data)
|
| 1779 |
|
| 1780 |
+
# Determine the sequence to analyze (prefer activation data prompt over input box)
|
| 1781 |
+
sequence_text = activation_data.get('prompt', prompt_input)
|
| 1782 |
+
|
| 1783 |
# Load model and tokenizer
|
| 1784 |
model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation='eager')
|
| 1785 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
|
|
|
| 1791 |
'norm_parameters': activation_data.get('norm_parameters', [])
|
| 1792 |
}
|
| 1793 |
|
| 1794 |
+
# 1. Run Standard Ablation (Forward Pass)
|
| 1795 |
ablated_data = execute_forward_pass_with_head_ablation(
|
| 1796 |
+
model, tokenizer, sequence_text, config, layer_num, selected_heads
|
| 1797 |
+
)
|
| 1798 |
+
|
| 1799 |
+
# 2. Compute Full Sequence Metrics (KL Divergence, Delta Probs)
|
| 1800 |
+
# This requires re-running passes (Original & Ablated) on the full sequence
|
| 1801 |
+
# We use a helper that handles the ablation hooking internally for the metric pass
|
| 1802 |
+
seq_metrics = evaluate_sequence_ablation(
|
| 1803 |
+
model, tokenizer, sequence_text, config,
|
| 1804 |
+
ablation_type='head', ablation_target=(layer_num, selected_heads)
|
| 1805 |
)
|
| 1806 |
+
ablated_data['sequence_metrics'] = seq_metrics
|
| 1807 |
+
|
| 1808 |
+
# 3. Re-score Top Generated Sequences (if available)
|
| 1809 |
+
if generation_results:
|
| 1810 |
+
top_sequences_comparison = []
|
| 1811 |
+
|
| 1812 |
+
# Helper to run ablation for scoring (we need to apply hook again)
|
| 1813 |
+
# Since we can't easily pass 'ablated model' around, we re-apply hooks
|
| 1814 |
+
# Simplification: We already have 'evaluate_sequence_ablation'.
|
| 1815 |
+
# But that compares Ref vs Abl.
|
| 1816 |
+
# Here we just want Ablated Score.
|
| 1817 |
+
# Actually, `score_sequence` runs valid forward pass.
|
| 1818 |
+
# We need to apply ablation hooks to validly score.
|
| 1819 |
+
|
| 1820 |
+
# Define localized hook manager for scoring
|
| 1821 |
+
def get_ablated_score(seq_text):
|
| 1822 |
+
# Apply hook
|
| 1823 |
+
hooks = []
|
| 1824 |
+
def head_ablation_hook(module, input, output):
|
| 1825 |
+
# Similar to evaluate_sequence_ablation hook
|
| 1826 |
+
if isinstance(output, tuple): h = output[0]
|
| 1827 |
+
else: h = output
|
| 1828 |
+
if not isinstance(h, torch.Tensor): h = torch.tensor(h)
|
| 1829 |
+
|
| 1830 |
+
num_heads = model.config.num_attention_heads
|
| 1831 |
+
head_dim = h.shape[-1] // num_heads
|
| 1832 |
+
new_shape = h.shape[:-1] + (num_heads, head_dim)
|
| 1833 |
+
reshaped = h.view(new_shape).clone()
|
| 1834 |
+
for h_idx in selected_heads: reshaped[..., h_idx, :] = 0
|
| 1835 |
+
ablated = reshaped.view(h.shape)
|
| 1836 |
+
return (ablated,) + output[1:] if isinstance(output, tuple) else ablated
|
| 1837 |
+
|
| 1838 |
+
# Register
|
| 1839 |
+
target_module = None
|
| 1840 |
+
for name, mod in model.named_modules():
|
| 1841 |
+
if f"layers.{layer_num}.self_attn" in name or f"h.{layer_num}.attn" in name:
|
| 1842 |
+
if "k_proj" not in name:
|
| 1843 |
+
target_module = mod; break
|
| 1844 |
+
|
| 1845 |
+
if target_module:
|
| 1846 |
+
hooks.append(target_module.register_forward_hook(head_ablation_hook))
|
| 1847 |
+
|
| 1848 |
+
try:
|
| 1849 |
+
score = score_sequence(model, tokenizer, seq_text)
|
| 1850 |
+
finally:
|
| 1851 |
+
for hook in hooks: hook.remove()
|
| 1852 |
+
return score
|
| 1853 |
+
|
| 1854 |
+
for res in generation_results:
|
| 1855 |
+
txt = res['text']
|
| 1856 |
+
orig_score = res['score']
|
| 1857 |
+
new_score = get_ablated_score(txt)
|
| 1858 |
+
top_sequences_comparison.append({
|
| 1859 |
+
'text': txt,
|
| 1860 |
+
'original_score': orig_score,
|
| 1861 |
+
'ablated_score': new_score,
|
| 1862 |
+
'delta': new_score - orig_score
|
| 1863 |
+
})
|
| 1864 |
+
|
| 1865 |
+
ablated_data['top_sequences_comparison'] = top_sequences_comparison
|
| 1866 |
+
|
| 1867 |
|
| 1868 |
# Update activation data with ablated results
|
|
|
|
| 1869 |
ablated_data['ablated'] = True
|
| 1870 |
ablated_data['ablated_layer'] = layer_num
|
| 1871 |
ablated_data['ablated_heads'] = selected_heads
|
| 1872 |
|
| 1873 |
+
# Preserve input_ids if needed
|
| 1874 |
if 'input_ids' not in ablated_data and 'input_ids' in activation_data:
|
| 1875 |
ablated_data['input_ids'] = activation_data['input_ids']
|
| 1876 |
|
|
|
|
| 1878 |
heads_str = ', '.join([f"H{h}" for h in sorted(selected_heads)])
|
| 1879 |
success_message = html.Div([
|
| 1880 |
html.I(className="fas fa-check-circle", style={'marginRight': '8px', 'color': '#28a745'}),
|
| 1881 |
+
f"Ablation complete: Layer {layer_num}, Heads {heads_str} removed. Scroll down for sequence analysis."
|
| 1882 |
], className="status-success")
|
| 1883 |
|
| 1884 |
return ablated_data, original_data, success_message
|
|
|
|
| 1940 |
return original_data, {}, success_message
|
| 1941 |
|
| 1942 |
|
| 1943 |
+
|
| 1944 |
+
# Callback to update sequence ablation analysis view
|
| 1945 |
+
@app.callback(
|
| 1946 |
+
[Output('sequence-ablation-results-container', 'children'),
|
| 1947 |
+
Output('sequence-ablation-results-container', 'style')],
|
| 1948 |
+
Input('session-activation-store', 'data'),
|
| 1949 |
+
prevent_initial_call=False
|
| 1950 |
+
)
|
| 1951 |
+
def update_sequence_ablation_view(activation_data):
|
| 1952 |
+
"""Update the sequence ablation results view (KL Divergence, Sequence Comparison)."""
|
| 1953 |
+
if not activation_data or not activation_data.get('ablated', False):
|
| 1954 |
+
return [], {'display': 'none'}
|
| 1955 |
+
|
| 1956 |
+
try:
|
| 1957 |
+
import plotly.graph_objs as go
|
| 1958 |
+
from dash import html, dcc
|
| 1959 |
+
|
| 1960 |
+
children = []
|
| 1961 |
+
|
| 1962 |
+
# 1. Header
|
| 1963 |
+
children.append(html.H3("Full Sequence Ablation Analysis", style={'marginTop': '0', 'marginBottom': '20px', 'color': '#2d3748'}))
|
| 1964 |
+
|
| 1965 |
+
# 2. Top-5 Sequence Comparison Table
|
| 1966 |
+
top_seqs = activation_data.get('top_sequences_comparison', [])
|
| 1967 |
+
if top_seqs:
|
| 1968 |
+
rows = []
|
| 1969 |
+
for i, seq in enumerate(top_seqs):
|
| 1970 |
+
delta = seq['delta']
|
| 1971 |
+
delta_color = '#28a745' if delta > 0 else '#dc3545' if delta < 0 else '#6c757d'
|
| 1972 |
+
|
| 1973 |
+
rows.append(html.Tr([
|
| 1974 |
+
html.Td(f"#{i+1}", style={'fontWeight': 'bold'}),
|
| 1975 |
+
html.Td(seq['text'], style={'fontFamily': 'monospace', 'maxWidth': '400px', 'overflow': 'hidden', 'textOverflow': 'ellipsis', 'whiteSpace': 'nowrap'}),
|
| 1976 |
+
html.Td(f"{seq['original_score']:.4f}"),
|
| 1977 |
+
html.Td(f"{seq['ablated_score']:.4f}"),
|
| 1978 |
+
html.Td(f"{delta:+.4f}", style={'color': delta_color, 'fontWeight': 'bold'})
|
| 1979 |
+
]))
|
| 1980 |
+
|
| 1981 |
+
table_header = html.Thead(html.Tr([
|
| 1982 |
+
html.Th("Rank"), html.Th("Sequence"), html.Th("Original Score"), html.Th("Ablated Score"), html.Th("Delta")
|
| 1983 |
+
]))
|
| 1984 |
+
table_body = html.Tbody(rows)
|
| 1985 |
+
|
| 1986 |
+
children.append(html.Div([
|
| 1987 |
+
html.H5("Top Sequences Impact", style={'marginBottom': '10px'}),
|
| 1988 |
+
html.Table([table_header, table_body], className="table table-striped table-bordered")
|
| 1989 |
+
], style={'marginBottom': '30px', 'padding': '15px', 'backgroundColor': '#fff', 'borderRadius': '8px', 'boxShadow': '0 2px 4px rgba(0,0,0,0.05)'}))
|
| 1990 |
+
|
| 1991 |
+
# 3. KL Divergence Chart
|
| 1992 |
+
seq_metrics = activation_data.get('sequence_metrics', {})
|
| 1993 |
+
kl_divs = seq_metrics.get('kl_divergence', [])
|
| 1994 |
+
tokens = seq_metrics.get('tokens', [])
|
| 1995 |
+
|
| 1996 |
+
if kl_divs:
|
| 1997 |
+
# KL Chart
|
| 1998 |
+
fig_kl = go.Figure()
|
| 1999 |
+
fig_kl.add_trace(go.Scatter(
|
| 2000 |
+
x=list(range(len(kl_divs))),
|
| 2001 |
+
y=kl_divs,
|
| 2002 |
+
mode='lines+markers',
|
| 2003 |
+
name='KL Divergence',
|
| 2004 |
+
line=dict(color='#6610f2', width=2),
|
| 2005 |
+
hovertext=[f"Token: {t}<br>KL: {v:.4f}" for t, v in zip(tokens, kl_divs)],
|
| 2006 |
+
hoverinfo='text'
|
| 2007 |
+
))
|
| 2008 |
+
|
| 2009 |
+
fig_kl.update_layout(
|
| 2010 |
+
title="KL Divergence per Position (Distribution Shift)",
|
| 2011 |
+
xaxis_title="Position / Token",
|
| 2012 |
+
yaxis_title="KL Divergence (nats)",
|
| 2013 |
+
margin=dict(l=20, r=20, t=40, b=20),
|
| 2014 |
+
height=300,
|
| 2015 |
+
xaxis=dict(
|
| 2016 |
+
tickmode='array',
|
| 2017 |
+
tickvals=list(range(len(tokens))),
|
| 2018 |
+
ticktext=tokens
|
| 2019 |
+
)
|
| 2020 |
+
)
|
| 2021 |
+
|
| 2022 |
+
children.append(html.Div([
|
| 2023 |
+
dcc.Graph(figure=fig_kl, config={'displayModeBar': False})
|
| 2024 |
+
], style={'marginBottom': '20px', 'padding': '15px', 'backgroundColor': '#fff', 'borderRadius': '8px', 'boxShadow': '0 2px 4px rgba(0,0,0,0.05)'}))
|
| 2025 |
+
|
| 2026 |
+
# 4. Target Probability Deltas Chart
|
| 2027 |
+
prob_deltas = seq_metrics.get('probability_deltas', [])
|
| 2028 |
+
if prob_deltas:
|
| 2029 |
+
# Shift tokens for x-axis (deltas are for prediction of next token)
|
| 2030 |
+
# Input: T0, T1, T2
|
| 2031 |
+
# Delta 0: Change in P(T1|T0)
|
| 2032 |
+
# So x-axis should be T1, T2...
|
| 2033 |
+
target_tokens = tokens[1:] if len(tokens) > 1 else []
|
| 2034 |
+
|
| 2035 |
+
fig_delta = go.Figure()
|
| 2036 |
+
fig_delta.add_trace(go.Bar(
|
| 2037 |
+
x=list(range(len(prob_deltas))),
|
| 2038 |
+
y=prob_deltas,
|
| 2039 |
+
name='Prob Delta',
|
| 2040 |
+
marker_color=['#28a745' if v >= 0 else '#dc3545' for v in prob_deltas],
|
| 2041 |
+
hovertext=[f"Target: {t}<br>Change: {v:+.4f}" for t, v in zip(target_tokens, prob_deltas)],
|
| 2042 |
+
hoverinfo='text'
|
| 2043 |
+
))
|
| 2044 |
+
|
| 2045 |
+
fig_delta.update_layout(
|
| 2046 |
+
title="Target Probability Change per Position",
|
| 2047 |
+
xaxis_title="Target Token",
|
| 2048 |
+
yaxis_title="Probability Delta",
|
| 2049 |
+
margin=dict(l=20, r=20, t=40, b=20),
|
| 2050 |
+
height=300,
|
| 2051 |
+
xaxis=dict(
|
| 2052 |
+
tickmode='array',
|
| 2053 |
+
tickvals=list(range(len(target_tokens))),
|
| 2054 |
+
ticktext=target_tokens
|
| 2055 |
+
)
|
| 2056 |
+
)
|
| 2057 |
+
|
| 2058 |
+
children.append(html.Div([
|
| 2059 |
+
dcc.Graph(figure=fig_delta, config={'displayModeBar': False})
|
| 2060 |
+
], style={'marginBottom': '20px', 'padding': '15px', 'backgroundColor': '#fff', 'borderRadius': '8px', 'boxShadow': '0 2px 4px rgba(0,0,0,0.05)'}))
|
| 2061 |
+
|
| 2062 |
+
return children, {'display': 'block', 'marginTop': '30px', 'paddingTop': '30px', 'borderTop': '1px solid #dee2e6'}
|
| 2063 |
+
|
| 2064 |
+
except Exception as e:
|
| 2065 |
+
print(f"Error in ablation view: {e}")
|
| 2066 |
+
import traceback
|
| 2067 |
+
traceback.print_exc()
|
| 2068 |
+
return html.Div(f"Error loading visualization: {str(e)}"), {'display': 'block'}
|
| 2069 |
+
|
| 2070 |
+
|
| 2071 |
if __name__ == '__main__':
|
| 2072 |
app.run(debug=True, port=8050)
|
components/main_panel.py
CHANGED
|
@@ -105,7 +105,10 @@ def create_main_panel():
|
|
| 105 |
html.I(className="fas fa-spinner fa-spin", style={'fontSize': '24px', 'color': '#667eea', 'marginRight': '10px'}),
|
| 106 |
html.Span("Loading visuals...", style={'fontSize': '16px', 'color': '#495057'})
|
| 107 |
], style={'display': 'flex', 'alignItems': 'center', 'justifyContent': 'center', 'padding': '2rem'})
|
| 108 |
-
)
|
|
|
|
|
|
|
|
|
|
| 109 |
], className="visualization-section")
|
| 110 |
])
|
| 111 |
], id="analysis-view-container", style={'display': 'none'}) # Hidden by default
|
|
|
|
| 105 |
html.I(className="fas fa-spinner fa-spin", style={'fontSize': '24px', 'color': '#667eea', 'marginRight': '10px'}),
|
| 106 |
html.Span("Loading visuals...", style={'fontSize': '16px', 'color': '#495057'})
|
| 107 |
], style={'display': 'flex', 'alignItems': 'center', 'justifyContent': 'center', 'padding': '2rem'})
|
| 108 |
+
),
|
| 109 |
+
|
| 110 |
+
# Sequence Ablation Results (New)
|
| 111 |
+
html.Div(id='sequence-ablation-results-container', style={'marginTop': '30px', 'display': 'none'})
|
| 112 |
], className="visualization-section")
|
| 113 |
])
|
| 114 |
], id="analysis-view-container", style={'display': 'none'}) # Hidden by default
|
utils/__init__.py
CHANGED
|
@@ -1,14 +1,17 @@
|
|
| 1 |
-
from .model_patterns import load_model_and_get_patterns, execute_forward_pass, logit_lens_transformation, extract_layer_data, generate_bertviz_html, generate_category_bertviz_html, get_check_token_probabilities, execute_forward_pass_with_layer_ablation, execute_forward_pass_with_head_ablation, merge_token_probabilities, compute_global_top5_tokens, detect_significant_probability_increases, compute_layer_wise_summaries
|
| 2 |
from .model_config import get_model_family, get_family_config, get_auto_selections, MODEL_TO_FAMILY, MODEL_FAMILIES
|
| 3 |
from .head_detection import categorize_all_heads, categorize_single_layer_heads, format_categorization_summary, HeadCategorizationConfig
|
| 4 |
from .prompt_comparison import compare_attention_layers, compare_output_probabilities, format_comparison_summary, ComparisonConfig
|
| 5 |
from .beam_search import perform_beam_search, compute_sequence_trajectory
|
|
|
|
|
|
|
| 6 |
|
| 7 |
__all__ = [
|
| 8 |
'load_model_and_get_patterns',
|
| 9 |
'execute_forward_pass',
|
| 10 |
'execute_forward_pass_with_layer_ablation',
|
| 11 |
'execute_forward_pass_with_head_ablation',
|
|
|
|
| 12 |
'logit_lens_transformation',
|
| 13 |
'extract_layer_data',
|
| 14 |
'generate_bertviz_html',
|
|
@@ -32,5 +35,9 @@ __all__ = [
|
|
| 32 |
'format_comparison_summary',
|
| 33 |
'ComparisonConfig',
|
| 34 |
'perform_beam_search',
|
| 35 |
-
'
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
]
|
|
|
|
| 1 |
+
from .model_patterns import load_model_and_get_patterns, execute_forward_pass, logit_lens_transformation, extract_layer_data, generate_bertviz_html, generate_category_bertviz_html, get_check_token_probabilities, execute_forward_pass_with_layer_ablation, execute_forward_pass_with_head_ablation, merge_token_probabilities, compute_global_top5_tokens, detect_significant_probability_increases, compute_layer_wise_summaries, evaluate_sequence_ablation
|
| 2 |
from .model_config import get_model_family, get_family_config, get_auto_selections, MODEL_TO_FAMILY, MODEL_FAMILIES
|
| 3 |
from .head_detection import categorize_all_heads, categorize_single_layer_heads, format_categorization_summary, HeadCategorizationConfig
|
| 4 |
from .prompt_comparison import compare_attention_layers, compare_output_probabilities, format_comparison_summary, ComparisonConfig
|
| 5 |
from .beam_search import perform_beam_search, compute_sequence_trajectory
|
| 6 |
+
from .ablation_metrics import compute_kl_divergence, score_sequence, get_token_probability_deltas
|
| 7 |
+
|
| 8 |
|
| 9 |
__all__ = [
|
| 10 |
'load_model_and_get_patterns',
|
| 11 |
'execute_forward_pass',
|
| 12 |
'execute_forward_pass_with_layer_ablation',
|
| 13 |
'execute_forward_pass_with_head_ablation',
|
| 14 |
+
'evaluate_sequence_ablation',
|
| 15 |
'logit_lens_transformation',
|
| 16 |
'extract_layer_data',
|
| 17 |
'generate_bertviz_html',
|
|
|
|
| 35 |
'format_comparison_summary',
|
| 36 |
'ComparisonConfig',
|
| 37 |
'perform_beam_search',
|
| 38 |
+
'perform_beam_search',
|
| 39 |
+
'compute_sequence_trajectory',
|
| 40 |
+
'compute_kl_divergence',
|
| 41 |
+
'score_sequence',
|
| 42 |
+
'get_token_probability_deltas'
|
| 43 |
]
|
utils/ablation_metrics.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from typing import List, Dict, Any, Tuple, Optional
|
| 5 |
+
|
| 6 |
+
def compute_kl_divergence(logits_p: torch.Tensor, logits_q: torch.Tensor) -> List[float]:
|
| 7 |
+
"""
|
| 8 |
+
Compute KL Divergence KL(P || Q) for each position in the sequence.
|
| 9 |
+
P is the reference distribution (logits_p), Q is the ablated distribution (logits_q).
|
| 10 |
+
|
| 11 |
+
Args:
|
| 12 |
+
logits_p: Reference logits [batch, seq_len, vocab_size]
|
| 13 |
+
logits_q: Ablated logits [batch, seq_len, vocab_size]
|
| 14 |
+
|
| 15 |
+
Returns:
|
| 16 |
+
List of KL divergence values for each position.
|
| 17 |
+
"""
|
| 18 |
+
with torch.no_grad():
|
| 19 |
+
# Ensure batch size 1 or handle appropriately
|
| 20 |
+
if logits_p.dim() == 3:
|
| 21 |
+
logits_p = logits_p.squeeze(0)
|
| 22 |
+
if logits_q.dim() == 3:
|
| 23 |
+
logits_q = logits_q.squeeze(0)
|
| 24 |
+
|
| 25 |
+
# P = softmax(logits_p)
|
| 26 |
+
# Q = softmax(logits_q)
|
| 27 |
+
# KL(P||Q) = sum(P * (log P - log Q))
|
| 28 |
+
|
| 29 |
+
# Use log_softmax for stability
|
| 30 |
+
log_probs_p = F.log_softmax(logits_p, dim=-1)
|
| 31 |
+
log_probs_q = F.log_softmax(logits_q, dim=-1)
|
| 32 |
+
probs_p = torch.exp(log_probs_p)
|
| 33 |
+
|
| 34 |
+
# Element-wise KL
|
| 35 |
+
kl_divs = torch.sum(probs_p * (log_probs_p - log_probs_q), dim=-1)
|
| 36 |
+
|
| 37 |
+
return kl_divs.tolist()
|
| 38 |
+
|
| 39 |
+
def score_sequence(model, tokenizer, text: str) -> float:
|
| 40 |
+
"""
|
| 41 |
+
Compute the total log probability (score) of a text sequence.
|
| 42 |
+
|
| 43 |
+
Args:
|
| 44 |
+
model: HuggingFace model
|
| 45 |
+
tokenizer: Tokenizer
|
| 46 |
+
text: The sequence to score
|
| 47 |
+
|
| 48 |
+
Returns:
|
| 49 |
+
Total log probability.
|
| 50 |
+
"""
|
| 51 |
+
inputs = tokenizer(text, return_tensors="pt")
|
| 52 |
+
input_ids = inputs["input_ids"].to(model.device)
|
| 53 |
+
|
| 54 |
+
with torch.no_grad():
|
| 55 |
+
outputs = model(input_ids)
|
| 56 |
+
logits = outputs.logits # [1, seq_len, vocab_size]
|
| 57 |
+
|
| 58 |
+
# We want P(token_i | tokens_<i)
|
| 59 |
+
# The logits at position i-1 predict position i
|
| 60 |
+
|
| 61 |
+
# Shift logits and labels
|
| 62 |
+
shift_logits = logits[0, :-1, :].contiguous()
|
| 63 |
+
shift_labels = input_ids[0, 1:].contiguous()
|
| 64 |
+
|
| 65 |
+
# Helper to pick specific token probabilities
|
| 66 |
+
# log_softmax
|
| 67 |
+
log_probs_all = F.log_softmax(shift_logits, dim=-1)
|
| 68 |
+
|
| 69 |
+
# Gather only the target label log probs
|
| 70 |
+
# gather needs index column vector
|
| 71 |
+
target_log_probs = log_probs_all.gather(1, shift_labels.unsqueeze(1)).squeeze(1)
|
| 72 |
+
|
| 73 |
+
total_score = target_log_probs.sum().item()
|
| 74 |
+
|
| 75 |
+
return total_score
|
| 76 |
+
|
| 77 |
+
def get_token_probability_deltas(logits_ref: torch.Tensor, logits_abl: torch.Tensor, input_ids: torch.Tensor) -> List[float]:
|
| 78 |
+
"""
|
| 79 |
+
Compute the change in probability (Prob_abl - Prob_ref) for the actual target tokens.
|
| 80 |
+
|
| 81 |
+
Args:
|
| 82 |
+
logits_ref: Reference logits
|
| 83 |
+
logits_abl: Ablated logits
|
| 84 |
+
input_ids: The sequence token IDs [1, seq_len]
|
| 85 |
+
|
| 86 |
+
Returns:
|
| 87 |
+
List of probability deltas for each position (starting from first prediction).
|
| 88 |
+
"""
|
| 89 |
+
with torch.no_grad():
|
| 90 |
+
if logits_ref.dim() == 3: logits_ref = logits_ref.squeeze(0)
|
| 91 |
+
if logits_abl.dim() == 3: logits_abl = logits_abl.squeeze(0)
|
| 92 |
+
|
| 93 |
+
target_ids = input_ids[0, 1:] # Targets are from index 1 onwards
|
| 94 |
+
|
| 95 |
+
# Probabilities
|
| 96 |
+
probs_ref = F.softmax(logits_ref[:-1], dim=-1) # Predicts 1..N
|
| 97 |
+
probs_abl = F.softmax(logits_abl[:-1], dim=-1)
|
| 98 |
+
|
| 99 |
+
# Gather target probs
|
| 100 |
+
ref_target_probs = probs_ref.gather(1, target_ids.unsqueeze(1)).squeeze(1)
|
| 101 |
+
abl_target_probs = probs_abl.gather(1, target_ids.unsqueeze(1)).squeeze(1)
|
| 102 |
+
|
| 103 |
+
deltas = (abl_target_probs - ref_target_probs).tolist()
|
| 104 |
+
|
| 105 |
+
return deltas
|
utils/model_patterns.py
CHANGED
|
@@ -614,6 +614,160 @@ def execute_forward_pass_with_layer_ablation(model, tokenizer, prompt: str, conf
|
|
| 614 |
return result
|
| 615 |
|
| 616 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 617 |
def logit_lens_transformation(layer_output: Any, norm_data: List[Any], model, tokenizer, norm_parameter: Optional[str] = None, top_k: int = 5) -> List[Tuple[str, float]]:
|
| 618 |
"""
|
| 619 |
Transform layer output to top K token probabilities using logit lens.
|
|
|
|
| 614 |
return result
|
| 615 |
|
| 616 |
|
| 617 |
+
def evaluate_sequence_ablation(model, tokenizer, sequence_text: str, config: Dict[str, Any],
|
| 618 |
+
ablation_type: str, ablation_target: Any) -> Dict[str, Any]:
|
| 619 |
+
"""
|
| 620 |
+
Evaluate the impact of ablation on a full sequence.
|
| 621 |
+
|
| 622 |
+
This runs TWO forward passes on the FULL sequence:
|
| 623 |
+
1. Reference pass (original model) -> Capture logits/probs
|
| 624 |
+
2. Ablated pass (modified model) -> Capture logits/probs
|
| 625 |
+
|
| 626 |
+
Then computes metrics: KL Divergence, Target Prob Changes.
|
| 627 |
+
|
| 628 |
+
Args:
|
| 629 |
+
model: Loaded transformer model
|
| 630 |
+
tokenizer: Tokenizer
|
| 631 |
+
sequence_text: The full text sequence to evaluate
|
| 632 |
+
config: Module configuration (needed for ablation setup)
|
| 633 |
+
ablation_type: 'head' or 'layer'
|
| 634 |
+
ablation_target: tuple (layer, head_indices) or int (layer_num)
|
| 635 |
+
|
| 636 |
+
Returns:
|
| 637 |
+
Dict with evaluation metrics.
|
| 638 |
+
"""
|
| 639 |
+
from .ablation_metrics import compute_kl_divergence, get_token_probability_deltas
|
| 640 |
+
|
| 641 |
+
print(f"Evaluating sequence ablation: Type={ablation_type}, Target={ablation_target}")
|
| 642 |
+
|
| 643 |
+
inputs = tokenizer(sequence_text, return_tensors="pt")
|
| 644 |
+
input_ids = inputs["input_ids"].to(model.device)
|
| 645 |
+
|
| 646 |
+
# --- 1. Reference Pass ---
|
| 647 |
+
with torch.no_grad():
|
| 648 |
+
outputs_ref = model(input_ids)
|
| 649 |
+
logits_ref = outputs_ref.logits # [1, seq_len, vocab_size]
|
| 650 |
+
|
| 651 |
+
# --- 2. Ablated Pass ---
|
| 652 |
+
# Setup ablation based on type
|
| 653 |
+
|
| 654 |
+
# We need to wrap the model using PyVene logic or custom hooks just for this pass
|
| 655 |
+
# Since we already have logic in execute_forward_pass_with_..._ablation, we can reuse the Hook logic
|
| 656 |
+
# But we want the full logits, not just captured activations.
|
| 657 |
+
|
| 658 |
+
# Let's manually register hooks here for simplicity and control
|
| 659 |
+
hooks = []
|
| 660 |
+
|
| 661 |
+
def head_ablation_hook_factory(layer_idx, head_indices):
    """Create a forward hook that zeroes the output of specific attention heads.

    Args:
        layer_idx: Layer index. Unused by the hook itself (the hook is
            registered directly on that layer's attention module); kept for
            interface compatibility with existing call sites.
        head_indices: Iterable of head indices to zero out.

    Returns:
        A callable suitable for ``module.register_forward_hook`` that returns
        the module output with the targeted heads' slices set to zero.
    """
    # Materialize once so the hook can use it as a fancy index on every call.
    target_heads = list(head_indices)

    def hook(module, input, output):
        # Attention modules may return a tuple (hidden_states, ...) or a bare tensor.
        hidden_states = output[0] if isinstance(output, tuple) else output

        # Defensive conversion for wrappers that hand back plain sequences.
        # ``as_tensor`` (unlike ``torch.tensor``) also accepts any array-like,
        # so we never fall through with a non-tensor and crash at ``.shape``.
        if not isinstance(hidden_states, torch.Tensor):
            hidden_states = torch.as_tensor(hidden_states)

        # NOTE(review): assumes hidden_states is [batch, seq, num_heads * head_dim]
        # with all heads materialized (no GQA grouping) — confirm for the
        # target architecture. ``model`` is a free variable from the enclosing
        # evaluation function.
        num_heads = model.config.num_attention_heads
        head_dim = hidden_states.shape[-1] // num_heads

        # [batch, seq, hidden] -> [batch, seq, heads, head_dim]
        per_head = hidden_states.view(hidden_states.shape[:-1] + (num_heads, head_dim))

        # Clone before mutating: a hook's return value must not alias the
        # module's own output tensor.
        per_head = per_head.clone()
        # Zero every targeted head in one vectorized index assignment.
        per_head[..., target_heads, :] = 0

        ablated_hidden = per_head.view(hidden_states.shape)

        if isinstance(output, tuple):
            return (ablated_hidden,) + output[1:]
        return ablated_hidden

    return hook
|
| 697 |
+
|
| 698 |
+
# Hook for Layer Ablation (Identity/Skip or Zero)
|
| 699 |
+
# We'll use Identity (Skip Layer) as a simpler approximation of "removing logic"
|
| 700 |
+
# OR Mean Ablation if we had the mean.
|
| 701 |
+
# For now, let's just do nothing for layer ablation or return error,
|
| 702 |
+
# as the user primarily asks for "ablation experiment updates" which often means Heads.
|
| 703 |
+
# But to be safe, let's implement the same Mean Ablation if possible, or Identity.
|
| 704 |
+
# Identity (Skip) is easier:
|
| 705 |
+
def identity_hook(module, input, output):
    """Forward hook that turns a layer into a pass-through (skip the layer).

    A forward hook's ``input`` is the tuple of positional arguments the module
    received; non-tuple values are wrapped so the replacement output is always
    a tuple, mirroring the module's original return convention.
    """
    if isinstance(input, tuple):
        return input
    return (input,)
|
| 708 |
+
|
| 709 |
+
try:
|
| 710 |
+
if ablation_type == 'head':
|
| 711 |
+
layer_num, head_indices = ablation_target
|
| 712 |
+
# Find module
|
| 713 |
+
# Standard transformers: model.layers[i].self_attn
|
| 714 |
+
# We need the exact module name map standard to HuggingFace
|
| 715 |
+
# Or use the config's mapping if available.
|
| 716 |
+
# Let's rely on standard naming or search
|
| 717 |
+
|
| 718 |
+
# Simple heuristic: find 'layers.X.self_attn' or 'h.X.attn'
|
| 719 |
+
target_module = None
|
| 720 |
+
for name, mod in model.named_modules():
|
| 721 |
+
# Check for standard patterns
|
| 722 |
+
# layer_num is int
|
| 723 |
+
if f"layers.{layer_num}.self_attn" in name or f"h.{layer_num}.attn" in name or f"blocks.{layer_num}.attn" in name:
|
| 724 |
+
if "k_proj" not in name and "v_proj" not in name and "q_proj" not in name: # avoid submodules
|
| 725 |
+
target_module = mod
|
| 726 |
+
break
|
| 727 |
+
|
| 728 |
+
if target_module:
|
| 729 |
+
hooks.append(target_module.register_forward_hook(head_ablation_hook_factory(layer_num, head_indices)))
|
| 730 |
+
else:
|
| 731 |
+
print(f"Warning: Could not find attention module for layer {layer_num}")
|
| 732 |
+
|
| 733 |
+
elif ablation_type == 'layer':
|
| 734 |
+
layer_num = ablation_target
|
| 735 |
+
target_module = None
|
| 736 |
+
for name, mod in model.named_modules():
|
| 737 |
+
# Layers are usually 'model.layers.X' or 'transformer.h.X'
|
| 738 |
+
# We want the module that corresponds to the layer block
|
| 739 |
+
# Be careful not to pick 'layers.X.mlp'
|
| 740 |
+
if (f"layers.{layer_num}" in name or f"h.{layer_num}" in name) and name.count('.') <= 2: # heuristic for top-level layer
|
| 741 |
+
target_module = mod
|
| 742 |
+
break
|
| 743 |
+
|
| 744 |
+
if target_module:
|
| 745 |
+
# Skip layer (Identity)
|
| 746 |
+
hooks.append(target_module.register_forward_hook(lambda m, i, o: i[0] if isinstance(i, tuple) else i))
|
| 747 |
+
|
| 748 |
+
# Run Ablated Pass
|
| 749 |
+
with torch.no_grad():
|
| 750 |
+
outputs_abl = model(input_ids)
|
| 751 |
+
logits_abl = outputs_abl.logits
|
| 752 |
+
|
| 753 |
+
finally:
|
| 754 |
+
for hook in hooks:
|
| 755 |
+
hook.remove()
|
| 756 |
+
|
| 757 |
+
# --- 3. Compute Metrics ---
|
| 758 |
+
# KL Divergence [seq_len]
|
| 759 |
+
kl_div = compute_kl_divergence(logits_ref, logits_abl)
|
| 760 |
+
|
| 761 |
+
# Prob Deltas for actual tokens [seq_len-1] (shifted)
|
| 762 |
+
prob_deltas = get_token_probability_deltas(logits_ref, logits_abl, input_ids)
|
| 763 |
+
|
| 764 |
+
return {
|
| 765 |
+
"kl_divergence": kl_div,
|
| 766 |
+
"probability_deltas": prob_deltas,
|
| 767 |
+
"tokens": [tokenizer.decode([tid]) for tid in input_ids[0].tolist()]
|
| 768 |
+
}
|
| 769 |
+
|
| 770 |
+
|
| 771 |
def logit_lens_transformation(layer_output: Any, norm_data: List[Any], model, tokenizer, norm_parameter: Optional[str] = None, top_k: int = 5) -> List[Tuple[str, float]]:
|
| 772 |
"""
|
| 773 |
Transform layer output to top K token probabilities using logit lens.
|