# Viewer residue kept from extraction: Spaces: Running | File size: 9,744 Bytes | eb1c19a
"""Tests for the memory evaluation framework."""
from headroom.evals.memory.judge import _parse_judge_response, simple_judge
from headroom.evals.memory.locomo import (
LOCOMO_CATEGORIES,
DialogueTurn,
LoCoMoCase,
LoCoMoConversation,
Session,
get_locomo_stats,
)
class TestLoCoMoDataStructures:
    """Checks on the LoCoMo containers: turns, sessions, QA cases, category table."""

    def test_dialogue_turn_from_dict(self):
        """A text-only payload should populate speaker/text/dia_id and no image URL."""
        payload = {
            "speaker": "Alice",
            "text": "Hello Bob!",
            "dia_id": "D1:1",
        }

        parsed = DialogueTurn.from_dict(payload)

        assert parsed.speaker == "Alice"
        assert parsed.text == "Hello Bob!"
        assert parsed.dia_id == "D1:1"
        # The payload carries no "img_file" key, so the URL is expected to be None.
        assert parsed.image_url is None

    def test_dialogue_turn_with_image(self):
        """The "img_file" and "blip_caption" keys should surface as image fields."""
        payload = {
            "speaker": "Bob",
            "text": "Check this out",
            "dia_id": "D1:2",
            "img_file": "http://example.com/img.jpg",
            "blip_caption": "A beautiful sunset",
        }

        parsed = DialogueTurn.from_dict(payload)

        assert parsed.image_url == "http://example.com/img.jpg"
        assert parsed.image_caption == "A beautiful sunset"

    def test_dialogue_turn_to_message_format(self):
        """Message rendering is checked for a plain turn and for a turn with an image."""
        plain_turn = DialogueTurn(speaker="Alice", text="I love Python", dia_id="D1:1")
        assert plain_turn.to_message_format() == "Alice: I love Python"

        # A captioned image is expected to add a bracketed marker to the message.
        picture_turn = DialogueTurn(
            speaker="Bob",
            text="Look at this",
            dia_id="D1:2",
            image_url="http://example.com/img.jpg",
            image_caption="A dog playing",
        )
        assert "[shares image: A dog playing]" in picture_turn.to_message_format()

    def test_session_properties(self):
        """A two-turn session should report two turns and include both lines in its text."""
        turns = [
            DialogueTurn(speaker="Alice", text="Hi", dia_id="D1:1"),
            DialogueTurn(speaker="Bob", text="Hello", dia_id="D1:2"),
        ]

        session = Session(session_num=1, datetime="2024-01-15", dialogues=turns)

        assert session.num_turns == 2
        assert "Alice: Hi" in session.text
        assert "Bob: Hello" in session.text

    def test_locomo_case_properties(self):
        """Check category_name and is_answerable on two contrasting QA cases."""
        answerable = LoCoMoCase(
            question="What is Alice's favorite color?",
            answer="Blue",
            category=1,
            evidence=["D1:5", "D2:3"],
            conversation_id="sample_1",
        )
        assert answerable.category_name == "single_hop"
        assert answerable.is_answerable is True

        # Category-5 case with answer "N/A" and no evidence: expected unanswerable.
        unanswerable = LoCoMoCase(
            question="What is unknown?",
            answer="N/A",
            category=5,
            evidence=[],
            conversation_id="sample_1",
        )
        assert unanswerable.is_answerable is False

    def test_locomo_categories(self):
        """Each category id from 1 to 5 should map to its expected label."""
        expected_labels = {
            1: "single_hop",
            2: "temporal",
            3: "multi_hop",
            4: "open_domain",
            5: "adversarial",
        }
        for category_id, label in expected_labels.items():
            assert LOCOMO_CATEGORIES[category_id] == label
class TestLoCoMoStats:
    """Checks on the aggregate numbers returned by get_locomo_stats."""

    def test_get_stats_empty(self):
        """With no conversations, the conversation and QA-pair counts should be zero."""
        summary = get_locomo_stats([])

        assert summary["num_conversations"] == 0
        assert summary["num_qa_pairs"] == 0

    def test_get_stats_with_data(self):
        """One conversation with one session, two turns and two QA cases is counted."""
        # Hand-built conversation used as the only input.
        turns = [
            DialogueTurn(speaker="A", text="Hello", dia_id="D1:1"),
            DialogueTurn(speaker="B", text="Hi there", dia_id="D1:2"),
        ]
        only_session = Session(session_num=1, datetime="2024-01-15", dialogues=turns)
        questions = [
            LoCoMoCase(question="Q1", answer="A1", category=1, evidence=[], conversation_id="s1"),
            LoCoMoCase(question="Q2", answer="A2", category=2, evidence=[], conversation_id="s1"),
        ]
        conversation = LoCoMoConversation(
            sample_id="s1",
            speaker_a="Alice",
            speaker_b="Bob",
            sessions=[only_session],
            qa_cases=questions,
        )

        summary = get_locomo_stats([conversation])

        assert summary["num_conversations"] == 1
        assert summary["num_sessions"] == 1
        assert summary["num_turns"] == 2
        assert summary["num_qa_pairs"] == 2
        # The QA cases use categories 1 and 2; both labels should appear in the breakdown.
        for label in ("single_hop", "temporal"):
            assert label in summary["questions_by_category"]
