Spaces:
Sleeping
Sleeping
Commit ·
e13b1ba
1
Parent(s): 3850656
Add comprehensive test suite for utility modules (73 tests)
Browse files- .cursor/debug.log +0 -4
- .cursor/rules/minimal_changes.mdc +6 -0
- .cursor/rules/planning_mode.mdc +62 -0
- components/__pycache__/investigation_panel.cpython-311.pyc +0 -0
- components/__pycache__/model_selector.cpython-311.pyc +0 -0
- components/__pycache__/pipeline.cpython-311.pyc +0 -0
- components/model_selector.py +11 -0
- plans.md +26 -15
- requirements.txt +3 -0
- tests/__init__.py +1 -0
- tests/__pycache__/__init__.cpython-311.pyc +0 -0
- tests/__pycache__/conftest.cpython-311-pytest-8.4.2.pyc +0 -0
- tests/__pycache__/test_ablation_metrics.cpython-311-pytest-8.4.2.pyc +0 -0
- tests/__pycache__/test_head_detection.cpython-311-pytest-8.4.2.pyc +0 -0
- tests/__pycache__/test_model_config.cpython-311-pytest-8.4.2.pyc +0 -0
- tests/__pycache__/test_model_patterns.cpython-311-pytest-8.4.2.pyc +0 -0
- tests/__pycache__/test_token_attribution.cpython-311-pytest-8.4.2.pyc +0 -0
- tests/conftest.py +206 -0
- tests/test_ablation_metrics.py +120 -0
- tests/test_head_detection.py +313 -0
- tests/test_model_config.py +161 -0
- tests/test_model_patterns.py +180 -0
- tests/test_token_attribution.py +182 -0
- todo.md +9 -16
- utils/__pycache__/__init__.cpython-311.pyc +0 -0
- utils/__pycache__/token_attribution.cpython-311.pyc +0 -0
.cursor/debug.log
DELETED
|
@@ -1,4 +0,0 @@
|
|
| 1 |
-
{"sessionId": "debug-session", "runId": "pre-fix", "hypothesisId": "H1", "location": "app.py:render_heatmap:entry", "message": "render_heatmap entry", "data": {"has_activation_data": true, "has_activation_data2": false, "has_original_activation_data": false, "mode_data": {"comparison": "prompt1", "ablation": "original"}, "model_name": "Qwen/Qwen2.5-0.5B", "plotly_version": "6.3.0"}, "timestamp": 1768427085403}
|
| 2 |
-
{"sessionId": "debug-session", "runId": "pre-fix", "hypothesisId": "H2", "location": "app.py:render_heatmap:active_data", "message": "selected active data source", "data": {"comparison_mode": "prompt1", "ablation_mode": "original", "show_comparison_toggle": false, "show_ablation_toggle": false, "active_data_source": "activation_data"}, "timestamp": 1768427090436}
|
| 3 |
-
{"sessionId": "debug-session", "runId": "pre-fix", "hypothesisId": "H3", "location": "app.py:render_heatmap:matrix_data", "message": "matrix data summary", "data": {"rows": 24, "cols": 11, "tokens_len": 11, "layers_len": 24, "top_tokens_rows": 24}, "timestamp": 1768427136380}
|
| 4 |
-
{"sessionId": "debug-session", "runId": "pre-fix", "hypothesisId": "H1", "location": "app.py:render_heatmap:colorbar_config", "message": "heatmap colorbar config", "data": {"colorbar_config": {"title": {"text": "Delta", "side": "right"}}}, "timestamp": 1768427136385}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.cursor/rules/minimal_changes.mdc
CHANGED
|
@@ -11,6 +11,12 @@ alwaysApply: true
|
|
| 11 |
|
| 12 |
# Minimal Change Rules
|
| 13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
- Plan first:
|
| 15 |
- Update `todo.md` with the smallest next actions tied to `plans.md`.
|
| 16 |
- Keep tasks atomic and check them off as you go.
|
|
|
|
| 11 |
|
| 12 |
# Minimal Change Rules
|
| 13 |
|
| 14 |
+
- Testing & verification:
|
| 15 |
+
- For substantial code changes (new files, new functionality), write tests first in `tests/` that describe expected behavior.
|
| 16 |
+
- Skip tests for UI/frontend changes, trivial additions, and documentation.
|
| 17 |
+
- After implementing changes, run `pytest` to verify all tests pass.
|
| 18 |
+
- If tests fail, iterate on debugging until fixed.
|
| 19 |
+
|
| 20 |
- Plan first:
|
| 21 |
- Update `todo.md` with the smallest next actions tied to `plans.md`.
|
| 22 |
- Keep tasks atomic and check them off as you go.
|
.cursor/rules/planning_mode.mdc
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
description: Rules for dividing work into parallel agent worktrees
|
| 3 |
+
globs:
|
| 4 |
+
- "plans.md"
|
| 5 |
+
- "todo.md"
|
| 6 |
+
- "**/PLAN*.md"
|
| 7 |
+
alwaysApply: false
|
| 8 |
+
---
|
| 9 |
+
|
| 10 |
+
# Parallel Agent Planning Rules
|
| 11 |
+
|
| 12 |
+
When creating or updating plans (plan mode active), structure work for independent parallel execution.
|
| 13 |
+
|
| 14 |
+
## Worktree Division Principles
|
| 15 |
+
|
| 16 |
+
- **Isolate by file/module**: Each agent task must target a distinct set of files that no other agent will touch.
|
| 17 |
+
- **No shared edits**: If two tasks could modify the same file, merge them into one task.
|
| 18 |
+
- **Define boundaries explicitly**: In your plan file, list the exact files/directories each agent owns.
|
| 19 |
+
|
| 20 |
+
## Plan Structure for Parallel Agents
|
| 21 |
+
|
| 22 |
+
When dividing work, create a section like:
|
| 23 |
+
|
| 24 |
+
```
|
| 25 |
+
## Parallel Worktrees
|
| 26 |
+
|
| 27 |
+
### Agent A: [Task Name]
|
| 28 |
+
- **Owns**: `utils/feature_a.py`, `tests/test_feature_a.py`
|
| 29 |
+
- **Does not touch**: everything else
|
| 30 |
+
- **Deliverable**: [specific outcome]
|
| 31 |
+
|
| 32 |
+
### Agent B: [Task Name]
|
| 33 |
+
- **Owns**: `components/widget.py`, `tests/test_widget.py`
|
| 34 |
+
- **Does not touch**: everything else
|
| 35 |
+
- **Deliverable**: [specific outcome]
|
| 36 |
+
|
| 37 |
+
### Sequential (after parallel completes)
|
| 38 |
+
- Integration task that touches shared files (e.g., `app.py` imports)
|
| 39 |
+
```
|
| 40 |
+
|
| 41 |
+
## Rules for Each Agent Task
|
| 42 |
+
|
| 43 |
+
- Each task must be **self-contained**: write, test, verify independently.
|
| 44 |
+
- Each task must specify:
|
| 45 |
+
1. Files it will create or modify (exclusive ownership)
|
| 46 |
+
2. Tests it will write/run
|
| 47 |
+
3. Success criteria
|
| 48 |
+
- Shared dependencies (imports, configs) should be locked before parallel work begins.
|
| 49 |
+
|
| 50 |
+
## Conflict Prevention Checklist
|
| 51 |
+
|
| 52 |
+
Before finalizing a parallel plan, verify:
|
| 53 |
+
- [ ] No two agents modify the same file
|
| 54 |
+
- [ ] No two agents add imports to the same `__init__.py`
|
| 55 |
+
- [ ] Shared interfaces are defined and frozen before parallel work
|
| 56 |
+
- [ ] Each agent's tests can run independently
|
| 57 |
+
|
| 58 |
+
## Git Strategy for Parallel Agents
|
| 59 |
+
|
| 60 |
+
- Each agent works on its own feature branch: `feature/<agent-task-name>`
|
| 61 |
+
- Branches are merged sequentially after all pass tests
|
| 62 |
+
- Order: merge least-dependent branches first
|
components/__pycache__/investigation_panel.cpython-311.pyc
ADDED
|
Binary file (17 kB). View file
|
|
|
components/__pycache__/model_selector.cpython-311.pyc
CHANGED
|
Binary files a/components/__pycache__/model_selector.cpython-311.pyc and b/components/__pycache__/model_selector.cpython-311.pyc differ
|
|
|
components/__pycache__/pipeline.cpython-311.pyc
ADDED
|
Binary file (23.4 kB). View file
|
|
|
components/model_selector.py
CHANGED
|
@@ -12,9 +12,20 @@ from dash import html, dcc
|
|
| 12 |
AVAILABLE_MODELS = [
|
| 13 |
# LLaMA-like models (Qwen)
|
| 14 |
{"label": "Qwen2.5-0.5B", "value": "Qwen/Qwen2.5-0.5B"},
|
|
|
|
| 15 |
|
| 16 |
# GPT-2 family
|
| 17 |
{"label": "GPT-2 (124M)", "value": "gpt2"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
]
|
| 19 |
|
| 20 |
def create_model_selector():
|
|
|
|
| 12 |
AVAILABLE_MODELS = [
|
| 13 |
# LLaMA-like models (Qwen)
|
| 14 |
{"label": "Qwen2.5-0.5B", "value": "Qwen/Qwen2.5-0.5B"},
|
| 15 |
+
# {"label": "Qwen2.5-1.5B", "value": "Qwen/Qwen2.5-1.5B"},
|
| 16 |
|
| 17 |
# GPT-2 family
|
| 18 |
{"label": "GPT-2 (124M)", "value": "gpt2"}
|
| 19 |
+
# {"label": "GPT-2 Medium (355M)", "value": "gpt2-medium"},
|
| 20 |
+
# {"label": "GPT-2 Large (774M)", "value": "gpt2-large"},
|
| 21 |
+
|
| 22 |
+
# # OPT family
|
| 23 |
+
# {"label": "OPT-125M", "value": "facebook/opt-125m"},
|
| 24 |
+
# {"label": "OPT-350M", "value": "facebook/opt-350m"},
|
| 25 |
+
|
| 26 |
+
# # GPT-NeoX family (Pythia)
|
| 27 |
+
# {"label": "Pythia-70M", "value": "EleutherAI/pythia-70m"},
|
| 28 |
+
# {"label": "Pythia-160M", "value": "EleutherAI/pythia-160m"},
|
| 29 |
]
|
| 30 |
|
| 31 |
def create_model_selector():
|
plans.md
CHANGED
|
@@ -1,19 +1,30 @@
|
|
| 1 |
-
|
|
|
|
|
|
|
| 2 |
|
| 3 |
-
|
|
|
|
| 4 |
|
| 5 |
-
|
|
|
|
| 6 |
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
3. **Simplified Codebase**: Removed heatmap, comparison mode, and ~900 lines of code
|
| 10 |
-
4. **Token Attribution**: New gradient-based feature importance analysis
|
| 11 |
|
| 12 |
-
|
| 13 |
-
-
|
| 14 |
-
-
|
| 15 |
-
-
|
| 16 |
-
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Generation Settings:
|
| 2 |
+
- change "Max New Tokens" label to "Number of New Tokens"
|
| 3 |
+
- the concept of beams is non-trivial, change "Beam" referencse to something like "Number of Generation Choices"
|
| 4 |
|
| 5 |
+
Generated Sequences:
|
| 6 |
+
- don't show the score value, it is virtually meaningless to a user
|
| 7 |
|
| 8 |
+
1. Tokenization
|
| 9 |
+
- stack tokens vertically instead of listing horizontally (more visually appealing)
|
| 10 |
|
| 11 |
+
2. Embedding
|
| 12 |
+
- add some explanation that embedding is taken from pre-learned table (improves understanding about how ID goes to embedding)
|
|
|
|
|
|
|
| 13 |
|
| 14 |
+
3. Attention
|
| 15 |
+
- use BertViz head view instead of model view, less overwhelming and lets user scroll through attention heads
|
| 16 |
+
- add explanation of what the user is looking at and how to navigate BertViz vizualization (answers questions "What am I looking at?", "What should this visualization show me?", "What should I look for?")
|
| 17 |
+
- categorize attention heads so that the different components of the model's learning are visible
|
| 18 |
+
- "Most attended tokens" section doesn't provide much value, remove this and focus on categorized attention heads and BertViz
|
| 19 |
+
|
| 20 |
+
4. MLP
|
| 21 |
+
- need more explanation about how the feed-forward networks are learned in training, allowing it to understand the current words based on its training set
|
| 22 |
+
|
| 23 |
+
5. Output
|
| 24 |
+
- include full prompt in "predicted next token", with the predicted token appended at the end (still highlighted)
|
| 25 |
+
- in top 5 tokens graph, the hover-over data should just show percent and token, not the long decimal value
|
| 26 |
+
|
| 27 |
+
Overall Conceptual Changes:
|
| 28 |
+
- the analysis should be done off the initial user-given prompt, not the prompt that includes max tokens
|
| 29 |
+
- the selected beam will be used for comparison after experiments, not for analysis. For example, the user input is "The capital of the US is", the chosen beam is "The capital of the US is Washington D.C.", but the analysis is only done with the first prompt. After either experiment, the chosen beam is used to compare to the new results. If ablation made the output "The capital of the US is New York City", then that can be compared to the original chosen beam to show differences.
|
| 30 |
+
- add testing and verification to the entire project so that each round of changes can be double checked and verified for correctness (try to avoid running the app, just test functions)
|
requirements.txt
CHANGED
|
@@ -13,3 +13,6 @@ bertviz>=1.4.0
|
|
| 13 |
|
| 14 |
# Utility dependencies
|
| 15 |
numpy>=1.24.0
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
# Utility dependencies
|
| 15 |
numpy>=1.24.0
|
| 16 |
+
|
| 17 |
+
# Testing dependencies
|
| 18 |
+
pytest>=7.0.0
|
tests/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Test suite for Transformer Activation Capture and Visualization
|
tests/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (172 Bytes). View file
|
|
|
tests/__pycache__/conftest.cpython-311-pytest-8.4.2.pyc
ADDED
|
Binary file (7.43 kB). View file
|
|
|
tests/__pycache__/test_ablation_metrics.cpython-311-pytest-8.4.2.pyc
ADDED
|
Binary file (18.2 kB). View file
|
|
|
tests/__pycache__/test_head_detection.cpython-311-pytest-8.4.2.pyc
ADDED
|
Binary file (44.9 kB). View file
|
|
|
tests/__pycache__/test_model_config.cpython-311-pytest-8.4.2.pyc
ADDED
|
Binary file (37 kB). View file
|
|
|
tests/__pycache__/test_model_patterns.cpython-311-pytest-8.4.2.pyc
ADDED
|
Binary file (34.1 kB). View file
|
|
|
tests/__pycache__/test_token_attribution.cpython-311-pytest-8.4.2.pyc
ADDED
|
Binary file (27.1 kB). View file
|
|
|
tests/conftest.py
ADDED
|
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Shared pytest fixtures for the test suite.
|
| 3 |
+
|
| 4 |
+
Provides reusable mock data structures and synthetic tensors
|
| 5 |
+
to test utility functions without loading actual ML models.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import pytest
|
| 9 |
+
import torch
|
| 10 |
+
import numpy as np
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
# =============================================================================
|
| 14 |
+
# Synthetic Attention Matrices
|
| 15 |
+
# =============================================================================
|
| 16 |
+
|
| 17 |
+
@pytest.fixture
|
| 18 |
+
def uniform_attention_matrix():
|
| 19 |
+
"""4x4 uniform attention matrix (each position attends equally to all)."""
|
| 20 |
+
size = 4
|
| 21 |
+
return torch.ones(size, size) / size
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@pytest.fixture
|
| 25 |
+
def previous_token_attention_matrix():
|
| 26 |
+
"""
|
| 27 |
+
4x4 attention matrix where each position attends primarily to the previous token.
|
| 28 |
+
Position 0 attends to itself (no previous token).
|
| 29 |
+
"""
|
| 30 |
+
size = 4
|
| 31 |
+
matrix = torch.zeros(size, size)
|
| 32 |
+
# Position 0 attends to itself
|
| 33 |
+
matrix[0, 0] = 1.0
|
| 34 |
+
# Other positions attend strongly to previous token
|
| 35 |
+
for i in range(1, size):
|
| 36 |
+
matrix[i, i-1] = 0.8
|
| 37 |
+
matrix[i, i] = 0.2 # Some self-attention
|
| 38 |
+
return matrix
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
@pytest.fixture
|
| 42 |
+
def first_token_attention_matrix():
|
| 43 |
+
"""4x4 attention matrix where all positions attend strongly to first token."""
|
| 44 |
+
size = 4
|
| 45 |
+
matrix = torch.zeros(size, size)
|
| 46 |
+
for i in range(size):
|
| 47 |
+
matrix[i, 0] = 0.7 # Strong attention to first token
|
| 48 |
+
matrix[i, i] = 0.3 # Some self-attention
|
| 49 |
+
return matrix
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
@pytest.fixture
|
| 53 |
+
def peaked_attention_matrix():
|
| 54 |
+
"""4x4 attention matrix with peaked (low entropy) attention at one position."""
|
| 55 |
+
size = 4
|
| 56 |
+
matrix = torch.zeros(size, size)
|
| 57 |
+
# Each position attends almost entirely to position 2
|
| 58 |
+
for i in range(size):
|
| 59 |
+
matrix[i, 2] = 0.95
|
| 60 |
+
# Distribute remaining across others
|
| 61 |
+
for j in range(size):
|
| 62 |
+
if j != 2:
|
| 63 |
+
matrix[i, j] = 0.05 / (size - 1)
|
| 64 |
+
return matrix
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
# =============================================================================
|
| 68 |
+
# Mock Activation Data Structures
|
| 69 |
+
# =============================================================================
|
| 70 |
+
|
| 71 |
+
@pytest.fixture
|
| 72 |
+
def mock_activation_data():
|
| 73 |
+
"""
|
| 74 |
+
Mock activation data structure similar to execute_forward_pass output.
|
| 75 |
+
Used for testing functions that process activation data.
|
| 76 |
+
"""
|
| 77 |
+
return {
|
| 78 |
+
'model': 'mock-model',
|
| 79 |
+
'prompt': 'Hello world',
|
| 80 |
+
'input_ids': [[1, 2, 3, 4]],
|
| 81 |
+
'attention_modules': ['model.layers.0.self_attn', 'model.layers.1.self_attn'],
|
| 82 |
+
'attention_outputs': {
|
| 83 |
+
'model.layers.0.self_attn': {
|
| 84 |
+
'output': [
|
| 85 |
+
[[0.1, 0.2, 0.3]], # Hidden states (simplified)
|
| 86 |
+
[[[[0.25, 0.25, 0.25, 0.25], # Attention weights [batch, heads, seq, seq]
|
| 87 |
+
[0.25, 0.25, 0.25, 0.25],
|
| 88 |
+
[0.25, 0.25, 0.25, 0.25],
|
| 89 |
+
[0.25, 0.25, 0.25, 0.25]]]]
|
| 90 |
+
]
|
| 91 |
+
},
|
| 92 |
+
'model.layers.1.self_attn': {
|
| 93 |
+
'output': [
|
| 94 |
+
[[0.1, 0.2, 0.3]],
|
| 95 |
+
[[[[0.1, 0.2, 0.3, 0.4],
|
| 96 |
+
[0.1, 0.2, 0.3, 0.4],
|
| 97 |
+
[0.1, 0.2, 0.3, 0.4],
|
| 98 |
+
[0.1, 0.2, 0.3, 0.4]]]]
|
| 99 |
+
]
|
| 100 |
+
}
|
| 101 |
+
},
|
| 102 |
+
'block_modules': ['model.layers.0', 'model.layers.1'],
|
| 103 |
+
'block_outputs': {
|
| 104 |
+
'model.layers.0': {'output': [[[0.1, 0.2, 0.3, 0.4]]]},
|
| 105 |
+
'model.layers.1': {'output': [[[0.2, 0.3, 0.4, 0.5]]]}
|
| 106 |
+
},
|
| 107 |
+
'norm_parameters': ['model.norm.weight'],
|
| 108 |
+
'norm_data': [[1.0, 1.0, 1.0, 1.0]],
|
| 109 |
+
'actual_output': {'token': ' world', 'probability': 0.85},
|
| 110 |
+
'global_top5_tokens': [
|
| 111 |
+
{'token': 'world', 'probability': 0.85},
|
| 112 |
+
{'token': 'there', 'probability': 0.05},
|
| 113 |
+
{'token': 'friend', 'probability': 0.03},
|
| 114 |
+
{'token': 'everyone', 'probability': 0.02},
|
| 115 |
+
{'token': 'all', 'probability': 0.01}
|
| 116 |
+
]
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
# =============================================================================
|
| 121 |
+
# Mock Module/Parameter Patterns
|
| 122 |
+
# =============================================================================
|
| 123 |
+
|
| 124 |
+
@pytest.fixture
|
| 125 |
+
def mock_module_patterns():
|
| 126 |
+
"""Mock module patterns as returned by extract_patterns."""
|
| 127 |
+
return {
|
| 128 |
+
'model.layers.{N}.self_attn': ['model.layers.0.self_attn', 'model.layers.1.self_attn'],
|
| 129 |
+
'model.layers.{N}.mlp': ['model.layers.0.mlp', 'model.layers.1.mlp'],
|
| 130 |
+
'model.layers.{N}': ['model.layers.0', 'model.layers.1'],
|
| 131 |
+
'model.embed_tokens': ['model.embed_tokens'],
|
| 132 |
+
'model.norm': ['model.norm']
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
@pytest.fixture
|
| 137 |
+
def mock_param_patterns():
|
| 138 |
+
"""Mock parameter patterns as returned by extract_patterns."""
|
| 139 |
+
return {
|
| 140 |
+
'model.layers.{N}.self_attn.q_proj.weight': ['model.layers.0.self_attn.q_proj.weight'],
|
| 141 |
+
'model.layers.{N}.self_attn.k_proj.weight': ['model.layers.0.self_attn.k_proj.weight'],
|
| 142 |
+
'model.norm.weight': ['model.norm.weight'],
|
| 143 |
+
'lm_head.weight': ['lm_head.weight']
|
| 144 |
+
}
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
# =============================================================================
|
| 148 |
+
# Synthetic Logits for Ablation Metrics
|
| 149 |
+
# =============================================================================
|
| 150 |
+
|
| 151 |
+
@pytest.fixture
|
| 152 |
+
def identical_logits():
|
| 153 |
+
"""Two identical logit tensors for testing KL divergence = 0."""
|
| 154 |
+
logits = torch.tensor([[[1.0, 2.0, 3.0, 4.0],
|
| 155 |
+
[2.0, 3.0, 4.0, 5.0]]]) # [1, 2, 4] = [batch, seq, vocab]
|
| 156 |
+
return logits, logits.clone()
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
@pytest.fixture
|
| 160 |
+
def different_logits():
|
| 161 |
+
"""Two different logit tensors for testing KL divergence > 0."""
|
| 162 |
+
logits_p = torch.tensor([[[1.0, 2.0, 3.0, 4.0],
|
| 163 |
+
[2.0, 3.0, 4.0, 5.0]]])
|
| 164 |
+
logits_q = torch.tensor([[[4.0, 3.0, 2.0, 1.0],
|
| 165 |
+
[5.0, 4.0, 3.0, 2.0]]])
|
| 166 |
+
return logits_p, logits_q
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
@pytest.fixture
|
| 170 |
+
def prob_delta_data():
|
| 171 |
+
"""Data for testing probability delta computation."""
|
| 172 |
+
# Reference favors token 3, ablated favors token 0
|
| 173 |
+
logits_ref = torch.tensor([[[1.0, 2.0, 3.0, 10.0], # pos 0: predicts token 3
|
| 174 |
+
[1.0, 2.0, 10.0, 3.0]]]) # pos 1: predicts token 2
|
| 175 |
+
logits_abl = torch.tensor([[[10.0, 2.0, 3.0, 1.0], # pos 0: predicts token 0
|
| 176 |
+
[10.0, 2.0, 1.0, 3.0]]]) # pos 1: predicts token 0
|
| 177 |
+
input_ids = torch.tensor([[0, 3, 2]]) # Actual tokens: start, 3, 2
|
| 178 |
+
return logits_ref, logits_abl, input_ids
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
# =============================================================================
|
| 182 |
+
# Attribution Data for Visualization Tests
|
| 183 |
+
# =============================================================================
|
| 184 |
+
|
| 185 |
+
@pytest.fixture
|
| 186 |
+
def mock_attribution_result():
|
| 187 |
+
"""Mock output from compute_integrated_gradients or compute_simple_gradient_attribution."""
|
| 188 |
+
return {
|
| 189 |
+
'tokens': ['Hello', ' world', '!'],
|
| 190 |
+
'token_ids': [1, 2, 3],
|
| 191 |
+
'attributions': [0.5, 1.0, 0.2], # Raw attribution scores
|
| 192 |
+
'normalized_attributions': [0.5, 1.0, 0.2], # Already normalized for simplicity
|
| 193 |
+
'target_token': 'next',
|
| 194 |
+
'target_token_id': 100
|
| 195 |
+
}
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
# =============================================================================
|
| 199 |
+
# Head Categorization Config
|
| 200 |
+
# =============================================================================
|
| 201 |
+
|
| 202 |
+
@pytest.fixture
|
| 203 |
+
def default_head_config():
|
| 204 |
+
"""Default head categorization configuration for testing."""
|
| 205 |
+
from utils.head_detection import HeadCategorizationConfig
|
| 206 |
+
return HeadCategorizationConfig()
|
tests/test_ablation_metrics.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Tests for utils/ablation_metrics.py
|
| 3 |
+
|
| 4 |
+
Tests KL divergence computation and probability delta calculations.
|
| 5 |
+
Uses synthetic tensors to avoid model loading.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import pytest
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from utils.ablation_metrics import compute_kl_divergence, get_token_probability_deltas
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class TestComputeKLDivergence:
|
| 15 |
+
"""Tests for compute_kl_divergence function."""
|
| 16 |
+
|
| 17 |
+
def test_identical_distributions_zero_kl(self, identical_logits):
|
| 18 |
+
"""KL divergence of identical distributions should be approximately 0."""
|
| 19 |
+
logits_p, logits_q = identical_logits
|
| 20 |
+
kl_divs = compute_kl_divergence(logits_p, logits_q)
|
| 21 |
+
|
| 22 |
+
assert isinstance(kl_divs, list)
|
| 23 |
+
assert len(kl_divs) == 2 # seq_len = 2
|
| 24 |
+
for kl in kl_divs:
|
| 25 |
+
assert abs(kl) < 1e-5, f"Expected ~0, got {kl}"
|
| 26 |
+
|
| 27 |
+
def test_different_distributions_positive_kl(self, different_logits):
|
| 28 |
+
"""KL divergence of different distributions should be positive."""
|
| 29 |
+
logits_p, logits_q = different_logits
|
| 30 |
+
kl_divs = compute_kl_divergence(logits_p, logits_q)
|
| 31 |
+
|
| 32 |
+
assert isinstance(kl_divs, list)
|
| 33 |
+
for kl in kl_divs:
|
| 34 |
+
assert kl > 0, f"Expected positive KL, got {kl}"
|
| 35 |
+
|
| 36 |
+
def test_kl_divergence_asymmetry(self, different_logits):
|
| 37 |
+
"""KL(P||Q) should not equal KL(Q||P) for different distributions."""
|
| 38 |
+
logits_p, logits_q = different_logits
|
| 39 |
+
kl_pq = compute_kl_divergence(logits_p, logits_q)
|
| 40 |
+
kl_qp = compute_kl_divergence(logits_q, logits_p)
|
| 41 |
+
|
| 42 |
+
# They should generally be different (asymmetry of KL divergence)
|
| 43 |
+
assert kl_pq != kl_qp, "KL divergence should be asymmetric"
|
| 44 |
+
|
| 45 |
+
def test_handles_3d_input(self):
|
| 46 |
+
"""Should handle [batch, seq_len, vocab_size] input correctly."""
|
| 47 |
+
logits = torch.randn(1, 5, 100) # batch=1, seq=5, vocab=100
|
| 48 |
+
kl_divs = compute_kl_divergence(logits, logits)
|
| 49 |
+
|
| 50 |
+
assert len(kl_divs) == 5
|
| 51 |
+
for kl in kl_divs:
|
| 52 |
+
assert abs(kl) < 1e-5
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class TestGetTokenProbabilityDeltas:
|
| 56 |
+
"""Tests for get_token_probability_deltas function."""
|
| 57 |
+
|
| 58 |
+
def test_deltas_with_synthetic_data(self):
|
| 59 |
+
"""Test probability delta computation with known inputs."""
|
| 60 |
+
# Logits shape: [1, seq_len, vocab_size] where seq_len matches input_ids
|
| 61 |
+
# input_ids has 3 tokens, so logits needs 3 positions
|
| 62 |
+
logits_ref = torch.tensor([[[1.0, 2.0, 3.0, 10.0], # pos 0
|
| 63 |
+
[1.0, 2.0, 10.0, 3.0], # pos 1
|
| 64 |
+
[1.0, 2.0, 3.0, 4.0]]]) # pos 2
|
| 65 |
+
logits_abl = torch.tensor([[[10.0, 2.0, 3.0, 1.0],
|
| 66 |
+
[10.0, 2.0, 1.0, 3.0],
|
| 67 |
+
[1.0, 2.0, 3.0, 4.0]]])
|
| 68 |
+
input_ids = torch.tensor([[0, 3, 2]])
|
| 69 |
+
|
| 70 |
+
deltas = get_token_probability_deltas(logits_ref, logits_abl, input_ids)
|
| 71 |
+
|
| 72 |
+
# Should return list of length seq_len - 1 (shifted prediction)
|
| 73 |
+
assert isinstance(deltas, list)
|
| 74 |
+
assert len(deltas) == 2 # seq_len=3, so 2 predictions (pos 0 predicts token 1, pos 1 predicts token 2)
|
| 75 |
+
|
| 76 |
+
def test_identical_logits_zero_delta(self):
|
| 77 |
+
"""Identical logits should produce zero deltas."""
|
| 78 |
+
# Logits need seq_len=3 to match input_ids
|
| 79 |
+
logits = torch.tensor([[[1.0, 2.0, 3.0, 4.0],
|
| 80 |
+
[2.0, 3.0, 4.0, 5.0],
|
| 81 |
+
[3.0, 4.0, 5.0, 6.0]]])
|
| 82 |
+
input_ids = torch.tensor([[0, 3, 2]])
|
| 83 |
+
|
| 84 |
+
deltas = get_token_probability_deltas(logits, logits.clone(), input_ids)
|
| 85 |
+
|
| 86 |
+
for delta in deltas:
|
| 87 |
+
assert abs(delta) < 1e-5, f"Expected ~0 delta, got {delta}"
|
| 88 |
+
|
| 89 |
+
def test_delta_direction(self):
|
| 90 |
+
"""When ablation increases a token's probability, delta should be positive."""
|
| 91 |
+
# 3 positions to match 3 input_ids
|
| 92 |
+
logits_ref = torch.tensor([[[1.0, 0.0, 0.0, 0.0], # favors token 0
|
| 93 |
+
[1.0, 0.0, 0.0, 0.0], # favors token 0
|
| 94 |
+
[1.0, 0.0, 0.0, 0.0]]])
|
| 95 |
+
logits_abl = torch.tensor([[[0.0, 10.0, 0.0, 0.0], # favors token 1
|
| 96 |
+
[0.0, 10.0, 0.0, 0.0], # favors token 1
|
| 97 |
+
[0.0, 10.0, 0.0, 0.0]]])
|
| 98 |
+
input_ids = torch.tensor([[0, 1, 1]]) # Target tokens: 1, 1
|
| 99 |
+
|
| 100 |
+
deltas = get_token_probability_deltas(logits_ref, logits_abl, input_ids)
|
| 101 |
+
|
| 102 |
+
# Both deltas should be positive (ablation increased target prob)
|
| 103 |
+
for delta in deltas:
|
| 104 |
+
assert delta > 0, f"Expected positive delta, got {delta}"
|
| 105 |
+
|
| 106 |
+
def test_delta_range(self):
|
| 107 |
+
"""Deltas should be bounded by [-1, 1] since they're probability differences."""
|
| 108 |
+
# 3 positions to match input_ids
|
| 109 |
+
logits_ref = torch.tensor([[[100.0, -100.0, -100.0, -100.0],
|
| 110 |
+
[-100.0, 100.0, -100.0, -100.0],
|
| 111 |
+
[-100.0, -100.0, 100.0, -100.0]]])
|
| 112 |
+
logits_abl = torch.tensor([[[-100.0, 100.0, -100.0, -100.0],
|
| 113 |
+
[-100.0, -100.0, 100.0, -100.0],
|
| 114 |
+
[-100.0, -100.0, -100.0, 100.0]]])
|
| 115 |
+
input_ids = torch.tensor([[0, 0, 1]]) # Targets: 0, 1
|
| 116 |
+
|
| 117 |
+
deltas = get_token_probability_deltas(logits_ref, logits_abl, input_ids)
|
| 118 |
+
|
| 119 |
+
for delta in deltas:
|
| 120 |
+
assert -1.0 <= delta <= 1.0, f"Delta {delta} out of bounds"
|
tests/test_head_detection.py
ADDED
|
@@ -0,0 +1,313 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Tests for utils/head_detection.py
|
| 3 |
+
|
| 4 |
+
Tests attention head categorization heuristics using synthetic attention matrices.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import pytest
|
| 8 |
+
import torch
|
| 9 |
+
import numpy as np
|
| 10 |
+
from utils.head_detection import (
|
| 11 |
+
compute_attention_entropy,
|
| 12 |
+
detect_previous_token_head,
|
| 13 |
+
detect_first_token_head,
|
| 14 |
+
detect_bow_head,
|
| 15 |
+
detect_syntactic_head,
|
| 16 |
+
categorize_attention_head,
|
| 17 |
+
categorize_all_heads,
|
| 18 |
+
format_categorization_summary,
|
| 19 |
+
HeadCategorizationConfig
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class TestComputeAttentionEntropy:
|
| 24 |
+
"""Tests for compute_attention_entropy function."""
|
| 25 |
+
|
| 26 |
+
def test_uniform_distribution_high_entropy(self):
|
| 27 |
+
"""Uniform attention should have high (near 1.0) normalized entropy."""
|
| 28 |
+
# 4 positions with equal attention
|
| 29 |
+
uniform = torch.tensor([0.25, 0.25, 0.25, 0.25])
|
| 30 |
+
entropy = compute_attention_entropy(uniform)
|
| 31 |
+
|
| 32 |
+
# Normalized entropy should be close to 1.0 for uniform
|
| 33 |
+
assert 0.95 <= entropy <= 1.0, f"Expected ~1.0, got {entropy}"
|
| 34 |
+
|
| 35 |
+
def test_peaked_distribution_low_entropy(self):
|
| 36 |
+
"""Peaked attention should have low normalized entropy."""
|
| 37 |
+
# One position dominates
|
| 38 |
+
peaked = torch.tensor([0.97, 0.01, 0.01, 0.01])
|
| 39 |
+
entropy = compute_attention_entropy(peaked)
|
| 40 |
+
|
| 41 |
+
# Should be low entropy
|
| 42 |
+
assert entropy < 0.3, f"Expected low entropy, got {entropy}"
|
| 43 |
+
|
| 44 |
+
def test_entropy_bounds(self):
|
| 45 |
+
"""Entropy should always be between 0 and 1 (normalized)."""
|
| 46 |
+
test_cases = [
|
| 47 |
+
torch.tensor([1.0, 0.0, 0.0, 0.0]), # Extreme peaked
|
| 48 |
+
torch.tensor([0.5, 0.5, 0.0, 0.0]), # Two positions
|
| 49 |
+
torch.tensor([0.25, 0.25, 0.25, 0.25]), # Uniform
|
| 50 |
+
]
|
| 51 |
+
|
| 52 |
+
for weights in test_cases:
|
| 53 |
+
entropy = compute_attention_entropy(weights)
|
| 54 |
+
assert 0.0 <= entropy <= 1.0, f"Entropy {entropy} out of bounds"
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class TestDetectPreviousTokenHead:
|
| 58 |
+
"""Tests for detect_previous_token_head function."""
|
| 59 |
+
|
| 60 |
+
def test_detects_previous_token_pattern(self, previous_token_attention_matrix, default_head_config):
|
| 61 |
+
"""Should detect matrix with strong previous-token attention."""
|
| 62 |
+
is_prev, score = detect_previous_token_head(
|
| 63 |
+
previous_token_attention_matrix,
|
| 64 |
+
default_head_config
|
| 65 |
+
)
|
| 66 |
+
|
| 67 |
+
assert is_prev == True
|
| 68 |
+
assert score > 0.5, f"Expected high score, got {score}"
|
| 69 |
+
|
| 70 |
+
def test_rejects_uniform_attention(self, uniform_attention_matrix, default_head_config):
|
| 71 |
+
"""Should reject matrix with uniform attention."""
|
| 72 |
+
is_prev, score = detect_previous_token_head(
|
| 73 |
+
uniform_attention_matrix,
|
| 74 |
+
default_head_config
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
assert is_prev == False
|
| 78 |
+
assert score < 0.4, f"Expected low score, got {score}"
|
| 79 |
+
|
| 80 |
+
def test_short_sequence_returns_false(self, default_head_config):
|
| 81 |
+
"""Sequence shorter than min_seq_len should return False."""
|
| 82 |
+
short_matrix = torch.ones(2, 2) / 2
|
| 83 |
+
is_prev, score = detect_previous_token_head(short_matrix, default_head_config)
|
| 84 |
+
|
| 85 |
+
assert is_prev == False
|
| 86 |
+
assert score == 0.0
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
class TestDetectFirstTokenHead:
|
| 90 |
+
"""Tests for detect_first_token_head function."""
|
| 91 |
+
|
| 92 |
+
def test_detects_first_token_pattern(self, first_token_attention_matrix, default_head_config):
|
| 93 |
+
"""Should detect matrix with strong first-token attention."""
|
| 94 |
+
is_first, score = detect_first_token_head(
|
| 95 |
+
first_token_attention_matrix,
|
| 96 |
+
default_head_config
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
assert is_first == True
|
| 100 |
+
assert score > 0.5, f"Expected high score, got {score}"
|
| 101 |
+
|
| 102 |
+
def test_low_first_token_attention(self, default_head_config):
|
| 103 |
+
"""Matrix with low attention to first token should not be detected."""
|
| 104 |
+
# Create matrix where first token gets very little attention
|
| 105 |
+
# Use size 5 to be above min_seq_len and avoid overlap at [0,0]
|
| 106 |
+
size = 5
|
| 107 |
+
matrix = torch.zeros(size, size)
|
| 108 |
+
for i in range(size):
|
| 109 |
+
# Distribute attention: 5% to first token, 95% to last token
|
| 110 |
+
matrix[i, 0] = 0.05
|
| 111 |
+
matrix[i, -1] = 0.95
|
| 112 |
+
|
| 113 |
+
is_first, score = detect_first_token_head(matrix, default_head_config)
|
| 114 |
+
|
| 115 |
+
assert is_first == False
|
| 116 |
+
assert score < 0.25, f"Expected low score, got {score}"
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
class TestDetectBowHead:
|
| 120 |
+
"""Tests for detect_bow_head (bag-of-words / diffuse attention)."""
|
| 121 |
+
|
| 122 |
+
def test_detects_uniform_as_bow(self, uniform_attention_matrix, default_head_config):
|
| 123 |
+
"""Uniform attention should be detected as BoW head."""
|
| 124 |
+
is_bow, score = detect_bow_head(uniform_attention_matrix, default_head_config)
|
| 125 |
+
|
| 126 |
+
# Uniform has high entropy and low max attention - should be BoW
|
| 127 |
+
assert is_bow == True
|
| 128 |
+
assert score > 0.9, f"Expected high entropy score, got {score}"
|
| 129 |
+
|
| 130 |
+
def test_rejects_peaked_attention(self, peaked_attention_matrix, default_head_config):
|
| 131 |
+
"""Peaked attention should not be detected as BoW."""
|
| 132 |
+
is_bow, score = detect_bow_head(peaked_attention_matrix, default_head_config)
|
| 133 |
+
|
| 134 |
+
# Peaked attention has low entropy - should not be BoW
|
| 135 |
+
assert is_bow == False
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
class TestDetectSyntacticHead:
|
| 139 |
+
"""Tests for detect_syntactic_head function."""
|
| 140 |
+
|
| 141 |
+
def test_consistent_distance_pattern(self, default_head_config):
|
| 142 |
+
"""Matrix with consistent distance pattern should be detected as syntactic."""
|
| 143 |
+
# Create matrix where each position attends to position 2 tokens back
|
| 144 |
+
size = 6
|
| 145 |
+
matrix = torch.zeros(size, size)
|
| 146 |
+
for i in range(size):
|
| 147 |
+
target = max(0, i - 2) # 2 tokens back
|
| 148 |
+
matrix[i, target] = 1.0
|
| 149 |
+
|
| 150 |
+
is_syn, score = detect_syntactic_head(matrix, default_head_config)
|
| 151 |
+
|
| 152 |
+
# Should have consistent distance pattern
|
| 153 |
+
assert score > 0.0, f"Expected positive score for consistent pattern"
|
| 154 |
+
|
| 155 |
+
def test_random_attention_returns_valid_values(self, default_head_config):
|
| 156 |
+
"""Random attention should return valid boolean and score."""
|
| 157 |
+
torch.manual_seed(42)
|
| 158 |
+
random_matrix = torch.softmax(torch.randn(6, 6), dim=-1)
|
| 159 |
+
|
| 160 |
+
is_syn, score = detect_syntactic_head(random_matrix, default_head_config)
|
| 161 |
+
|
| 162 |
+
# Check it returns valid types (bool or numpy bool, and numeric score)
|
| 163 |
+
assert is_syn in [True, False] or bool(is_syn) in [True, False]
|
| 164 |
+
assert 0 <= float(score) <= 1
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
class TestCategorizeAttentionHead:
|
| 168 |
+
"""Tests for categorize_attention_head function."""
|
| 169 |
+
|
| 170 |
+
def test_categorizes_previous_token_head(self, previous_token_attention_matrix, default_head_config):
|
| 171 |
+
"""Should categorize previous-token pattern correctly."""
|
| 172 |
+
result = categorize_attention_head(
|
| 173 |
+
previous_token_attention_matrix,
|
| 174 |
+
layer_idx=0,
|
| 175 |
+
head_idx=3,
|
| 176 |
+
config=default_head_config
|
| 177 |
+
)
|
| 178 |
+
|
| 179 |
+
assert result['category'] == 'previous_token'
|
| 180 |
+
assert result['layer'] == 0
|
| 181 |
+
assert result['head'] == 3
|
| 182 |
+
assert result['label'] == 'L0-H3'
|
| 183 |
+
assert 'scores' in result
|
| 184 |
+
|
| 185 |
+
def test_categorizes_first_token_head(self, first_token_attention_matrix, default_head_config):
|
| 186 |
+
"""Should categorize first-token pattern correctly."""
|
| 187 |
+
result = categorize_attention_head(
|
| 188 |
+
first_token_attention_matrix,
|
| 189 |
+
layer_idx=2,
|
| 190 |
+
head_idx=5,
|
| 191 |
+
config=default_head_config
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
+
assert result['category'] == 'first_token'
|
| 195 |
+
assert result['label'] == 'L2-H5'
|
| 196 |
+
|
| 197 |
+
def test_categorizes_bow_head(self, default_head_config):
|
| 198 |
+
"""Should categorize diffuse attention as BoW when it doesn't match other patterns."""
|
| 199 |
+
# Create BoW-like matrix: diffuse attention but first token gets LESS than threshold
|
| 200 |
+
# This avoids triggering first_token detection (threshold 0.25)
|
| 201 |
+
size = 5
|
| 202 |
+
matrix = torch.zeros(size, size)
|
| 203 |
+
for i in range(size):
|
| 204 |
+
# First token gets only 0.1, rest get roughly equal share
|
| 205 |
+
matrix[i, 0] = 0.1
|
| 206 |
+
remaining = 0.9 / (size - 1)
|
| 207 |
+
for j in range(1, size):
|
| 208 |
+
matrix[i, j] = remaining
|
| 209 |
+
|
| 210 |
+
result = categorize_attention_head(
|
| 211 |
+
matrix,
|
| 212 |
+
layer_idx=1,
|
| 213 |
+
head_idx=0,
|
| 214 |
+
config=default_head_config
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
assert result['category'] == 'bow'
|
| 218 |
+
|
| 219 |
+
def test_result_structure(self, uniform_attention_matrix):
|
| 220 |
+
"""Result should have all required keys."""
|
| 221 |
+
result = categorize_attention_head(
|
| 222 |
+
uniform_attention_matrix,
|
| 223 |
+
layer_idx=0,
|
| 224 |
+
head_idx=0
|
| 225 |
+
)
|
| 226 |
+
|
| 227 |
+
required_keys = ['layer', 'head', 'category', 'scores', 'label']
|
| 228 |
+
for key in required_keys:
|
| 229 |
+
assert key in result, f"Missing key: {key}"
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
class TestCategorizeAllHeads:
|
| 233 |
+
"""Tests for categorize_all_heads function."""
|
| 234 |
+
|
| 235 |
+
def test_returns_all_categories(self, mock_activation_data, default_head_config):
|
| 236 |
+
"""Should return dict with all category keys."""
|
| 237 |
+
result = categorize_all_heads(mock_activation_data, default_head_config)
|
| 238 |
+
|
| 239 |
+
expected_categories = ['previous_token', 'first_token', 'bow', 'syntactic', 'other']
|
| 240 |
+
for cat in expected_categories:
|
| 241 |
+
assert cat in result, f"Missing category: {cat}"
|
| 242 |
+
assert isinstance(result[cat], list)
|
| 243 |
+
|
| 244 |
+
def test_handles_empty_attention_data(self, default_head_config):
|
| 245 |
+
"""Should handle activation data with no attention outputs."""
|
| 246 |
+
empty_data = {'attention_outputs': {}}
|
| 247 |
+
result = categorize_all_heads(empty_data, default_head_config)
|
| 248 |
+
|
| 249 |
+
# Should return empty lists for all categories
|
| 250 |
+
for cat, heads in result.items():
|
| 251 |
+
assert heads == []
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
class TestFormatCategorizationSummary:
|
| 255 |
+
"""Tests for format_categorization_summary function."""
|
| 256 |
+
|
| 257 |
+
def test_formats_empty_categorization(self):
|
| 258 |
+
"""Should format empty categorization without error."""
|
| 259 |
+
empty = {
|
| 260 |
+
'previous_token': [],
|
| 261 |
+
'first_token': [],
|
| 262 |
+
'bow': [],
|
| 263 |
+
'syntactic': [],
|
| 264 |
+
'other': []
|
| 265 |
+
}
|
| 266 |
+
result = format_categorization_summary(empty)
|
| 267 |
+
|
| 268 |
+
assert isinstance(result, str)
|
| 269 |
+
assert "Total Heads: 0" in result
|
| 270 |
+
|
| 271 |
+
def test_formats_with_heads(self):
|
| 272 |
+
"""Should format categorization with heads correctly."""
|
| 273 |
+
categorized = {
|
| 274 |
+
'previous_token': [
|
| 275 |
+
{'layer': 0, 'head': 1, 'label': 'L0-H1'},
|
| 276 |
+
{'layer': 0, 'head': 2, 'label': 'L0-H2'},
|
| 277 |
+
],
|
| 278 |
+
'first_token': [
|
| 279 |
+
{'layer': 1, 'head': 0, 'label': 'L1-H0'},
|
| 280 |
+
],
|
| 281 |
+
'bow': [],
|
| 282 |
+
'syntactic': [],
|
| 283 |
+
'other': []
|
| 284 |
+
}
|
| 285 |
+
result = format_categorization_summary(categorized)
|
| 286 |
+
|
| 287 |
+
assert "Total Heads: 3" in result
|
| 288 |
+
assert "Previous-Token Heads: 2" in result
|
| 289 |
+
assert "First/Positional-Token Heads: 1" in result
|
| 290 |
+
assert "Layer 0" in result
|
| 291 |
+
assert "Layer 1" in result
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
class TestHeadCategorizationConfig:
|
| 295 |
+
"""Tests for HeadCategorizationConfig defaults."""
|
| 296 |
+
|
| 297 |
+
def test_default_values(self):
|
| 298 |
+
"""Default config should have reasonable values."""
|
| 299 |
+
config = HeadCategorizationConfig()
|
| 300 |
+
|
| 301 |
+
assert 0 < config.prev_token_threshold < 1
|
| 302 |
+
assert 0 < config.first_token_threshold < 1
|
| 303 |
+
assert 0 < config.bow_entropy_threshold < 1
|
| 304 |
+
assert config.min_seq_len > 0
|
| 305 |
+
|
| 306 |
+
def test_config_is_mutable(self):
|
| 307 |
+
"""Config values should be mutable for customization."""
|
| 308 |
+
config = HeadCategorizationConfig()
|
| 309 |
+
original = config.prev_token_threshold
|
| 310 |
+
|
| 311 |
+
config.prev_token_threshold = 0.8
|
| 312 |
+
assert config.prev_token_threshold == 0.8
|
| 313 |
+
assert config.prev_token_threshold != original
|
tests/test_model_config.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Tests for utils/model_config.py
|
| 3 |
+
|
| 4 |
+
Tests model family lookups, configuration retrieval, and auto-selection logic.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import pytest
|
| 8 |
+
from utils.model_config import (
|
| 9 |
+
get_model_family,
|
| 10 |
+
get_family_config,
|
| 11 |
+
get_auto_selections,
|
| 12 |
+
_pattern_matches_template,
|
| 13 |
+
MODEL_TO_FAMILY,
|
| 14 |
+
MODEL_FAMILIES
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class TestGetModelFamily:
|
| 19 |
+
"""Tests for get_model_family function."""
|
| 20 |
+
|
| 21 |
+
def test_known_gpt2_model(self):
|
| 22 |
+
"""Known GPT-2 model should return 'gpt2' family."""
|
| 23 |
+
assert get_model_family("gpt2") == "gpt2"
|
| 24 |
+
assert get_model_family("gpt2-medium") == "gpt2"
|
| 25 |
+
assert get_model_family("openai-community/gpt2") == "gpt2"
|
| 26 |
+
|
| 27 |
+
def test_known_llama_model(self):
|
| 28 |
+
"""Known LLaMA-like models should return 'llama_like' family."""
|
| 29 |
+
assert get_model_family("Qwen/Qwen2.5-0.5B") == "llama_like"
|
| 30 |
+
assert get_model_family("meta-llama/Llama-2-7b-hf") == "llama_like"
|
| 31 |
+
assert get_model_family("mistralai/Mistral-7B-v0.1") == "llama_like"
|
| 32 |
+
|
| 33 |
+
def test_known_opt_model(self):
|
| 34 |
+
"""Known OPT models should return 'opt' family."""
|
| 35 |
+
assert get_model_family("facebook/opt-125m") == "opt"
|
| 36 |
+
assert get_model_family("facebook/opt-1.3b") == "opt"
|
| 37 |
+
|
| 38 |
+
def test_unknown_model_returns_none(self):
|
| 39 |
+
"""Unknown models should return None."""
|
| 40 |
+
assert get_model_family("unknown/model-name") is None
|
| 41 |
+
assert get_model_family("random-string") is None
|
| 42 |
+
assert get_model_family("") is None
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class TestGetFamilyConfig:
|
| 46 |
+
"""Tests for get_family_config function."""
|
| 47 |
+
|
| 48 |
+
def test_valid_gpt2_config(self):
|
| 49 |
+
"""GPT-2 family config should have correct structure."""
|
| 50 |
+
config = get_family_config("gpt2")
|
| 51 |
+
assert config is not None
|
| 52 |
+
assert "templates" in config
|
| 53 |
+
assert "attention_pattern" in config["templates"]
|
| 54 |
+
assert config["templates"]["attention_pattern"] == "transformer.h.{N}.attn"
|
| 55 |
+
assert config["norm_type"] == "layernorm"
|
| 56 |
+
|
| 57 |
+
def test_valid_llama_config(self):
|
| 58 |
+
"""LLaMA-like family config should have correct structure."""
|
| 59 |
+
config = get_family_config("llama_like")
|
| 60 |
+
assert config is not None
|
| 61 |
+
assert config["templates"]["attention_pattern"] == "model.layers.{N}.self_attn"
|
| 62 |
+
assert config["norm_type"] == "rmsnorm"
|
| 63 |
+
assert config["norm_parameter"] == "model.norm.weight"
|
| 64 |
+
|
| 65 |
+
def test_invalid_family_returns_none(self):
|
| 66 |
+
"""Invalid family name should return None."""
|
| 67 |
+
assert get_family_config("invalid_family") is None
|
| 68 |
+
assert get_family_config("") is None
|
| 69 |
+
assert get_family_config("GPT2") is None # Case-sensitive
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class TestPatternMatchesTemplate:
|
| 73 |
+
"""Tests for _pattern_matches_template function."""
|
| 74 |
+
|
| 75 |
+
def test_exact_match(self):
|
| 76 |
+
"""Pattern that exactly matches template should return True."""
|
| 77 |
+
assert _pattern_matches_template(
|
| 78 |
+
"model.layers.{N}.self_attn",
|
| 79 |
+
"model.layers.{N}.self_attn"
|
| 80 |
+
) is True
|
| 81 |
+
|
| 82 |
+
def test_matching_with_n_placeholder(self):
|
| 83 |
+
"""Patterns with {N} placeholder should match correctly."""
|
| 84 |
+
assert _pattern_matches_template(
|
| 85 |
+
"transformer.h.{N}.attn",
|
| 86 |
+
"transformer.h.{N}.attn"
|
| 87 |
+
) is True
|
| 88 |
+
|
| 89 |
+
def test_non_matching_pattern(self):
|
| 90 |
+
"""Different patterns should not match."""
|
| 91 |
+
assert _pattern_matches_template(
|
| 92 |
+
"model.layers.{N}.self_attn",
|
| 93 |
+
"transformer.h.{N}.attn"
|
| 94 |
+
) is False
|
| 95 |
+
|
| 96 |
+
def test_empty_template_returns_false(self):
|
| 97 |
+
"""Empty template should return False."""
|
| 98 |
+
assert _pattern_matches_template("model.layers.{N}.self_attn", "") is False
|
| 99 |
+
assert _pattern_matches_template("", "") is False
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
class TestGetAutoSelections:
|
| 103 |
+
"""Tests for get_auto_selections function."""
|
| 104 |
+
|
| 105 |
+
def test_unknown_model_returns_empty_selections(self):
|
| 106 |
+
"""Unknown model should return empty selections."""
|
| 107 |
+
result = get_auto_selections(
|
| 108 |
+
"unknown/model",
|
| 109 |
+
{"model.layers.{N}.self_attn": ["model.layers.0.self_attn"]},
|
| 110 |
+
{"model.norm.weight": ["model.norm.weight"]}
|
| 111 |
+
)
|
| 112 |
+
assert result["attention_selection"] == []
|
| 113 |
+
assert result["block_selection"] == []
|
| 114 |
+
assert result["norm_selection"] == []
|
| 115 |
+
assert result["family_name"] is None
|
| 116 |
+
|
| 117 |
+
def test_known_model_matches_patterns(self, mock_module_patterns, mock_param_patterns):
|
| 118 |
+
"""Known model should match appropriate patterns."""
|
| 119 |
+
result = get_auto_selections(
|
| 120 |
+
"Qwen/Qwen2.5-0.5B", # llama_like family
|
| 121 |
+
mock_module_patterns,
|
| 122 |
+
mock_param_patterns
|
| 123 |
+
)
|
| 124 |
+
assert result["family_name"] == "llama_like"
|
| 125 |
+
# Should find self_attn pattern
|
| 126 |
+
assert "model.layers.{N}.self_attn" in result["attention_selection"]
|
| 127 |
+
# Should find block pattern
|
| 128 |
+
assert "model.layers.{N}" in result["block_selection"]
|
| 129 |
+
# Should find norm pattern
|
| 130 |
+
assert result["norm_selection"] == ["model.norm.weight"]
|
| 131 |
+
|
| 132 |
+
def test_result_structure(self, mock_module_patterns, mock_param_patterns):
|
| 133 |
+
"""Result should have all required keys."""
|
| 134 |
+
result = get_auto_selections(
|
| 135 |
+
"gpt2",
|
| 136 |
+
{}, # Empty patterns - no matches expected
|
| 137 |
+
{}
|
| 138 |
+
)
|
| 139 |
+
assert "attention_selection" in result
|
| 140 |
+
assert "block_selection" in result
|
| 141 |
+
assert "norm_selection" in result
|
| 142 |
+
assert "family_name" in result
|
| 143 |
+
assert isinstance(result["attention_selection"], list)
|
| 144 |
+
assert isinstance(result["norm_selection"], list)
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
class TestModelRegistryIntegrity:
|
| 148 |
+
"""Tests to verify the model registry data is consistent."""
|
| 149 |
+
|
| 150 |
+
def test_all_families_have_required_fields(self):
|
| 151 |
+
"""All model families should have required configuration fields."""
|
| 152 |
+
required_fields = ["description", "templates", "norm_type"]
|
| 153 |
+
for family_name, config in MODEL_FAMILIES.items():
|
| 154 |
+
for field in required_fields:
|
| 155 |
+
assert field in config, f"Family {family_name} missing {field}"
|
| 156 |
+
|
| 157 |
+
def test_all_mapped_families_exist(self):
|
| 158 |
+
"""All families referenced in MODEL_TO_FAMILY should exist in MODEL_FAMILIES."""
|
| 159 |
+
for model_name, family_name in MODEL_TO_FAMILY.items():
|
| 160 |
+
assert family_name in MODEL_FAMILIES, \
|
| 161 |
+
f"Model {model_name} references unknown family {family_name}"
|
tests/test_model_patterns.py
ADDED
|
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Tests for utils/model_patterns.py
|
| 3 |
+
|
| 4 |
+
Tests pure logic functions that don't require model loading:
|
| 5 |
+
- merge_token_probabilities
|
| 6 |
+
- safe_to_serializable
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import pytest
|
| 10 |
+
import torch
|
| 11 |
+
import numpy as np
|
| 12 |
+
from utils.model_patterns import merge_token_probabilities, safe_to_serializable
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class TestMergeTokenProbabilities:
|
| 16 |
+
"""Tests for merge_token_probabilities function."""
|
| 17 |
+
|
| 18 |
+
def test_merges_tokens_with_leading_space(self):
|
| 19 |
+
"""Tokens with and without leading space should be merged."""
|
| 20 |
+
token_probs = [
|
| 21 |
+
(" cat", 0.15),
|
| 22 |
+
("cat", 0.05),
|
| 23 |
+
(" dog", 0.10),
|
| 24 |
+
]
|
| 25 |
+
result = merge_token_probabilities(token_probs)
|
| 26 |
+
|
| 27 |
+
# Convert to dict for easier checking
|
| 28 |
+
result_dict = dict(result)
|
| 29 |
+
|
| 30 |
+
assert "cat" in result_dict
|
| 31 |
+
assert abs(result_dict["cat"] - 0.20) < 1e-6 # 0.15 + 0.05
|
| 32 |
+
assert "dog" in result_dict
|
| 33 |
+
assert abs(result_dict["dog"] - 0.10) < 1e-6
|
| 34 |
+
|
| 35 |
+
def test_sorts_by_probability_descending(self):
|
| 36 |
+
"""Results should be sorted by probability (highest first)."""
|
| 37 |
+
token_probs = [
|
| 38 |
+
("low", 0.01),
|
| 39 |
+
("high", 0.50),
|
| 40 |
+
("medium", 0.20),
|
| 41 |
+
]
|
| 42 |
+
result = merge_token_probabilities(token_probs)
|
| 43 |
+
|
| 44 |
+
# Check order: high, medium, low
|
| 45 |
+
assert result[0][0] == "high"
|
| 46 |
+
assert result[1][0] == "medium"
|
| 47 |
+
assert result[2][0] == "low"
|
| 48 |
+
|
| 49 |
+
def test_handles_empty_input(self):
|
| 50 |
+
"""Empty input should return empty list."""
|
| 51 |
+
result = merge_token_probabilities([])
|
| 52 |
+
assert result == []
|
| 53 |
+
|
| 54 |
+
def test_handles_single_token(self):
|
| 55 |
+
"""Single token should be returned as-is (stripped)."""
|
| 56 |
+
result = merge_token_probabilities([(" hello", 0.5)])
|
| 57 |
+
|
| 58 |
+
assert len(result) == 1
|
| 59 |
+
assert result[0][0] == "hello"
|
| 60 |
+
assert result[0][1] == 0.5
|
| 61 |
+
|
| 62 |
+
def test_strips_multiple_spaces(self):
|
| 63 |
+
"""Multiple leading spaces should all be stripped."""
|
| 64 |
+
token_probs = [
|
| 65 |
+
(" word", 0.3), # Two spaces
|
| 66 |
+
(" word", 0.2), # One space
|
| 67 |
+
("word", 0.1), # No space
|
| 68 |
+
]
|
| 69 |
+
result = merge_token_probabilities(token_probs)
|
| 70 |
+
|
| 71 |
+
result_dict = dict(result)
|
| 72 |
+
assert "word" in result_dict
|
| 73 |
+
assert abs(result_dict["word"] - 0.6) < 1e-6 # All merged
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class TestSafeToSerializable:
|
| 77 |
+
"""Tests for safe_to_serializable function."""
|
| 78 |
+
|
| 79 |
+
def test_converts_tensor_to_list(self):
|
| 80 |
+
"""PyTorch tensor should be converted to Python list."""
|
| 81 |
+
tensor = torch.tensor([1.0, 2.0, 3.0])
|
| 82 |
+
result = safe_to_serializable(tensor)
|
| 83 |
+
|
| 84 |
+
assert isinstance(result, list)
|
| 85 |
+
assert result == [1.0, 2.0, 3.0]
|
| 86 |
+
|
| 87 |
+
def test_converts_nested_tensor(self):
|
| 88 |
+
"""2D tensor should become nested list."""
|
| 89 |
+
tensor = torch.tensor([[1, 2], [3, 4]])
|
| 90 |
+
result = safe_to_serializable(tensor)
|
| 91 |
+
|
| 92 |
+
assert isinstance(result, list)
|
| 93 |
+
assert result == [[1, 2], [3, 4]]
|
| 94 |
+
|
| 95 |
+
def test_converts_list_of_tensors(self):
|
| 96 |
+
"""List containing tensors should have tensors converted."""
|
| 97 |
+
data = [torch.tensor([1, 2]), torch.tensor([3, 4])]
|
| 98 |
+
result = safe_to_serializable(data)
|
| 99 |
+
|
| 100 |
+
assert result == [[1, 2], [3, 4]]
|
| 101 |
+
|
| 102 |
+
def test_converts_dict_with_tensor_values(self):
|
| 103 |
+
"""Dict with tensor values should have values converted."""
|
| 104 |
+
data = {
|
| 105 |
+
"a": torch.tensor([1.0, 2.0]),
|
| 106 |
+
"b": "string_value",
|
| 107 |
+
"c": 42
|
| 108 |
+
}
|
| 109 |
+
result = safe_to_serializable(data)
|
| 110 |
+
|
| 111 |
+
assert result["a"] == [1.0, 2.0]
|
| 112 |
+
assert result["b"] == "string_value"
|
| 113 |
+
assert result["c"] == 42
|
| 114 |
+
|
| 115 |
+
def test_handles_tuple_input(self):
|
| 116 |
+
"""Tuple with tensors should be converted to list."""
|
| 117 |
+
data = (torch.tensor([1]), torch.tensor([2]))
|
| 118 |
+
result = safe_to_serializable(data)
|
| 119 |
+
|
| 120 |
+
assert isinstance(result, list)
|
| 121 |
+
assert result == [[1], [2]]
|
| 122 |
+
|
| 123 |
+
def test_passes_through_primitives(self):
|
| 124 |
+
"""Primitive types should pass through unchanged."""
|
| 125 |
+
assert safe_to_serializable(42) == 42
|
| 126 |
+
assert safe_to_serializable(3.14) == 3.14
|
| 127 |
+
assert safe_to_serializable("hello") == "hello"
|
| 128 |
+
assert safe_to_serializable(None) is None
|
| 129 |
+
assert safe_to_serializable(True) is True
|
| 130 |
+
|
| 131 |
+
def test_handles_deeply_nested_structure(self):
|
| 132 |
+
"""Should handle deeply nested structures with tensors."""
|
| 133 |
+
data = {
|
| 134 |
+
"level1": {
|
| 135 |
+
"level2": {
|
| 136 |
+
"tensor": torch.tensor([1, 2, 3])
|
| 137 |
+
}
|
| 138 |
+
},
|
| 139 |
+
"list": [torch.tensor([4, 5])]
|
| 140 |
+
}
|
| 141 |
+
result = safe_to_serializable(data)
|
| 142 |
+
|
| 143 |
+
assert result["level1"]["level2"]["tensor"] == [1, 2, 3]
|
| 144 |
+
assert result["list"] == [[4, 5]]
|
| 145 |
+
|
| 146 |
+
def test_handles_empty_containers(self):
|
| 147 |
+
"""Empty lists, dicts, tuples should remain empty."""
|
| 148 |
+
assert safe_to_serializable([]) == []
|
| 149 |
+
assert safe_to_serializable({}) == {}
|
| 150 |
+
assert safe_to_serializable(()) == [] # Tuple becomes list
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
class TestSafeToSerializableEdgeCases:
|
| 154 |
+
"""Edge case tests for safe_to_serializable."""
|
| 155 |
+
|
| 156 |
+
def test_handles_scalar_tensor(self):
|
| 157 |
+
"""Scalar tensor should become a Python scalar."""
|
| 158 |
+
scalar = torch.tensor(42.0)
|
| 159 |
+
result = safe_to_serializable(scalar)
|
| 160 |
+
|
| 161 |
+
# Scalar tensor.tolist() returns a Python number
|
| 162 |
+
assert result == 42.0
|
| 163 |
+
|
| 164 |
+
def test_handles_integer_tensor(self):
|
| 165 |
+
"""Integer tensor should be converted correctly."""
|
| 166 |
+
tensor = torch.tensor([1, 2, 3], dtype=torch.int64)
|
| 167 |
+
result = safe_to_serializable(tensor)
|
| 168 |
+
|
| 169 |
+
assert result == [1, 2, 3]
|
| 170 |
+
assert all(isinstance(x, int) for x in result)
|
| 171 |
+
|
| 172 |
+
def test_handles_mixed_list(self):
|
| 173 |
+
"""List with mixed tensor and non-tensor items should work."""
|
| 174 |
+
data = [torch.tensor([1]), "string", 42, {"key": torch.tensor([2])}]
|
| 175 |
+
result = safe_to_serializable(data)
|
| 176 |
+
|
| 177 |
+
assert result[0] == [1]
|
| 178 |
+
assert result[1] == "string"
|
| 179 |
+
assert result[2] == 42
|
| 180 |
+
assert result[3] == {"key": [2]}
|
tests/test_token_attribution.py
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Tests for utils/token_attribution.py
|
| 3 |
+
|
| 4 |
+
Tests the visualization data formatting function (pure logic).
|
| 5 |
+
The gradient computation functions require models and are not tested here.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import pytest
|
| 9 |
+
from utils.token_attribution import create_attribution_visualization_data
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class TestCreateAttributionVisualizationData:
|
| 13 |
+
"""Tests for create_attribution_visualization_data function."""
|
| 14 |
+
|
| 15 |
+
def test_returns_correct_structure(self, mock_attribution_result):
|
| 16 |
+
"""Should return list of dicts with required keys."""
|
| 17 |
+
result = create_attribution_visualization_data(mock_attribution_result)
|
| 18 |
+
|
| 19 |
+
assert isinstance(result, list)
|
| 20 |
+
assert len(result) == 3 # 3 tokens in mock data
|
| 21 |
+
|
| 22 |
+
required_keys = ['token', 'index', 'attribution', 'normalized', 'color', 'text_color']
|
| 23 |
+
for item in result:
|
| 24 |
+
for key in required_keys:
|
| 25 |
+
assert key in item, f"Missing key: {key}"
|
| 26 |
+
|
| 27 |
+
def test_preserves_token_order(self, mock_attribution_result):
|
| 28 |
+
"""Tokens should be in same order as input."""
|
| 29 |
+
result = create_attribution_visualization_data(mock_attribution_result)
|
| 30 |
+
|
| 31 |
+
assert result[0]['token'] == 'Hello'
|
| 32 |
+
assert result[1]['token'] == ' world'
|
| 33 |
+
assert result[2]['token'] == '!'
|
| 34 |
+
|
| 35 |
+
assert result[0]['index'] == 0
|
| 36 |
+
assert result[1]['index'] == 1
|
| 37 |
+
assert result[2]['index'] == 2
|
| 38 |
+
|
| 39 |
+
def test_preserves_attribution_values(self, mock_attribution_result):
|
| 40 |
+
"""Raw attribution values should be preserved."""
|
| 41 |
+
result = create_attribution_visualization_data(mock_attribution_result)
|
| 42 |
+
|
| 43 |
+
assert result[0]['attribution'] == 0.5
|
| 44 |
+
assert result[1]['attribution'] == 1.0
|
| 45 |
+
assert result[2]['attribution'] == 0.2
|
| 46 |
+
|
| 47 |
+
def test_color_format(self, mock_attribution_result):
|
| 48 |
+
"""Colors should be valid RGB format."""
|
| 49 |
+
result = create_attribution_visualization_data(mock_attribution_result)
|
| 50 |
+
|
| 51 |
+
for item in result:
|
| 52 |
+
color = item['color']
|
| 53 |
+
assert color.startswith('rgb(')
|
| 54 |
+
assert color.endswith(')')
|
| 55 |
+
# Extract RGB values
|
| 56 |
+
rgb_str = color[4:-1]
|
| 57 |
+
r, g, b = [int(x) for x in rgb_str.split(',')]
|
| 58 |
+
assert 0 <= r <= 255
|
| 59 |
+
assert 0 <= g <= 255
|
| 60 |
+
assert 0 <= b <= 255
|
| 61 |
+
|
| 62 |
+
def test_text_color_contrast(self, mock_attribution_result):
|
| 63 |
+
"""Text color should be black or white for contrast."""
|
| 64 |
+
result = create_attribution_visualization_data(mock_attribution_result)
|
| 65 |
+
|
| 66 |
+
for item in result:
|
| 67 |
+
assert item['text_color'] in ['#000000', '#ffffff']
|
| 68 |
+
|
| 69 |
+
def test_high_attribution_gets_color(self):
|
| 70 |
+
"""High attribution should result in colored background."""
|
| 71 |
+
data = {
|
| 72 |
+
'tokens': ['high'],
|
| 73 |
+
'token_ids': [1],
|
| 74 |
+
'attributions': [1.0], # Maximum positive attribution
|
| 75 |
+
'normalized_attributions': [1.0],
|
| 76 |
+
'target_token': 'x',
|
| 77 |
+
'target_token_id': 100
|
| 78 |
+
}
|
| 79 |
+
result = create_attribution_visualization_data(data)
|
| 80 |
+
|
| 81 |
+
# High positive attribution should have red-ish color (r=255)
|
| 82 |
+
color = result[0]['color']
|
| 83 |
+
rgb_str = color[4:-1]
|
| 84 |
+
r, g, b = [int(x) for x in rgb_str.split(',')]
|
| 85 |
+
|
| 86 |
+
# Red should be at max, green/blue should be reduced
|
| 87 |
+
assert r == 255
|
| 88 |
+
assert g < 255 # Reduced for visibility
|
| 89 |
+
assert b < 255
|
| 90 |
+
|
| 91 |
+
def test_handles_zero_attributions(self):
|
| 92 |
+
"""Zero attributions should produce neutral colors."""
|
| 93 |
+
data = {
|
| 94 |
+
'tokens': ['zero'],
|
| 95 |
+
'token_ids': [1],
|
| 96 |
+
'attributions': [0.0],
|
| 97 |
+
'normalized_attributions': [0.0],
|
| 98 |
+
'target_token': 'x',
|
| 99 |
+
'target_token_id': 100
|
| 100 |
+
}
|
| 101 |
+
result = create_attribution_visualization_data(data)
|
| 102 |
+
|
| 103 |
+
# Zero normalized attribution should give white-ish color
|
| 104 |
+
color = result[0]['color']
|
| 105 |
+
rgb_str = color[4:-1]
|
| 106 |
+
r, g, b = [int(x) for x in rgb_str.split(',')]
|
| 107 |
+
|
| 108 |
+
# All components should be high (near white)
|
| 109 |
+
assert r == 255
|
| 110 |
+
assert g == 255
|
| 111 |
+
assert b == 255
|
| 112 |
+
|
| 113 |
+
def test_handles_negative_attributions(self):
|
| 114 |
+
"""Negative attributions should get blue-ish color."""
|
| 115 |
+
data = {
|
| 116 |
+
'tokens': ['negative'],
|
| 117 |
+
'token_ids': [1],
|
| 118 |
+
'attributions': [-1.0], # Negative attribution
|
| 119 |
+
'normalized_attributions': [1.0], # Abs normalized
|
| 120 |
+
'target_token': 'x',
|
| 121 |
+
'target_token_id': 100
|
| 122 |
+
}
|
| 123 |
+
result = create_attribution_visualization_data(data)
|
| 124 |
+
|
| 125 |
+
# Negative attribution should have blue-ish color
|
| 126 |
+
color = result[0]['color']
|
| 127 |
+
rgb_str = color[4:-1]
|
| 128 |
+
r, g, b = [int(x) for x in rgb_str.split(',')]
|
| 129 |
+
|
| 130 |
+
# Blue should be at max, red/green should be reduced
|
| 131 |
+
assert b == 255
|
| 132 |
+
assert r < 255
|
| 133 |
+
assert g < 255
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
class TestAttributionVisualizationEdgeCases:
|
| 137 |
+
"""Edge case tests for create_attribution_visualization_data."""
|
| 138 |
+
|
| 139 |
+
def test_handles_single_token(self):
|
| 140 |
+
"""Should handle single token input."""
|
| 141 |
+
data = {
|
| 142 |
+
'tokens': ['only'],
|
| 143 |
+
'token_ids': [1],
|
| 144 |
+
'attributions': [0.5],
|
| 145 |
+
'normalized_attributions': [1.0], # Normalized to max
|
| 146 |
+
'target_token': 'x',
|
| 147 |
+
'target_token_id': 100
|
| 148 |
+
}
|
| 149 |
+
result = create_attribution_visualization_data(data)
|
| 150 |
+
|
| 151 |
+
assert len(result) == 1
|
| 152 |
+
assert result[0]['token'] == 'only'
|
| 153 |
+
|
| 154 |
+
def test_handles_empty_input(self):
|
| 155 |
+
"""Should handle empty token list."""
|
| 156 |
+
data = {
|
| 157 |
+
'tokens': [],
|
| 158 |
+
'token_ids': [],
|
| 159 |
+
'attributions': [],
|
| 160 |
+
'normalized_attributions': [],
|
| 161 |
+
'target_token': 'x',
|
| 162 |
+
'target_token_id': 100
|
| 163 |
+
}
|
| 164 |
+
result = create_attribution_visualization_data(data)
|
| 165 |
+
|
| 166 |
+
assert result == []
|
| 167 |
+
|
| 168 |
+
def test_handles_special_characters_in_tokens(self):
|
| 169 |
+
"""Should handle tokens with special characters."""
|
| 170 |
+
data = {
|
| 171 |
+
'tokens': ['<s>', '</s>', '\n', ' '],
|
| 172 |
+
'token_ids': [1, 2, 3, 4],
|
| 173 |
+
'attributions': [0.1, 0.2, 0.3, 0.4],
|
| 174 |
+
'normalized_attributions': [0.25, 0.5, 0.75, 1.0],
|
| 175 |
+
'target_token': 'x',
|
| 176 |
+
'target_token_id': 100
|
| 177 |
+
}
|
| 178 |
+
result = create_attribution_visualization_data(data)
|
| 179 |
+
|
| 180 |
+
assert len(result) == 4
|
| 181 |
+
assert result[0]['token'] == '<s>'
|
| 182 |
+
assert result[2]['token'] == '\n'
|
todo.md
CHANGED
|
@@ -1,5 +1,14 @@
|
|
| 1 |
# Todo
|
| 2 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
## Completed: Pipeline Explanation Refactor
|
| 4 |
|
| 5 |
### Phase 1: New Components (Done)
|
|
@@ -16,19 +25,3 @@
|
|
| 16 |
- [x] Delete `prompt_comparison.py`
|
| 17 |
- [x] Update `utils/__init__.py` exports
|
| 18 |
- [x] Add pipeline CSS styles to `assets/style.css`
|
| 19 |
-
|
| 20 |
-
---
|
| 21 |
-
|
| 22 |
-
## Next Steps
|
| 23 |
-
|
| 24 |
-
### Testing
|
| 25 |
-
- [ ] Run the dashboard and verify all pipeline stages render correctly
|
| 26 |
-
- [ ] Test ablation experiment workflow
|
| 27 |
-
- [ ] Test token attribution (both methods)
|
| 28 |
-
- [ ] Verify beam search still works with multi-token generation
|
| 29 |
-
|
| 30 |
-
### Enhancements (Optional)
|
| 31 |
-
- [ ] Add loading spinners to investigation tools
|
| 32 |
-
- [ ] Improve attention visualization formatting
|
| 33 |
-
- [ ] Add more detailed MLP stage visualization
|
| 34 |
-
- [ ] Consider adding "copy to clipboard" for token data
|
|
|
|
| 1 |
# Todo
|
| 2 |
|
| 3 |
+
## Completed: Test Suite Setup (Done)
|
| 4 |
+
- [x] Create `tests/` folder with `__init__.py` and `conftest.py` (shared fixtures)
|
| 5 |
+
- [x] Create `test_model_config.py` - 15 tests for model family lookups
|
| 6 |
+
- [x] Create `test_ablation_metrics.py` - 8 tests for KL divergence and probability deltas
|
| 7 |
+
- [x] Create `test_head_detection.py` - 20 tests for attention head categorization
|
| 8 |
+
- [x] Create `test_model_patterns.py` - 16 tests for merge_token_probabilities, safe_to_serializable
|
| 9 |
+
- [x] Create `test_token_attribution.py` - 11 tests for visualization data formatting
|
| 10 |
+
- [x] Verify all 73 tests pass with `pytest tests/ -v`
|
| 11 |
+
|
| 12 |
## Completed: Pipeline Explanation Refactor
|
| 13 |
|
| 14 |
### Phase 1: New Components (Done)
|
|
|
|
| 25 |
- [x] Delete `prompt_comparison.py`
|
| 26 |
- [x] Update `utils/__init__.py` exports
|
| 27 |
- [x] Add pipeline CSS styles to `assets/style.css`
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
utils/__pycache__/__init__.cpython-311.pyc
CHANGED
|
Binary files a/utils/__pycache__/__init__.cpython-311.pyc and b/utils/__pycache__/__init__.cpython-311.pyc differ
|
|
|
utils/__pycache__/token_attribution.cpython-311.pyc
ADDED
|
Binary file (11.8 kB). View file
|
|
|