cdpearlman committed on
Commit
e13b1ba
·
1 Parent(s): 3850656

Add comprehensive test suite for utility modules (73 tests)

.cursor/debug.log DELETED
@@ -1,4 +0,0 @@
- {"sessionId": "debug-session", "runId": "pre-fix", "hypothesisId": "H1", "location": "app.py:render_heatmap:entry", "message": "render_heatmap entry", "data": {"has_activation_data": true, "has_activation_data2": false, "has_original_activation_data": false, "mode_data": {"comparison": "prompt1", "ablation": "original"}, "model_name": "Qwen/Qwen2.5-0.5B", "plotly_version": "6.3.0"}, "timestamp": 1768427085403}
- {"sessionId": "debug-session", "runId": "pre-fix", "hypothesisId": "H2", "location": "app.py:render_heatmap:active_data", "message": "selected active data source", "data": {"comparison_mode": "prompt1", "ablation_mode": "original", "show_comparison_toggle": false, "show_ablation_toggle": false, "active_data_source": "activation_data"}, "timestamp": 1768427090436}
- {"sessionId": "debug-session", "runId": "pre-fix", "hypothesisId": "H3", "location": "app.py:render_heatmap:matrix_data", "message": "matrix data summary", "data": {"rows": 24, "cols": 11, "tokens_len": 11, "layers_len": 24, "top_tokens_rows": 24}, "timestamp": 1768427136380}
- {"sessionId": "debug-session", "runId": "pre-fix", "hypothesisId": "H1", "location": "app.py:render_heatmap:colorbar_config", "message": "heatmap colorbar config", "data": {"colorbar_config": {"title": {"text": "Delta", "side": "right"}}}, "timestamp": 1768427136385}
.cursor/rules/minimal_changes.mdc CHANGED
@@ -11,6 +11,12 @@ alwaysApply: true

# Minimal Change Rules

+ - Testing & verification:
+   - For substantial code changes (new files, new functionality), write tests first in `tests/` that describe expected behavior.
+   - Skip tests for UI/frontend changes, trivial additions, and documentation.
+   - After implementing changes, run `pytest` to verify all tests pass.
+   - If tests fail, iterate on debugging until fixed.
+
- Plan first:
  - Update `todo.md` with the smallest next actions tied to `plans.md`.
  - Keep tasks atomic and check them off as you go.
.cursor/rules/planning_mode.mdc ADDED
@@ -0,0 +1,62 @@
+ ---
+ description: Rules for dividing work into parallel agent worktrees
+ globs:
+   - "plans.md"
+   - "todo.md"
+   - "**/PLAN*.md"
+ alwaysApply: false
+ ---
+
+ # Parallel Agent Planning Rules
+
+ When creating or updating plans (plan mode active), structure work for independent parallel execution.
+
+ ## Worktree Division Principles
+
+ - **Isolate by file/module**: Each agent task must target a distinct set of files that no other agent will touch.
+ - **No shared edits**: If two tasks could modify the same file, merge them into one task.
+ - **Define boundaries explicitly**: In your plan file, list the exact files/directories each agent owns.
+
+ ## Plan Structure for Parallel Agents
+
+ When dividing work, create a section like:
+
+ ```
+ ## Parallel Worktrees
+
+ ### Agent A: [Task Name]
+ - **Owns**: `utils/feature_a.py`, `tests/test_feature_a.py`
+ - **Does not touch**: everything else
+ - **Deliverable**: [specific outcome]
+
+ ### Agent B: [Task Name]
+ - **Owns**: `components/widget.py`, `tests/test_widget.py`
+ - **Does not touch**: everything else
+ - **Deliverable**: [specific outcome]
+
+ ### Sequential (after parallel completes)
+ - Integration task that touches shared files (e.g., `app.py` imports)
+ ```
+
+ ## Rules for Each Agent Task
+
+ - Each task must be **self-contained**: write, test, verify independently.
+ - Each task must specify:
+   1. Files it will create or modify (exclusive ownership)
+   2. Tests it will write/run
+   3. Success criteria
+ - Shared dependencies (imports, configs) should be locked before parallel work begins.
+
+ ## Conflict Prevention Checklist
+
+ Before finalizing a parallel plan, verify:
+ - [ ] No two agents modify the same file
+ - [ ] No two agents add imports to the same `__init__.py`
+ - [ ] Shared interfaces are defined and frozen before parallel work
+ - [ ] Each agent's tests can run independently
+
+ ## Git Strategy for Parallel Agents
+
+ - Each agent works on its own feature branch: `feature/<agent-task-name>`
+ - Branches are merged sequentially after all pass tests
+ - Order: merge least-dependent branches first
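The git strategy above pairs naturally with `git worktree`, which gives each agent branch its own checkout. A minimal sketch in a throwaway repository (the branch name `feature/agent-a-task` is hypothetical, following the rule's naming convention):

```shell
# Create a throwaway repo with one agent branch, then attach a worktree to it
repo=$(mktemp -d)/demo
git init -q "$repo"
git -C "$repo" -c user.name=demo -c user.email=demo@example.com \
    commit -q --allow-empty -m "init"
git -C "$repo" branch feature/agent-a-task
# Each agent gets an isolated checkout of its own branch
git -C "$repo" worktree add "$repo-agent-a" feature/agent-a-task
git -C "$repo" worktree list
```

Because each worktree is a separate directory on a separate branch, two agents can never produce a dirty shared checkout; conflicts only surface at the sequential merge step.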
components/__pycache__/investigation_panel.cpython-311.pyc ADDED
Binary file (17 kB).
 
components/__pycache__/model_selector.cpython-311.pyc CHANGED
Binary files a/components/__pycache__/model_selector.cpython-311.pyc and b/components/__pycache__/model_selector.cpython-311.pyc differ
 
components/__pycache__/pipeline.cpython-311.pyc ADDED
Binary file (23.4 kB).
 
components/model_selector.py CHANGED
@@ -12,9 +12,20 @@ from dash import html, dcc
 AVAILABLE_MODELS = [
     # LLaMA-like models (Qwen)
     {"label": "Qwen2.5-0.5B", "value": "Qwen/Qwen2.5-0.5B"},
+    # {"label": "Qwen2.5-1.5B", "value": "Qwen/Qwen2.5-1.5B"},

     # GPT-2 family
     {"label": "GPT-2 (124M)", "value": "gpt2"}
+    # {"label": "GPT-2 Medium (355M)", "value": "gpt2-medium"},
+    # {"label": "GPT-2 Large (774M)", "value": "gpt2-large"},
+
+    # # OPT family
+    # {"label": "OPT-125M", "value": "facebook/opt-125m"},
+    # {"label": "OPT-350M", "value": "facebook/opt-350m"},
+
+    # # GPT-NeoX family (Pythia)
+    # {"label": "Pythia-70M", "value": "EleutherAI/pythia-70m"},
+    # {"label": "Pythia-160M", "value": "EleutherAI/pythia-160m"},
 ]

 def create_model_selector():
plans.md CHANGED
@@ -1,19 +1,30 @@
- ## Current Plan
-
- ### Pipeline Explanation Refactor - COMPLETED
-
- The dashboard has been refactored from a testing/analysis tool into an explanation-first interface:
-
- 1. **New Pipeline Visualization**: Linear flow (Input → Tokens → Embed → Attention → MLP → Output) with click-to-expand stages
- 2. **Investigation Panel**: Consolidated ablation and token attribution tools
- 3. **Simplified Codebase**: Removed heatmap, comparison mode, and ~900 lines of code
- 4. **Token Attribution**: New gradient-based feature importance analysis
-
- #### File Changes
- - `app.py`: Reduced from 1781 to ~750 lines
- - `components/pipeline.py`: NEW - Main explanation flow
- - `components/investigation_panel.py`: NEW - Ablation + Attribution
- - `utils/token_attribution.py`: NEW - Integrated Gradients
- - `model_selector.py`: Simplified (removed comparison UI)
- - `main_panel.py`: DELETED
- - `prompt_comparison.py`: DELETED
+ Generation Settings:
+ - change the "Max New Tokens" label to "Number of New Tokens"
+ - the concept of beams is non-trivial; change "Beam" references to something like "Number of Generation Choices"
+
+ Generated Sequences:
+ - don't show the score value, it is virtually meaningless to a user
+
+ 1. Tokenization
+ - stack tokens vertically instead of listing horizontally (more visually appealing)
+
+ 2. Embedding
+ - add some explanation that the embedding is taken from a pre-learned table (improves understanding of how a token ID becomes an embedding)
+
+ 3. Attention
+ - use the BertViz head view instead of the model view; it is less overwhelming and lets the user scroll through attention heads
+ - add an explanation of what the user is looking at and how to navigate the BertViz visualization (answers "What am I looking at?", "What should this visualization show me?", "What should I look for?")
+ - categorize attention heads so that the different components of the model's learning are visible
+ - the "Most attended tokens" section doesn't provide much value; remove it and focus on categorized attention heads and BertViz
+
+ 4. MLP
+ - add more explanation of how the feed-forward networks are learned during training, letting the model interpret the current words based on its training set
+
+ 5. Output
+ - include the full prompt in "predicted next token", with the predicted token appended at the end (still highlighted)
+ - in the top-5 tokens graph, the hover-over data should show just the percent and token, not the long decimal value
+
+ Overall Conceptual Changes:
+ - the analysis should be done on the initial user-given prompt, not the prompt that includes max tokens
+ - the selected beam is used for comparison after experiments, not for analysis. For example, the user input is "The capital of the US is" and the chosen beam is "The capital of the US is Washington D.C.", but the analysis is done only with the first prompt. After either experiment, the chosen beam is compared to the new results. If ablation made the output "The capital of the US is New York City", that can be compared to the original chosen beam to show the differences.
+ - add testing and verification to the entire project so that each round of changes can be double-checked and verified for correctness (try to avoid running the app; just test functions)
requirements.txt CHANGED
@@ -13,3 +13,6 @@ bertviz>=1.4.0

 # Utility dependencies
 numpy>=1.24.0
+
+ # Testing dependencies
+ pytest>=7.0.0
tests/__init__.py ADDED
@@ -0,0 +1 @@
+ # Test suite for Transformer Activation Capture and Visualization
tests/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (172 Bytes).
 
tests/__pycache__/conftest.cpython-311-pytest-8.4.2.pyc ADDED
Binary file (7.43 kB).
 
tests/__pycache__/test_ablation_metrics.cpython-311-pytest-8.4.2.pyc ADDED
Binary file (18.2 kB).
 
tests/__pycache__/test_head_detection.cpython-311-pytest-8.4.2.pyc ADDED
Binary file (44.9 kB).
 
tests/__pycache__/test_model_config.cpython-311-pytest-8.4.2.pyc ADDED
Binary file (37 kB).
 
tests/__pycache__/test_model_patterns.cpython-311-pytest-8.4.2.pyc ADDED
Binary file (34.1 kB).
 
tests/__pycache__/test_token_attribution.cpython-311-pytest-8.4.2.pyc ADDED
Binary file (27.1 kB).
 
tests/conftest.py ADDED
@@ -0,0 +1,206 @@
+ """
+ Shared pytest fixtures for the test suite.
+
+ Provides reusable mock data structures and synthetic tensors
+ to test utility functions without loading actual ML models.
+ """
+
+ import pytest
+ import torch
+ import numpy as np
+
+
+ # =============================================================================
+ # Synthetic Attention Matrices
+ # =============================================================================
+
+ @pytest.fixture
+ def uniform_attention_matrix():
+     """4x4 uniform attention matrix (each position attends equally to all)."""
+     size = 4
+     return torch.ones(size, size) / size
+
+
+ @pytest.fixture
+ def previous_token_attention_matrix():
+     """
+     4x4 attention matrix where each position attends primarily to the previous token.
+     Position 0 attends to itself (no previous token).
+     """
+     size = 4
+     matrix = torch.zeros(size, size)
+     # Position 0 attends to itself
+     matrix[0, 0] = 1.0
+     # Other positions attend strongly to previous token
+     for i in range(1, size):
+         matrix[i, i-1] = 0.8
+         matrix[i, i] = 0.2  # Some self-attention
+     return matrix
+
+
+ @pytest.fixture
+ def first_token_attention_matrix():
+     """4x4 attention matrix where all positions attend strongly to first token."""
+     size = 4
+     matrix = torch.zeros(size, size)
+     for i in range(size):
+         matrix[i, 0] = 0.7  # Strong attention to first token
+         matrix[i, i] = 0.3  # Some self-attention
+     return matrix
+
+
+ @pytest.fixture
+ def peaked_attention_matrix():
+     """4x4 attention matrix with peaked (low entropy) attention at one position."""
+     size = 4
+     matrix = torch.zeros(size, size)
+     # Each position attends almost entirely to position 2
+     for i in range(size):
+         matrix[i, 2] = 0.95
+         # Distribute remaining across others
+         for j in range(size):
+             if j != 2:
+                 matrix[i, j] = 0.05 / (size - 1)
+     return matrix
+
+
+ # =============================================================================
+ # Mock Activation Data Structures
+ # =============================================================================
+
+ @pytest.fixture
+ def mock_activation_data():
+     """
+     Mock activation data structure similar to execute_forward_pass output.
+     Used for testing functions that process activation data.
+     """
+     return {
+         'model': 'mock-model',
+         'prompt': 'Hello world',
+         'input_ids': [[1, 2, 3, 4]],
+         'attention_modules': ['model.layers.0.self_attn', 'model.layers.1.self_attn'],
+         'attention_outputs': {
+             'model.layers.0.self_attn': {
+                 'output': [
+                     [[0.1, 0.2, 0.3]],  # Hidden states (simplified)
+                     [[[[0.25, 0.25, 0.25, 0.25],  # Attention weights [batch, heads, seq, seq]
+                        [0.25, 0.25, 0.25, 0.25],
+                        [0.25, 0.25, 0.25, 0.25],
+                        [0.25, 0.25, 0.25, 0.25]]]]
+                 ]
+             },
+             'model.layers.1.self_attn': {
+                 'output': [
+                     [[0.1, 0.2, 0.3]],
+                     [[[[0.1, 0.2, 0.3, 0.4],
+                        [0.1, 0.2, 0.3, 0.4],
+                        [0.1, 0.2, 0.3, 0.4],
+                        [0.1, 0.2, 0.3, 0.4]]]]
+                 ]
+             }
+         },
+         'block_modules': ['model.layers.0', 'model.layers.1'],
+         'block_outputs': {
+             'model.layers.0': {'output': [[[0.1, 0.2, 0.3, 0.4]]]},
+             'model.layers.1': {'output': [[[0.2, 0.3, 0.4, 0.5]]]}
+         },
+         'norm_parameters': ['model.norm.weight'],
+         'norm_data': [[1.0, 1.0, 1.0, 1.0]],
+         'actual_output': {'token': ' world', 'probability': 0.85},
+         'global_top5_tokens': [
+             {'token': 'world', 'probability': 0.85},
+             {'token': 'there', 'probability': 0.05},
+             {'token': 'friend', 'probability': 0.03},
+             {'token': 'everyone', 'probability': 0.02},
+             {'token': 'all', 'probability': 0.01}
+         ]
+     }
+
+
+ # =============================================================================
+ # Mock Module/Parameter Patterns
+ # =============================================================================
+
+ @pytest.fixture
+ def mock_module_patterns():
+     """Mock module patterns as returned by extract_patterns."""
+     return {
+         'model.layers.{N}.self_attn': ['model.layers.0.self_attn', 'model.layers.1.self_attn'],
+         'model.layers.{N}.mlp': ['model.layers.0.mlp', 'model.layers.1.mlp'],
+         'model.layers.{N}': ['model.layers.0', 'model.layers.1'],
+         'model.embed_tokens': ['model.embed_tokens'],
+         'model.norm': ['model.norm']
+     }
+
+
+ @pytest.fixture
+ def mock_param_patterns():
+     """Mock parameter patterns as returned by extract_patterns."""
+     return {
+         'model.layers.{N}.self_attn.q_proj.weight': ['model.layers.0.self_attn.q_proj.weight'],
+         'model.layers.{N}.self_attn.k_proj.weight': ['model.layers.0.self_attn.k_proj.weight'],
+         'model.norm.weight': ['model.norm.weight'],
+         'lm_head.weight': ['lm_head.weight']
+     }
+
+
+ # =============================================================================
+ # Synthetic Logits for Ablation Metrics
+ # =============================================================================
+
+ @pytest.fixture
+ def identical_logits():
+     """Two identical logit tensors for testing KL divergence = 0."""
+     logits = torch.tensor([[[1.0, 2.0, 3.0, 4.0],
+                             [2.0, 3.0, 4.0, 5.0]]])  # [1, 2, 4] = [batch, seq, vocab]
+     return logits, logits.clone()
+
+
+ @pytest.fixture
+ def different_logits():
+     """Two different logit tensors for testing KL divergence > 0."""
+     logits_p = torch.tensor([[[1.0, 2.0, 3.0, 4.0],
+                               [2.0, 3.0, 4.0, 5.0]]])
+     logits_q = torch.tensor([[[4.0, 3.0, 2.0, 1.0],
+                               [5.0, 4.0, 3.0, 2.0]]])
+     return logits_p, logits_q
+
+
+ @pytest.fixture
+ def prob_delta_data():
+     """Data for testing probability delta computation."""
+     # Reference favors token 3, ablated favors token 0
+     logits_ref = torch.tensor([[[1.0, 2.0, 3.0, 10.0],    # pos 0: predicts token 3
+                                 [1.0, 2.0, 10.0, 3.0]]])  # pos 1: predicts token 2
+     logits_abl = torch.tensor([[[10.0, 2.0, 3.0, 1.0],    # pos 0: predicts token 0
+                                 [10.0, 2.0, 1.0, 3.0]]])  # pos 1: predicts token 0
+     input_ids = torch.tensor([[0, 3, 2]])  # Actual tokens: start, 3, 2
+     return logits_ref, logits_abl, input_ids
+
+
+ # =============================================================================
+ # Attribution Data for Visualization Tests
+ # =============================================================================
+
+ @pytest.fixture
+ def mock_attribution_result():
+     """Mock output from compute_integrated_gradients or compute_simple_gradient_attribution."""
+     return {
+         'tokens': ['Hello', ' world', '!'],
+         'token_ids': [1, 2, 3],
+         'attributions': [0.5, 1.0, 0.2],  # Raw attribution scores
+         'normalized_attributions': [0.5, 1.0, 0.2],  # Already normalized for simplicity
+         'target_token': 'next',
+         'target_token_id': 100
+     }
+
+
+ # =============================================================================
+ # Head Categorization Config
+ # =============================================================================
+
+ @pytest.fixture
+ def default_head_config():
+     """Default head categorization configuration for testing."""
+     from utils.head_detection import HeadCategorizationConfig
+     return HeadCategorizationConfig()
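For orientation, the previous-token fixture above can be reproduced outside torch. A small numpy sketch (illustrative only, mirroring the fixture rather than the committed code) confirms that each row forms a valid attention distribution:

```python
import numpy as np

def previous_token_matrix(size=4):
    """Mirror of the previous_token_attention_matrix fixture: position 0
    self-attends; every later position puts 0.8 on the previous token
    and 0.2 on itself."""
    m = np.zeros((size, size))
    m[0, 0] = 1.0
    for i in range(1, size):
        m[i, i - 1] = 0.8
        m[i, i] = 0.2
    return m

# Every row is a valid attention distribution (each row sums to 1.0)
print(previous_token_matrix().sum(axis=1))
```

Row sums of 1 matter here: the entropy and head-detection heuristics under test assume the matrices behave like real softmax attention weights.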
tests/test_ablation_metrics.py ADDED
@@ -0,0 +1,120 @@
+ """
+ Tests for utils/ablation_metrics.py
+
+ Tests KL divergence computation and probability delta calculations.
+ Uses synthetic tensors to avoid model loading.
+ """
+
+ import pytest
+ import torch
+ import torch.nn.functional as F
+ from utils.ablation_metrics import compute_kl_divergence, get_token_probability_deltas
+
+
+ class TestComputeKLDivergence:
+     """Tests for compute_kl_divergence function."""
+
+     def test_identical_distributions_zero_kl(self, identical_logits):
+         """KL divergence of identical distributions should be approximately 0."""
+         logits_p, logits_q = identical_logits
+         kl_divs = compute_kl_divergence(logits_p, logits_q)
+
+         assert isinstance(kl_divs, list)
+         assert len(kl_divs) == 2  # seq_len = 2
+         for kl in kl_divs:
+             assert abs(kl) < 1e-5, f"Expected ~0, got {kl}"
+
+     def test_different_distributions_positive_kl(self, different_logits):
+         """KL divergence of different distributions should be positive."""
+         logits_p, logits_q = different_logits
+         kl_divs = compute_kl_divergence(logits_p, logits_q)
+
+         assert isinstance(kl_divs, list)
+         for kl in kl_divs:
+             assert kl > 0, f"Expected positive KL, got {kl}"
+
+     def test_kl_divergence_asymmetry(self, different_logits):
+         """KL(P||Q) should not equal KL(Q||P) for different distributions."""
+         logits_p, logits_q = different_logits
+         kl_pq = compute_kl_divergence(logits_p, logits_q)
+         kl_qp = compute_kl_divergence(logits_q, logits_p)
+
+         # They should generally be different (asymmetry of KL divergence)
+         assert kl_pq != kl_qp, "KL divergence should be asymmetric"
+
+     def test_handles_3d_input(self):
+         """Should handle [batch, seq_len, vocab_size] input correctly."""
+         logits = torch.randn(1, 5, 100)  # batch=1, seq=5, vocab=100
+         kl_divs = compute_kl_divergence(logits, logits)
+
+         assert len(kl_divs) == 5
+         for kl in kl_divs:
+             assert abs(kl) < 1e-5
+
+
+ class TestGetTokenProbabilityDeltas:
+     """Tests for get_token_probability_deltas function."""
+
+     def test_deltas_with_synthetic_data(self):
+         """Test probability delta computation with known inputs."""
+         # Logits shape: [1, seq_len, vocab_size] where seq_len matches input_ids
+         # input_ids has 3 tokens, so logits needs 3 positions
+         logits_ref = torch.tensor([[[1.0, 2.0, 3.0, 10.0],   # pos 0
+                                     [1.0, 2.0, 10.0, 3.0],   # pos 1
+                                     [1.0, 2.0, 3.0, 4.0]]])  # pos 2
+         logits_abl = torch.tensor([[[10.0, 2.0, 3.0, 1.0],
+                                     [10.0, 2.0, 1.0, 3.0],
+                                     [1.0, 2.0, 3.0, 4.0]]])
+         input_ids = torch.tensor([[0, 3, 2]])
+
+         deltas = get_token_probability_deltas(logits_ref, logits_abl, input_ids)
+
+         # Should return list of length seq_len - 1 (shifted prediction)
+         assert isinstance(deltas, list)
+         assert len(deltas) == 2  # seq_len=3, so 2 predictions (pos 0 predicts token 1, pos 1 predicts token 2)
+
+     def test_identical_logits_zero_delta(self):
+         """Identical logits should produce zero deltas."""
+         # Logits need seq_len=3 to match input_ids
+         logits = torch.tensor([[[1.0, 2.0, 3.0, 4.0],
+                                 [2.0, 3.0, 4.0, 5.0],
+                                 [3.0, 4.0, 5.0, 6.0]]])
+         input_ids = torch.tensor([[0, 3, 2]])
+
+         deltas = get_token_probability_deltas(logits, logits.clone(), input_ids)
+
+         for delta in deltas:
+             assert abs(delta) < 1e-5, f"Expected ~0 delta, got {delta}"
+
+     def test_delta_direction(self):
+         """When ablation increases a token's probability, delta should be positive."""
+         # 3 positions to match 3 input_ids
+         logits_ref = torch.tensor([[[1.0, 0.0, 0.0, 0.0],    # favors token 0
+                                     [1.0, 0.0, 0.0, 0.0],    # favors token 0
+                                     [1.0, 0.0, 0.0, 0.0]]])
+         logits_abl = torch.tensor([[[0.0, 10.0, 0.0, 0.0],   # favors token 1
+                                     [0.0, 10.0, 0.0, 0.0],   # favors token 1
+                                     [0.0, 10.0, 0.0, 0.0]]])
+         input_ids = torch.tensor([[0, 1, 1]])  # Target tokens: 1, 1
+
+         deltas = get_token_probability_deltas(logits_ref, logits_abl, input_ids)
+
+         # Both deltas should be positive (ablation increased target prob)
+         for delta in deltas:
+             assert delta > 0, f"Expected positive delta, got {delta}"
+
+     def test_delta_range(self):
+         """Deltas should be bounded by [-1, 1] since they're probability differences."""
+         # 3 positions to match input_ids
+         logits_ref = torch.tensor([[[100.0, -100.0, -100.0, -100.0],
+                                     [-100.0, 100.0, -100.0, -100.0],
+                                     [-100.0, -100.0, 100.0, -100.0]]])
+         logits_abl = torch.tensor([[[-100.0, 100.0, -100.0, -100.0],
+                                     [-100.0, -100.0, 100.0, -100.0],
+                                     [-100.0, -100.0, -100.0, 100.0]]])
+         input_ids = torch.tensor([[0, 0, 1]])  # Targets: 0, 1
+
+         deltas = get_token_probability_deltas(logits_ref, logits_abl, input_ids)
+
+         for delta in deltas:
+             assert -1.0 <= delta <= 1.0, f"Delta {delta} out of bounds"
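For readers unfamiliar with the metric under test: per-position KL divergence over the vocabulary can be computed as below. This is an illustrative numpy sketch consistent with the tests' expectations (identical logits give ~0, different logits give positive values), not the committed `utils/ablation_metrics.py` implementation:

```python
import numpy as np

def kl_divergence_per_position(logits_ref, logits_abl):
    """KL(P || Q) at each sequence position for two [seq, vocab] logit arrays."""
    def log_softmax(x):
        # Numerically stable log-softmax over the vocab axis
        x = x - x.max(axis=-1, keepdims=True)
        return x - np.log(np.exp(x).sum(axis=-1, keepdims=True))

    log_p = log_softmax(np.asarray(logits_ref, dtype=float))
    log_q = log_softmax(np.asarray(logits_abl, dtype=float))
    p = np.exp(log_p)
    # sum_v p(v) * (log p(v) - log q(v)) at each position
    return (p * (log_p - log_q)).sum(axis=-1).tolist()
```

Working in log space (log-softmax rather than softmax followed by log) avoids underflow for large vocabularies, which is why the tests can safely use logits as extreme as ±100.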
tests/test_head_detection.py ADDED
@@ -0,0 +1,313 @@
+ """
+ Tests for utils/head_detection.py
+
+ Tests attention head categorization heuristics using synthetic attention matrices.
+ """
+
+ import pytest
+ import torch
+ import numpy as np
+ from utils.head_detection import (
+     compute_attention_entropy,
+     detect_previous_token_head,
+     detect_first_token_head,
+     detect_bow_head,
+     detect_syntactic_head,
+     categorize_attention_head,
+     categorize_all_heads,
+     format_categorization_summary,
+     HeadCategorizationConfig
+ )
+
+
+ class TestComputeAttentionEntropy:
+     """Tests for compute_attention_entropy function."""
+
+     def test_uniform_distribution_high_entropy(self):
+         """Uniform attention should have high (near 1.0) normalized entropy."""
+         # 4 positions with equal attention
+         uniform = torch.tensor([0.25, 0.25, 0.25, 0.25])
+         entropy = compute_attention_entropy(uniform)
+
+         # Normalized entropy should be close to 1.0 for uniform
+         assert 0.95 <= entropy <= 1.0, f"Expected ~1.0, got {entropy}"
+
+     def test_peaked_distribution_low_entropy(self):
+         """Peaked attention should have low normalized entropy."""
+         # One position dominates
+         peaked = torch.tensor([0.97, 0.01, 0.01, 0.01])
+         entropy = compute_attention_entropy(peaked)
+
+         # Should be low entropy
+         assert entropy < 0.3, f"Expected low entropy, got {entropy}"
+
+     def test_entropy_bounds(self):
+         """Entropy should always be between 0 and 1 (normalized)."""
+         test_cases = [
+             torch.tensor([1.0, 0.0, 0.0, 0.0]),      # Extreme peaked
+             torch.tensor([0.5, 0.5, 0.0, 0.0]),      # Two positions
+             torch.tensor([0.25, 0.25, 0.25, 0.25]),  # Uniform
+         ]
+
+         for weights in test_cases:
+             entropy = compute_attention_entropy(weights)
+             assert 0.0 <= entropy <= 1.0, f"Entropy {entropy} out of bounds"
+
+
+ class TestDetectPreviousTokenHead:
+     """Tests for detect_previous_token_head function."""
+
+     def test_detects_previous_token_pattern(self, previous_token_attention_matrix, default_head_config):
+         """Should detect matrix with strong previous-token attention."""
+         is_prev, score = detect_previous_token_head(
+             previous_token_attention_matrix,
+             default_head_config
+         )
+
+         assert is_prev == True
+         assert score > 0.5, f"Expected high score, got {score}"
+
+     def test_rejects_uniform_attention(self, uniform_attention_matrix, default_head_config):
+         """Should reject matrix with uniform attention."""
+         is_prev, score = detect_previous_token_head(
+             uniform_attention_matrix,
+             default_head_config
+         )
+
+         assert is_prev == False
+         assert score < 0.4, f"Expected low score, got {score}"
+
+     def test_short_sequence_returns_false(self, default_head_config):
+         """Sequence shorter than min_seq_len should return False."""
+         short_matrix = torch.ones(2, 2) / 2
+         is_prev, score = detect_previous_token_head(short_matrix, default_head_config)
+
+         assert is_prev == False
+         assert score == 0.0
+
+
+ class TestDetectFirstTokenHead:
+     """Tests for detect_first_token_head function."""
+
+     def test_detects_first_token_pattern(self, first_token_attention_matrix, default_head_config):
+         """Should detect matrix with strong first-token attention."""
+         is_first, score = detect_first_token_head(
+             first_token_attention_matrix,
+             default_head_config
+         )
+
+         assert is_first == True
+         assert score > 0.5, f"Expected high score, got {score}"
+
+     def test_low_first_token_attention(self, default_head_config):
+         """Matrix with low attention to first token should not be detected."""
+         # Create matrix where first token gets very little attention
+         # Use size 5 to be above min_seq_len and avoid overlap at [0,0]
+         size = 5
+         matrix = torch.zeros(size, size)
+         for i in range(size):
+             # Distribute attention: 5% to first token, 95% to last token
+             matrix[i, 0] = 0.05
+             matrix[i, -1] = 0.95
+
+         is_first, score = detect_first_token_head(matrix, default_head_config)
+
+         assert is_first == False
+         assert score < 0.25, f"Expected low score, got {score}"
+
+
+ class TestDetectBowHead:
+     """Tests for detect_bow_head (bag-of-words / diffuse attention)."""
+
+     def test_detects_uniform_as_bow(self, uniform_attention_matrix, default_head_config):
+         """Uniform attention should be detected as BoW head."""
+         is_bow, score = detect_bow_head(uniform_attention_matrix, default_head_config)
+
+         # Uniform has high entropy and low max attention - should be BoW
+         assert is_bow == True
+         assert score > 0.9, f"Expected high entropy score, got {score}"
+
+     def test_rejects_peaked_attention(self, peaked_attention_matrix, default_head_config):
+         """Peaked attention should not be detected as BoW."""
+         is_bow, score = detect_bow_head(peaked_attention_matrix, default_head_config)
+
+         # Peaked attention has low entropy - should not be BoW
+         assert is_bow == False
+
+
+ class TestDetectSyntacticHead:
+     """Tests for detect_syntactic_head function."""
+
+     def test_consistent_distance_pattern(self, default_head_config):
+         """Matrix with consistent distance pattern should be detected as syntactic."""
+         # Create matrix where each position attends to position 2 tokens back
+         size = 6
+         matrix = torch.zeros(size, size)
+         for i in range(size):
+             target = max(0, i - 2)  # 2 tokens back
+             matrix[i, target] = 1.0
+
+         is_syn, score = detect_syntactic_head(matrix, default_head_config)
+
+         # Should have consistent distance pattern
+         assert score > 0.0, "Expected positive score for consistent pattern"
+
+     def test_random_attention_returns_valid_values(self, default_head_config):
+         """Random attention should return valid boolean and score."""
+         torch.manual_seed(42)
+         random_matrix = torch.softmax(torch.randn(6, 6), dim=-1)
+
+         is_syn, score = detect_syntactic_head(random_matrix, default_head_config)
+
+         # Check it returns valid types (bool or numpy bool, and numeric score)
+         assert is_syn in [True, False] or bool(is_syn) in [True, False]
+         assert 0 <= float(score) <= 1
+
+
+ class TestCategorizeAttentionHead:
+     """Tests for categorize_attention_head function."""
+
+     def test_categorizes_previous_token_head(self, previous_token_attention_matrix, default_head_config):
+         """Should categorize previous-token pattern correctly."""
+         result = categorize_attention_head(
+             previous_token_attention_matrix,
+             layer_idx=0,
+             head_idx=3,
+             config=default_head_config
+         )
+
+         assert result['category'] == 'previous_token'
+         assert result['layer'] == 0
+         assert result['head'] == 3
+         assert result['label'] == 'L0-H3'
+         assert 'scores' in result
+
+     def test_categorizes_first_token_head(self, first_token_attention_matrix, default_head_config):
+         """Should categorize first-token pattern correctly."""
+         result = categorize_attention_head(
+             first_token_attention_matrix,
+             layer_idx=2,
+             head_idx=5,
+             config=default_head_config
+         )
+
+         assert result['category'] == 'first_token'
+         assert result['label'] == 'L2-H5'
+
+     def test_categorizes_bow_head(self, default_head_config):
+         """Should categorize diffuse attention as BoW when it doesn't match other patterns."""
+         # Create BoW-like matrix: diffuse attention but first token gets LESS than threshold
+         # This avoids triggering first_token detection (threshold 0.25)
+         size = 5
+         matrix = torch.zeros(size, size)
+         for i in range(size):
+             # First token gets only 0.1, rest get roughly equal share
+             matrix[i, 0] = 0.1
+             remaining = 0.9 / (size - 1)
+             for j in range(1, size):
+                 matrix[i, j] = remaining
+
+         result = categorize_attention_head(
+             matrix,
+             layer_idx=1,
+             head_idx=0,
+             config=default_head_config
+         )
+
+         assert result['category'] == 'bow'
+
+     def test_result_structure(self, uniform_attention_matrix):
+         """Result should have all required keys."""
+         result = categorize_attention_head(
+             uniform_attention_matrix,
+             layer_idx=0,
+             head_idx=0
+         )
+
+         required_keys = ['layer', 'head', 'category', 'scores', 'label']
+         for key in required_keys:
+             assert key in result, f"Missing key: {key}"
+
+
+ class TestCategorizeAllHeads:
+     """Tests for categorize_all_heads function."""
+
+     def test_returns_all_categories(self, mock_activation_data, default_head_config):
+         """Should return dict with all category keys."""
+         result = categorize_all_heads(mock_activation_data, default_head_config)
+
+         expected_categories = ['previous_token', 'first_token', 'bow', 'syntactic', 'other']
+         for cat in expected_categories:
+             assert cat in result, f"Missing category: {cat}"
+             assert isinstance(result[cat], list)
+
+     def test_handles_empty_attention_data(self, default_head_config):
+         """Should handle activation data with no attention outputs."""
+         empty_data = {'attention_outputs': {}}
247
+ result = categorize_all_heads(empty_data, default_head_config)
248
+
249
+ # Should return empty lists for all categories
250
+ for cat, heads in result.items():
251
+ assert heads == []
252
+
253
+
254
+ class TestFormatCategorizationSummary:
255
+ """Tests for format_categorization_summary function."""
256
+
257
+ def test_formats_empty_categorization(self):
258
+ """Should format empty categorization without error."""
259
+ empty = {
260
+ 'previous_token': [],
261
+ 'first_token': [],
262
+ 'bow': [],
263
+ 'syntactic': [],
264
+ 'other': []
265
+ }
266
+ result = format_categorization_summary(empty)
267
+
268
+ assert isinstance(result, str)
269
+ assert "Total Heads: 0" in result
270
+
271
+ def test_formats_with_heads(self):
272
+ """Should format categorization with heads correctly."""
273
+ categorized = {
274
+ 'previous_token': [
275
+ {'layer': 0, 'head': 1, 'label': 'L0-H1'},
276
+ {'layer': 0, 'head': 2, 'label': 'L0-H2'},
277
+ ],
278
+ 'first_token': [
279
+ {'layer': 1, 'head': 0, 'label': 'L1-H0'},
280
+ ],
281
+ 'bow': [],
282
+ 'syntactic': [],
283
+ 'other': []
284
+ }
285
+ result = format_categorization_summary(categorized)
286
+
287
+ assert "Total Heads: 3" in result
288
+ assert "Previous-Token Heads: 2" in result
289
+ assert "First/Positional-Token Heads: 1" in result
290
+ assert "Layer 0" in result
291
+ assert "Layer 1" in result
292
+
293
+
294
+ class TestHeadCategorizationConfig:
295
+ """Tests for HeadCategorizationConfig defaults."""
296
+
297
+ def test_default_values(self):
298
+ """Default config should have reasonable values."""
299
+ config = HeadCategorizationConfig()
300
+
301
+ assert 0 < config.prev_token_threshold < 1
302
+ assert 0 < config.first_token_threshold < 1
303
+ assert 0 < config.bow_entropy_threshold < 1
304
+ assert config.min_seq_len > 0
305
+
306
+ def test_config_is_mutable(self):
307
+ """Config values should be mutable for customization."""
308
+ config = HeadCategorizationConfig()
309
+ original = config.prev_token_threshold
310
+
311
+ config.prev_token_threshold = 0.8
312
+ assert config.prev_token_threshold == 0.8
313
+ assert config.prev_token_threshold != original
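The "consistent distance" behavior the syntactic-head tests above rely on can be sketched roughly as follows. This is an assumed illustration of the kind of check `detect_syntactic_head` might perform, not the actual `utils/head_detection` code; plain lists stand in for tensors.

```python
def distance_consistency_score(matrix):
    """Fraction of query rows whose argmax key sits at the most common offset.

    Hypothetical sketch: a high score means most positions attend at the
    same relative distance, the signature the tests treat as syntactic.
    """
    offsets = []
    for i, row in enumerate(matrix):
        j = row.index(max(row))  # key position this query attends to most
        offsets.append(i - j)
    most_common = max(set(offsets), key=offsets.count)
    return offsets.count(most_common) / len(offsets)
```

On the "2 tokens back" matrix built in the first test, four of six rows share offset 2 (the first two rows are clamped to position 0), so the score is 4/6 and comfortably positive.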
tests/test_model_config.py ADDED
@@ -0,0 +1,161 @@
+"""
+Tests for utils/model_config.py
+
+Tests model family lookups, configuration retrieval, and auto-selection logic.
+"""
+
+import pytest
+from utils.model_config import (
+    get_model_family,
+    get_family_config,
+    get_auto_selections,
+    _pattern_matches_template,
+    MODEL_TO_FAMILY,
+    MODEL_FAMILIES
+)
+
+
+class TestGetModelFamily:
+    """Tests for get_model_family function."""
+
+    def test_known_gpt2_model(self):
+        """Known GPT-2 models should return the 'gpt2' family."""
+        assert get_model_family("gpt2") == "gpt2"
+        assert get_model_family("gpt2-medium") == "gpt2"
+        assert get_model_family("openai-community/gpt2") == "gpt2"
+
+    def test_known_llama_model(self):
+        """Known LLaMA-like models should return the 'llama_like' family."""
+        assert get_model_family("Qwen/Qwen2.5-0.5B") == "llama_like"
+        assert get_model_family("meta-llama/Llama-2-7b-hf") == "llama_like"
+        assert get_model_family("mistralai/Mistral-7B-v0.1") == "llama_like"
+
+    def test_known_opt_model(self):
+        """Known OPT models should return the 'opt' family."""
+        assert get_model_family("facebook/opt-125m") == "opt"
+        assert get_model_family("facebook/opt-1.3b") == "opt"
+
+    def test_unknown_model_returns_none(self):
+        """Unknown models should return None."""
+        assert get_model_family("unknown/model-name") is None
+        assert get_model_family("random-string") is None
+        assert get_model_family("") is None
+
+
+class TestGetFamilyConfig:
+    """Tests for get_family_config function."""
+
+    def test_valid_gpt2_config(self):
+        """GPT-2 family config should have the correct structure."""
+        config = get_family_config("gpt2")
+        assert config is not None
+        assert "templates" in config
+        assert "attention_pattern" in config["templates"]
+        assert config["templates"]["attention_pattern"] == "transformer.h.{N}.attn"
+        assert config["norm_type"] == "layernorm"
+
+    def test_valid_llama_config(self):
+        """LLaMA-like family config should have the correct structure."""
+        config = get_family_config("llama_like")
+        assert config is not None
+        assert config["templates"]["attention_pattern"] == "model.layers.{N}.self_attn"
+        assert config["norm_type"] == "rmsnorm"
+        assert config["norm_parameter"] == "model.norm.weight"
+
+    def test_invalid_family_returns_none(self):
+        """Invalid family name should return None."""
+        assert get_family_config("invalid_family") is None
+        assert get_family_config("") is None
+        assert get_family_config("GPT2") is None  # Case-sensitive
+
+
+class TestPatternMatchesTemplate:
+    """Tests for _pattern_matches_template function."""
+
+    def test_exact_match(self):
+        """Pattern that exactly matches the template should return True."""
+        assert _pattern_matches_template(
+            "model.layers.{N}.self_attn",
+            "model.layers.{N}.self_attn"
+        ) is True
+
+    def test_matching_with_n_placeholder(self):
+        """Patterns with the {N} placeholder should match correctly."""
+        assert _pattern_matches_template(
+            "transformer.h.{N}.attn",
+            "transformer.h.{N}.attn"
+        ) is True
+
+    def test_non_matching_pattern(self):
+        """Different patterns should not match."""
+        assert _pattern_matches_template(
+            "model.layers.{N}.self_attn",
+            "transformer.h.{N}.attn"
+        ) is False
+
+    def test_empty_template_returns_false(self):
+        """Empty template should return False."""
+        assert _pattern_matches_template("model.layers.{N}.self_attn", "") is False
+        assert _pattern_matches_template("", "") is False
+
+
+class TestGetAutoSelections:
+    """Tests for get_auto_selections function."""
+
+    def test_unknown_model_returns_empty_selections(self):
+        """Unknown model should return empty selections."""
+        result = get_auto_selections(
+            "unknown/model",
+            {"model.layers.{N}.self_attn": ["model.layers.0.self_attn"]},
+            {"model.norm.weight": ["model.norm.weight"]}
+        )
+        assert result["attention_selection"] == []
+        assert result["block_selection"] == []
+        assert result["norm_selection"] == []
+        assert result["family_name"] is None
+
+    def test_known_model_matches_patterns(self, mock_module_patterns, mock_param_patterns):
+        """Known model should match the appropriate patterns."""
+        result = get_auto_selections(
+            "Qwen/Qwen2.5-0.5B",  # llama_like family
+            mock_module_patterns,
+            mock_param_patterns
+        )
+        assert result["family_name"] == "llama_like"
+        # Should find the self_attn pattern
+        assert "model.layers.{N}.self_attn" in result["attention_selection"]
+        # Should find the block pattern
+        assert "model.layers.{N}" in result["block_selection"]
+        # Should find the norm pattern
+        assert result["norm_selection"] == ["model.norm.weight"]
+
+    def test_result_structure(self):
+        """Result should have all required keys."""
+        result = get_auto_selections(
+            "gpt2",
+            {},  # Empty patterns - no matches expected
+            {}
+        )
+        assert "attention_selection" in result
+        assert "block_selection" in result
+        assert "norm_selection" in result
+        assert "family_name" in result
+        assert isinstance(result["attention_selection"], list)
+        assert isinstance(result["norm_selection"], list)
+
+
+class TestModelRegistryIntegrity:
+    """Tests to verify the model registry data is consistent."""
+
+    def test_all_families_have_required_fields(self):
+        """All model families should have the required configuration fields."""
+        required_fields = ["description", "templates", "norm_type"]
+        for family_name, config in MODEL_FAMILIES.items():
+            for field in required_fields:
+                assert field in config, f"Family {family_name} missing {field}"
+
+    def test_all_mapped_families_exist(self):
+        """All families referenced in MODEL_TO_FAMILY should exist in MODEL_FAMILIES."""
+        for model_name, family_name in MODEL_TO_FAMILY.items():
+            assert family_name in MODEL_FAMILIES, \
+                f"Model {model_name} references unknown family {family_name}"
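The lookup behavior these tests exercise amounts to a plain dictionary mapping with a `None` fallback. A minimal sketch, using an illustrative subset of the mapping (the real `MODEL_TO_FAMILY` lives in `utils/model_config.py` and covers more models):

```python
# Illustrative subset of the registry; the real mapping is larger.
MODEL_TO_FAMILY = {
    "gpt2": "gpt2",
    "gpt2-medium": "gpt2",
    "Qwen/Qwen2.5-0.5B": "llama_like",
    "facebook/opt-125m": "opt",
}

def get_model_family(model_name):
    # Unknown models fall through to None rather than raising,
    # which is the behavior test_unknown_model_returns_none pins down.
    return MODEL_TO_FAMILY.get(model_name)
```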
tests/test_model_patterns.py ADDED
@@ -0,0 +1,180 @@
+"""
+Tests for utils/model_patterns.py
+
+Tests pure logic functions that don't require model loading:
+- merge_token_probabilities
+- safe_to_serializable
+"""
+
+import pytest
+import torch
+from utils.model_patterns import merge_token_probabilities, safe_to_serializable
+
+
+class TestMergeTokenProbabilities:
+    """Tests for merge_token_probabilities function."""
+
+    def test_merges_tokens_with_leading_space(self):
+        """Tokens with and without a leading space should be merged."""
+        token_probs = [
+            (" cat", 0.15),
+            ("cat", 0.05),
+            (" dog", 0.10),
+        ]
+        result = merge_token_probabilities(token_probs)
+
+        # Convert to dict for easier checking
+        result_dict = dict(result)
+
+        assert "cat" in result_dict
+        assert abs(result_dict["cat"] - 0.20) < 1e-6  # 0.15 + 0.05
+        assert "dog" in result_dict
+        assert abs(result_dict["dog"] - 0.10) < 1e-6
+
+    def test_sorts_by_probability_descending(self):
+        """Results should be sorted by probability (highest first)."""
+        token_probs = [
+            ("low", 0.01),
+            ("high", 0.50),
+            ("medium", 0.20),
+        ]
+        result = merge_token_probabilities(token_probs)
+
+        # Check order: high, medium, low
+        assert result[0][0] == "high"
+        assert result[1][0] == "medium"
+        assert result[2][0] == "low"
+
+    def test_handles_empty_input(self):
+        """Empty input should return an empty list."""
+        result = merge_token_probabilities([])
+        assert result == []
+
+    def test_handles_single_token(self):
+        """Single token should be returned as-is (stripped)."""
+        result = merge_token_probabilities([(" hello", 0.5)])
+
+        assert len(result) == 1
+        assert result[0][0] == "hello"
+        assert result[0][1] == 0.5
+
+    def test_strips_multiple_spaces(self):
+        """Multiple leading spaces should all be stripped."""
+        token_probs = [
+            ("  word", 0.3),  # Two spaces
+            (" word", 0.2),   # One space
+            ("word", 0.1),    # No space
+        ]
+        result = merge_token_probabilities(token_probs)
+
+        result_dict = dict(result)
+        assert "word" in result_dict
+        assert abs(result_dict["word"] - 0.6) < 1e-6  # All merged
+
+
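The behavior these tests pin down — strip leading spaces, sum probabilities for colliding tokens, sort descending — can be sketched in a few lines. This is an assumed reference implementation, not a copy of the actual `utils/model_patterns.py` code:

```python
def merge_token_probabilities(token_probs):
    """Merge tokens that differ only by leading whitespace, summing their
    probabilities, and return (token, prob) pairs sorted highest-first."""
    merged = {}
    for token, prob in token_probs:
        key = token.lstrip()  # " cat", "  cat", and "cat" all collapse to "cat"
        merged[key] = merged.get(key, 0.0) + prob
    return sorted(merged.items(), key=lambda kv: kv[1], reverse=True)
```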
+class TestSafeToSerializable:
+    """Tests for safe_to_serializable function."""
+
+    def test_converts_tensor_to_list(self):
+        """PyTorch tensor should be converted to Python list."""
+        tensor = torch.tensor([1.0, 2.0, 3.0])
+        result = safe_to_serializable(tensor)
+
+        assert isinstance(result, list)
+        assert result == [1.0, 2.0, 3.0]
+
+    def test_converts_nested_tensor(self):
+        """2D tensor should become nested list."""
+        tensor = torch.tensor([[1, 2], [3, 4]])
+        result = safe_to_serializable(tensor)
+
+        assert isinstance(result, list)
+        assert result == [[1, 2], [3, 4]]
+
+    def test_converts_list_of_tensors(self):
+        """List containing tensors should have tensors converted."""
+        data = [torch.tensor([1, 2]), torch.tensor([3, 4])]
+        result = safe_to_serializable(data)
+
+        assert result == [[1, 2], [3, 4]]
+
+    def test_converts_dict_with_tensor_values(self):
+        """Dict with tensor values should have values converted."""
+        data = {
+            "a": torch.tensor([1.0, 2.0]),
+            "b": "string_value",
+            "c": 42
+        }
+        result = safe_to_serializable(data)
+
+        assert result["a"] == [1.0, 2.0]
+        assert result["b"] == "string_value"
+        assert result["c"] == 42
+
+    def test_handles_tuple_input(self):
+        """Tuple with tensors should be converted to list."""
+        data = (torch.tensor([1]), torch.tensor([2]))
+        result = safe_to_serializable(data)
+
+        assert isinstance(result, list)
+        assert result == [[1], [2]]
+
+    def test_passes_through_primitives(self):
+        """Primitive types should pass through unchanged."""
+        assert safe_to_serializable(42) == 42
+        assert safe_to_serializable(3.14) == 3.14
+        assert safe_to_serializable("hello") == "hello"
+        assert safe_to_serializable(None) is None
+        assert safe_to_serializable(True) is True
+
+    def test_handles_deeply_nested_structure(self):
+        """Should handle deeply nested structures with tensors."""
+        data = {
+            "level1": {
+                "level2": {
+                    "tensor": torch.tensor([1, 2, 3])
+                }
+            },
+            "list": [torch.tensor([4, 5])]
+        }
+        result = safe_to_serializable(data)
+
+        assert result["level1"]["level2"]["tensor"] == [1, 2, 3]
+        assert result["list"] == [[4, 5]]
+
+    def test_handles_empty_containers(self):
+        """Empty lists, dicts, tuples should remain empty."""
+        assert safe_to_serializable([]) == []
+        assert safe_to_serializable({}) == {}
+        assert safe_to_serializable(()) == []  # Tuple becomes list
+
+
+class TestSafeToSerializableEdgeCases:
+    """Edge case tests for safe_to_serializable."""
+
+    def test_handles_scalar_tensor(self):
+        """Scalar tensor should become a Python scalar."""
+        scalar = torch.tensor(42.0)
+        result = safe_to_serializable(scalar)
+
+        # Scalar tensor.tolist() returns a Python number
+        assert result == 42.0
+
+    def test_handles_integer_tensor(self):
+        """Integer tensor should be converted correctly."""
+        tensor = torch.tensor([1, 2, 3], dtype=torch.int64)
+        result = safe_to_serializable(tensor)
+
+        assert result == [1, 2, 3]
+        assert all(isinstance(x, int) for x in result)
+
+    def test_handles_mixed_list(self):
+        """List with mixed tensor and non-tensor items should work."""
+        data = [torch.tensor([1]), "string", 42, {"key": torch.tensor([2])}]
+        result = safe_to_serializable(data)
+
+        assert result[0] == [1]
+        assert result[1] == "string"
+        assert result[2] == 42
+        assert result[3] == {"key": [2]}
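Taken together, these tests describe a simple recursive conversion. A sketch of that shape follows; the real `safe_to_serializable` dispatches on `torch.Tensor`, while this version duck-types on `.tolist()` (which tensors and numpy arrays both expose) so it can be read without a torch install:

```python
def safe_to_serializable(obj):
    """Recursively convert tensors and containers to JSON-friendly types
    (assumed sketch, not the actual utils/model_patterns implementation)."""
    if hasattr(obj, "tolist"):                # torch.Tensor / np.ndarray
        return obj.tolist()                   # scalar tensors become numbers
    if isinstance(obj, dict):
        return {k: safe_to_serializable(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):        # tuples become lists
        return [safe_to_serializable(v) for v in obj]
    return obj                                # primitives pass through
```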
tests/test_token_attribution.py ADDED
@@ -0,0 +1,182 @@
+"""
+Tests for utils/token_attribution.py
+
+Tests the visualization data formatting function (pure logic).
+The gradient computation functions require models and are not tested here.
+"""
+
+import pytest
+from utils.token_attribution import create_attribution_visualization_data
+
+
+class TestCreateAttributionVisualizationData:
+    """Tests for create_attribution_visualization_data function."""
+
+    def test_returns_correct_structure(self, mock_attribution_result):
+        """Should return list of dicts with required keys."""
+        result = create_attribution_visualization_data(mock_attribution_result)
+
+        assert isinstance(result, list)
+        assert len(result) == 3  # 3 tokens in mock data
+
+        required_keys = ['token', 'index', 'attribution', 'normalized', 'color', 'text_color']
+        for item in result:
+            for key in required_keys:
+                assert key in item, f"Missing key: {key}"
+
+    def test_preserves_token_order(self, mock_attribution_result):
+        """Tokens should be in same order as input."""
+        result = create_attribution_visualization_data(mock_attribution_result)
+
+        assert result[0]['token'] == 'Hello'
+        assert result[1]['token'] == ' world'
+        assert result[2]['token'] == '!'
+
+        assert result[0]['index'] == 0
+        assert result[1]['index'] == 1
+        assert result[2]['index'] == 2
+
+    def test_preserves_attribution_values(self, mock_attribution_result):
+        """Raw attribution values should be preserved."""
+        result = create_attribution_visualization_data(mock_attribution_result)
+
+        assert result[0]['attribution'] == 0.5
+        assert result[1]['attribution'] == 1.0
+        assert result[2]['attribution'] == 0.2
+
+    def test_color_format(self, mock_attribution_result):
+        """Colors should be valid RGB format."""
+        result = create_attribution_visualization_data(mock_attribution_result)
+
+        for item in result:
+            color = item['color']
+            assert color.startswith('rgb(')
+            assert color.endswith(')')
+            # Extract RGB values
+            rgb_str = color[4:-1]
+            r, g, b = [int(x) for x in rgb_str.split(',')]
+            assert 0 <= r <= 255
+            assert 0 <= g <= 255
+            assert 0 <= b <= 255
+
+    def test_text_color_contrast(self, mock_attribution_result):
+        """Text color should be black or white for contrast."""
+        result = create_attribution_visualization_data(mock_attribution_result)
+
+        for item in result:
+            assert item['text_color'] in ['#000000', '#ffffff']
+
+    def test_high_attribution_gets_color(self):
+        """High attribution should result in colored background."""
+        data = {
+            'tokens': ['high'],
+            'token_ids': [1],
+            'attributions': [1.0],  # Maximum positive attribution
+            'normalized_attributions': [1.0],
+            'target_token': 'x',
+            'target_token_id': 100
+        }
+        result = create_attribution_visualization_data(data)
+
+        # High positive attribution should have red-ish color (r=255)
+        color = result[0]['color']
+        rgb_str = color[4:-1]
+        r, g, b = [int(x) for x in rgb_str.split(',')]
+
+        # Red should be at max, green/blue should be reduced
+        assert r == 255
+        assert g < 255  # Reduced for visibility
+        assert b < 255
+
+    def test_handles_zero_attributions(self):
+        """Zero attributions should produce neutral colors."""
+        data = {
+            'tokens': ['zero'],
+            'token_ids': [1],
+            'attributions': [0.0],
+            'normalized_attributions': [0.0],
+            'target_token': 'x',
+            'target_token_id': 100
+        }
+        result = create_attribution_visualization_data(data)
+
+        # Zero normalized attribution should give white-ish color
+        color = result[0]['color']
+        rgb_str = color[4:-1]
+        r, g, b = [int(x) for x in rgb_str.split(',')]
+
+        # All components should be high (near white)
+        assert r == 255
+        assert g == 255
+        assert b == 255
+
+    def test_handles_negative_attributions(self):
+        """Negative attributions should get blue-ish color."""
+        data = {
+            'tokens': ['negative'],
+            'token_ids': [1],
+            'attributions': [-1.0],  # Negative attribution
+            'normalized_attributions': [1.0],  # Abs normalized
+            'target_token': 'x',
+            'target_token_id': 100
+        }
+        result = create_attribution_visualization_data(data)
+
+        # Negative attribution should have blue-ish color
+        color = result[0]['color']
+        rgb_str = color[4:-1]
+        r, g, b = [int(x) for x in rgb_str.split(',')]
+
+        # Blue should be at max, red/green should be reduced
+        assert b == 255
+        assert r < 255
+        assert g < 255
+
+
+class TestAttributionVisualizationEdgeCases:
+    """Edge case tests for create_attribution_visualization_data."""
+
+    def test_handles_single_token(self):
+        """Should handle single token input."""
+        data = {
+            'tokens': ['only'],
+            'token_ids': [1],
+            'attributions': [0.5],
+            'normalized_attributions': [1.0],  # Normalized to max
+            'target_token': 'x',
+            'target_token_id': 100
+        }
+        result = create_attribution_visualization_data(data)
+
+        assert len(result) == 1
+        assert result[0]['token'] == 'only'
+
+    def test_handles_empty_input(self):
+        """Should handle empty token list."""
+        data = {
+            'tokens': [],
+            'token_ids': [],
+            'attributions': [],
+            'normalized_attributions': [],
+            'target_token': 'x',
+            'target_token_id': 100
+        }
+        result = create_attribution_visualization_data(data)
+
+        assert result == []
+
+    def test_handles_special_characters_in_tokens(self):
+        """Should handle tokens with special characters."""
+        data = {
+            'tokens': ['<s>', '</s>', '\n', ' '],
+            'token_ids': [1, 2, 3, 4],
+            'attributions': [0.1, 0.2, 0.3, 0.4],
+            'normalized_attributions': [0.25, 0.5, 0.75, 1.0],
+            'target_token': 'x',
+            'target_token_id': 100
+        }
+        result = create_attribution_visualization_data(data)
+
+        assert len(result) == 4
+        assert result[0]['token'] == '<s>'
+        assert result[2]['token'] == '\n'
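The color assertions above imply a signed mapping: positive attributions tint toward red, negative toward blue, and zero stays white, with the normalized magnitude controlling how deep the tint is. A hypothetical sketch of that rule (not the actual `utils/token_attribution.py` code):

```python
def attribution_color(attribution, normalized):
    """Map a signed attribution plus its normalized magnitude to an rgb() string.

    Assumed rule: the dominant channel (red for positive, blue for negative)
    stays at 255 while the other channels fade as magnitude grows.
    """
    fade = int(255 * (1 - normalized))  # stronger attribution -> deeper tint
    if attribution >= 0:
        r, g, b = 255, fade, fade       # positive -> red-ish
    else:
        r, g, b = fade, fade, 255       # negative -> blue-ish
    return f"rgb({r},{g},{b})"
```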
todo.md CHANGED
@@ -1,5 +1,14 @@
 # Todo
 
+## Completed: Test Suite Setup (Done)
+- [x] Create `tests/` folder with `__init__.py` and `conftest.py` (shared fixtures)
+- [x] Create `test_model_config.py` - 15 tests for model family lookups
+- [x] Create `test_ablation_metrics.py` - 8 tests for KL divergence and probability deltas
+- [x] Create `test_head_detection.py` - 20 tests for attention head categorization
+- [x] Create `test_model_patterns.py` - 16 tests for merge_token_probabilities, safe_to_serializable
+- [x] Create `test_token_attribution.py` - 11 tests for visualization data formatting
+- [x] Verify all 73 tests pass with `pytest tests/ -v`
+
 ## Completed: Pipeline Explanation Refactor
 
 ### Phase 1: New Components (Done)
@@ -16,19 +25,3 @@
 - [x] Delete `prompt_comparison.py`
 - [x] Update `utils/__init__.py` exports
 - [x] Add pipeline CSS styles to `assets/style.css`
-
----
-
-## Next Steps
-
-### Testing
-- [ ] Run the dashboard and verify all pipeline stages render correctly
-- [ ] Test ablation experiment workflow
-- [ ] Test token attribution (both methods)
-- [ ] Verify beam search still works with multi-token generation
-
-### Enhancements (Optional)
-- [ ] Add loading spinners to investigation tools
-- [ ] Improve attention visualization formatting
-- [ ] Add more detailed MLP stage visualization
-- [ ] Consider adding "copy to clipboard" for token data
utils/__pycache__/__init__.cpython-311.pyc CHANGED
Binary files a/utils/__pycache__/__init__.cpython-311.pyc and b/utils/__pycache__/__init__.cpython-311.pyc differ
 
utils/__pycache__/token_attribution.cpython-311.pyc ADDED
Binary file (11.8 kB). View file