Commit aabd66e
Parent(s): 95244f4
conductor setup

Files changed:
- .cursor/rules/minimal_changes.mdc +1 -0
- README.md +76 -181
- app.py +23 -30
- components/__pycache__/ablation_panel.cpython-311.pyc +0 -0
- components/__pycache__/pipeline.cpython-311.pyc +0 -0
- components/ablation_panel.py +68 -101
- components/pipeline.py +15 -0
- conductor/tracks/ablation_20260129/index.md +0 -5
- conductor/tracks/ablation_20260129/metadata.json +0 -8
- conductor/tracks/ablation_20260129/plan.md +0 -31
- conductor/tracks/ablation_20260129/spec.md +0 -36
- utils/__pycache__/model_patterns.cpython-311.pyc +0 -0
- utils/model_patterns.py +14 -6
`.cursor/rules/minimal_changes.mdc` (CHANGED)

```diff
@@ -20,6 +20,7 @@ alwaysApply: true
 - Plan first:
   - Update `todo.md` with the smallest next actions tied to `plans.md`.
   - Keep tasks atomic and check them off as you go.
+  - Use the `conductor` folder to learn about the project. Maintain this folder after every change to the code in order to keep running memory (only make changes if necessary).
 
 - Keep edits minimal:
   - Prefer small, surgical changes over refactors.
```
`README.md` (CHANGED)

````diff
@@ -1,210 +1,105 @@
-# Transformer
-
-##
-
-1. **Interactive Dashboard** (`app.py`) - Web-based visualization with automatic model family detection
-2. **Model Configuration** (`utils/model_config.py`) - Hard-coded mappings for common model families
-3. **Activation Capture** (`utils/model_patterns.py`) - PyVene-based activation capture utilities
-4. **Legacy Tools** (`agnostic_capture.py`, `bertviz_head_model_view.py`) - Command-line tools
-
-##
-
-###
-
-##
-
-- **Attention modules**: e.g., `transformer.h.{N}.attn` for GPT-2
-- **MLP modules**: e.g., `model.layers.{N}.mlp` for LLaMA
-- **Normalization parameters**: e.g., `model.norm.weight` for LLaMA
-- **Logit lens parameter**: e.g., `lm_head.weight`
-4. You can still manually adjust selections if needed
-
-###
-
-}
-```
-
-No code changes needed if the model follows an existing family's architecture!
-
-##
-
-A model-agnostic activation capture tool that hooks into transformer modules and saves their outputs.
-
-**Key Features:**
-- Automatically categorizes modules into attention, MLP, and other types
-- Interactive module selection by pattern
-- Supports both PyTorch hooks and PyVene integration
-- Saves data in organized JSON structure for easy retrieval
-
-**Usage:**
-```bash
-python agnostic_capture.py --model "Qwen/Qwen2.5-0.5B" --prompt "Once upon a time"
-
-# Capture attention weights for bertviz
-python agnostic_capture.py --model "Qwen/Qwen2.5-0.5B" --prompt "Once upon a time" --output my_activations.json
-
-# Auto-select patterns (attention:0, mlp:0 selects first pattern from each)
-python agnostic_capture.py --auto-select "attn:0;mlp:0;other:" --model "Qwen/Qwen2.5-0.5B"
-```
-
-When run without `--auto-select`, the script will:
-1. Show available module patterns grouped by type
-2. Allow you to select patterns by index, name, or suffix
-3. Selected patterns apply to ALL layers that contain them
-
-### `bertviz_head_model_view.py`
-Creates interactive HTML visualizations of attention patterns using the bertviz library.
-
-**Features:**
-- Generates head view (attention patterns per head)
-- Generates model view (attention patterns across layers)
-- Automatically extracts attention weights from captured data
-- Saves HTML files that can be opened in any browser
-
-```bash
-python bertviz_head_model_view.py
-```
-
-```
-{
-  "model": "model_name",
-  "prompt": "input_text",
-  "input_ids": [[token_ids]],
-  "selected_patterns": {
-    "attention": ["pattern1", "pattern2"],
-    "mlp": ["pattern1"],
-    "other": []
-  },
-  "selected_modules": {
-    "attention": ["model.layers.0.self_attn", "model.layers.1.self_attn", ...],
-    "mlp": ["model.layers.0.mlp", "model.layers.1.mlp", ...],
-    "other": []
-  },
-  "captured": {
-    "attention_outputs": {
-      "model.layers.0.self_attn": {
-        "output": [
-          [[...]], // Attention output (processed values)
-          [[...]]  // Attention weights (used by bertviz)
-        ]
-      }
-    },
-    "mlp_outputs": { ... },
-    "other_outputs": { ... }
-  }
-}
-```
-
-##
-
-```bash
-python bertviz_head_model_view.py
-```
-- Reads from `agnostic_activations.json`
-- Creates HTML visualization files in `bertviz/` directory
-
-##
-
-```bash
-pip install torch transformers bertviz
-```
-
-Optional:
-```bash
-```
-
-## Important Notes
-
-### For Attention Visualization:
-- **Must capture `self_attn` modules** (not `self_attn.o_proj`) for bertviz to work
-- Attention modules return tuples: `(output, attention_weights)`
-- bertviz uses the attention weights (element 1) for visualization
-
-### Module Selection:
-- Patterns use `{layer}` placeholder (e.g., `model.{layer}.self_attn`)
-- Selected patterns apply to ALL layers automatically
-- Use indices, exact names, or unique suffixes for selection
-
-### File Outputs:
-- `agnostic_activations.json` - Captured activation data
-- `bertviz/attention_head_view_{model}.html` - Per-head attention visualization
-- `bertviz/attention_model_view_{model}.html` - Cross-layer attention visualization
-
-## Troubleshooting
-
-**"Attention tensor does not have correct dimensions"**
-- Ensure you captured `self_attn` modules, not output projections
-- Check that attention weights have shape `(batch, heads, seq_len, seq_len)`
-
-**"Module not found"**
-- Verify module patterns match your model architecture
-- Use the interactive selection to see available patterns
-
-**"No data captured"**
-- Check hook registration succeeded
-- Ensure selected modules exist in the model
-- Verify the model actually runs forward pass
-
-## Example Session
-
-```bash
-# 1. Capture attention data
-python agnostic_capture.py --model "Qwen/Qwen2.5-0.5B" --prompt "The cat sat on the mat"
-# Select attention patterns: 0 (for model.{layer}.self_attn)
-# Select MLP patterns: (press enter to skip)
-# Select other patterns: (press enter to skip)
-
-# 2. Generate visualizations
-python bertviz_head_model_view.py
-
-# 3. Open bertviz/attention_head_view_Qwen_Qwen2.5-0.5B.html in browser
-```
-
-This will show you how the model attends to different tokens when processing "The cat sat on the mat".
+# Transformer Explanation Dashboard
+
+A comprehensive, interactive tool for capturing, visualizing, and experimenting with Transformer-based Large Language Models (LLMs). This project demystifies the inner workings of models by transforming abstract architectural concepts into tangible, observable phenomena.
+
+## Vision
+
+To foster a deep, intuitive understanding of how powerful models process information by combining interactive visualizations with hands-on experimentation capabilities.
+
+## Key Features
+
+### 🔍 Interactive Pipeline Visualization
+Follow the data flow step-by-step through the model's architecture:
+1. **Tokenization**: See how text is split and assigned IDs.
+2. **Embedding**: Visualize the look-up of semantic vectors.
+3. **Attention**: Explore head-level attention patterns using **BertViz**.
+4. **MLP (Feed-Forward)**: Understand where factual knowledge is stored.
+5. **Output Selection**: View probability distributions and top predictions.
+
+### 🧪 Experiments & Investigation
+Go beyond static observation with interactive experiments:
+* **Ablation Studies**: Selectively disable specific attention heads across different layers to observe their impact on generation and probability.
+* **Token Attribution**: Use **Integrated Gradients** to see which input tokens contributed most to a specific prediction.
+* **Beam Search Analysis**: Visualize how multiple generation choices are explored.
+
+### 🤖 Broad Model Support
+The dashboard features **Automatic Model Family Detection**, supporting a wide range of architectures without manual configuration:
+* **LLaMA-like**: LLaMA 2/3, Mistral, Mixtral, Qwen2/2.5
+* **GPT-2**: GPT-2 (Small/Medium/Large/XL)
+* **OPT**: Facebook OPT models
+* **GPT-NeoX**: Pythia, GPT-NeoX
+* **BLOOM**: BigScience BLOOM
+* **Falcon**: TII Falcon
+* **MPT**: MosaicML MPT
+
+## Getting Started
+
+### Prerequisites
+* Python 3.11+ recommended
+* PyTorch
+
+### Installation
+1. Clone the repository:
+```bash
+git clone https://github.com/yourusername/transformer-dashboard.git
+cd transformer-dashboard
+```
+
+2. Install dependencies:
+```bash
+pip install -r requirements.txt
+```
+
+### Running the Dashboard
+Launch the application:
+```bash
+python app.py
+```
+
+Open your browser and navigate to `http://127.0.0.1:8050/`.
+
+## Usage Guide
+
+1. **Select a Model**: Choose from the predefined list or enter a HuggingFace model ID. The system will auto-detect the architecture.
+2. **Enter a Prompt**: Type a sentence to analyze.
+3. **Configure Generation**: Adjust "Number of New Tokens" and "Number of Generation Choices" (Beam Width).
+4. **Run Analysis**: Click "Analyze" to run the forward pass.
+5. **Explore the Pipeline**: Click on the pipeline stages (Tokenization, Attention, etc.) to expand detailed views.
+6. **Run Experiments**:
+   * Use the **Investigation Panel** at the bottom to switch between Ablation and Attribution tabs.
+   * In **Ablation**, select layers and heads to disable, then click "Run Ablation Experiment".
+   * In **Attribution**, select a target token and method to visualize feature importance.
+
+## Project Structure
+
+* `app.py`: Main application entry point and layout orchestration.
+* `components/`: Modular UI components.
+  * `pipeline.py`: The core 5-stage visualization.
+  * `investigation_panel.py`: Ablation and attribution interfaces.
+  * `ablation_panel.py`: Specific logic for head ablation UI.
+  * `tokenization_panel.py`: Token visualization.
+* `utils/`: Backend logic and helper functions.
+  * `model_patterns.py`: Activation capture and hooking logic.
+  * `model_config.py`: Model family definitions and auto-detection.
+  * `head_detection.py`: Attention head categorization logic.
+  * `beam_search.py`: Beam search implementation.
+* `tests/`: Comprehensive test suite ensuring stability.
+* `conductor/`: Detailed project documentation and product guidelines.
+
+## Documentation
+
+For more detailed information on the project's background and technical details, check the `conductor/` directory:
+* [Product Definition](conductor/product.md)
+* [Tech Stack](conductor/tech-stack.md)
+* [Workflow](conductor/workflow.md)
+
+## Contributing
+
+Contributions are welcome! Please ensure that any new features include appropriate tests in the `tests/` directory. Run the test suite before submitting:
+```bash
+pytest tests/
+```
````
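The head-ablation idea advertised in the new README (disable selected attention heads and observe the effect on the output) can be sketched in plain PyTorch. This is an illustrative toy only — the project's real logic lives in `utils/model_patterns.py` and uses hooks on the actual model; the module, names, and shapes below are assumptions, not the project's code.

```python
import torch
import torch.nn as nn

# Toy multi-head attention with head ablation (hypothetical sketch,
# not the dashboard's implementation).
class TinyMHA(nn.Module):
    def __init__(self, embed_dim=16, num_heads=4):
        super().__init__()
        self.h, self.d = num_heads, embed_dim // num_heads
        self.qkv = nn.Linear(embed_dim, 3 * embed_dim)
        self.out = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, ablate_heads=()):
        B, T, E = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # reshape each projection to (batch, heads, seq, head_dim)
        split = lambda t: t.view(B, T, self.h, self.d).transpose(1, 2)
        q, k, v = split(q), split(k), split(v)
        w = torch.softmax(q @ k.transpose(-2, -1) / self.d ** 0.5, dim=-1)
        ctx = w @ v  # per-head context vectors
        for h in ablate_heads:
            ctx[:, h] = 0.0  # "ablate": zero this head's contribution
        return self.out(ctx.transpose(1, 2).reshape(B, T, E))

torch.manual_seed(0)
m = TinyMHA()
x = torch.randn(2, 5, 16)
with torch.no_grad():
    full = m(x)
    ablated = m(x, ablate_heads=[0, 2])
# With random weights the ablated heads contribute, so the outputs differ.
```

In the dashboard itself, the analogous effect is achieved by hooking the real model's attention modules, so heads can be disabled without editing model code.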
`app.py` (CHANGED)

```diff
@@ -639,26 +639,6 @@ def switch_investigation_tab(abl_clicks, attr_clicks, current_tab):
 # CALLBACKS: Investigation Panel - Ablation (Updated for New UI)
 # ============================================================================
 
-@app.callback(
-    Output('ablation-model-view-container', 'children'),
-    [Input('session-activation-store', 'data')]
-)
-def update_ablation_model_view(activation_data):
-    """Update BertViz model view when new analysis is run."""
-    if not activation_data:
-        return html.Div("Run analysis to see attention visualization.",
-                        style={'padding': '20px', 'color': '#6c757d', 'textAlign': 'center'})
-
-    try:
-        html_content = generate_bertviz_model_view_html(activation_data)
-        return html.Iframe(
-            srcDoc=html_content,
-            style={'width': '100%', 'height': '100%', 'border': 'none'}
-        )
-    except Exception as e:
-        return html.Div(f"Error generating visualization: {str(e)}", style={'color': 'red', 'padding': '20px'})
-
-
 @app.callback(
     [Output('ablation-layer-select', 'options'),
      Output('ablation-head-select', 'options')],
@@ -768,11 +748,13 @@ def manage_ablation_heads(add_clicks, clear_clicks, remove_clicks,
      State('session-activation-store', 'data'),
      State('model-dropdown', 'value'),
      State('prompt-input', 'value'),
-     State('session-selected-beam-store', 'data')
+     State('session-selected-beam-store', 'data'),
+     State('max-new-tokens-slider', 'value'),
+     State('beam-width-slider', 'value')],
     prevent_initial_call=True
 )
-def run_ablation_experiment(n_clicks, selected_heads, activation_data, model_name, prompt, selected_beam):
-    """Run ablation on ORIGINAL PROMPT and compare results."""
+def run_ablation_experiment(n_clicks, selected_heads, activation_data, model_name, prompt, selected_beam, max_new_tokens, beam_width):
+    """Run ablation on ORIGINAL PROMPT and compare results, including beam generation."""
     if not n_clicks or not selected_heads or not activation_data:
         return no_update, no_update
@@ -809,7 +791,7 @@ def run_ablation_experiment(n_clicks, selected_heads, activation_data, model_nam
     if not heads_by_layer:
         return html.Div("No valid heads selected.", style={'color': '#dc3545'}), no_update
 
-    # Run ablation
+    # Run ablation for analysis (single pass)
     ablated_data = execute_forward_pass_with_multi_layer_head_ablation(
         model, tokenizer, sequence_text, config, heads_by_layer
     )
@@ -821,9 +803,25 @@ def run_ablation_experiment(n_clicks, selected_heads, activation_data, model_nam
     ablated_token = ablated_output.get('token', '')
     ablated_prob = ablated_output.get('probability', 0)
 
+    # Run ablation for generation
+    ablated_beam = None
+    try:
+        # Always perform beam search during ablation to show comparison
+        beam_results = perform_beam_search(
+            model, tokenizer, sequence_text,
+            beam_width=beam_width,
+            max_new_tokens=max_new_tokens,
+            ablation_config=heads_by_layer
+        )
+        if beam_results:
+            # Select the top beam
+            ablated_beam = {'text': beam_results[0]['text'], 'score': beam_results[0].get('score', 0)}
+    except Exception as e:
+        print(f"Error during ablated generation: {e}")
+
     results_display = create_ablation_results_display(
         original_token, ablated_token, original_prob, ablated_prob,
-        selected_heads, selected_beam
+        selected_heads, selected_beam, ablated_beam
     )
 
     return results_display, ablated_data
@@ -832,11 +830,6 @@ def run_ablation_experiment(n_clicks, selected_heads, activation_data, model_nam
         import traceback
         traceback.print_exc()
         return html.Div(f"Ablation error: {str(e)}", style={'color': '#dc3545'}), no_update
-
-    except Exception as e:
-        import traceback
-        traceback.print_exc()
-        return html.Div(f"Ablation error: {str(e)}", style={'color': '#dc3545'})
 
 
 # ============================================================================
```
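The `perform_beam_search` call added above belongs to the project (`utils/beam_search.py`); its exact signature is only shown in this diff. As a rough, self-contained illustration of what beam search does, here is a minimal sketch with a hypothetical stand-in `next_token_probs` function and no ablation:

```python
import math

def beam_search(next_token_probs, start, beam_width=2, max_new_tokens=3):
    """Keep the `beam_width` highest-scoring partial sequences at each step."""
    beams = [(start, 0.0)]  # (token sequence, cumulative log-probability)
    for _ in range(max_new_tokens):
        candidates = []
        for seq, score in beams:
            for tok, p in next_token_probs(seq).items():
                candidates.append((seq + [tok], score + math.log(p)))
        # prune to the best `beam_width` candidates
        beams = sorted(candidates, key=lambda b: b[1], reverse=True)[:beam_width]
    return beams

# Toy distribution: the model always slightly prefers token "a" over "b".
probs = lambda seq: {"a": 0.6, "b": 0.4}
best = beam_search(probs, start=["<s>"], beam_width=2, max_new_tokens=2)
print(best[0][0])  # ['<s>', 'a', 'a']
```

Unlike greedy decoding, the pruning step keeps several hypotheses alive, so a sequence that starts with a slightly worse token can still win overall.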
`components/__pycache__/ablation_panel.cpython-311.pyc` (CHANGED)
Binary files a/components/__pycache__/ablation_panel.cpython-311.pyc and b/components/__pycache__/ablation_panel.cpython-311.pyc differ

`components/__pycache__/pipeline.cpython-311.pyc` (CHANGED)
Binary files a/components/__pycache__/pipeline.cpython-311.pyc and b/components/__pycache__/pipeline.cpython-311.pyc differ
`components/ablation_panel.py` (CHANGED)

```diff
@@ -22,17 +22,6 @@ def create_ablation_panel():
         ], style={'color': '#6c757d', 'fontSize': '14px', 'marginBottom': '16px'})
     ]),
 
-    # BertViz Model View Visualization
-    html.Div([
-        html.H6("Attention Visualization (BertViz Model View)", style={'color': '#495057', 'marginBottom': '8px'}),
-        html.P("Use this view to identify interesting heads to ablate.", style={'color': '#6c757d', 'fontSize': '12px'}),
-        dcc.Loading(
-            id="loading-ablation-viz",
-            type="default",
-            children=html.Div(id='ablation-model-view-container', style={'height': '500px', 'width': '100%', 'overflow': 'auto', 'border': '1px solid #e2e8f0', 'borderRadius': '8px'})
-        )
-    ], style={'marginBottom': '20px'}),
-
     # Head Selector Interface
     html.Div([
         html.Label("Add Head to Ablation List:", className="input-label", style={'marginBottom': '8px', 'display': 'block'}),
@@ -170,21 +159,10 @@ def create_selected_heads_display(selected_heads):
 
 
 def create_ablation_results_display(original_token, ablated_token, original_prob, ablated_prob,
-                                    selected_heads, selected_beam=None):
+                                    selected_heads, selected_beam=None, ablated_beam=None):
     """
-    Create the ablation results display.
-
-    Args:
-        original_token: Original predicted token
-        ablated_token: Predicted token after ablation
-        original_prob: Original prediction probability
-        ablated_prob: Ablated prediction probability
-        selected_heads: List of {layer: N, head: M} dicts
-        selected_beam: Optional data for original beam comparison
+    Create the ablation results display focusing on full generation comparison.
     """
-    output_changed = original_token != ablated_token
-    prob_delta = ablated_prob - original_prob
-
     # Format selected heads for display
     all_heads_formatted = [f"L{item['layer']}-H{item['head']}" for item in selected_heads if isinstance(item, dict)]
 
@@ -200,87 +178,76 @@ def create_ablation_results_display(original_token, ablated_token, original_prob
         ], style={'marginBottom': '16px'})
     ]))
 
-        # Original
-        html.Div([
-            html.Div("Original", style={'fontSize': '12px', 'color': '#6c757d', 'marginBottom': '6px'}),
-            html.Div(original_token, style={
-                'padding': '10px 16px',
-                'backgroundColor': '#e8f5e9',
-                'border': '2px solid #4caf50',
-                'borderRadius': '6px',
-                'fontFamily': 'monospace',
-                'fontWeight': '600',
-                'textAlign': 'center'
-            }),
-            html.Div(f"{original_prob:.1%}", style={
-                'fontSize': '12px',
-                'color': '#6c757d',
-                'marginTop': '4px',
-                'textAlign': 'center'
-            })
-        ], style={'flex': '1'}),
-
-        # Arrow
-        html.Div([
-            html.I(className='fas fa-arrow-right', style={
-                'fontSize': '24px',
-                'color': '#dc3545' if output_changed else '#6c757d'
-            })
-        ], style={'display': 'flex', 'alignItems': 'center', 'padding': '0 20px'}),
-
-        # Ablated
-        html.Div([
-            html.Div("After Ablation", style={'fontSize': '12px', 'color': '#6c757d', 'marginBottom': '6px'}),
-            html.Div(ablated_token, style={
-                'padding': '10px 16px',
-                'backgroundColor': '#ffebee' if output_changed else '#e8f5e9',
-                'border': f'2px solid {"#dc3545" if output_changed else "#4caf50"}',
-                'borderRadius': '6px',
-                'fontFamily': 'monospace',
-                'fontWeight': '600',
-                'textAlign': 'center'
-            }),
-            html.Div(f"{ablated_prob:.1%} ({prob_delta:+.1%})", style={
-                'fontSize': '12px',
-                'color': '#dc3545' if prob_delta < 0 else '#4caf50' if prob_delta > 0 else '#6c757d',
-                'marginTop': '4px',
-                'textAlign': 'center'
-            })
-        ], style={'flex': '1'})
-
-    ], style={
-        'display': 'flex',
-        'alignItems': 'stretch',
-        'padding': '16px',
-        'backgroundColor': 'white',
-        'borderRadius': '8px',
-        'border': '1px solid #e2e8f0'
-    }))
-
-    # Interpretation
-    results.append(html.Div([
-        html.I(className='fas fa-lightbulb', style={'color': '#ffc107', 'marginRight': '8px'}),
-        html.Span(
-            "The prediction changed! These heads are important for this input."
-            if output_changed else
-            "The prediction stayed the same. These heads may not be critical for this specific input.",
-            style={'color': '#6c757d', 'fontSize': '13px'}
-        )
-    ], style={'marginTop': '16px', 'padding': '12px', 'backgroundColor': '#fff8e1', 'borderRadius': '6px'}))
-
-    # Selected beam comparison context
-    if selected_beam and selected_beam.get('text'):
-        results.append(html.Div([
-            html.Div([
-                html.I(className=
-                html.Span(
-        ]))
+    # Generation Comparison (Main Display)
+    if selected_beam and selected_beam.get('text') and ablated_beam and ablated_beam.get('text'):
+        gen_changed = selected_beam['text'] != ablated_beam['text']
+
+        results.append(html.Div([
+            html.H6("Full Generation Comparison", style={'color': '#495057', 'marginBottom': '15px'}),
+
+            # Comparison Grid
+            html.Div([
+                # Original Generation
+                html.Div([
+                    html.Div("Original Generation", style={'fontSize': '12px', 'fontWeight': 'bold', 'color': '#28a745', 'marginBottom': '8px'}),
+                    html.Div(selected_beam['text'], style={
+                        'fontFamily': 'monospace',
+                        'fontSize': '13px',
+                        'padding': '12px',
+                        'backgroundColor': '#f8f9fa',
+                        'border': '1px solid #dee2e6',
+                        'borderRadius': '6px',
+                        'whiteSpace': 'pre-wrap',
+                        'height': '100%'
+                    })
+                ], style={'flex': '1', 'marginRight': '10px'}),
+
+                # Ablated Generation
+                html.Div([
+                    html.Div("Ablated Generation", style={'fontSize': '12px', 'fontWeight': 'bold', 'color': '#dc3545', 'marginBottom': '8px'}),
+                    html.Div(ablated_beam['text'], style={
+                        'fontFamily': 'monospace',
+                        'fontSize': '13px',
+                        'padding': '12px',
+                        'backgroundColor': '#fff5f5' if gen_changed else '#f8f9fa',
+                        'border': '1px solid #dee2e6',
+                        'borderRadius': '6px',
+                        'whiteSpace': 'pre-wrap',
+                        'height': '100%'
+                    })
+                ], style={'flex': '1', 'marginLeft': '10px'})
+            ], style={'display': 'flex', 'alignItems': 'stretch'}),
+
+            # Generation Change Indicator
+            html.Div([
+                html.I(className=f"fas {'fa-exclamation-circle' if gen_changed else 'fa-check-circle'}",
+                       style={'color': '#dc3545' if gen_changed else '#28a745', 'marginRight': '8px'}),
+                html.Span(
+                    "The generated sequence changed significantly after ablation."
+                    if gen_changed else
+                    "The generated sequence remained identical.",
+                    style={'fontWeight': '500', 'color': '#495057'}
+                )
+            ], style={'marginTop': '15px', 'fontSize': '13px'}),
+
+            # Probability info as secondary context
+            html.Div([
+                html.Hr(style={'margin': '15px 0', 'borderTop': '1px dotted #dee2e6'}),
+                html.Span("Immediate next-token probability: ", style={'color': '#6c757d', 'fontSize': '12px'}),
+                html.Span(f"{original_prob:.1%} → {ablated_prob:.1%} ",
+                          style={'fontSize': '12px', 'fontWeight': 'bold', 'color': '#495057'}),
+                html.Span(f"({ablated_prob - original_prob:+.1%})",
+                          style={'fontSize': '12px', 'color': '#dc3545' if ablated_prob < original_prob else '#28a745'})
+            ], style={'marginTop': '10px'})
+
+        ]))
+    else:
+        # Fallback if beam data is missing
+        results.append(html.Div([
+            html.I(className='fas fa-info-circle', style={'color': '#6c757d', 'marginRight': '8px'}),
+            html.Span("Run a full analysis first to select a generation for comparison.",
+                      style={'color': '#6c757d', 'fontSize': '14px'})
+        ], style={'padding': '20px', 'backgroundColor': '#f8f9fa', 'borderRadius': '8px', 'border': '1px solid #dee2e6'}))
 
     return html.Div(results, style={
         'padding': '20px',
```
components/pipeline.py
CHANGED

```diff
@@ -792,6 +792,21 @@ def create_output_content(top_tokens=None, predicted_token=None, predicted_prob=
         ], style={'backgroundColor': 'white', 'borderRadius': '8px', 'border': '1px solid #e2e8f0'})
     )
 
+    # Disclaimer about token selection drivers
+    content_items.append(
+        html.Div([
+            html.I(className='fas fa-info-circle', style={'color': '#6c757d', 'marginRight': '8px'}),
+            html.Span([
+                html.Strong("Note on Token Selection: "),
+                "While the probabilities above show the model's raw preference for the immediate next token, the final choice ",
+                "can be influenced by other factors. Techniques like ", html.Strong("Beam Search"),
+                " look ahead at multiple possible sequences to find the best overall result, rather than just the single most likely token at each step. ",
+                "Additionally, architectures like ", html.Strong("Mixture of Experts (MoE)"),
+                " might route processing through different specialized internal networks which can impact the final output distribution."
+            ], style={'color': '#6c757d', 'fontSize': '13px'})
+        ], style={'marginTop': '16px', 'padding': '12px', 'backgroundColor': '#f8f9fa', 'borderRadius': '6px', 'border': '1px solid #dee2e6'})
+    )
+
     return html.Div(content_items)
```
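The beam-search behaviour this disclaimer describes can be sketched with a toy model. Everything here is illustrative, not the app's actual decoding code: `TABLE`, the token names, and the probabilities are made up so that the locally best first token ("A") loses to a path ("B") with a stronger continuation.

```python
import math

# Toy next-token distributions: "A" is locally most likely, but "B" leads
# to a much stronger continuation, so beam search prefers the "B" path.
TABLE = {
    (): {"A": 0.6, "B": 0.4},
    ("A",): {"x": 0.3, "y": 0.3},
    ("B",): {"z": 0.9},
}

def next_probs(prefix):
    return TABLE.get(prefix, {"<eos>": 1.0})

def greedy_decode(length=2):
    # Always take the single most likely next token.
    prefix = ()
    for _ in range(length):
        probs = next_probs(prefix)
        prefix += (max(probs, key=probs.get),)
    return prefix

def beam_search(beam_width=2, length=2):
    # Keep the beam_width best partial sequences by cumulative log-probability.
    beams = [((), 0.0)]
    for _ in range(length):
        candidates = [
            (prefix + (tok,), score + math.log(p))
            for prefix, score in beams
            for tok, p in next_probs(prefix).items()
        ]
        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:beam_width]
    return beams[0][0]
```

Greedy decoding returns `("A", "x")` with sequence probability 0.6 × 0.3 = 0.18, while the beam finds `("B", "z")` at 0.4 × 0.9 = 0.36 — the "look ahead" effect the note refers to.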
conductor/tracks/ablation_20260129/index.md
DELETED

```diff
@@ -1,5 +0,0 @@
-# Track ablation_20260129 Context
-
-- [Specification](./spec.md)
-- [Implementation Plan](./plan.md)
-- [Metadata](./metadata.json)
```
conductor/tracks/ablation_20260129/metadata.json
DELETED

```diff
@@ -1,8 +0,0 @@
-{
-    "track_id": "ablation_20260129",
-    "type": "feature",
-    "status": "new",
-    "created_at": "2026-01-29T12:40:00Z",
-    "updated_at": "2026-01-29T12:40:00Z",
-    "description": "Implement interactive ablation studies for attention heads in the Dash dashboard."
-}
```
conductor/tracks/ablation_20260129/plan.md
DELETED

```diff
@@ -1,31 +0,0 @@
-# Implementation Plan - Interactive Attention Head Ablation
-
-## Phase 1: Backend Support for Ablation [checkpoint: 5fc7374]
-- [x] Task: Create a reproduction script to test manual PyVene interventions for head ablation. [43cf4ff]
-    - [x] Sub-task: Write a standalone script that loads a small model (e.g., GPT-2) and uses PyVene to zero out a specific head (e.g., L0H0).
-    - [x] Sub-task: Verify that the output logits change compared to the baseline run.
-- [x] Task: Extend `utils/model_patterns.py` (or creating `utils/ablation.py`) to support dynamic head masking. [890f413]
-    - [x] Sub-task: Write tests for the new ablation utility function.
-    - [x] Sub-task: Implement a function `apply_ablation_mask(model, heads_to_ablate)` that registers the necessary PyVene hooks.
-- [x] Task: Update the main inference pipeline to accept an ablation configuration. [d2ea949]
-    - [ ] Sub-task: Modify the capture logic to check for an "ablation list" in the request.
-    - [ ] Sub-task: Ensure the pipeline correctly applies the mask before running the forward pass.
-- [ ] Task: Conductor - User Manual Verification 'Backend Support for Ablation' (Protocol in workflow.md)
-
-## Phase 2: Frontend Control Panel [checkpoint: 24bd049]
-- [x] Task: Create an `AblationPanel` component in `components/`. [e2e8f0]
-    - [x] Sub-task: Integrate BertViz Model View to visualize attention heads (replacing custom heatmap/grid).
-    - [x] Sub-task: Implement the callback to handle clicks on the grid and update a `dcc.Store` with the list of disabled heads.
-- [x] Task: Integrate the `AblationPanel` into `app.py`. [e2e8f0]
-    - [ ] Sub-task: Add the panel to the main layout (likely in a new "Experiments" tab or collapsible sidebar).
-    - [ ] Sub-task: Connect the global "Run" or "Update" callback to include the ablation state from the store.
-- [ ] Task: Conductor - User Manual Verification 'Frontend Control Panel' (Protocol in workflow.md)
-
-## Phase 3: Visualization & Feedback Loop
-- [~] Task: Connect the Frontend Ablation State to the Backend Inference.
-    - [ ] Sub-task: Update the main `app.py` callback to pass the `disabled_heads` list to the backend capture function.
-    - [ ] Sub-task: Verify that toggling a head in the UI updates the Logit Lens/Output display.
-- [ ] Task: Visual Polish for Ablated State.
-    - [ ] Sub-task: Ensure the Attention Map visualization shows disabled heads as blank or "inactive".
-    - [ ] Sub-task: Add a "Reset Ablations" button to quickly restore the original model state.
-- [ ] Task: Conductor - User Manual Verification 'Visualization & Feedback Loop' (Protocol in workflow.md)
```
conductor/tracks/ablation_20260129/spec.md
DELETED

```diff
@@ -1,36 +0,0 @@
-# Track Specification: Interactive Attention Head Ablation
-
-## Overview
-This track introduces interactive ablation capabilities to the Dash dashboard. Users will be able to selectively disable (zero out) specific attention heads in the transformer model and observe the resulting changes in model output (logits/probabilities) and attention patterns. This directly supports the "Interactive Experimentation" core value proposition.
-
-## Goals
-- Enable users to toggle specific attention heads on/off via the UI.
-- Update the model's forward pass to respect these ablation masks.
-- Visualize the "ablated" state compared to the "original" state (if feasible) or simply show the new state.
-- Provide immediate feedback on how head removal affects token prediction.
-
-## User Stories
-- As a student, I want to turn off a specific attention head to see if it is responsible for a particular grammatical dependency (e.g., matching plural subjects to verbs).
-- As a researcher, I want to ablate a group of heads to test a hypothesis about distributed representations.
-- As a user, I want clear visual indicators of which heads are currently active or disabled.
-
-## Requirements
-
-### Frontend (Dash)
-- **Ablation Control Panel:** A UI component (e.g., a grid of toggles or a heatmap with clickable cells) representing all attention heads in the model (Layers x Heads).
-- **State Management:** Store the set of "disabled heads" in the Dash app state (`dcc.Store`).
-- **Visual Feedback:**
-    - Disabled heads should be visually distinct (e.g., grayed out) in the visualization.
-    - The output (Logit Lens or Top-K tokens) must update dynamically when heads are toggled.
-
-### Backend (Model Logic)
-- **Intervention Mechanism:** Modify the `model_patterns.py` or `agnostic_capture.py` logic to accept an "ablation mask".
-- **PyVene Integration:** Use PyVene's intervention capabilities to zero out the activations of specific heads during the forward pass.
-    - *Technical Note:* This might require defining a specific intervention function that takes the head output and multiplies it by 0 if the index matches the ablated head.
-
-### Visualization
-- Update the attention map visualization to reflect that the ablated head is contributing nothing (blank map or "Disabled" overlay).
-
-## Non-Functional Requirements
-- **Latency:** The update loop (Toggle -> Inference -> Update UI) should be fast enough for interactive exploration (< 2-3 seconds for small/medium models).
-- **Clarity:** It must be obvious to the user that they have modified the model. A "Reset All" button is essential.
```
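The intervention mechanism the deleted spec describes — multiplying a head's output by zero when its index is in the ablation set — can be sketched in plain Python. The function name and the flat per-head data layout are hypothetical illustration, not PyVene's actual hook API:

```python
def apply_head_ablation(head_outputs, heads_to_ablate):
    """Return a copy of per-head outputs with the ablated heads zeroed.

    head_outputs: list indexed by head, each entry the head's output vector.
    heads_to_ablate: set of head indices to disable.
    """
    return [
        [0.0] * len(vec) if head in heads_to_ablate else list(vec)
        for head, vec in enumerate(head_outputs)
    ]

# Two heads with 2-dimensional outputs; ablating head 0 removes its
# contribution entirely while head 1 passes through unchanged.
outputs = [[1.0, 2.0], [3.0, 4.0]]
ablated = apply_head_ablation(outputs, {0})
# ablated == [[0.0, 0.0], [3.0, 4.0]]
```

In the real dashboard this zeroing would run inside a forward hook registered by PyVene on the attention output, so the downstream layers never see the disabled head's contribution.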
utils/__pycache__/model_patterns.cpython-311.pyc
CHANGED

Binary files a/utils/__pycache__/model_patterns.cpython-311.pyc and b/utils/__pycache__/model_patterns.cpython-311.pyc differ
utils/model_patterns.py
CHANGED

```diff
@@ -979,6 +979,18 @@ def evaluate_sequence_ablation(model, tokenizer, sequence_text: str, config: Dic
     }
 
 
+def _prepare_hidden_state(layer_output: Any) -> torch.Tensor:
+    """Helper to convert layer output to tensor, handling tuple outputs."""
+    # Handle PyVene captured tuple outputs where 2nd element is None (e.g. use_cache=False)
+    if isinstance(layer_output, (list, tuple)) and len(layer_output) > 1 and layer_output[1] is None:
+        layer_output = layer_output[0]
+
+    hidden = torch.tensor(layer_output) if not isinstance(layer_output, torch.Tensor) else layer_output
+    if hidden.dim() == 4:
+        hidden = hidden.squeeze(0)
+    return hidden
+
+
 def logit_lens_transformation(layer_output: Any, norm_data: List[Any], model, tokenizer, norm_parameter: Optional[str] = None, top_k: int = 5) -> List[Tuple[str, float]]:
     """
     Transform layer output to top K token probabilities using logit lens.
@@ -1003,9 +1015,7 @@ def logit_lens_transformation(layer_output: Any, norm_data: List[Any], model, to
     """
     with torch.no_grad():
         # Convert to tensor and ensure proper shape [batch, seq_len, hidden_dim]
-        hidden = …
-        if hidden.dim() == 4:
-            hidden = hidden.squeeze(0)
+        hidden = _prepare_hidden_state(layer_output)
 
         # Step 1: Apply final layer normalization (critical for intermediate layers)
         final_norm = get_norm_layer_from_parameter(model, norm_parameter)
@@ -1096,9 +1106,7 @@ def _get_token_probabilities_for_layer(activation_data: Dict[str, Any], module_n
     lm_head = model.get_output_embeddings()
 
     with torch.no_grad():
-        hidden = …
-        if hidden.dim() == 4:
-            hidden = hidden.squeeze(0)
+        hidden = _prepare_hidden_state(layer_output)
 
         if final_norm is not None:
             hidden = final_norm(hidden)
```