File size: 5,378 Bytes
2730fd2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
"""Tests for utility functions."""

import pytest

from namer.utils import (
    EOS_IDX,
    VOCABULARY,
    decode,
    digits_to_int,
    encode,
    int_to_digits,
    read_digits,
    read_double,
    read_triplet,
)


class TestIntToDigits:
    """Check how int_to_digits splits an integer into decimal digits."""

    def test_zero(self) -> None:
        """Zero is expected to come back as the single digit 0."""
        digits = int_to_digits(0)
        assert digits == [0]

    def test_positive(self) -> None:
        """Positive integers are expected most-significant digit first."""
        expectations = (
            (123, [1, 2, 3]),
            (7, [7]),
        )
        for number, expected in expectations:
            assert int_to_digits(number) == expected

    def test_negative(self) -> None:
        """A negative input is expected to yield the digits of its magnitude."""
        digits = int_to_digits(-456)
        assert digits == [4, 5, 6]

    def test_large_number(self) -> None:
        """Zeros inside a seven-digit number must be preserved."""
        expected = [1, 0, 0, 2, 0, 0, 3]
        assert int_to_digits(1002003) == expected


class TestDigitsToInt:
    """Check how digits_to_int folds a digit list back into an integer."""

    def test_empty(self) -> None:
        """An empty digit list is expected to evaluate to 0."""
        value = digits_to_int([])
        assert value == 0

    def test_single_digit(self) -> None:
        """A one-element list is expected to evaluate to that digit."""
        value = digits_to_int([5])
        assert value == 5

    def test_multiple_digits(self) -> None:
        """Digits are read most-significant first."""
        value = digits_to_int([1, 2, 3])
        assert value == 123

    def test_with_zeros(self) -> None:
        """Zeros in the middle of the list keep their positional weight."""
        value = digits_to_int([1, 0, 0, 2])
        assert value == 1002

    def test_invalid_digit(self) -> None:
        """An element outside 0-9 is expected to raise ValueError."""
        not_a_digit = [10]
        with pytest.raises(ValueError, match="Invalid digit"):
            digits_to_int(not_a_digit)


class TestRoundTrip:
    """Tests for int_to_digits <-> digits_to_int round-trip."""

    def test_round_trip(self) -> None:
        """Converting to digits and back must return the magnitude of the input.

        Negative samples are included on purpose: the expected value is
        ``abs(n)``, which only differs from ``n`` when the sign is dropped
        during the conversion to digits.
        """
        samples = [0, 42, 123, 1000, 999999, 1000000, -1, -42, -1000]
        for n in samples:
            restored = digits_to_int(int_to_digits(n))
            # Name the offending input so a failure inside the loop is traceable.
            assert restored == abs(n), f"round trip failed for n={n}"


class TestReadDouble:
    """Check the English wording read_double gives a (tens, ones) digit pair."""

    def test_single_digit(self) -> None:
        """With a zero tens digit only the ones word is expected."""
        for ones, word in ((7, "seven"), (0, "zero")):
            assert read_double(0, ones) == word

    def test_teens(self) -> None:
        """A tens digit of 1 is expected to yield the dedicated teen word."""
        for ones, word in ((1, "eleven"), (9, "nineteen")):
            assert read_double(1, ones) == word

    def test_tens(self) -> None:
        """Round tens are expected as a single word."""
        for tens, word in ((3, "thirty"), (5, "fifty")):
            assert read_double(tens, 0) == word

    def test_tens_and_ones(self) -> None:
        """Tens and ones words are expected to be separated by one space."""
        expectations = (
            ((2, 3), "twenty three"),
            ((5, 9), "fifty nine"),
        )
        for (tens, ones), phrase in expectations:
            assert read_double(tens, ones) == phrase

    def test_invalid_digits(self) -> None:
        """A tens argument of 10 is expected to raise ValueError."""
        bad_tens, ones = 10, 5
        with pytest.raises(ValueError, match="must be between 0 and 9"):
            read_double(bad_tens, ones)


class TestReadTriplet:
    """Check the English wording read_triplet gives a three-digit group."""

    def test_hundreds(self) -> None:
        """A non-zero hundreds digit is expected to produce '<n> hundred ...'."""
        expectations = (
            ((1, 0, 6), "one hundred six"),
            ((2, 0, 0), "two hundred"),
        )
        for triplet, phrase in expectations:
            assert read_triplet(*triplet) == phrase

    def test_zero_hundreds(self) -> None:
        """With a zero hundreds digit only the two-digit part is expected."""
        phrase = read_triplet(0, 5, 5)
        assert phrase == "fifty five"

    def test_all_zeros(self) -> None:
        """Three zeros are expected to read as 'zero'."""
        phrase = read_triplet(0, 0, 0)
        assert phrase == "zero"


