"""Tests for utility functions."""
import pytest
from namer.utils import (
EOS_IDX,
VOCABULARY,
decode,
digits_to_int,
encode,
int_to_digits,
read_digits,
read_double,
read_triplet,
)
class TestIntToDigits:
    """Tests for int_to_digits, which splits an integer into decimal digits.

    Digits come back most-significant first.
    """

    def test_zero(self) -> None:
        # Zero is a single 0 digit, not an empty list.
        digits = int_to_digits(0)
        assert digits == [0]

    def test_positive(self) -> None:
        for value, expected in ((123, [1, 2, 3]), (7, [7])):
            assert int_to_digits(value) == expected

    def test_negative(self) -> None:
        # The sign is discarded; only the digits of the magnitude are returned.
        digits = int_to_digits(-456)
        assert digits == [4, 5, 6]

    def test_large_number(self) -> None:
        # Interior zeros must be preserved in place.
        digits = int_to_digits(1002003)
        assert digits == [1, 0, 0, 2, 0, 0, 3]
class TestDigitsToInt:
    """Tests for digits_to_int, which joins decimal digits back into an int."""

    def test_empty(self) -> None:
        # An empty digit list is read as zero.
        assert digits_to_int([]) == 0

    def test_single_digit(self) -> None:
        value = digits_to_int([5])
        assert value == 5

    def test_multiple_digits(self) -> None:
        # The first element is the most significant digit.
        value = digits_to_int([1, 2, 3])
        assert value == 123

    def test_with_zeros(self) -> None:
        value = digits_to_int([1, 0, 0, 2])
        assert value == 1002

    def test_invalid_digit(self) -> None:
        # 10 is not a single decimal digit, so it must be rejected.
        with pytest.raises(ValueError, match="Invalid digit"):
            digits_to_int([10])
class TestRoundTrip:
    """Tests for the int_to_digits -> digits_to_int round-trip."""

    def test_round_trip(self) -> None:
        # int_to_digits discards the sign (int_to_digits(-456) == [4, 5, 6]),
        # so a round-trip recovers only the magnitude. Negative inputs are
        # included so the abs() in the expectation is actually exercised.
        for n in [0, 42, 123, 1000, 999999, 1000000, -7, -42, -1002003]:
            # Report the failing input; the loop otherwise hides which n broke.
            assert digits_to_int(int_to_digits(n)) == abs(n), n
class TestReadDouble:
    """Tests for read_double, which spells out a (tens, ones) digit pair."""

    def test_single_digit(self) -> None:
        # A zero tens digit reads as just the ones word, including "zero".
        for tens, ones, words in ((0, 7, "seven"), (0, 0, "zero")):
            assert read_double(tens, ones) == words

    def test_teens(self) -> None:
        for tens, ones, words in ((1, 1, "eleven"), (1, 9, "nineteen")):
            assert read_double(tens, ones) == words

    def test_tens(self) -> None:
        # A zero ones digit adds no trailing word.
        for tens, ones, words in ((3, 0, "thirty"), (5, 0, "fifty")):
            assert read_double(tens, ones) == words

    def test_tens_and_ones(self) -> None:
        # The tens and ones words are separated by a single space.
        cases = ((2, 3, "twenty three"), (5, 9, "fifty nine"))
        for tens, ones, words in cases:
            assert read_double(tens, ones) == words

    def test_invalid_digits(self) -> None:
        with pytest.raises(ValueError, match="must be between 0 and 9"):
            read_double(10, 5)
class TestReadTriplet:
    """Tests for read_triplet, which spells out a (hundreds, tens, ones) group."""

    def test_hundreds(self) -> None:
        cases = [
            ((1, 0, 6), "one hundred six"),
            # Zero tens and ones add nothing after "hundred".
            ((2, 0, 0), "two hundred"),
        ]
        for digits, words in cases:
            assert read_triplet(*digits) == words

    def test_zero_hundreds(self) -> None:
        # With no hundreds digit the group reads like a plain two-digit number.
        words = read_triplet(0, 5, 5)
        assert words == "fifty five"

    def test_all_zeros(self) -> None:
        words = read_triplet(0, 0, 0)
        assert words == "zero"
