# LLMVis/tests/test_token_attribution.py
# Origin: cdpearlman — "Add comprehensive test suite for utility modules (73 tests)" (commit e13b1ba)
"""
Tests for utils/token_attribution.py
Tests the visualization data formatting function (pure logic).
The gradient computation functions require models and are not tested here.
"""
import pytest
from utils.token_attribution import create_attribution_visualization_data
class TestCreateAttributionVisualizationData:
    """Tests for create_attribution_visualization_data function."""

    @staticmethod
    def _rgb_components(color):
        """Split an 'rgb(r, g, b)' string into a (r, g, b) tuple of ints."""
        inner = color[4:-1]
        return tuple(int(part) for part in inner.split(','))

    def test_returns_correct_structure(self, mock_attribution_result):
        """Should return list of dicts with required keys."""
        entries = create_attribution_visualization_data(mock_attribution_result)
        assert isinstance(entries, list)
        # The mock fixture carries exactly three tokens.
        assert len(entries) == 3
        expected_keys = ('token', 'index', 'attribution', 'normalized', 'color', 'text_color')
        for entry in entries:
            for key in expected_keys:
                assert key in entry, f"Missing key: {key}"

    def test_preserves_token_order(self, mock_attribution_result):
        """Tokens should be in same order as input."""
        entries = create_attribution_visualization_data(mock_attribution_result)
        for position, expected_token in enumerate(['Hello', ' world', '!']):
            assert entries[position]['token'] == expected_token
            assert entries[position]['index'] == position

    def test_preserves_attribution_values(self, mock_attribution_result):
        """Raw attribution values should be preserved."""
        entries = create_attribution_visualization_data(mock_attribution_result)
        for entry, expected_value in zip(entries, (0.5, 1.0, 0.2)):
            assert entry['attribution'] == expected_value

    def test_color_format(self, mock_attribution_result):
        """Colors should be valid RGB format."""
        for entry in create_attribution_visualization_data(mock_attribution_result):
            color = entry['color']
            assert color.startswith('rgb(')
            assert color.endswith(')')
            # Every channel must be a legal 8-bit value.
            for channel in self._rgb_components(color):
                assert 0 <= channel <= 255

    def test_text_color_contrast(self, mock_attribution_result):
        """Text color should be black or white for contrast."""
        entries = create_attribution_visualization_data(mock_attribution_result)
        for entry in entries:
            assert entry['text_color'] in ('#000000', '#ffffff')

    def test_high_attribution_gets_color(self):
        """High attribution should result in colored background."""
        payload = {
            'tokens': ['high'],
            'token_ids': [1],
            'attributions': [1.0],  # Maximum positive attribution
            'normalized_attributions': [1.0],
            'target_token': 'x',
            'target_token_id': 100,
        }
        entries = create_attribution_visualization_data(payload)
        # High positive attribution should have red-ish color (r=255),
        # with green/blue pulled down for visibility.
        r, g, b = self._rgb_components(entries[0]['color'])
        assert r == 255
        assert g < 255
        assert b < 255

    def test_handles_zero_attributions(self):
        """Zero attributions should produce neutral colors."""
        payload = {
            'tokens': ['zero'],
            'token_ids': [1],
            'attributions': [0.0],
            'normalized_attributions': [0.0],
            'target_token': 'x',
            'target_token_id': 100,
        }
        entries = create_attribution_visualization_data(payload)
        # Zero normalized attribution should give a white background
        # (all channels at the maximum).
        r, g, b = self._rgb_components(entries[0]['color'])
        assert (r, g, b) == (255, 255, 255)

    def test_handles_negative_attributions(self):
        """Negative attributions should get blue-ish color."""
        payload = {
            'tokens': ['negative'],
            'token_ids': [1],
            'attributions': [-1.0],  # Negative attribution
            'normalized_attributions': [1.0],  # Abs normalized
            'target_token': 'x',
            'target_token_id': 100,
        }
        entries = create_attribution_visualization_data(payload)
        # Negative attribution should be blue-dominant: blue at max,
        # red and green reduced.
        r, g, b = self._rgb_components(entries[0]['color'])
        assert b == 255
        assert r < 255
        assert g < 255
class TestAttributionVisualizationEdgeCases:
    """Edge case tests for create_attribution_visualization_data."""

    def test_handles_single_token(self):
        """Should handle single token input."""
        payload = {
            'tokens': ['only'],
            'token_ids': [1],
            'attributions': [0.5],
            'normalized_attributions': [1.0],  # Normalized to max
            'target_token': 'x',
            'target_token_id': 100,
        }
        entries = create_attribution_visualization_data(payload)
        # A one-token input yields exactly one visualization entry.
        assert len(entries) == 1
        assert entries[0]['token'] == 'only'

    def test_handles_empty_input(self):
        """Should handle empty token list."""
        payload = {
            'tokens': [],
            'token_ids': [],
            'attributions': [],
            'normalized_attributions': [],
            'target_token': 'x',
            'target_token_id': 100,
        }
        # Empty input must produce an empty list, not raise.
        assert create_attribution_visualization_data(payload) == []

    def test_handles_special_characters_in_tokens(self):
        """Should handle tokens with special characters."""
        special_tokens = ['<s>', '</s>', '\n', ' ']
        payload = {
            'tokens': special_tokens,
            'token_ids': [1, 2, 3, 4],
            'attributions': [0.1, 0.2, 0.3, 0.4],
            'normalized_attributions': [0.25, 0.5, 0.75, 1.0],
            'target_token': 'x',
            'target_token_id': 100,
        }
        entries = create_attribution_visualization_data(payload)
        assert len(entries) == 4
        # Markup-like and whitespace tokens pass through unchanged.
        assert entries[0]['token'] == '<s>'
        assert entries[2]['token'] == '\n'