cdpearlman committed on
Commit aabd66e · 1 Parent(s): 95244f4

conductor setup

Browse files
.cursor/rules/minimal_changes.mdc CHANGED
@@ -20,6 +20,7 @@ alwaysApply: true
 - Plan first:
   - Update `todo.md` with the smallest next actions tied to `plans.md`.
   - Keep tasks atomic and check them off as you go.
+ - Use the `conductor` folder to learn about the project. Maintain this folder after every change to the code in order to keep running memory (only make changes if necessary).

 - Keep edits minimal:
   - Prefer small, surgical changes over refactors.
README.md CHANGED
@@ -1,210 +1,105 @@
- # Transformer Activation Capture and Visualization
-
- This project provides tools for capturing activations from transformer models and visualizing attention patterns using bertviz and an interactive Dash web application.
-
- ## Overview
-
- The project consists of multiple components:
- 1. **Interactive Dashboard** (`app.py`) - Web-based visualization with automatic model family detection
- 2. **Model Configuration** (`utils/model_config.py`) - Hard-coded mappings for common model families
- 3. **Activation Capture** (`utils/model_patterns.py`) - PyVene-based activation capture utilities
- 4. **Legacy Tools** (`agnostic_capture.py`, `bertviz_head_model_view.py`) - Command-line tools
-
- ## New Feature: Automatic Model Family Detection
-
- The dashboard now automatically detects model families and pre-fills dropdown selections with appropriate modules and parameters. This eliminates manual selection for common architectures.
-
- ### Supported Model Families
-
- - **LLaMA-like**: LLaMA 2/3, Mistral, Mixtral, Qwen2/2.5
- - **GPT-2**: GPT-2, GPT-2 Medium/Large/XL
- - **OPT**: Facebook OPT models (125M - 30B)
- - **GPT-NeoX**: EleutherAI Pythia, GPT-NeoX-20B
- - **BLOOM**: BigScience BLOOM models
- - **Falcon**: TII Falcon models
- - **MPT**: MosaicML MPT models
-
- ### How It Works
-
- 1. Select a model from the dropdown
- 2. The app detects the model family (e.g., "gpt2", "llama_like")
- 3. Dropdowns auto-fill with family-specific patterns:
-    - **Attention modules**: e.g., `transformer.h.{N}.attn` for GPT-2
-    - **MLP modules**: e.g., `model.layers.{N}.mlp` for LLaMA
-    - **Normalization parameters**: e.g., `model.norm.weight` for LLaMA
-    - **Logit lens parameter**: e.g., `lm_head.weight`
- 4. You can still manually adjust selections if needed
-
- ### Adding New Models
-
- Edit `utils/model_config.py` and add entries to `MODEL_TO_FAMILY`:
-
- ```python
- MODEL_TO_FAMILY = {
-     "your-org/your-model": "llama_like",  # or gpt2, opt, etc.
-     # ...
- }
- ```
-
- No code changes needed if the model follows an existing family's architecture!
-
- ## Files
-
- ### `agnostic_capture.py`
- A model-agnostic activation capture tool that hooks into transformer modules and saves their outputs.
-
- **Key Features:**
- - Automatically categorizes modules into attention, MLP, and other types
- - Interactive module selection by pattern
- - Supports both PyTorch hooks and PyVene integration
- - Saves data in organized JSON structure for easy retrieval
-
- **Usage:**
  ```bash
- # Basic usage
- python agnostic_capture.py --model "Qwen/Qwen2.5-0.5B" --prompt "Once upon a time"
-
- # Capture attention weights for bertviz
- python agnostic_capture.py --model "Qwen/Qwen2.5-0.5B" --prompt "Once upon a time" --output my_activations.json
-
- # Auto-select patterns (attention:0, mlp:0 selects first pattern from each)
- python agnostic_capture.py --auto-select "attn:0;mlp:0;other:" --model "Qwen/Qwen2.5-0.5B"
  ```
-
- **Interactive Selection:**
- When run without `--auto-select`, the script will:
- 1. Show available module patterns grouped by type
- 2. Allow you to select patterns by index, name, or suffix
- 3. Selected patterns apply to ALL layers that contain them
-
- ### `bertviz_head_model_view.py`
- Creates interactive HTML visualizations of attention patterns using the bertviz library.
-
- **Features:**
- - Generates head view (attention patterns per head)
- - Generates model view (attention patterns across layers)
- - Automatically extracts attention weights from captured data
- - Saves HTML files that can be opened in any browser
-
- **Usage:**
- ```bash
- python bertviz_head_model_view.py
- ```
-
- **Output:**
- - `bertviz/attention_head_view_{model_name}.html` - Head-level attention patterns
- - `bertviz/attention_model_view_{model_name}.html` - Model-level attention patterns
-
- ## Data Structure
-
- The captured data is organized in the following JSON structure:
-
- ```json
- {
-   "model": "model_name",
-   "prompt": "input_text",
-   "input_ids": [[token_ids]],
-   "selected_patterns": {
-     "attention": ["pattern1", "pattern2"],
-     "mlp": ["pattern1"],
-     "other": []
-   },
-   "selected_modules": {
-     "attention": ["model.layers.0.self_attn", "model.layers.1.self_attn", ...],
-     "mlp": ["model.layers.0.mlp", "model.layers.1.mlp", ...],
-     "other": []
-   },
-   "captured": {
-     "attention_outputs": {
-       "model.layers.0.self_attn": {
-         "output": [
-           [[...]], // Attention output (processed values)
-           [[...]]  // Attention weights (used by bertviz)
-         ]
-       }
-     },
-     "mlp_outputs": { ... },
-     "other_outputs": { ... }
-   }
- }
- ```
-
- ## Workflow
-
- 1. **Capture Activations:**
-    ```bash
-    python agnostic_capture.py --model "Qwen/Qwen2.5-0.5B" --prompt "Your text here"
-    ```
-    - Select attention patterns (e.g., `model.{layer}.self_attn`)
-    - This creates `agnostic_activations.json`
-
- 2. **Generate Visualizations:**
-    ```bash
-    python bertviz_head_model_view.py
-    ```
-    - Reads from `agnostic_activations.json`
-    - Creates HTML visualization files in `bertviz/` directory
-
- 3. **View Results:**
-    - Open the generated HTML files in your browser
-    - Explore attention patterns across heads and layers
-
- ## Requirements
-
- ```bash
- pip install torch transformers bertviz
- ```
-
- Optional:
  ```bash
- pip install pyvene  # For enhanced hooking capabilities
  ```
-
- ## Important Notes
-
- ### For Attention Visualization:
- - **Must capture `self_attn` modules** (not `self_attn.o_proj`) for bertviz to work
- - Attention modules return tuples: `(output, attention_weights)`
- - bertviz uses the attention weights (element 1) for visualization
-
- ### Module Selection:
- - Patterns use `{layer}` placeholder (e.g., `model.{layer}.self_attn`)
- - Selected patterns apply to ALL layers automatically
- - Use indices, exact names, or unique suffixes for selection
-
- ### File Outputs:
- - `agnostic_activations.json` - Captured activation data
- - `bertviz/attention_head_view_{model}.html` - Per-head attention visualization
- - `bertviz/attention_model_view_{model}.html` - Cross-layer attention visualization
-
- ## Troubleshooting
-
- **"Attention tensor does not have correct dimensions"**
- - Ensure you captured `self_attn` modules, not output projections
- - Check that attention weights have shape `(batch, heads, seq_len, seq_len)`
-
- **"Module not found"**
- - Verify module patterns match your model architecture
- - Use the interactive selection to see available patterns
-
- **"No data captured"**
- - Check hook registration succeeded
- - Ensure selected modules exist in the model
- - Verify the model actually runs forward pass
-
- ## Example Session
-
- ```bash
- # 1. Capture attention data
- python agnostic_capture.py --model "Qwen/Qwen2.5-0.5B" --prompt "The cat sat on the mat"
- # Select attention patterns: 0 (for model.{layer}.self_attn)
- # Select MLP patterns: (press enter to skip)
- # Select other patterns: (press enter to skip)
-
- # 2. Generate visualizations
- python bertviz_head_model_view.py
-
- # 3. Open bertviz/attention_head_view_Qwen_Qwen2.5-0.5B.html in browser
- ```
-
- This will show you how the model attends to different tokens when processing "The cat sat on the mat".
+ # Transformer Explanation Dashboard
+
+ A comprehensive, interactive tool for capturing, visualizing, and experimenting with Transformer-based Large Language Models (LLMs). This project demystifies the inner workings of models by transforming abstract architectural concepts into tangible, observable phenomena.
+
+ ## Vision
+
+ To foster a deep, intuitive understanding of how powerful models process information by combining interactive visualizations with hands-on experimentation capabilities.
+
+ ## Key Features
+
+ ### 🔍 Interactive Pipeline Visualization
+ Follow the data flow step-by-step through the model's architecture:
+ 1. **Tokenization**: See how text is split and assigned IDs.
+ 2. **Embedding**: Visualize the look-up of semantic vectors.
+ 3. **Attention**: Explore head-level attention patterns using **BertViz**.
+ 4. **MLP (Feed-Forward)**: Understand where factual knowledge is stored.
+ 5. **Output Selection**: View probability distributions and top predictions.
+
+ ### 🧪 Experiments & Investigation
+ Go beyond static observation with interactive experiments:
+ * **Ablation Studies**: Selectively disable specific attention heads across different layers to observe their impact on generation and probability.
+ * **Token Attribution**: Use **Integrated Gradients** to see which input tokens contributed most to a specific prediction.
+ * **Beam Search Analysis**: Visualize how multiple generation choices are explored.
+
+ ### 🤖 Broad Model Support
+ The dashboard features **Automatic Model Family Detection**, supporting a wide range of architectures without manual configuration:
+ * **LLaMA-like**: LLaMA 2/3, Mistral, Mixtral, Qwen2/2.5
+ * **GPT-2**: GPT-2 (Small/Medium/Large/XL)
+ * **OPT**: Facebook OPT models
+ * **GPT-NeoX**: Pythia, GPT-NeoX
+ * **BLOOM**: BigScience BLOOM
+ * **Falcon**: TII Falcon
+ * **MPT**: MosaicML MPT
+
+ ## Getting Started
+
+ ### Prerequisites
+ * Python 3.11+ recommended
+ * PyTorch
+
+ ### Installation
+
+ 1. Clone the repository:
+    ```bash
+    git clone https://github.com/yourusername/transformer-dashboard.git
+    cd transformer-dashboard
+    ```
+
+ 2. Install dependencies:
+    ```bash
+    pip install -r requirements.txt
+    ```
+
+ ### Running the Dashboard
+
+ Launch the application:
  ```bash
+ python app.py
  ```
+
+ Open your browser and navigate to `http://127.0.0.1:8050/`.
+
+ ## Usage Guide
+
+ 1. **Select a Model**: Choose from the predefined list or enter a HuggingFace model ID. The system will auto-detect the architecture.
+ 2. **Enter a Prompt**: Type a sentence to analyze.
+ 3. **Configure Generation**: Adjust "Number of New Tokens" and "Number of Generation Choices" (Beam Width).
+ 4. **Run Analysis**: Click "Analyze" to run the forward pass.
+ 5. **Explore the Pipeline**: Click on the pipeline stages (Tokenization, Attention, etc.) to expand detailed views.
+ 6. **Run Experiments**:
+    * Use the **Investigation Panel** at the bottom to switch between Ablation and Attribution tabs.
+    * In **Ablation**, select layers and heads to disable, then click "Run Ablation Experiment".
+    * In **Attribution**, select a target token and method to visualize feature importance.
+
+ ## Project Structure
+
+ * `app.py`: Main application entry point and layout orchestration.
+ * `components/`: Modular UI components.
+   * `pipeline.py`: The core 5-stage visualization.
+   * `investigation_panel.py`: Ablation and attribution interfaces.
+   * `ablation_panel.py`: Specific logic for head ablation UI.
+   * `tokenization_panel.py`: Token visualization.
+ * `utils/`: Backend logic and helper functions.
+   * `model_patterns.py`: Activation capture and hooking logic.
+   * `model_config.py`: Model family definitions and auto-detection.
+   * `head_detection.py`: Attention head categorization logic.
+   * `beam_search.py`: Beam search implementation.
+ * `tests/`: Comprehensive test suite ensuring stability.
+ * `conductor/`: Detailed project documentation and product guidelines.
+
+ ## Documentation
+
+ For more detailed information on the project's background and technical details, check the `conductor/` directory:
+ * [Product Definition](conductor/product.md)
+ * [Tech Stack](conductor/tech-stack.md)
+ * [Workflow](conductor/workflow.md)
+
+ ## Contributing
+
+ Contributions are welcome! Please ensure that any new features include appropriate tests in the `tests/` directory. Run the test suite before submitting:
+
  ```bash
+ pytest tests/
  ```
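The new README lists Token Attribution via **Integrated Gradients** among the features. As a rough, self-contained sketch of the underlying idea (not this repository's actual implementation — the toy function and names below are hypothetical), Integrated Gradients scales each input's deviation from a baseline by the gradient averaged along the straight path between them:

```python
def integrated_gradients(grad_f, x, baseline, steps=100):
    """IG_i ~= (x_i - b_i) * mean over the path of dF/dx_i (midpoint rule)."""
    dim = len(x)
    totals = [0.0] * dim
    for k in range(steps):
        alpha = (k + 0.5) / steps                               # path position
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        g = grad_f(point)                                       # gradient at that point
        for i in range(dim):
            totals[i] += g[i]
    return [(xi - b) * t / steps for xi, b, t in zip(x, baseline, totals)]

# Hypothetical toy model F(x) = 2*x0 + x0*x1, so grad F = [2 + x1, x0]
# (a stand-in for a network's gradient w.r.t. token embeddings).
grad_f = lambda p: [2 + p[1], p[0]]
attrs = integrated_gradients(grad_f, x=[1.0, 1.0], baseline=[0.0, 0.0])
print([round(a, 3) for a in attrs])  # [2.5, 0.5]
```

A useful sanity check is the completeness property: the attributions sum to `F(x) - F(baseline)` (here 3.0), which real implementations such as Captum also verify.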
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app.py CHANGED
@@ -639,26 +639,6 @@ def switch_investigation_tab(abl_clicks, attr_clicks, current_tab):
 # CALLBACKS: Investigation Panel - Ablation (Updated for New UI)
 # ============================================================================

-@app.callback(
-    Output('ablation-model-view-container', 'children'),
-    [Input('session-activation-store', 'data')]
-)
-def update_ablation_model_view(activation_data):
-    """Update BertViz model view when new analysis is run."""
-    if not activation_data:
-        return html.Div("Run analysis to see attention visualization.",
-                        style={'padding': '20px', 'color': '#6c757d', 'textAlign': 'center'})
-
-    try:
-        html_content = generate_bertviz_model_view_html(activation_data)
-        return html.Iframe(
-            srcDoc=html_content,
-            style={'width': '100%', 'height': '100%', 'border': 'none'}
-        )
-    except Exception as e:
-        return html.Div(f"Error generating visualization: {str(e)}", style={'color': 'red', 'padding': '20px'})
-
-
 @app.callback(
     [Output('ablation-layer-select', 'options'),
      Output('ablation-head-select', 'options')],
@@ -768,11 +748,13 @@ def manage_ablation_heads(add_clicks, clear_clicks, remove_clicks,
      State('session-activation-store', 'data'),
      State('model-dropdown', 'value'),
      State('prompt-input', 'value'),
-     State('session-selected-beam-store', 'data')],
+     State('session-selected-beam-store', 'data'),
+     State('max-new-tokens-slider', 'value'),
+     State('beam-width-slider', 'value')],
     prevent_initial_call=True
 )
-def run_ablation_experiment(n_clicks, selected_heads, activation_data, model_name, prompt, selected_beam):
-    """Run ablation on ORIGINAL PROMPT and compare results."""
+def run_ablation_experiment(n_clicks, selected_heads, activation_data, model_name, prompt, selected_beam, max_new_tokens, beam_width):
+    """Run ablation on ORIGINAL PROMPT and compare results, including beam generation."""
     if not n_clicks or not selected_heads or not activation_data:
         return no_update, no_update
@@ -809,7 +791,7 @@ def run_ablation_experiment(n_clicks, selected_heads, activation_data, model_nam
     if not heads_by_layer:
         return html.Div("No valid heads selected.", style={'color': '#dc3545'}), no_update

-    # Run ablation
+    # Run ablation for analysis (single pass)
     ablated_data = execute_forward_pass_with_multi_layer_head_ablation(
         model, tokenizer, sequence_text, config, heads_by_layer
     )
@@ -821,9 +803,25 @@ def run_ablation_experiment(n_clicks, selected_heads, activation_data, model_nam
     ablated_token = ablated_output.get('token', '')
     ablated_prob = ablated_output.get('probability', 0)

+    # Run ablation for generation
+    ablated_beam = None
+    try:
+        # Always perform beam search during ablation to show comparison
+        beam_results = perform_beam_search(
+            model, tokenizer, sequence_text,
+            beam_width=beam_width,
+            max_new_tokens=max_new_tokens,
+            ablation_config=heads_by_layer
+        )
+        if beam_results:
+            # Select the top beam
+            ablated_beam = {'text': beam_results[0]['text'], 'score': beam_results[0].get('score', 0)}
+    except Exception as e:
+        print(f"Error during ablated generation: {e}")
+
     results_display = create_ablation_results_display(
         original_token, ablated_token, original_prob, ablated_prob,
-        selected_heads, selected_beam
+        selected_heads, selected_beam, ablated_beam
     )

     return results_display, ablated_data
@@ -832,11 +830,6 @@ def run_ablation_experiment(n_clicks, selected_heads, activation_data, model_nam
         import traceback
         traceback.print_exc()
         return html.Div(f"Ablation error: {str(e)}", style={'color': '#dc3545'}), no_update
-
-    except Exception as e:
-        import traceback
-        traceback.print_exc()
-        return html.Div(f"Ablation error: {str(e)}", style={'color': '#dc3545'})


 # ============================================================================
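The `execute_forward_pass_with_multi_layer_head_ablation` helper called above is defined elsewhere in the repo and not shown in this diff. As a rough sketch of the core idea only (names, shapes, and the pure-Python representation are hypothetical), ablating a head amounts to zeroing that head's slice of the attention output before it reaches the output projection:

```python
def ablate_heads(attn_output, num_heads, heads_to_ablate):
    """attn_output: per-token hidden vectors (seq_len x hidden_dim).
    Zero the contiguous slice of each vector belonging to the ablated heads."""
    hidden = len(attn_output[0])
    head_dim = hidden // num_heads
    result = []
    for vec in attn_output:
        vec = list(vec)  # copy so the original activations are untouched
        for h in heads_to_ablate:
            for i in range(h * head_dim, (h + 1) * head_dim):
                vec[i] = 0.0
        result.append(vec)
    return result

# 3 tokens, hidden size 8, 2 heads of dimension 4; ablate head 0.
seq = [[1.0] * 8 for _ in range(3)]
out = ablate_heads(seq, num_heads=2, heads_to_ablate=[0])
print(out[0])  # [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0]
```

In practice this masking would be applied inside a forward hook (or a PyVene intervention, which the conductor plan mentions) on each targeted attention module, so the zeroed slice propagates through the rest of the forward pass.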
components/__pycache__/ablation_panel.cpython-311.pyc CHANGED
Binary files a/components/__pycache__/ablation_panel.cpython-311.pyc and b/components/__pycache__/ablation_panel.cpython-311.pyc differ
 
components/__pycache__/pipeline.cpython-311.pyc CHANGED
Binary files a/components/__pycache__/pipeline.cpython-311.pyc and b/components/__pycache__/pipeline.cpython-311.pyc differ
 
components/ablation_panel.py CHANGED
@@ -22,17 +22,6 @@ def create_ablation_panel():
     ], style={'color': '#6c757d', 'fontSize': '14px', 'marginBottom': '16px'})
     ]),

-    # BertViz Model View Visualization
-    html.Div([
-        html.H6("Attention Visualization (BertViz Model View)", style={'color': '#495057', 'marginBottom': '8px'}),
-        html.P("Use this view to identify interesting heads to ablate.", style={'color': '#6c757d', 'fontSize': '12px'}),
-        dcc.Loading(
-            id="loading-ablation-viz",
-            type="default",
-            children=html.Div(id='ablation-model-view-container', style={'height': '500px', 'width': '100%', 'overflow': 'auto', 'border': '1px solid #e2e8f0', 'borderRadius': '8px'})
-        )
-    ], style={'marginBottom': '20px'}),
-
     # Head Selector Interface
     html.Div([
         html.Label("Add Head to Ablation List:", className="input-label", style={'marginBottom': '8px', 'display': 'block'}),
@@ -170,21 +159,10 @@
 def create_ablation_results_display(original_token, ablated_token, original_prob, ablated_prob,
-                                    selected_heads, selected_beam=None):
+                                    selected_heads, selected_beam=None, ablated_beam=None):
     """
-    Create the ablation results display.
-
-    Args:
-        original_token: Original predicted token
-        ablated_token: Predicted token after ablation
-        original_prob: Original prediction probability
-        ablated_prob: Ablated prediction probability
-        selected_heads: List of {layer: N, head: M} dicts
-        selected_beam: Optional data for original beam comparison
+    Create the ablation results display focusing on full generation comparison.
     """
-    output_changed = original_token != ablated_token
-    prob_delta = ablated_prob - original_prob
-
     # Format selected heads for display
     all_heads_formatted = [f"L{item['layer']}-H{item['head']}" for item in selected_heads if isinstance(item, dict)]
@@ -200,87 +178,76 @@ def create_ablation_results_display(original_token, ablated_token, original_prob
     ], style={'marginBottom': '16px'})
     ]))

-    # Before/After comparison
-    results.append(html.Div([
-        # Original
-        html.Div([
-            html.Div("Original", style={'fontSize': '12px', 'color': '#6c757d', 'marginBottom': '6px'}),
-            html.Div(original_token, style={
-                'padding': '10px 16px',
-                'backgroundColor': '#e8f5e9',
-                'border': '2px solid #4caf50',
-                'borderRadius': '6px',
-                'fontFamily': 'monospace',
-                'fontWeight': '600',
-                'textAlign': 'center'
-            }),
-            html.Div(f"{original_prob:.1%}", style={
-                'fontSize': '12px',
-                'color': '#6c757d',
-                'marginTop': '4px',
-                'textAlign': 'center'
-            })
-        ], style={'flex': '1'}),
-
-        # Arrow
-        html.Div([
-            html.I(className='fas fa-arrow-right', style={
-                'fontSize': '24px',
-                'color': '#dc3545' if output_changed else '#6c757d'
-            })
-        ], style={'display': 'flex', 'alignItems': 'center', 'padding': '0 20px'}),
-
-        # Ablated
-        html.Div([
-            html.Div("After Ablation", style={'fontSize': '12px', 'color': '#6c757d', 'marginBottom': '6px'}),
-            html.Div(ablated_token, style={
-                'padding': '10px 16px',
-                'backgroundColor': '#ffebee' if output_changed else '#e8f5e9',
-                'border': f'2px solid {"#dc3545" if output_changed else "#4caf50"}',
-                'borderRadius': '6px',
-                'fontFamily': 'monospace',
-                'fontWeight': '600',
-                'textAlign': 'center'
-            }),
-            html.Div(f"{ablated_prob:.1%} ({prob_delta:+.1%})", style={
-                'fontSize': '12px',
-                'color': '#dc3545' if prob_delta < 0 else '#4caf50' if prob_delta > 0 else '#6c757d',
-                'marginTop': '4px',
-                'textAlign': 'center'
-            })
-        ], style={'flex': '1'})
-    ], style={
-        'display': 'flex',
-        'alignItems': 'stretch',
-        'padding': '16px',
-        'backgroundColor': 'white',
-        'borderRadius': '8px',
-        'border': '1px solid #e2e8f0'
-    }))
-
-    # Interpretation
-    results.append(html.Div([
-        html.I(className='fas fa-lightbulb', style={'color': '#ffc107', 'marginRight': '8px'}),
-        html.Span(
-            "The prediction changed! These heads are important for this input."
-            if output_changed else
-            "The prediction stayed the same. These heads may not be critical for this specific input.",
-            style={'color': '#6c757d', 'fontSize': '13px'}
-        )
-    ], style={'marginTop': '16px', 'padding': '12px', 'backgroundColor': '#fff8e1', 'borderRadius': '6px'}))
-
-    # Selected beam comparison context
-    if selected_beam and selected_beam.get('text'):
-        results.append(html.Div([
-            html.Hr(style={'margin': '15px 0', 'borderColor': '#dee2e6'}),
-            html.Div([
-                html.I(className='fas fa-info-circle', style={'color': '#6c757d', 'marginRight': '8px'}),
-                html.Span("Original generation for reference: ", style={'fontWeight': '500', 'color': '#495057'}),
-                html.Span(selected_beam['text'], style={'fontFamily': 'monospace', 'backgroundColor': '#f8f9fa',
-                                                        'padding': '4px 8px', 'borderRadius': '4px'})
-            ], style={'fontSize': '13px'})
-        ]))
+    # Generation Comparison (Main Display)
+    if selected_beam and selected_beam.get('text') and ablated_beam and ablated_beam.get('text'):
+        gen_changed = selected_beam['text'] != ablated_beam['text']
+
+        results.append(html.Div([
+            html.H6("Full Generation Comparison", style={'color': '#495057', 'marginBottom': '15px'}),
+
+            # Comparison Grid
+            html.Div([
+                # Original Generation
+                html.Div([
+                    html.Div("Original Generation", style={'fontSize': '12px', 'fontWeight': 'bold', 'color': '#28a745', 'marginBottom': '8px'}),
+                    html.Div(selected_beam['text'], style={
+                        'fontFamily': 'monospace',
+                        'fontSize': '13px',
+                        'padding': '12px',
+                        'backgroundColor': '#f8f9fa',
+                        'border': '1px solid #dee2e6',
+                        'borderRadius': '6px',
+                        'whiteSpace': 'pre-wrap',
+                        'height': '100%'
+                    })
+                ], style={'flex': '1', 'marginRight': '10px'}),
+
+                # Ablated Generation
+                html.Div([
+                    html.Div("Ablated Generation", style={'fontSize': '12px', 'fontWeight': 'bold', 'color': '#dc3545', 'marginBottom': '8px'}),
+                    html.Div(ablated_beam['text'], style={
+                        'fontFamily': 'monospace',
+                        'fontSize': '13px',
+                        'padding': '12px',
+                        'backgroundColor': '#fff5f5' if gen_changed else '#f8f9fa',
+                        'border': '1px solid #dee2e6',
+                        'borderRadius': '6px',
+                        'whiteSpace': 'pre-wrap',
+                        'height': '100%'
+                    })
+                ], style={'flex': '1', 'marginLeft': '10px'})
+            ], style={'display': 'flex', 'alignItems': 'stretch'}),
+
+            # Generation Change Indicator
+            html.Div([
+                html.I(className=f"fas {'fa-exclamation-circle' if gen_changed else 'fa-check-circle'}",
+                       style={'color': '#dc3545' if gen_changed else '#28a745', 'marginRight': '8px'}),
+                html.Span(
+                    "The generated sequence changed significantly after ablation."
+                    if gen_changed else
+                    "The generated sequence remained identical.",
+                    style={'fontWeight': '500', 'color': '#495057'}
+                )
+            ], style={'marginTop': '15px', 'fontSize': '13px'}),
+
+            # Probability info as secondary context
+            html.Div([
+                html.Hr(style={'margin': '15px 0', 'borderTop': '1px dotted #dee2e6'}),
+                html.Span("Immediate next-token probability: ", style={'color': '#6c757d', 'fontSize': '12px'}),
+                html.Span(f"{original_prob:.1%} → {ablated_prob:.1%} ",
+                          style={'fontSize': '12px', 'fontWeight': 'bold', 'color': '#495057'}),
+                html.Span(f"({ablated_prob - original_prob:+.1%})",
+                          style={'fontSize': '12px', 'color': '#dc3545' if ablated_prob < original_prob else '#28a745'})
+            ], style={'marginTop': '10px'})
+        ]))
+    else:
+        # Fallback if beam data is missing
+        results.append(html.Div([
+            html.I(className='fas fa-info-circle', style={'color': '#6c757d', 'marginRight': '8px'}),
+            html.Span("Run a full analysis first to select a generation for comparison.",
+                      style={'color': '#6c757d', 'fontSize': '14px'})
+        ], style={'padding': '20px', 'backgroundColor': '#f8f9fa', 'borderRadius': '8px', 'border': '1px solid #dee2e6'}))

     return html.Div(results, style={
         'padding': '20px',
components/pipeline.py CHANGED
@@ -792,6 +792,21 @@ def create_output_content(top_tokens=None, predicted_token=None, predicted_prob=
     ], style={'backgroundColor': 'white', 'borderRadius': '8px', 'border': '1px solid #e2e8f0'})
     )

+    # Disclaimer about token selection drivers
+    content_items.append(
+        html.Div([
+            html.I(className='fas fa-info-circle', style={'color': '#6c757d', 'marginRight': '8px'}),
+            html.Span([
+                html.Strong("Note on Token Selection: "),
+                "While the probabilities above show the model's raw preference for the immediate next token, the final choice ",
+                "can be influenced by other factors. Techniques like ", html.Strong("Beam Search"),
+                " look ahead at multiple possible sequences to find the best overall result, rather than just the single most likely token at each step. ",
+                "Additionally, architectures like ", html.Strong("Mixture of Experts (MoE)"),
+                " might route processing through different specialized internal networks which can impact the final output distribution."
+            ], style={'color': '#6c757d', 'fontSize': '13px'})
+        ], style={'marginTop': '16px', 'padding': '12px', 'backgroundColor': '#f8f9fa', 'borderRadius': '6px', 'border': '1px solid #dee2e6'})
+    )
+
     return html.Div(content_items)
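`perform_beam_search` itself is not part of this diff (per the new README it lives in `utils/beam_search.py`). A minimal sketch of the look-ahead behavior the disclaimer describes, using a hypothetical toy next-token distribution where greedy decoding is suboptimal:

```python
import math

def beam_search(step_logprobs, beam_width, max_new_tokens):
    """Keep the `beam_width` highest-scoring partial sequences at every step."""
    beams = [([], 0.0)]  # (token sequence, cumulative log-probability)
    for _ in range(max_new_tokens):
        candidates = []
        for seq, score in beams:
            for tok, lp in step_logprobs(seq).items():
                candidates.append((seq + [tok], score + lp))
        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:beam_width]
    return beams

# Toy conditional distributions (hypothetical, not from any real model):
TABLE = {
    (): {"a": 0.5, "b": 0.4},
    ("a",): {"x": 0.3, "y": 0.3},
    ("b",): {"x": 0.9, "y": 0.1},
}
step_logprobs = lambda seq: {t: math.log(p) for t, p in TABLE[tuple(seq)].items()}

best_seq, best_score = beam_search(step_logprobs, beam_width=2, max_new_tokens=2)[0]
print(best_seq)  # ['b', 'x']
```

Greedy decoding would commit to "a" (0.5) and end with total probability 0.5 * 0.3 = 0.15, while the beam keeps "b" alive and finds 0.4 * 0.9 = 0.36 — exactly the "best overall result vs. most likely single token" distinction the note makes.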
 
conductor/tracks/ablation_20260129/index.md DELETED
@@ -1,5 +0,0 @@
- # Track ablation_20260129 Context
-
- - [Specification](./spec.md)
- - [Implementation Plan](./plan.md)
- - [Metadata](./metadata.json)
conductor/tracks/ablation_20260129/metadata.json DELETED
@@ -1,8 +0,0 @@
- {
-   "track_id": "ablation_20260129",
-   "type": "feature",
-   "status": "new",
-   "created_at": "2026-01-29T12:40:00Z",
-   "updated_at": "2026-01-29T12:40:00Z",
-   "description": "Implement interactive ablation studies for attention heads in the Dash dashboard."
- }
conductor/tracks/ablation_20260129/plan.md DELETED
@@ -1,31 +0,0 @@
-# Implementation Plan - Interactive Attention Head Ablation
-
-## Phase 1: Backend Support for Ablation [checkpoint: 5fc7374]
-- [x] Task: Create a reproduction script to test manual PyVene interventions for head ablation. [43cf4ff]
-  - [x] Sub-task: Write a standalone script that loads a small model (e.g., GPT-2) and uses PyVene to zero out a specific head (e.g., L0H0).
-  - [x] Sub-task: Verify that the output logits change compared to the baseline run.
-- [x] Task: Extend `utils/model_patterns.py` (or creating `utils/ablation.py`) to support dynamic head masking. [890f413]
-  - [x] Sub-task: Write tests for the new ablation utility function.
-  - [x] Sub-task: Implement a function `apply_ablation_mask(model, heads_to_ablate)` that registers the necessary PyVene hooks.
-- [x] Task: Update the main inference pipeline to accept an ablation configuration. [d2ea949]
-  - [ ] Sub-task: Modify the capture logic to check for an "ablation list" in the request.
-  - [ ] Sub-task: Ensure the pipeline correctly applies the mask before running the forward pass.
-- [ ] Task: Conductor - User Manual Verification 'Backend Support for Ablation' (Protocol in workflow.md)
-
-## Phase 2: Frontend Control Panel [checkpoint: 24bd049]
-- [x] Task: Create an `AblationPanel` component in `components/`. [e2e8f0]
-  - [x] Sub-task: Integrate BertViz Model View to visualize attention heads (replacing custom heatmap/grid).
-  - [x] Sub-task: Implement the callback to handle clicks on the grid and update a `dcc.Store` with the list of disabled heads.
-- [x] Task: Integrate the `AblationPanel` into `app.py`. [e2e8f0]
-  - [ ] Sub-task: Add the panel to the main layout (likely in a new "Experiments" tab or collapsible sidebar).
-  - [ ] Sub-task: Connect the global "Run" or "Update" callback to include the ablation state from the store.
-- [ ] Task: Conductor - User Manual Verification 'Frontend Control Panel' (Protocol in workflow.md)
-
-## Phase 3: Visualization & Feedback Loop
-- [~] Task: Connect the Frontend Ablation State to the Backend Inference.
-  - [ ] Sub-task: Update the main `app.py` callback to pass the `disabled_heads` list to the backend capture function.
-  - [ ] Sub-task: Verify that toggling a head in the UI updates the Logit Lens/Output display.
-- [ ] Task: Visual Polish for Ablated State.
-  - [ ] Sub-task: Ensure the Attention Map visualization shows disabled heads as blank or "inactive".
-  - [ ] Sub-task: Add a "Reset Ablations" button to quickly restore the original model state.
-- [ ] Task: Conductor - User Manual Verification 'Visualization & Feedback Loop' (Protocol in workflow.md)

conductor/tracks/ablation_20260129/spec.md DELETED
@@ -1,36 +0,0 @@
-# Track Specification: Interactive Attention Head Ablation
-
-## Overview
-This track introduces interactive ablation capabilities to the Dash dashboard. Users will be able to selectively disable (zero out) specific attention heads in the transformer model and observe the resulting changes in model output (logits/probabilities) and attention patterns. This directly supports the "Interactive Experimentation" core value proposition.
-
-## Goals
-- Enable users to toggle specific attention heads on/off via the UI.
-- Update the model's forward pass to respect these ablation masks.
-- Visualize the "ablated" state compared to the "original" state (if feasible) or simply show the new state.
-- Provide immediate feedback on how head removal affects token prediction.
-
-## User Stories
-- As a student, I want to turn off a specific attention head to see if it is responsible for a particular grammatical dependency (e.g., matching plural subjects to verbs).
-- As a researcher, I want to ablate a group of heads to test a hypothesis about distributed representations.
-- As a user, I want clear visual indicators of which heads are currently active or disabled.
-
-## Requirements
-
-### Frontend (Dash)
-- **Ablation Control Panel:** A UI component (e.g., a grid of toggles or a heatmap with clickable cells) representing all attention heads in the model (Layers x Heads).
-- **State Management:** Store the set of "disabled heads" in the Dash app state (`dcc.Store`).
-- **Visual Feedback:**
-  - Disabled heads should be visually distinct (e.g., grayed out) in the visualization.
-  - The output (Logit Lens or Top-K tokens) must update dynamically when heads are toggled.
-
-### Backend (Model Logic)
-- **Intervention Mechanism:** Modify the `model_patterns.py` or `agnostic_capture.py` logic to accept an "ablation mask".
-- **PyVene Integration:** Use PyVene's intervention capabilities to zero out the activations of specific heads during the forward pass.
-  - *Technical Note:* This might require defining a specific intervention function that takes the head output and multiplies it by 0 if the index matches the ablated head.
-
-### Visualization
-- Update the attention map visualization to reflect that the ablated head is contributing nothing (blank map or "Disabled" overlay).
-
-## Non-Functional Requirements
-- **Latency:** The update loop (Toggle -> Inference -> Update UI) should be fast enough for interactive exploration (< 2-3 seconds for small/medium models).
-- **Clarity:** It must be obvious to the user that they have modified the model. A "Reset All" button is essential.

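The zero-ablation the spec calls for reduces to masking one head's output before it re-enters the residual stream. Below is a minimal tensor-level sketch of just that masking step; it deliberately sidesteps PyVene's intervention API and the hook registration that `apply_ablation_mask` would perform in the real code, so the function name and shapes here are illustrative assumptions:

```python
import torch

def ablate_heads(per_head_output: torch.Tensor, heads_to_ablate: list) -> torch.Tensor:
    """Zero out selected heads in a [batch, n_heads, seq_len, head_dim] tensor.

    In a real model this would run inside a forward hook on the attention
    module, before head outputs are concatenated and projected back into
    the residual stream.
    """
    out = per_head_output.clone()  # leave the captured activation untouched
    out[:, heads_to_ablate, :, :] = 0.0
    return out

# Example: ablate head 0 of a toy 2-head activation
acts = torch.ones(1, 2, 4, 8)  # [batch=1, heads=2, seq=4, head_dim=8]
ablated = ablate_heads(acts, [0])
```

Because the masked head now contributes a zero vector, any downstream change in the logits can be attributed to that head — which is exactly the feedback loop the UI stories above describe.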
utils/__pycache__/model_patterns.cpython-311.pyc CHANGED
Binary files a/utils/__pycache__/model_patterns.cpython-311.pyc and b/utils/__pycache__/model_patterns.cpython-311.pyc differ
 
utils/model_patterns.py CHANGED
@@ -979,6 +979,18 @@ def evaluate_sequence_ablation(model, tokenizer, sequence_text: str, config: Dic
     }


+def _prepare_hidden_state(layer_output: Any) -> torch.Tensor:
+    """Helper to convert layer output to tensor, handling tuple outputs."""
+    # Handle PyVene captured tuple outputs where 2nd element is None (e.g. use_cache=False)
+    if isinstance(layer_output, (list, tuple)) and len(layer_output) > 1 and layer_output[1] is None:
+        layer_output = layer_output[0]
+
+    hidden = torch.tensor(layer_output) if not isinstance(layer_output, torch.Tensor) else layer_output
+    if hidden.dim() == 4:
+        hidden = hidden.squeeze(0)
+    return hidden
+
+
 def logit_lens_transformation(layer_output: Any, norm_data: List[Any], model, tokenizer, norm_parameter: Optional[str] = None, top_k: int = 5) -> List[Tuple[str, float]]:
     """
     Transform layer output to top K token probabilities using logit lens.
@@ -1003,9 +1015,7 @@ def logit_lens_transformation(layer_output: Any, norm_data: List[Any], model, to
     """
     with torch.no_grad():
         # Convert to tensor and ensure proper shape [batch, seq_len, hidden_dim]
-        hidden = torch.tensor(layer_output) if not isinstance(layer_output, torch.Tensor) else layer_output
-        if hidden.dim() == 4:
-            hidden = hidden.squeeze(0)
+        hidden = _prepare_hidden_state(layer_output)

         # Step 1: Apply final layer normalization (critical for intermediate layers)
         final_norm = get_norm_layer_from_parameter(model, norm_parameter)
@@ -1096,9 +1106,7 @@ def _get_token_probabilities_for_layer(activation_data: Dict[str, Any], module_n
         lm_head = model.get_output_embeddings()

         with torch.no_grad():
-            hidden = torch.tensor(layer_output) if not isinstance(layer_output, torch.Tensor) else layer_output
-            if hidden.dim() == 4:
-                hidden = hidden.squeeze(0)
+            hidden = _prepare_hidden_state(layer_output)

             if final_norm is not None:
                 hidden = final_norm(hidden)
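The `_prepare_hidden_state` helper introduced in this diff normalizes the two shapes a capture can produce: a bare tensor, or a `(hidden_states, None)` tuple when the layer returns an unused cache slot (e.g. `use_cache=False`). A standalone copy of the same logic, runnable outside the repo:

```python
from typing import Any
import torch

def prepare_hidden_state(layer_output: Any) -> torch.Tensor:
    """Normalize a captured layer output to a [batch, seq_len, hidden_dim] tensor."""
    # Unwrap tuple outputs whose second element is None (e.g. use_cache=False)
    if isinstance(layer_output, (list, tuple)) and len(layer_output) > 1 and layer_output[1] is None:
        layer_output = layer_output[0]
    hidden = layer_output if isinstance(layer_output, torch.Tensor) else torch.tensor(layer_output)
    if hidden.dim() == 4:
        hidden = hidden.squeeze(0)  # drop a leading singleton axis from batched capture
    return hidden

# All three capture shapes normalize to [batch, seq_len, hidden_dim]
raw = torch.randn(1, 3, 8)
assert prepare_hidden_state(raw).shape == (1, 3, 8)
assert prepare_hidden_state((raw, None)).shape == (1, 3, 8)
assert prepare_hidden_state(torch.randn(1, 1, 3, 8)).shape == (1, 3, 8)
```

Factoring this into one helper, as the diff does for `logit_lens_transformation` and `_get_token_probabilities_for_layer`, means any future capture-format quirk only needs handling in one place.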