# namer/tests/test_utils.py
# Edwin Jose Palathinkal
# Initial commit
# 2730fd2
"""Tests for utility functions."""
import pytest
from namer.utils import (
EOS_IDX,
VOCABULARY,
decode,
digits_to_int,
encode,
int_to_digits,
read_digits,
read_double,
read_triplet,
)
class TestIntToDigits:
    """Check that int_to_digits splits an integer into a list of digits."""

    def test_zero(self) -> None:
        """Zero maps to a single-element list holding 0."""
        digits = int_to_digits(0)
        assert digits == [0]

    def test_positive(self) -> None:
        """Positive integers yield their digits, most significant first."""
        expectations = (
            (123, [1, 2, 3]),
            (7, [7]),
        )
        for number, expected in expectations:
            assert int_to_digits(number) == expected

    def test_negative(self) -> None:
        """The sign of a negative input does not show up in the digits."""
        digits = int_to_digits(-456)
        assert digits == [4, 5, 6]

    def test_large_number(self) -> None:
        """Interior zeros of a seven-digit number are preserved."""
        expected = [1, 0, 0, 2, 0, 0, 3]
        assert int_to_digits(1002003) == expected
class TestDigitsToInt:
    """Check that digits_to_int assembles a digit list into an integer."""

    def test_empty(self) -> None:
        """An empty digit list evaluates to 0."""
        value = digits_to_int([])
        assert value == 0

    def test_single_digit(self) -> None:
        """A one-element list evaluates to that digit."""
        value = digits_to_int([5])
        assert value == 5

    def test_multiple_digits(self) -> None:
        """Digits are combined most significant first."""
        value = digits_to_int([1, 2, 3])
        assert value == 123

    def test_with_zeros(self) -> None:
        """Zeros in the middle of the list keep their place value."""
        value = digits_to_int([1, 0, 0, 2])
        assert value == 1002

    def test_invalid_digit(self) -> None:
        """An element of 10 is rejected with a ValueError."""
        bad_digits = [10]
        with pytest.raises(ValueError, match="Invalid digit"):
            digits_to_int(bad_digits)
class TestRoundTrip:
    """Tests for int_to_digits <-> digits_to_int round-trip."""

    def test_round_trip(self) -> None:
        """Converting to digits and back yields the absolute value.

        The expected value is ``abs(n)`` because int_to_digits drops the
        sign; negative samples are included so that this sign-dropping
        path is actually exercised by the round-trip.
        """
        samples = [0, 42, 123, 1000, 999999, 1000000, -1, -42, -1000000]
        for n in samples:
            assert digits_to_int(int_to_digits(n)) == abs(n)
class TestReadDouble:
    """Check the words read_double produces for a pair of digits."""

    def test_single_digit(self) -> None:
        """With a leading 0, only the second digit is named."""
        for second, word in ((7, "seven"), (0, "zero")):
            assert read_double(0, second) == word

    def test_teens(self) -> None:
        """A leading 1 produces the dedicated teen words."""
        for second, word in ((1, "eleven"), (9, "nineteen")):
            assert read_double(1, second) == word

    def test_tens(self) -> None:
        """A trailing 0 produces only the tens word."""
        for first, word in ((3, "thirty"), (5, "fifty")):
            assert read_double(first, 0) == word

    def test_tens_and_ones(self) -> None:
        """Tens and ones words are joined by a single space."""
        expectations = {
            (2, 3): "twenty three",
            (5, 9): "fifty nine",
        }
        for pair, words in expectations.items():
            assert read_double(*pair) == words

    def test_invalid_digits(self) -> None:
        """A first argument of 10 is rejected with a ValueError."""
        with pytest.raises(ValueError, match="must be between 0 and 9"):
            read_double(10, 5)
class TestReadTriplet:
    """Check the words read_triplet produces for three digits."""

    def test_hundreds(self) -> None:
        """A non-zero first digit adds a "hundred" part."""
        expectations = (
            ((1, 0, 6), "one hundred six"),
            ((2, 0, 0), "two hundred"),
        )
        for triplet, words in expectations:
            assert read_triplet(*triplet) == words

    def test_zero_hundreds(self) -> None:
        """A zero first digit leaves only the two-digit reading."""
        words = read_triplet(0, 5, 5)
        assert words == "fifty five"

    def test_all_zeros(self) -> None:
        """Three zeros read as "zero"."""
        words = read_triplet(0, 0, 0)
        assert words == "zero"
