mathtok / tests /test_comparison.py
SurweeshSP's picture
Initial clean MathTok release
edede4c
"""
Tests for the Semantic Tokenizer Comparison Framework.
"""
import pytest
from evaluation.comparison import (
TokenizerStats, ComparisonRecord, TokenizerComparison,
_score_char, _score_gpt2, _score_mathtok,
_jaccard, _mean,
STANDARD_EXPRESSIONS, DEEP_NESTING_EXPRESSIONS, CANONICAL_PAIRS,
)
from mathtok.pipeline import MathTokPipeline
@pytest.fixture(scope="module")
def pipeline():
return MathTokPipeline(include_metadata=True)
@pytest.fixture(scope="module")
def comp(pipeline):
return TokenizerComparison(pipeline, gpt2_fn=None, save_jsonl=False)
# ── TokenizerStats ────────────────────────────────────────────────────────
class TestTokenizerStats:
def test_scr_computed(self):
stats = TokenizerStats(
name="test", tokens=["OP_ADD", "VAR_X", "CONST_1"],
token_count=3,
operator_nodes=1, tree_depth=1,
parent_child_relations=1, function_scope=0,
canonical_bonus=2,
)
stats.compute_scr()
assert stats.structural_score == 5 # 1+1+1+0+2
assert abs(stats.raw_scr - 5/3) < 1e-9
assert abs(stats.structural_efficiency - 1/3) < 1e-9
def test_zero_token_count_safe(self):
stats = TokenizerStats(name="empty", tokens=[], token_count=0)
stats.compute_scr()
assert stats.raw_scr == 0.0
# ── Character-level scorer ─────────────────────────────────────────────────
class TestCharScore:
def test_simple(self):
stats = _score_char("x + 1")
assert stats.token_count == 5
assert stats.operator_nodes >= 1 # at least +
assert stats.raw_scr >= 0
def test_nested_parens_depth(self):
stats = _score_char("sin((x+1)^2)")
assert stats.tree_depth >= 2 # at least 2 levels of parens
def test_no_function_scope(self):
# Character-level can't identify functions
stats = _score_char("sin(x)")
assert stats.function_scope == 0
# ── GPT-2 heuristic scorer ─────────────────────────────────────────────────
class TestGPT2Score:
def test_operators_detected(self):
tokens = ["(", "x", "+", "1", ")", "^", "2"]
stats = _score_gpt2(tokens)
assert stats.operator_nodes >= 1
def test_function_detected(self):
tokens = ["sin", "(", "x", ")"]
stats = _score_gpt2(tokens)
assert stats.function_scope >= 1
def test_paren_depth(self):
tokens = ["(", "(", "x", ")", ")"]
stats = _score_gpt2(tokens)
assert stats.tree_depth == 2
def test_scr_positive(self):
tokens = ["sin", "(", "x", "^", "2", ")"]
stats = _score_gpt2(tokens)
stats.compute_scr()
assert stats.raw_scr >= 0
# ── MathTok scorer ────────────────────────────────────────────────────────
class TestMathTokScore:
def test_add_expression(self, pipeline):
out = pipeline.encode_math_only("x + 1")
stats = _score_mathtok(out)
assert stats.token_count > 0
assert stats.operator_nodes >= 1 # OP_ADD
assert stats.canonical_bonus == 2 # successful parse
def test_function_expression(self, pipeline):
out = pipeline.encode_math_only("sin(x^2)")
stats = _score_mathtok(out)
assert stats.function_scope >= 1 # FUNC_SIN
def test_depth_nonzero(self, pipeline):
out = pipeline.encode_math_only("sin(x^2 + 1)")
stats = _score_mathtok(out)
assert stats.tree_depth >= 2
def test_scr_computed(self, pipeline):
out = pipeline.encode_math_only("(x+1)^2")
stats = _score_mathtok(out)
assert stats.raw_scr > 0
def test_mathtok_scr_higher_than_char(self, pipeline):
expr = "sin(x^2 + 1)"
out = pipeline.encode_math_only(expr)
mt = _score_mathtok(out)
ch = _score_char(expr)
# MathTok should have higher SCR due to semantic richness
assert mt.raw_scr > ch.raw_scr
# ── Comparison mechanics ──────────────────────────────────────────────────
class TestComparison:
def test_compare_one(self, comp):
rec = comp._compare_one("x + 1", "test")
assert isinstance(rec, ComparisonRecord)
assert rec.mathtok.token_count > 0
assert rec.char_level.token_count > 0
assert rec.gpt2 is None # no GPT-2 in fixture
def test_scr_improvement_vs_char(self, comp):
rec = comp._compare_one("sin(x^2)", "test")
# MathTok should outperform char-level on SCR
assert rec.scr_improvement_vs_char > 0
def test_canonical_jaccard(self, comp, pipeline):
# Equivalent expressions should have high Jaccard
out_a = pipeline.encode_math_only("x + 2")
out_b = pipeline.encode_math_only("2 + x")
mt_a = set(t for t in out_a.tokens if not t.startswith("["))
mt_b = set(t for t in out_b.tokens if not t.startswith("["))
jac = _jaccard(mt_a, mt_b)
assert jac > 0.5 # should be near 1.0 due to canonicalization
def test_run_standard_small(self, comp):
# Run just 3 expressions to keep test fast
for expr in STANDARD_EXPRESSIONS[:3]:
rec = comp._compare_one(expr, "standard")
assert rec.mathtok.token_count > 0
def test_deep_nesting_depth_increases(self, comp, pipeline):
flat = pipeline.encode_math_only("x + 1")
nested = pipeline.encode_math_only("sin(cos((x+1)^2))")
flat_d = max((m.depth for m in flat.metadata if m.depth >= 0), default=0)
nest_d = max((m.depth for m in nested.metadata if m.depth >= 0), default=0)
assert nest_d > flat_d
# ── Utility helpers ───────────────────────────────────────────────────────
class TestHelpers:
def test_jaccard_identical(self):
assert _jaccard({"a", "b"}, {"a", "b"}) == 1.0
def test_jaccard_disjoint(self):
assert _jaccard({"a"}, {"b"}) == 0.0
def test_jaccard_partial(self):
j = _jaccard({"a", "b"}, {"b", "c"})
assert abs(j - 1/3) < 1e-9
def test_mean_empty(self):
assert _mean([]) == 0.0
def test_mean_values(self):
assert abs(_mean([1.0, 2.0, 3.0]) - 2.0) < 1e-9