File size: 4,162 Bytes
edede4c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
"""
Integration tests for the end-to-end MathTok Pipeline.
"""

import pytest
from mathtok.pipeline import MathTokPipeline, TokenizedOutput


@pytest.fixture(scope="module")
def pipeline():
    return MathTokPipeline(include_metadata=True)


class TestBasicEncode:
    def test_returns_output(self, pipeline):
        out = pipeline.encode("x^2 + 1")
        assert isinstance(out, TokenizedOutput)

    def test_tokens_nonempty(self, pipeline):
        out = pipeline.encode("sin(x)")
        assert len(out.tokens) > 0

    def test_input_ids_match_tokens(self, pipeline):
        out = pipeline.encode("x^2 + 2*x + 1")
        assert len(out.tokens) == len(out.input_ids)

    def test_ids_are_integers(self, pipeline):
        out = pipeline.encode("x + 1")
        assert all(isinstance(i, int) for i in out.input_ids)

    def test_no_negative_ids(self, pipeline):
        out = pipeline.encode("x + 1")
        # All IDs should be non-negative (UNK=1 is minimum valid)
        assert all(i >= 0 for i in out.input_ids)


class TestMathSpans:
    def test_math_start_end_tokens(self, pipeline):
        out = pipeline.encode("x^2")
        assert "[MATH_START]" in out.tokens
        assert "[MATH_END]" in out.tokens

    def test_sexp_nonempty(self, pipeline):
        out = pipeline.encode("x^2 + 1")
        assert len(out.sexp) > 0

    def test_sexp_contains_op(self, pipeline):
        out = pipeline.encode("x^2")
        assert "OP_POW" in out.sexp

    def test_canon_results(self, pipeline):
        # Use a simple ASCII expression guaranteed to parse successfully
        out = pipeline.encode("x^2 + 1")
        assert len(out.canon_results) >= 1
        assert out.canon_results[0].success


class TestMixedInput:
    def test_mixed_latex(self, pipeline):
        out = pipeline.encode("The result is $x^2 + 1$.")
        assert len(out.tokens) > 0

    def test_mixed_ascii(self, pipeline):
        out = pipeline.encode("Compute sin(x) for x = pi.")
        assert len(out.tokens) > 0

    def test_multiple_math_spans(self, pipeline):
        out = pipeline.encode("If $a > 0$ and $b < 0$ then $a + b$ can be zero.")
        # Should have at least some math tokens
        math_toks = [t for t in out.tokens if t.startswith("OP_") or t.startswith("VAR_")]
        assert len(math_toks) > 0


class TestMetadata:
    def test_metadata_present(self, pipeline):
        out = pipeline.encode("x + 1")
        assert len(out.metadata) > 0

    def test_metadata_positions_sequential(self, pipeline):
        out = pipeline.encode("x^2 + 1")
        positions = [m.position for m in out.metadata]
        assert positions == sorted(positions)

    def test_metadata_categories(self, pipeline):
        out = pipeline.encode("x + 1")
        categories = {m.token_category for m in out.metadata}
        assert "operator" in categories or "variable" in categories or "constant" in categories

    def test_tree_position_keys(self, pipeline):
        out = pipeline.encode("x + 1")
        keys = [m.tree_position_key for m in out.metadata if m.node_id >= 0]
        assert len(keys) > 0
        assert all(isinstance(k, str) for k in keys)


class TestEncodeMathOnly:
    def test_encode_math_only(self, pipeline):
        out = pipeline.encode_math_only("x^2 + 2*x + 1")
        assert len(out.tokens) > 0
        assert "OP_ADD" in out.tokens or "OP_POW" in out.tokens

    def test_encode_batch(self, pipeline):
        exprs = ["x + 1", "sin(x)", "x^2"]
        outs = pipeline.encode_batch(exprs)
        assert len(outs) == 3
        assert all(len(o.tokens) > 0 for o in outs)


class TestHFTokenizer:
    def test_hf_tokenizer_callable(self, pipeline):
        hf_tok = pipeline.get_hf_tokenizer()
        result = hf_tok("x^2 + 1")
        assert "input_ids" in result
        assert len(result["input_ids"]) == 1

    def test_hf_tokenizer_encode(self, pipeline):
        hf_tok = pipeline.get_hf_tokenizer()
        ids = hf_tok.encode("sin(x)")
        assert isinstance(ids, list)
        assert len(ids) > 0

    def test_hf_vocab_size(self, pipeline):
        hf_tok = pipeline.get_hf_tokenizer()
        assert len(hf_tok) > 100