class TestReadDigits:
    """Check the words read_digits produces for a list of digits."""

    def test_empty(self) -> None:
        """An empty list reads as "zero"."""
        words = read_digits([])
        assert words == "zero"

    def test_zero(self) -> None:
        """Lists made only of zeros read as "zero"."""
        for zeros in ([0], [0, 0, 0]):
            assert read_digits(zeros) == "zero"

    def test_single_digit(self) -> None:
        """A single digit reads as its own word."""
        words = read_digits([5])
        assert words == "five"

    def test_double_digit(self) -> None:
        """Two digits read as tens followed by ones."""
        words = read_digits([4, 2])
        assert words == "forty two"

    def test_triple_digit(self) -> None:
        """Three digits include a "hundred" part."""
        words = read_digits([1, 2, 3])
        assert words == "one hundred twenty three"

    def test_thousands(self) -> None:
        """Four-digit inputs include a "thousand" part."""
        expectations = (
            ([1, 0, 0, 0], "one thousand"),
            ([1, 2, 3, 4], "one thousand two hundred thirty four"),
        )
        for digits, words in expectations:
            assert read_digits(digits) == words

    def test_millions(self) -> None:
        """A one followed by six zeros reads as "one million"."""
        one_million = [1, 0, 0, 0, 0, 0, 0]
        assert read_digits(one_million) == "one million"

    def test_complex(self) -> None:
        """Each three-digit group of 1,234,567 appears in the reading."""
        # 1,234,567
        result = read_digits([1, 2, 3, 4, 5, 6, 7])
        expected_fragments = (
            "one million",
            "two hundred thirty four thousand",
            "five hundred sixty seven",
        )
        for fragment in expected_fragments:
            assert fragment in result

    def test_invalid_digit(self) -> None:
        """An element of 10 is rejected with a ValueError."""
        bad_digits = [1, 10, 3]
        with pytest.raises(ValueError, match="must be digits"):
            read_digits(bad_digits)
class TestEncode:
    """Check that encode turns text into a list of vocabulary indices."""

    def test_simple(self) -> None:
        """Two words give two indices, each inside the vocabulary range."""
        indices = encode("one million")
        assert len(indices) == 2
        vocab_size = len(VOCABULARY)
        for index in indices:
            assert 0 <= index < vocab_size

    def test_multi_word(self) -> None:
        """A two-word phrase is encoded as two indices."""
        token_count = len(encode("twenty three"))
        assert token_count == 2

    def test_empty(self) -> None:
        """Empty and whitespace-only strings encode to an empty list."""
        for blank in ("", " "):
            assert encode(blank) == []

    def test_unknown_word(self) -> None:
        """A word outside the vocabulary is rejected with a ValueError."""
        with pytest.raises(ValueError, match="Unknown word"):
            encode("unknown")
class TestDecode:
    """Check that decode turns vocabulary indices back into text."""

    def test_simple(self) -> None:
        """Decoding an encoded phrase gives the phrase back."""
        phrase = "one million"
        assert decode(encode(phrase)) == phrase

    def test_with_eos(self) -> None:
        """A trailing EOS_IDX does not appear in the decoded text."""
        with_eos = encode("one million") + [EOS_IDX]
        assert decode(with_eos) == "one million"

    def test_empty(self) -> None:
        """An empty index list decodes to the empty string."""
        text = decode([])
        assert text == ""

    def test_invalid_index(self) -> None:
        """The index 9999 is rejected with a ValueError."""
        bad_indices = [9999]
        with pytest.raises(ValueError, match="out of range"):
            decode(bad_indices)
class TestEncodeDecodeRoundTrip:
    """Check that encode followed by decode reproduces the input text."""

    def test_round_trip(self) -> None:
        """Each sample phrase survives an encode/decode cycle unchanged."""
        phrases = (
            "one million",
            "twenty three",
            "one hundred twenty three",
            "zero",
            "nine hundred nineteen",
        )
        for phrase in phrases:
            restored = decode(encode(phrase))
            assert restored == phrase