class TestJudge:
    """Checks on judge-response parsing and on the simple_judge scorer."""

    def test_parse_judge_response_standard(self):
        """A "Reasoning: ... / Score: N" reply should yield the score and the reasoning."""
        # The continuation line stays at column 0 so the literal has no extra whitespace.
        reply = """Reasoning: The prediction captures the main point.
Score: 4"""

        parsed_score, parsed_reasoning = _parse_judge_response(reply)

        assert parsed_score == 4.0
        assert "main point" in parsed_reasoning

    def test_parse_judge_response_with_decimal(self):
        """A fractional score should be parsed as-is."""
        reply = """Reasoning: Partially correct.
Score: 3.5"""

        parsed_score, _ = _parse_judge_response(reply)

        assert parsed_score == 3.5

    def test_parse_judge_response_clamping(self):
        """Scores of 10 and 0 are expected to come back as 5.0 and 1.0."""
        # Above the range: expected to come back as 5.0.
        too_high = "Reasoning: Perfect\nScore: 10"
        parsed_score, _ = _parse_judge_response(too_high)
        assert parsed_score == 5.0

        # Below the range: expected to come back as 1.0.
        too_low = "Reasoning: Terrible\nScore: 0"
        parsed_score, _ = _parse_judge_response(too_low)
        assert parsed_score == 1.0

    def test_simple_judge_exact_match(self):
        """Identical second and third arguments should score 5.0 with an "Exact match" note."""
        verdict, explanation = simple_judge("What color?", "Blue", "Blue")

        assert verdict == 5.0
        assert "Exact match" in explanation

    def test_simple_judge_high_overlap(self):
        """Near-identical sentences should score at least 4.0 and mention F1."""
        verdict, explanation = simple_judge(
            "What happened?",
            "Alice went to the store to buy groceries",
            "Alice went to the store for groceries",
        )

        assert verdict >= 4.0
        assert "F1" in explanation

    def test_simple_judge_no_overlap(self):
        """Texts that share no words should score 1.0 with a "Very low" note."""
        verdict, explanation = simple_judge(
            "What color?",
            "Blue",
            "The weather is nice",
        )

        assert verdict == 1.0
        assert "Very low" in explanation
class TestMemoryEvalConfig:
    """Checks on MemoryEvalConfig defaults and keyword overrides."""

    def test_default_config(self):
        """A config built with no arguments should expose the expected defaults."""
        from headroom.evals.memory import MemoryEvalConfig

        defaults = MemoryEvalConfig()

        assert defaults.n_conversations is None
        assert defaults.skip_adversarial is True
        assert defaults.top_k_memories == 10
        assert defaults.llm_judge_enabled is False
        assert defaults.f1_threshold == 0.5

    def test_custom_config(self):
        """Values passed as keywords should be readable back unchanged."""
        from headroom.evals.memory import MemoryEvalConfig

        customised = MemoryEvalConfig(
            n_conversations=5,
            categories=[1, 2],
            top_k_memories=20,
            llm_judge_enabled=True,
            f1_threshold=0.7,
        )

        assert customised.n_conversations == 5
        assert customised.categories == [1, 2]
        assert customised.top_k_memories == 20
        assert customised.llm_judge_enabled is True
        assert customised.f1_threshold == 0.7
class TestMemoryEvalResult:
    """Checks on MemoryEvalResult.to_dict and MemoryEvalSuiteResult.summary."""

    def test_eval_result_to_dict(self):
        """to_dict should expose the question, both answers, the F1 score and correctness."""
        from headroom.evals.memory.runner import MemoryEvalResult

        qa_case = LoCoMoCase(
            question="What color?",
            answer="Blue",
            category=1,
            evidence=[],
            conversation_id="s1",
        )
        outcome = MemoryEvalResult(
            case=qa_case,
            predicted_answer="Blue",
            retrieved_memories=["Memory 1", "Memory 2"],
            retrieval_scores=[0.9, 0.8],
            f1_score=1.0,
            exact_match=True,
            is_correct=True,
        )

        serialised = outcome.to_dict()

        assert serialised["question"] == "What color?"
        assert serialised["ground_truth"] == "Blue"
        assert serialised["predicted"] == "Blue"
        assert serialised["f1_score"] == 1.0
        assert serialised["is_correct"] is True

    def test_suite_result_summary(self):
        """The summary text should mention the totals, the F1 score and each category."""
        from headroom.evals.memory.runner import MemoryEvalSuiteResult

        per_category = {
            "single_hop": {"count": 30, "accuracy": 0.9, "avg_f1": 0.88, "correct": 27},
            "temporal": {"count": 25, "accuracy": 0.7, "avg_f1": 0.75, "correct": 18},
        }
        suite = MemoryEvalSuiteResult(
            total_cases=100,
            correct_cases=75,
            accuracy=0.75,
            avg_f1_score=0.82,
            exact_match_rate=0.5,
            avg_llm_judge_score=4.2,
            metrics_by_category=per_category,
            total_duration_seconds=120.5,
            avg_retrieval_latency_ms=15.3,
            avg_generation_latency_ms=250.0,
        )

        report = suite.summary()

        assert "100" in report
        # "75" matches either correct_cases=75 or accuracy=0.75 shown as a percentage;
        # presumably the percentage is the intended target - confirm against summary().
        assert "75" in report
        # avg_f1_score=0.82 is expected to be rendered with three decimals.
        assert "0.820" in report
        assert "single_hop" in report
        assert "temporal" in report
|