Commit d60cfe2 · 1 Parent(s): 3f991b4
feat(output): Add token scrubber with per-position top-5 next-token probabilities
Files changed:
- app.py +62 -7
- components/pipeline.py +219 -46
- plans.md +6 -5
- tests/conftest.py +4 -0
- tests/test_model_patterns.py +130 -0
- todo.md +17 -0
- utils/__init__.py +3 -1
- utils/model_patterns.py +128 -3
app.py CHANGED

@@ -28,7 +28,8 @@ from components.model_selector import create_model_selector
 from components.glossary import create_glossary_modal
 from components.pipeline import (create_pipeline_container, create_tokenization_content,
                                  create_embedding_content, create_attention_content,
-                                 create_mlp_content, create_output_content)
+                                 create_mlp_content, create_output_content,
+                                 _build_token_display, _build_top5_chart)
 from components.investigation_panel import create_investigation_panel, create_attribution_results_display
 from components.ablation_panel import create_selected_heads_display, create_ablation_results_display
 from components.chatbot import create_chatbot_container, render_messages

@@ -375,9 +376,12 @@ def run_generation(n_clicks, model_name, prompt, max_new_tokens, beam_width, pat
     # the full-sequence analysis runs when the user selects a beam.
     if max_new_tokens == 1:
         full_text = results[0]['text']
+        # Pass original_prompt so per-position top-5 is computed for the scrubber
+        activation_data = execute_forward_pass(model, tokenizer, full_text, config,
+                                               original_prompt=prompt)
     else:
         full_text = prompt
-    activation_data = execute_forward_pass(model, tokenizer, full_text, config)
+        activation_data = execute_forward_pass(model, tokenizer, full_text, config)

     results_ui = []
     if max_new_tokens > 1:

@@ -418,10 +422,11 @@ def run_generation(n_clicks, model_name, prompt, max_new_tokens, beam_width, pat
      Output('session-activation-store', 'data', allow_duplicate=True)],
     Input({'type': 'result-item', 'index': ALL}, 'n_clicks'),
     [State('generation-results-store', 'data'),
-     State('session-activation-store', 'data')],
+     State('session-activation-store', 'data'),
+     State('session-original-prompt-store', 'data')],
     prevent_initial_call=True
 )
-def store_selected_beam(n_clicks_list, results_data, existing_activation_data):
+def store_selected_beam(n_clicks_list, results_data, existing_activation_data, original_prompt_data):
     """
     Store selected beam and re-run forward pass on the full sequence.

@@ -490,7 +495,12 @@ def store_selected_beam(n_clicks_list, results_data, existing_activation_data):
         model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation='eager')
         tokenizer = AutoTokenizer.from_pretrained(model_name)
         model.eval()
-        new_activation_data = execute_forward_pass(model, tokenizer, result['text'], config)
+        # Pass original_prompt so per-position top-5 data is computed for scrubber
+        orig_prompt = original_prompt_data.get('prompt', '') if original_prompt_data else ''
+        new_activation_data = execute_forward_pass(
+            model, tokenizer, result['text'], config,
+            original_prompt=orig_prompt
+        )
     except Exception as e:
         import traceback
         traceback.print_exc()

@@ -594,9 +604,19 @@ def update_pipeline_content(activation_data, model_name):
     # Stage 5: Output
     # Get original prompt for context display
     original_prompt = activation_data.get('prompt', '')
+    # Per-position data for the scrubber (populated when original_prompt was given)
+    per_position_data = activation_data.get('per_position_top5', [])
+    generated_tokens = activation_data.get('generated_tokens', [])
+    scrubber_prompt = activation_data.get('original_prompt', original_prompt)
+
     outputs.append(f"→ {predicted_token}")
-    outputs.append(create_output_content(
-        top_tokens, predicted_token, predicted_prob, original_prompt=original_prompt))
+    outputs.append(create_output_content(
+        top_tokens, predicted_token, predicted_prob,
+        original_prompt=original_prompt,
+        per_position_data=per_position_data,
+        generated_tokens=generated_tokens,
+        prompt_text=scrubber_prompt
+    ))

     return tuple(outputs)

@@ -606,6 +626,41 @@ def update_pipeline_content(activation_data, model_name):
         return tuple(empty_outputs)


+# ============================================================================
+# CALLBACKS: Output Scrubber
+# ============================================================================
+
+@app.callback(
+    [Output('output-token-display', 'children'),
+     Output('output-top5-chart', 'children')],
+    [Input('output-scrubber-slider', 'value')],
+    [State('session-activation-store', 'data')],
+    prevent_initial_call=True
+)
+def update_output_scrubber(position, activation_data):
+    """Update the token display and top-5 chart when the scrubber slider moves."""
+    if activation_data is None or position is None:
+        return no_update, no_update
+
+    per_position_data = activation_data.get('per_position_top5', [])
+    generated_tokens = activation_data.get('generated_tokens', [])
+    prompt_text = activation_data.get('original_prompt', activation_data.get('prompt', ''))
+
+    if not per_position_data or not generated_tokens:
+        return no_update, no_update
+
+    # Clamp position to valid range
+    position = max(0, min(position, len(per_position_data) - 1))
+    pos_data = per_position_data[position]
+
+    token_display = _build_token_display(
+        prompt_text, generated_tokens, position, pos_data['actual_prob']
+    )
+    top5_chart = _build_top5_chart(pos_data['top5'], pos_data.get('actual_token'))
+
+    return token_display, top5_chart
+
+
 # ============================================================================
 # CALLBACKS: Sidebar
 # ============================================================================
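For orientation, the `update_output_scrubber` callback above reads a payload of the following shape from `session-activation-store`. This is a sketch with invented token strings and probabilities; the keys match those written by `execute_forward_pass` below:

```python
# Hypothetical store contents; tokens and probabilities are made up.
activation_data = {
    'original_prompt': 'The capital of France is',
    'generated_tokens': [' Paris', '.'],
    'per_position_top5': [
        {'position': 0,
         'top5': [{'token': ' Paris', 'probability': 0.82},
                  {'token': ' Lyon', 'probability': 0.05}],
         'actual_token': ' Paris',
         'actual_prob': 0.82},
        # ...one dict per generated token
    ],
}
```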
components/pipeline.py CHANGED

@@ -689,18 +689,218 @@ def create_mlp_content(layer_count=None, hidden_dim=None, intermediate_dim=None)
     ])


-def create_output_content(top_tokens=None, predicted_token=None, predicted_prob=None,
-                          top5_chart=None, original_prompt=None):
+def _build_token_display(prompt_text, generated_tokens, position, actual_prob):
+    """Build the token display for a given scrubber position.
+
+    Args:
+        prompt_text: Original prompt string.
+        generated_tokens: List of generated token strings.
+        position: Current slider position (0-indexed into generated_tokens).
+        actual_prob: Probability of the highlighted token at this position.
+    """
+    # Context = prompt + all generated tokens before the current position
+    context_parts = [
+        html.Span(prompt_text, style={
+            'color': '#6c757d',
+            'fontFamily': 'monospace',
+            'fontSize': '15px'
+        })
+    ]
+    for j in range(position):
+        context_parts.append(
+            html.Span(generated_tokens[j], style={
+                'color': '#6c757d',
+                'fontFamily': 'monospace',
+                'fontSize': '15px'
+            })
+        )
+
+    # Highlighted token at the current position
+    highlighted = html.Span(generated_tokens[position], style={
+        'padding': '4px 8px',
+        'backgroundColor': '#00f2fe',
+        'color': '#1a1a2e',
+        'borderRadius': '4px',
+        'fontFamily': 'monospace',
+        'fontWeight': '600',
+        'fontSize': '15px',
+        'marginLeft': '2px'
+    })
+
+    confidence = html.Div([
+        html.Span(
+            f"{actual_prob:.1%} confidence" if actual_prob else "",
+            style={'color': '#6c757d', 'fontSize': '13px', 'marginTop': '8px', 'display': 'block'}
+        )
+    ])
+
+    return html.Div([
+        html.Div([
+            html.Span(
+                f"Token {position + 1} of {len(generated_tokens)}:",
+                style={'color': '#495057', 'marginBottom': '12px', 'display': 'block', 'fontWeight': '500'}
+            ),
+            html.Div(context_parts + [highlighted], style={'display': 'inline'}),
+            confidence
+        ], style={'textAlign': 'center'})
+    ], style={
+        'padding': '20px', 'backgroundColor': 'white', 'borderRadius': '8px',
+        'border': '2px solid #00f2fe', 'marginBottom': '16px'
+    })
+
+
+def _build_top5_chart(top5_data, actual_token=None):
+    """Build the top-5 bar chart for a single scrubber position.
+
+    Args:
+        top5_data: List of {'token': str, 'probability': float}.
+        actual_token: The token that was actually generated (highlighted if present).
+    """
+    tokens = [entry['token'] for entry in top5_data]
+    probs = [entry['probability'] for entry in top5_data]
+
+    # Highlight the actual chosen token if it appears in the top 5
+    colors = []
+    actual_in_top5 = False
+    for t in tokens:
+        if actual_token and t.strip() == actual_token.strip():
+            colors.append('#00f2fe')
+            actual_in_top5 = True
+        else:
+            colors.append('#4facfe')
+
+    fig = go.Figure(go.Bar(
+        x=probs,
+        y=tokens,
+        orientation='h',
+        marker_color=colors,
+        text=[f"{p:.1%}" for p in probs],
+        textposition='outside',
+        hovertemplate='%{y} (%{x:.1%})<extra></extra>'
+    ))
+
+    fig.update_layout(
+        title="Top 5 Next-Token Predictions",
+        xaxis_title="Probability",
+        yaxis_title="Token",
+        height=250,
+        margin=dict(l=20, r=60, t=40, b=20),
+        paper_bgcolor='rgba(0,0,0,0)',
+        plot_bgcolor='rgba(0,0,0,0)',
+        yaxis=dict(autorange='reversed')
+    )
+
+    children = [dcc.Graph(figure=fig, config={'displayModeBar': False})]
+
+    # If the actual token is not in the top 5, add a note below
+    if actual_token and not actual_in_top5:
+        children.append(html.Div([
+            html.I(className='fas fa-info-circle', style={'color': '#6c757d', 'marginRight': '6px'}),
+            html.Span([
+                "The actual token \"", html.Strong(actual_token.strip()),
+                "\" was not in the top 5 predictions at this position."
+            ], style={'color': '#6c757d', 'fontSize': '13px'})
+        ], style={'padding': '8px 12px'}))
+
+    return html.Div(children, style={
+        'backgroundColor': 'white', 'borderRadius': '8px', 'border': '1px solid #e2e8f0'
+    })
+
+
+def create_output_content(top_tokens=None, predicted_token=None, predicted_prob=None,
+                          top5_chart=None, original_prompt=None,
+                          per_position_data=None, generated_tokens=None,
+                          prompt_text=None):
     """
     Create content for the output selection stage.
+
+    When per_position_data is available the output is an interactive scrubber
+    that lets the user step through each generated-token position. Otherwise
+    it falls back to the previous static display.

     Args:
-        top_tokens: List of (token, probability) tuples for top predictions
-        predicted_token: The final predicted token
-        predicted_prob: Probability of the predicted token
-        top5_chart: Optional Plotly figure for top-5 visualization
-        original_prompt: Original input prompt to show context with prediction
+        top_tokens: List of (token, probability) tuples for top predictions (static mode).
+        predicted_token: The final predicted token (static mode).
+        predicted_prob: Probability of the predicted token (static mode).
+        top5_chart: Optional Plotly figure for top-5 visualization (static mode).
+        original_prompt: Original input prompt to show context with prediction (static mode).
+        per_position_data: List of per-position dicts from compute_per_position_top5 (scrubber mode).
+        generated_tokens: List of generated token strings (scrubber mode).
+        prompt_text: Original prompt text for context display (scrubber mode).
     """
+    # --- Scrubber mode ---
+    if per_position_data and generated_tokens:
+        num_positions = len(per_position_data)
+        prompt_display = prompt_text or original_prompt or ""
+
+        content_items = [
+            html.Div([
+                html.H5("What happens here:", style={'color': '#495057', 'marginBottom': '8px'}),
+                html.P([
+                    "The model converts the final hidden state into a ",
+                    html.Strong("probability distribution"),
+                    " over all possible next tokens. Use the slider below to step through "
+                    "each generated token and see the model's top predictions at that point."
+                ], style={'color': '#6c757d', 'fontSize': '14px', 'marginBottom': '16px'})
+            ])
+        ]
+
+        # Slider / scrubber
+        slider_marks = {i: {'label': generated_tokens[i].strip() or repr(generated_tokens[i])}
+                        for i in range(num_positions)}
+        content_items.append(
+            html.Div([
+                html.Span("Step through generated tokens:",
+                          style={'color': '#495057', 'fontWeight': '500', 'display': 'block',
+                                 'marginBottom': '8px'}),
+                dcc.Slider(
+                    id='output-scrubber-slider',
+                    min=0,
+                    max=max(num_positions - 1, 0),
+                    step=1,
+                    value=0,
+                    marks=slider_marks,
+                    included=False,
+                )
+            ], style={'marginBottom': '20px', 'padding': '12px 16px',
+                      'backgroundColor': '#f8f9fa', 'borderRadius': '8px',
+                      'border': '1px solid #dee2e6'})
+        )
+
+        # Initial render at position 0
+        pos0 = per_position_data[0]
+        content_items.append(
+            html.Div(
+                _build_token_display(prompt_display, generated_tokens, 0, pos0['actual_prob']),
+                id='output-token-display'
+            )
+        )
+        content_items.append(
+            html.Div(
+                _build_top5_chart(pos0['top5'], pos0.get('actual_token')),
+                id='output-top5-chart'
+            )
+        )
+
+        # Disclaimer
+        content_items.append(
+            html.Div([
+                html.I(className='fas fa-info-circle', style={'color': '#6c757d', 'marginRight': '8px'}),
+                html.Span([
+                    html.Strong("Note on Token Selection: "),
+                    "While the probabilities above show the model's raw preference for the immediate next token, the final choice ",
+                    "can be influenced by other factors. Techniques like ", html.Strong("Beam Search"),
+                    " look ahead at multiple possible sequences to find the best overall result, rather than just the single most likely token at each step. ",
+                    "Additionally, architectures like ", html.Strong("Mixture of Experts (MoE)"),
+                    " might route processing through different specialized internal networks which can impact the final output distribution."
+                ], style={'color': '#6c757d', 'fontSize': '13px'})
+            ], style={'marginTop': '16px', 'padding': '12px', 'backgroundColor': '#f8f9fa',
+                      'borderRadius': '6px', 'border': '1px solid #dee2e6'})
+        )
+
+        return html.Div(content_items)
+
+    # --- Static fallback (prompt-only analysis, no generated tokens yet) ---
     content_items = [
         html.Div([
             html.H5("What happens here:", style={'color': '#495057', 'marginBottom': '8px'}),

@@ -711,75 +911,49 @@ def create_output_content(top_tokens=None, predicted_token=None, predicted_prob=
         ])
     ]

-    # Predicted token display with full prompt context
     if predicted_token:
-        # Build the full prompt + predicted token display
         prompt_display = original_prompt if original_prompt else ""
-
         content_items.append(
             html.Div([
                 html.Div([
                     html.Span("Model prediction:", style={'color': '#495057', 'marginBottom': '12px', 'display': 'block', 'fontWeight': '500'}),
                     html.Div([
-                        # Original prompt (dimmed)
                         html.Span(prompt_display, style={
-                            'color': '#6c757d',
-                            'fontFamily': 'monospace',
-                            'fontSize': '15px'
+                            'color': '#6c757d', 'fontFamily': 'monospace', 'fontSize': '15px'
                         }),
-                        # Predicted token (highlighted)
                         html.Span(predicted_token, style={
-                            'padding': '4px 8px',
-                            'backgroundColor': '#00f2fe',
-                            'color': '#1a1a2e',
-                            'borderRadius': '4px',
-                            'fontFamily': 'monospace',
-                            'fontWeight': '600',
-                            'fontSize': '15px',
-                            'marginLeft': '2px'
+                            'padding': '4px 8px', 'backgroundColor': '#00f2fe',
+                            'color': '#1a1a2e', 'borderRadius': '4px',
+                            'fontFamily': 'monospace', 'fontWeight': '600',
+                            'fontSize': '15px', 'marginLeft': '2px'
                         })
                     ], style={'display': 'inline'}),
-                    # Confidence indicator
                     html.Div([
                         html.Span(f"{predicted_prob:.1%} confidence" if predicted_prob else "", style={
-                            'color': '#6c757d',
-                            'fontSize': '13px',
-                            'marginTop': '8px',
-                            'display': 'block'
+                            'color': '#6c757d', 'fontSize': '13px', 'marginTop': '8px', 'display': 'block'
                         })
                     ])
                 ], style={'textAlign': 'center'})
             ], style={'padding': '20px', 'backgroundColor': 'white', 'borderRadius': '8px',
                       'border': '2px solid #00f2fe', 'marginBottom': '16px'})
         )

-    # Top-5 bar chart with improved hover formatting
     if top_tokens:
         tokens = [t[0] for t in top_tokens[:5]]
         probs = [t[1] for t in top_tokens[:5]]

         fig = go.Figure(go.Bar(
-            x=probs,
-            y=tokens,
-            orientation='h',
+            x=probs, y=tokens, orientation='h',
             marker_color=['#00f2fe' if i == 0 else '#4facfe' for i in range(len(tokens))],
-            text=[f"{p:.1%}" for p in probs],
-            textposition='outside',
-            # Format hover to show "Token (X%)" instead of long decimals
+            text=[f"{p:.1%}" for p in probs], textposition='outside',
            hovertemplate='%{y} (%{x:.1%})<extra></extra>'
         ))
-
         fig.update_layout(
-            title="Top 5 Predictions",
-            xaxis_title="Probability",
-            yaxis_title="Token",
-            height=250,
-            margin=dict(l=20, r=60, t=40, b=20),
-            paper_bgcolor='rgba(0,0,0,0)',
-            plot_bgcolor='rgba(0,0,0,0)',
+            title="Top 5 Predictions", xaxis_title="Probability", yaxis_title="Token",
+            height=250, margin=dict(l=20, r=60, t=40, b=20),
+            paper_bgcolor='rgba(0,0,0,0)', plot_bgcolor='rgba(0,0,0,0)',
             yaxis=dict(autorange='reversed')
         )
-
         content_items.append(
             html.Div([
                 dcc.Graph(figure=fig, config={'displayModeBar': False})

@@ -792,7 +966,6 @@ def create_output_content(top_tokens=None, predicted_token=None, predicted_prob=
         ], style={'backgroundColor': 'white', 'borderRadius': '8px', 'border': '1px solid #e2e8f0'})
     )

-    # Disclaimer about token selection drivers
     content_items.append(
         html.Div([
             html.I(className='fas fa-info-circle', style={'color': '#6c757d', 'marginRight': '8px'}),
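The chart helper can be exercised on its own with synthetic data, which is handy when tweaking the styling. A sketch; the tokens and probabilities are invented, and the result is a Dash `html.Div` that can be dropped into any layout:

```python
from components.pipeline import _build_top5_chart

# Synthetic entries in the {'token', 'probability'} shape used by the scrubber
sample_top5 = [
    {'token': ' the', 'probability': 0.41},
    {'token': ' a', 'probability': 0.22},
    {'token': ' an', 'probability': 0.09},
    {'token': ' this', 'probability': 0.05},
    {'token': ' that', 'probability': 0.03},
]

# An actual_token outside the top 5 triggers the informational note
chart = _build_top5_chart(sample_top5, actual_token=' its')
```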
plans.md CHANGED

@@ -1,11 +1,12 @@
 - specs on what each attention head does
-- change attention to entire generated sequence
-- output slider to look at each token
-  - put in a more obvious place?
 - experiment results side by side comparison
 - output streaming for chatbot
-- change width of chatbot window
 - shorter, concise responses in system prompt
 - add video links to glossary
 - three blue one brown

+Done:
+- change attention to entire generated sequence
+- change width of chatbot window
+- add output token generation to attention, tokenization, etc
+- output slider to look at each token (scrubber with top-5 at each position)
tests/conftest.py CHANGED

@@ -5,6 +5,10 @@ Provides reusable mock data structures and synthetic tensors
 to test utility functions without loading actual ML models.
 """

+# Disable TensorFlow before any other imports (mirrors app.py)
+import os
+os.environ["USE_TF"] = "0"
+
 import pytest
 import torch
 import numpy as np
tests/test_model_patterns.py CHANGED

@@ -263,6 +263,136 @@ class TestMultiLayerHeadAblation:
         assert '99' in result['error']  # Should mention the invalid layer


+class TestComputePerPositionTop5:
+    """Tests for compute_per_position_top5 function."""
+
+    def _make_mock_output(self, seq_len, vocab_size=10):
+        """Create a mock model output with predictable logits.
+
+        At each position i, logit[i] = 10.0 (highest), so the top-1 token
+        is always token index == position index. Other logits are 1.0.
+        """
+        logits = torch.ones(1, seq_len, vocab_size)
+        for i in range(seq_len):
+            # Make token (i % vocab_size) the top prediction at position i
+            logits[0, i, i % vocab_size] = 10.0
+
+        class MockOutput:
+            pass
+        out = MockOutput()
+        out.logits = logits
+        return out
+
+    def _make_mock_tokenizer(self, vocab_size=10):
+        """Create a mock tokenizer that decodes token IDs to 'tok_N'."""
+        from unittest.mock import MagicMock
+        tok = MagicMock()
+        def decode_fn(ids, skip_special_tokens=False):
+            if isinstance(ids, list) and len(ids) == 1:
+                return f"tok_{ids[0]}"
+            return "".join(f"tok_{i}" for i in ids)
+        tok.decode = decode_fn
+        return tok
+
+    def test_returns_correct_number_of_positions(self):
+        """With prompt_token_count=3 and seq_len=7, should return 4 positions (7-3)."""
+        from utils.model_patterns import compute_per_position_top5
+        model_output = self._make_mock_output(seq_len=7, vocab_size=10)
+        tokenizer = self._make_mock_tokenizer(vocab_size=10)
+        # Full sequence has 7 tokens, prompt has 3, so 4 generated tokens
+        result = compute_per_position_top5(model_output, tokenizer, prompt_token_count=3, top_k=5)
+        assert len(result) == 4  # positions 0, 1, 2, 3
+
+    def test_single_generated_token(self):
+        """With 1 generated token, should return exactly 1 position."""
+        from utils.model_patterns import compute_per_position_top5
+        model_output = self._make_mock_output(seq_len=4, vocab_size=10)
+        tokenizer = self._make_mock_tokenizer(vocab_size=10)
+        result = compute_per_position_top5(model_output, tokenizer, prompt_token_count=3, top_k=5)
+        assert len(result) == 1
+        assert result[0]['position'] == 0
+
+    def test_each_position_has_top_k_entries(self):
+        """Each position should have exactly top_k entries in top5 list."""
+        from utils.model_patterns import compute_per_position_top5
+        model_output = self._make_mock_output(seq_len=8, vocab_size=10)
+        tokenizer = self._make_mock_tokenizer(vocab_size=10)
+        result = compute_per_position_top5(model_output, tokenizer, prompt_token_count=3, top_k=5)
+        for pos_data in result:
+            assert len(pos_data['top5']) == 5
+
+    def test_top_k_3(self):
+        """Should respect custom top_k parameter."""
+        from utils.model_patterns import compute_per_position_top5
+        model_output = self._make_mock_output(seq_len=6, vocab_size=10)
+        tokenizer = self._make_mock_tokenizer(vocab_size=10)
+        result = compute_per_position_top5(model_output, tokenizer, prompt_token_count=3, top_k=3)
+        for pos_data in result:
+            assert len(pos_data['top5']) == 3
+
+    def test_probabilities_sorted_descending(self):
+        """Top-5 probabilities should be in descending order."""
+        from utils.model_patterns import compute_per_position_top5
+        model_output = self._make_mock_output(seq_len=6, vocab_size=10)
+        tokenizer = self._make_mock_tokenizer(vocab_size=10)
+        result = compute_per_position_top5(model_output, tokenizer, prompt_token_count=3, top_k=5)
+        for pos_data in result:
+            probs = [entry['probability'] for entry in pos_data['top5']]
+            assert probs == sorted(probs, reverse=True)
+
+    def test_probabilities_are_valid(self):
+        """All probabilities should be between 0 and 1."""
+        from utils.model_patterns import compute_per_position_top5
+        model_output = self._make_mock_output(seq_len=6, vocab_size=10)
+        tokenizer = self._make_mock_tokenizer(vocab_size=10)
+        result = compute_per_position_top5(model_output, tokenizer, prompt_token_count=3, top_k=5)
+        for pos_data in result:
+            for entry in pos_data['top5']:
+                assert 0.0 <= entry['probability'] <= 1.0
+            assert 0.0 <= pos_data['actual_prob'] <= 1.0
+
+    def test_actual_token_field_present(self):
+        """Each position should have actual_token and actual_prob fields."""
+        from utils.model_patterns import compute_per_position_top5
+        model_output = self._make_mock_output(seq_len=6, vocab_size=10)
+        tokenizer = self._make_mock_tokenizer(vocab_size=10)
+        result = compute_per_position_top5(model_output, tokenizer, prompt_token_count=3, top_k=5)
+        for pos_data in result:
+            assert 'actual_token' in pos_data
+            assert 'actual_prob' in pos_data
+            assert isinstance(pos_data['actual_token'], str)
+            assert isinstance(pos_data['actual_prob'], float)
+
+    def test_position_indices_sequential(self):
+        """Position indices should be 0, 1, 2, ... N-1."""
+        from utils.model_patterns import compute_per_position_top5
+        model_output = self._make_mock_output(seq_len=8, vocab_size=10)
+        tokenizer = self._make_mock_tokenizer(vocab_size=10)
+        result = compute_per_position_top5(model_output, tokenizer, prompt_token_count=3, top_k=5)
+        positions = [r['position'] for r in result]
+        assert positions == list(range(5))  # 8 - 3 = 5 positions
+
+    def test_does_not_include_position_beyond_sequence(self):
+        """Should NOT produce a position that predicts beyond the last token."""
+        from utils.model_patterns import compute_per_position_top5
+        model_output = self._make_mock_output(seq_len=5, vocab_size=10)
+        tokenizer = self._make_mock_tokenizer(vocab_size=10)
+        # prompt=3, seq=5, so 2 generated tokens -> positions 0 and 1
+        result = compute_per_position_top5(model_output, tokenizer, prompt_token_count=3, top_k=5)
+        assert len(result) == 2
+        # Position 0: logits at index 2 (prompt_len-1), predicts token at index 3
+        # Position 1: logits at index 3, predicts token at index 4
+        # NO position for logits at index 4 (would predict beyond sequence)
+
+    def test_prompt_equals_sequence_returns_empty(self):
+        """When prompt_token_count == seq_len (no generated tokens), return empty."""
+        from utils.model_patterns import compute_per_position_top5
+        model_output = self._make_mock_output(seq_len=3, vocab_size=10)
+        tokenizer = self._make_mock_tokenizer(vocab_size=10)
+        result = compute_per_position_top5(model_output, tokenizer, prompt_token_count=3, top_k=5)
+        assert result == []
+
+
 class TestFullSequenceAttentionData:
     """
     Tests verifying that activation data for full sequences (prompt + generated output)
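The mock logits above make expected values easy to reason about: with one logit of 10.0 against nine of 1.0, softmax gives the top token e^10 / (e^10 + 9e) ≈ 0.999, so the top-1 is unambiguous at every position. A quick standalone check in plain PyTorch, independent of the app code:

```python
import torch
import torch.nn.functional as F

# Reproduce one row of the mock logits: one 10.0 among nine 1.0s
row = torch.ones(10)
row[3] = 10.0
probs = F.softmax(row, dim=-1)
print(probs[3].item())  # ~0.9989
```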
todo.md CHANGED

@@ -208,3 +208,20 @@
 - [x] Added 5 tests in `test_model_patterns.py` (`TestFullSequenceAttentionData`) verifying attention matrix dimensions match full sequence length
 - Attention visualization now covers the entire chosen output (input + generated tokens), not just the input prompt
 - No changes needed in `model_patterns.py`, `beam_search.py`, `pipeline.py`, or `head_detection.py`
+
+## Completed: Output Token Scrubber
+
+- [x] Add `compute_per_position_top5()` to `utils/model_patterns.py` — extracts top-5 next-token probabilities at each generated-token position from a single forward pass
+- [x] Add `original_prompt` parameter to `execute_forward_pass()` — when provided, computes per-position top-5 data and stores it in activation_data
+- [x] Export `compute_per_position_top5` in `utils/__init__.py`
+- [x] Update `run_generation()` in app.py — passes `original_prompt=prompt` for single-token generation
+- [x] Update `store_selected_beam()` in app.py — reads original prompt from session store and passes it to the forward pass
+- [x] Rewrite `create_output_content()` in `components/pipeline.py` — scrubber mode with `dcc.Slider`, token display, and top-5 chart; falls back to static mode when no per-position data
+- [x] Add `_build_token_display()` and `_build_top5_chart()` helpers in pipeline.py
+- [x] Add `update_output_scrubber()` callback in app.py — responds to slider changes, updates token highlight and chart
+- [x] Update `update_pipeline_content()` in app.py — extracts per-position data and passes it to output content
+- [x] Add 10 tests for `compute_per_position_top5` in `test_model_patterns.py`
+- [x] Fix `conftest.py` to set `USE_TF=0` for test import compatibility
+- [x] All 100 tests pass
+- Scrubber shows prompt context (gray) + highlighted token (cyan) + top-5 bar chart at each slider position
+- Pre-beam-selection falls back to static output display; scrubber activates after beam selection or single-token generation
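The extraction idea behind `compute_per_position_top5()` can be reproduced outside the app. Below is a minimal standalone sketch with a Hugging Face causal LM; "gpt2" and the texts are placeholders, and it assumes the prompt's tokenization is a prefix of the full-sequence tokenization, as in the app:

```python
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")   # placeholder model
model = AutoModelForCausalLM.from_pretrained("gpt2")
model.eval()

prompt = "The capital of France is"                 # placeholder texts
full_text = "The capital of France is Paris."       # prompt + generated

inputs = tokenizer(full_text, return_tensors="pt")
prompt_len = tokenizer(prompt, return_tensors="pt")["input_ids"].shape[1]

with torch.no_grad():
    logits = model(**inputs).logits[0]              # [seq_len, vocab]

ids = inputs["input_ids"][0]
for i in range(prompt_len, ids.shape[0]):
    # Logits at position i-1 are the model's distribution over token i
    probs = F.softmax(logits[i - 1], dim=-1)
    top_probs, top_ids = torch.topk(probs, k=5)
    top5 = [(tokenizer.decode([t.item()]), round(p.item(), 3))
            for t, p in zip(top_ids, top_probs)]
    actual = tokenizer.decode([ids[i].item()])
    print(f"pos {i - prompt_len}: actual={actual!r} "
          f"p={probs[ids[i]].item():.3f} top5={top5}")
```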
utils/__init__.py CHANGED

@@ -4,7 +4,8 @@ from .model_patterns import (load_model_and_get_patterns, execute_forward_pass,
                              execute_forward_pass_with_head_ablation,
                              execute_forward_pass_with_multi_layer_head_ablation,
                              merge_token_probabilities,
-                             compute_global_top5_tokens,
+                             compute_global_top5_tokens, compute_per_position_top5,
+                             detect_significant_probability_increases,
                              evaluate_sequence_ablation, generate_bertviz_model_view_html)
 from .model_config import get_model_family, get_family_config, get_auto_selections, MODEL_TO_FAMILY, MODEL_FAMILIES
 from .head_detection import categorize_all_heads, categorize_single_layer_heads, format_categorization_summary, HeadCategorizationConfig

@@ -25,6 +26,7 @@ __all__ = [
     'generate_bertviz_html',
     'merge_token_probabilities',
     'compute_global_top5_tokens',
+    'compute_per_position_top5',
     'detect_significant_probability_increases',
     'generate_bertviz_model_view_html',
utils/model_patterns.py CHANGED

@@ -125,6 +125,98 @@ def compute_global_top5_tokens(model_output, tokenizer, top_k: int = 5) -> List[
     return [{'token': t, 'probability': p} for t, p in merged[:top_k]]


+def compute_per_position_top5(model_output, tokenizer, prompt_token_count: int, top_k: int = 5) -> List[Dict[str, Any]]:
+    """
+    Compute top-K next-token probabilities at each generated-token position.
+
+    Uses logits already produced by the forward pass on the full sequence
+    (prompt + generated tokens). Position i in the returned list corresponds
+    to the prediction of generated token g_i given the prefix up to g_{i-1}.
+
+    Args:
+        model_output: Output from model(**inputs) containing logits [1, seq_len, vocab].
+        tokenizer: Tokenizer for decoding token IDs.
+        prompt_token_count: Number of tokens in the original prompt (P).
+        top_k: Number of top tokens per position (default 5).
+
+    Returns:
+        List of dicts, one per generated-token position::
+
+            [
+                {
+                    "position": 0,
+                    "top5": [{"token": str, "probability": float}, ...],
+                    "actual_token": str,   # token actually generated at this position
+                    "actual_prob": float   # its probability at this position
+                },
+                ...
+            ]
+    """
+    seq_len = model_output.logits.shape[1]
+    num_generated = seq_len - prompt_token_count
+    if num_generated <= 0:
+        return []
+
+    results = []
+    with torch.no_grad():
+        # The actual token at generated position i sits at input index
+        # prompt_token_count + i; the caller attaches the full-sequence
+        # input_ids to model_output so it can be looked up directly.
+        all_logits = model_output.logits[0]  # [seq_len, vocab]
+
+        for i in range(num_generated):
+            logit_idx = prompt_token_count - 1 + i   # index into logits
+            next_token_idx = prompt_token_count + i  # index of the actual next token
+
+            probs = F.softmax(all_logits[logit_idx], dim=-1)
+
+            # --- top-K with merge ---
+            top_probs, top_indices = torch.topk(probs, k=min(top_k * 2, len(probs)))
+            candidates = [
+                (tokenizer.decode([idx.item()], skip_special_tokens=False), prob.item())
+                for idx, prob in zip(top_indices, top_probs)
+            ]
+            merged = merge_token_probabilities(candidates)
+            top5 = [{'token': t, 'probability': p} for t, p in merged[:top_k]]
+
+            # --- actual token at this position ---
+            # Prefer the real input ids attached by the caller; fall back to
+            # argmax (correct only for greedy decoding) when they are missing.
+            actual_token_id = None
+            if hasattr(model_output, 'input_ids') and model_output.input_ids is not None:
+                actual_token_id = model_output.input_ids[0, next_token_idx].item()
+            elif hasattr(model_output, '_input_ids'):
+                actual_token_id = model_output._input_ids[0, next_token_idx].item()
+
+            if actual_token_id is not None:
+                actual_token = tokenizer.decode([actual_token_id], skip_special_tokens=False)
+                actual_prob = probs[actual_token_id].item()
+            else:
+                # Fallback: use the argmax as "actual" (only correct for greedy)
+                top_prob, top_idx = probs.max(dim=-1)
+                actual_token = tokenizer.decode([top_idx.item()], skip_special_tokens=False)
+                actual_prob = top_prob.item()
+
+            results.append({
+                'position': i,
+                'top5': top5,
+                'actual_token': actual_token,
+                'actual_prob': float(actual_prob),
+            })
+
+    return results
+
+
 def get_actual_model_output(model_output, tokenizer) -> Tuple[str, float]:
     """
     Extract the predicted token from model's output.

@@ -148,16 +240,21 @@ def get_actual_model_output(model_output, tokenizer) -> Tuple[str, float]:
     return token_str, top_prob.item()


-def execute_forward_pass(model, tokenizer, prompt: str, config: Dict[str, Any], ablation_config: Optional[Dict[int, List[int]]] = None) -> Dict[str, Any]:
+def execute_forward_pass(model, tokenizer, prompt: str, config: Dict[str, Any],
+                         ablation_config: Optional[Dict[int, List[int]]] = None,
+                         original_prompt: Optional[str] = None) -> Dict[str, Any]:
     """
     Execute forward pass with PyVene IntervenableModel to capture activations from specified modules.

     Args:
         model: Loaded transformer model
         tokenizer: Loaded tokenizer
-        prompt: Input text prompt
+        prompt: Input text prompt (may be full sequence: original prompt + generated tokens)
         config: Dict with module lists like {"attention_modules": [...], "block_modules": [...], ...}
         ablation_config: Optional dict mapping layer numbers to list of head indices to ablate.
+        original_prompt: When provided, enables per-position top-5 computation for
+            the output scrubber. If prompt contains generated tokens beyond
+            original_prompt, each generated-token position gets its own top-5 data.

     Returns:
         JSON-serializable dict with captured activations and metadata

@@ -255,6 +352,30 @@ def execute_forward_pass(model, tokenizer, prompt: str, config: Dict[str, Any],
     except Exception as e:
         print(f"Warning: Could not extract model output: {e}")

+    # --- Per-position top-5 for the output scrubber ---
+    per_position_top5 = []
+    prompt_token_count = None
+    generated_tokens = []
+    if original_prompt is not None:
+        prompt_ids = tokenizer(original_prompt, return_tensors="pt")["input_ids"]
+        prompt_token_count = prompt_ids.shape[1]
+        seq_len = inputs["input_ids"].shape[1]
+        num_generated = seq_len - prompt_token_count
+
+        if num_generated > 0:
+            # Attach input_ids to model_output so compute_per_position_top5
+            # can look up the actual token at each position.
+            model_output.input_ids = inputs["input_ids"]
+            per_position_top5 = compute_per_position_top5(
+                model_output, tokenizer, prompt_token_count, top_k=5
+            )
+            # Decode each generated token individually for slider marks
+            full_ids = inputs["input_ids"][0].tolist()
+            generated_tokens = [
+                tokenizer.decode([full_ids[prompt_token_count + i]], skip_special_tokens=False)
+                for i in range(num_generated)
+            ]
+
     # Build output dictionary
     result = {
         "model": getattr(model.config, "name_or_path", "unknown"),

@@ -267,7 +388,11 @@ def execute_forward_pass(model, tokenizer, prompt: str, config: Dict[str, Any],
         "norm_parameters": norm_parameters,
         "norm_data": norm_data,
         "actual_output": actual_output,
-        "global_top5_tokens": global_top5_tokens
+        "global_top5_tokens": global_top5_tokens,
+        "per_position_top5": per_position_top5,
+        "prompt_token_count": prompt_token_count,
+        "generated_tokens": generated_tokens,
+        "original_prompt": original_prompt,
     }

     print(f"Captured {len(captured)} module outputs using PyVene")
print(f"Captured {len(captured)} module outputs using PyVene")
|