class TestReadDigits:
    """Tests for read_digits, which spells out a whole digit list in words."""

    def test_empty(self) -> None:
        assert read_digits([]) == "zero"

    def test_zero(self) -> None:
        # An all-zero list reads as "zero" regardless of its length.
        assert read_digits([0]) == "zero"
        assert read_digits([0, 0, 0]) == "zero"

    def test_single_digit(self) -> None:
        assert read_digits([5]) == "five"

    def test_double_digit(self) -> None:
        assert read_digits([4, 2]) == "forty two"

    def test_triple_digit(self) -> None:
        assert read_digits([1, 2, 3]) == "one hundred twenty three"

    def test_thousands(self) -> None:
        # An all-zero lower group is omitted rather than read as "zero".
        assert read_digits([1, 0, 0, 0]) == "one thousand"
        assert read_digits([1, 2, 3, 4]) == "one thousand two hundred thirty four"

    def test_millions(self) -> None:
        assert read_digits([1, 0, 0, 0, 0, 0, 0]) == "one million"

    def test_complex(self) -> None:
        # 1,234,567. Compare the whole string: substring checks alone would
        # also accept the groups in the wrong order or with extra words in
        # between them.
        digits = [1, 2, 3, 4, 5, 6, 7]
        assert read_digits(digits) == (
            "one million two hundred thirty four thousand five hundred sixty seven"
        )

    def test_invalid_digit(self) -> None:
        with pytest.raises(ValueError, match="must be digits"):
            read_digits([1, 10, 3])
class TestEncode:
    """Tests for encode, which maps each word of a text to a VOCABULARY index."""

    def test_simple(self) -> None:
        indices = encode("one million")
        assert len(indices) == 2
        assert all(0 <= i < len(VOCABULARY) for i in indices)

    def test_single_word(self) -> None:
        # One word yields exactly one in-range index.
        indices = encode("zero")
        assert len(indices) == 1
        assert 0 <= indices[0] < len(VOCABULARY)

    def test_multi_word(self) -> None:
        indices = encode("twenty three")
        assert len(indices) == 2
        assert all(0 <= i < len(VOCABULARY) for i in indices)
        # Distinct words need distinct indices, or decode could not recover
        # the original text.
        assert indices[0] != indices[1]

    def test_empty(self) -> None:
        # Empty and whitespace-only text produce no indices.
        assert encode("") == []
        assert encode(" ") == []

    def test_unknown_word(self) -> None:
        with pytest.raises(ValueError, match="Unknown word"):
            encode("unknown")
class TestDecode:
    """Tests for decode, which turns vocabulary indices back into text."""

    def test_simple(self) -> None:
        phrase = "one million"
        assert decode(encode(phrase)) == phrase

    def test_with_eos(self) -> None:
        # A trailing EOS_IDX marker contributes nothing to the decoded text.
        phrase = "one million"
        tokens = [*encode(phrase), EOS_IDX]
        assert decode(tokens) == phrase

    def test_empty(self) -> None:
        text = decode([])
        assert text == ""

    def test_invalid_index(self) -> None:
        # 9999 is presumed to exceed len(VOCABULARY).
        with pytest.raises(ValueError, match="out of range"):
            decode([9999])
class TestEncodeDecodeRoundTrip:
    """Tests that decode(encode(text)) reproduces the original text."""

    def test_round_trip(self) -> None:
        phrases = (
            "one million",
            "twenty three",
            "one hundred twenty three",
            "zero",
            "nine hundred nineteen",
        )
        for phrase in phrases:
            tokens = encode(phrase)
            assert decode(tokens) == phrase
|