class TestReadDigits:
    """Check the English wording read_digits gives a whole digit list."""

    def test_empty(self) -> None:
        """An empty list is expected to read as 'zero'."""
        spoken = read_digits([])
        assert spoken == "zero"

    def test_zero(self) -> None:
        """Lists made only of zeros are expected to read as 'zero'."""
        for zeros in ([0], [0, 0, 0]):
            assert read_digits(zeros) == "zero"

    def test_single_digit(self) -> None:
        """A single digit is expected to read as its own word."""
        spoken = read_digits([5])
        assert spoken == "five"

    def test_double_digit(self) -> None:
        """Two digits are expected as tens and ones words."""
        spoken = read_digits([4, 2])
        assert spoken == "forty two"

    def test_triple_digit(self) -> None:
        """Three digits are expected to include the hundreds part."""
        spoken = read_digits([1, 2, 3])
        assert spoken == "one hundred twenty three"

    def test_thousands(self) -> None:
        """Four-digit inputs are expected to use the 'thousand' scale word."""
        expectations = (
            ([1, 0, 0, 0], "one thousand"),
            ([1, 2, 3, 4], "one thousand two hundred thirty four"),
        )
        for digits, phrase in expectations:
            assert read_digits(digits) == phrase

    def test_millions(self) -> None:
        """A seven-digit round number is expected to read as 'one million'."""
        one_million = [1, 0, 0, 0, 0, 0, 0]
        assert read_digits(one_million) == "one million"

    def test_complex(self) -> None:
        """Each three-digit group of 1,234,567 must appear in the reading."""
        result = read_digits([1, 2, 3, 4, 5, 6, 7])
        expected_fragments = (
            "one million",
            "two hundred thirty four thousand",
            "five hundred sixty seven",
        )
        # Only substrings are checked; the exact joined phrase is not pinned.
        for fragment in expected_fragments:
            assert fragment in result

    def test_invalid_digit(self) -> None:
        """An element outside 0-9 is expected to raise ValueError."""
        contains_ten = [1, 10, 3]
        with pytest.raises(ValueError, match="must be digits"):
            read_digits(contains_ten)


class TestEncode:
    """Check how encode maps a phrase onto vocabulary indices."""

    def test_simple(self) -> None:
        """Two words give two indices, each a valid VOCABULARY position."""
        token_ids = encode("one million")
        vocab_size = len(VOCABULARY)
        assert len(token_ids) == 2
        assert all(0 <= token_id < vocab_size for token_id in token_ids)

    def test_multi_word(self) -> None:
        """Each whitespace-separated word is expected to yield one index."""
        token_ids = encode("twenty three")
        assert len(token_ids) == 2

    def test_empty(self) -> None:
        """Empty and whitespace-only input are expected to encode to []."""
        for blank in ("", "   "):
            assert encode(blank) == []

    def test_unknown_word(self) -> None:
        """A word missing from the vocabulary is expected to raise ValueError."""
        phrase = "unknown"
        with pytest.raises(ValueError, match="Unknown word"):
            encode(phrase)


class TestDecode:
    """Check how decode turns vocabulary indices back into a phrase."""

    def test_simple(self) -> None:
        """Decoding an encoded phrase is expected to give the phrase back."""
        phrase = "one million"
        token_ids = encode(phrase)
        assert decode(token_ids) == "one million"

    def test_with_eos(self) -> None:
        """A trailing EOS_IDX is expected to be left out of the decoded text."""
        token_ids = encode("one million")
        terminated = token_ids + [EOS_IDX]
        assert decode(terminated) == "one million"

    def test_empty(self) -> None:
        """No indices are expected to decode to the empty string."""
        text = decode([])
        assert text == ""

    def test_invalid_index(self) -> None:
        """The index 9999 is expected to be rejected with a ValueError."""
        bogus = [9999]
        with pytest.raises(ValueError, match="out of range"):
            decode(bogus)


class TestEncodeDecodeRoundTrip:
    """Check that decode reverses encode for a set of sample phrases."""

    def test_round_trip(self) -> None:
        """Each sample phrase must survive encode followed by decode."""
        phrases = (
            "one million",
            "twenty three",
            "one hundred twenty three",
            "zero",
            "nine hundred nineteen",
        )
        for phrase in phrases:
            restored = decode(encode(phrase))
            assert restored == phrase