import pytest
from linalg_zero.grpo.verify import parse_string, verify_answers
class TestVerifyAnswersCorrectness:
    """
    Test the end-to-end correctness of verify_answers function.
    Focus on whether the permissive extract_math_content leads to wrong verification results.

    Every case is a pair of raw strings. Both are run through parse_string and the
    parsed values are handed to verify_answers as (ground_truth, completion).
    """

    @staticmethod
    def _verify(completion: str, ground_truth: str):
        """Parse both raw strings and return verify_answers(ground_truth, completion).

        The ground truth is parsed first, then the completion, and the parsed
        ground truth is passed as the first argument to verify_answers.
        """
        parsed_ground_truth = parse_string(ground_truth)
        parsed_completion = parse_string(completion)
        return verify_answers(parsed_ground_truth, parsed_completion)

    def test_correct_exact_matches(self):
        """Test that exact matches are correctly verified as True."""
        test_cases = [
            ("42", "42"),
            ("-17.5", "-17.5"),
            ("[[1, 2], [3, 4]]", "[[1, 2], [3, 4]]"),
            ("[1, 2, 3]", "[1, 2, 3]"),
            ("0", "0"),
            ("1.5e-10", "1.5e-10"),
        ]
        for completion, ground_truth in test_cases:
            result = self._verify(completion, ground_truth)
            assert result is True, f"Failed: {completion} should equal {ground_truth}"

    def test_incorrect_exact_mismatches(self):
        """Test that clear mismatches are correctly verified as False."""
        test_cases = [
            ("42", "43"),
            ("-17.5", "17.5"),
            ("[[1, 2], [3, 4]]", "[[4, 3], [2, 1]]"),
            ("[1, 2, 3]", "[3, 2, 1]"),
            ("5", "0"),
        ]
        for completion, ground_truth in test_cases:
            result = self._verify(completion, ground_truth)
            assert result is False, f"Failed: {completion} should NOT equal {ground_truth}"

    def test_current_behavior_strict_matching(self):
        """Test that current implementation only does strict string matching."""
        test_cases = [
            # Bare literals that are identical on both sides must match.
            ("42", "42", True),
            ("[[1, 2], [3, 4]]", "[[1, 2], [3, 4]]", True),
            ("-17.5", "-17.5", True),
            ("5", "5", True),
            # The right answer embedded in surrounding prose must NOT match.
            # NOTE(review): a ("42", "42", False) row used to sit here; it
            # contradicted the identical True row above, so the test could
            # never pass. It presumably lost some wrapper markup (answer
            # tags?) -- confirm the intended input before re-adding it.
            ("Some text [[1, 2], [3, 4]] more text", "[[1, 2], [3, 4]]", False),
            ("The determinant is -17.5", "-17.5", False),
            ("The answer is 42", "42", False),
        ]
        for completion, ground_truth, expected in test_cases:
            result = self._verify(completion, ground_truth)
            assert result == expected, (
                f"Failed: '{completion}' vs '{ground_truth}' expected {expected}, got {result}"
            )

    def test_no_answer_tags_fallback(self):
        """Test that text without answer tags is processed as-is."""
        matching_cases = [
            ("42", "42"),
            ("[[1, 2], [3, 4]]", "[[1, 2], [3, 4]]"),
            ("-17.5", "-17.5"),
            ("5", "5"),
        ]
        for completion, ground_truth in matching_cases:
            result = self._verify(completion, ground_truth)
            assert result is True, f"Failed: '{completion}' should match '{ground_truth}'"

        # Test that unparseable text returns False.
        # The completion below is expected to parse to None (per the original
        # author's note); only the final False verdict is asserted here.
        result = self._verify("The answer is 42", "42")
        assert result is False, "Text without answer tags should not flexibly extract numbers"

    def test_malformed_input_edge_cases(self):
        """
        Test that malformed input cases return False as expected.
        """
        malformed_cases = [
            ("[[1, 2], [3, 4", "[[1, 2], [3, 4]]"),  # unbalanced brackets
            ("42.5.6", "42.5"),  # two decimal points
            ("Multiple answers: 42 and 24", "42"),
            ("[[1, 2]] and [[3, 4]]", "[[1, 2]]"),
            ("Text 42 and 24", "42"),
        ]
        for completion, ground_truth in malformed_cases:
            result = self._verify(completion, ground_truth)
            assert result is False, f"Malformed input should be False: '{completion}' vs '{ground_truth}'"

    def test_verify_mathematical_equivalence(self):
        """Test that mathematically equivalent but differently formatted answers are correctly identified."""
        equivalence_cases = [
            ("2.0", "2"),
            ("[[1.0, 2.0], [3.0, 4.0]]", "[[1, 2], [3, 4]]"),
            ("0.0", "0"),
            ("-0", "0"),
        ]
        for completion, ground_truth in equivalence_cases:
            result = self._verify(completion, ground_truth)
            assert result is True, (
                f"Mathematical equivalence failed: '{completion}' should equal '{ground_truth}'"
            )

    def test_empty_and_invalid_inputs(self):
        """Test that empty or completely invalid inputs return False."""
        edge_cases = [
            ("", "42"),
            ("42", ""),
            ("", ""),
            ("No numbers here", "42"),
            ("The result is undefined", "42"),
        ]
        for completion, ground_truth in edge_cases:
            result = self._verify(completion, ground_truth)
            assert result is False, f"Edge case should be False: '{completion}' vs '{ground_truth}'"

    @pytest.mark.parametrize(
        "completion,ground_truth,should_match",
        [
            # Clear correct cases
            ("42", "42", True),
            ("[[1, 2], [3, 4]]", "[[1, 2], [3, 4]]", True),
            # Clear incorrect cases
            ("42", "43", False),
            ("[[1, 2], [3, 4]]", "[[4, 3], [2, 1]]", False),
            # Cases that fail
            # NOTE(review): a ("42", "42", False) row was removed from this
            # group because it contradicted the first row of the table; it
            # presumably lost wrapper markup -- confirm before re-adding.
            ("The answer is 42", "42", False),
            ("The answer is 42", "43", False),
            # Malformed input
            ("[[1, 2], [3, 4", "[[1, 2], [3, 4]]", False),
            ("42.5.6", "42.5", False),
        ],
    )
    def test_verify_answers_comprehensive(self, completion, ground_truth, should_match):
        """Comprehensive test of verify_answers behavior."""
        result = self._verify(completion, ground_truth)
        # Always assert: every row carries an explicit expectation, and an
        # unchecked row would otherwise pass silently.
        assert result == should_match, (
            f"Expected {should_match} for '{completion}' vs '{ground_truth}', got {result}"
        )
if __name__ == "__main__":
    # Allow running this file directly. pytest.main() returns the exit code
    # rather than exiting, so propagate it; otherwise the process would
    # report success even when tests fail.
    raise SystemExit(pytest.main([__file__, "-v"]))