"""Tests for LatinBertTokenizer. Validates that the HF wrapper produces identical output to the original tensor2tensor SubwordTextEncoder used in Bamman & Burns (2020). Reference IDs generated from the standalone encoder on the cluster. """ import os import pytest from pathlib import Path VOCAB_FILE = str(Path(__file__).parent / "latin.subword.encoder") @pytest.fixture def tokenizer(): from latincy_latinbert import LatinBertTokenizer return LatinBertTokenizer(vocab_file=VOCAB_FILE) class TestSpecialTokens: def test_special_token_ids(self, tokenizer): """BERT special tokens must occupy IDs 0-4.""" assert tokenizer.convert_tokens_to_ids("[PAD]") == 0 assert tokenizer.convert_tokens_to_ids("[UNK]") == 1 assert tokenizer.convert_tokens_to_ids("[CLS]") == 2 assert tokenizer.convert_tokens_to_ids("[SEP]") == 3 assert tokenizer.convert_tokens_to_ids("[MASK]") == 4 def test_special_token_strings(self, tokenizer): assert tokenizer.pad_token == "[PAD]" assert tokenizer.unk_token == "[UNK]" assert tokenizer.cls_token == "[CLS]" assert tokenizer.sep_token == "[SEP]" assert tokenizer.mask_token == "[MASK]" def test_vocab_size_includes_specials(self, tokenizer): """vocab_size = 5 special + 32895 subtokens = 32900.""" assert tokenizer.vocab_size == 32900 def test_subtoken_offset(self, tokenizer): """First subtoken '_' from encoder should be at ID 5, not 0.""" assert tokenizer.convert_tokens_to_ids("_") == 5 def test_add_special_tokens_encoding(self, tokenizer): """encode with add_special_tokens=True should wrap with [CLS]/[SEP].""" ids = tokenizer.encode("et", add_special_tokens=True) assert ids[0] == 2 # [CLS] assert ids[-1] == 3 # [SEP] class TestVocab: def test_vocab_size(self, tokenizer): assert tokenizer.vocab_size == 32900 def test_pad_token_id(self, tokenizer): assert tokenizer.pad_token == "[PAD]" assert tokenizer.convert_tokens_to_ids("[PAD]") == 0 def test_eos_token(self, tokenizer): assert tokenizer.eos_token == "_" assert tokenizer.convert_tokens_to_ids("_") == 6 # was 1, now 1+5 class TestEncoding: """Reference IDs from original LatinTokenizer (with +5 offset).""" def test_gallia(self, tokenizer): ids = tokenizer.encode("Gallia est omnis divisa in partes tres", add_special_tokens=False) # With do_lower_case=True, "Gallia" → "gallia_" (single token) expected = [6533, 32888, 7735, 13, 15, 32888, 7735, 13, 343, 32888, 7735, 13, 6773, 32888, 7735, 13, 12, 32888, 7735, 13, 568, 32888, 7735, 13, 564] assert ids == expected def test_arma(self, tokenizer): ids = tokenizer.encode("arma virumque cano", add_special_tokens=False) expected = [915, 32888, 7735, 13, 18566, 8107, 32888, 7735, 13, 4420] assert ids == expected def test_uppercase(self, tokenizer): """Uppercase input should be lowercased, not escaped to codepoints.""" ids = tokenizer.encode("ROMA", add_special_tokens=False) expected = [2560] # 'roma_' — single token, not 10 escaped codepoints assert ids == expected def test_empty(self, tokenizer): ids = tokenizer.encode("", add_special_tokens=False) assert ids == [] class TestRoundtrip: def test_decode_lowercase(self, tokenizer): """Lowercase text should roundtrip exactly.""" text = "gallia est omnis divisa in partes tres" ids = tokenizer.encode(text, add_special_tokens=False) decoded = tokenizer.decode(ids) assert decoded == text def test_decode_arma(self, tokenizer): text = "arma virumque cano" ids = tokenizer.encode(text, add_special_tokens=False) decoded = tokenizer.decode(ids) assert decoded == text def test_decode_with_punctuation(self, tokenizer): text = "gallia est omnis divisa in partes tres." ids = tokenizer.encode(text, add_special_tokens=False) decoded = tokenizer.decode(ids) assert decoded == text def test_decode_uppercase_lossy(self, tokenizer): """Uppercase input decodes to lowercase (lowercasing is lossy).""" ids = tokenizer.encode("Gallia", add_special_tokens=False) decoded = tokenizer.decode(ids) assert decoded == "gallia" class TestLowercasing: """Verify do_lower_case=True matches original Latin BERT behavior.""" def test_case_insensitive_ids(self, tokenizer): """Uppercase and lowercase input must produce identical IDs.""" assert (tokenizer.encode("gallia", add_special_tokens=False) == tokenizer.encode("Gallia", add_special_tokens=False)) assert (tokenizer.encode("roma", add_special_tokens=False) == tokenizer.encode("ROMA", add_special_tokens=False)) def test_no_codepoint_escapes(self, tokenizer): """Uppercase letters should not produce \\; escape sequences.""" tokens = tokenizer.tokenize("Cytherea") # Should be clean subwords, not ['\\', '67', ';', ...] assert tokens[0] != "\\" assert all(not t.isdigit() or len(t) > 2 for t in tokens) def test_reasonable_expansion_ratio(self, tokenizer): """With lowercasing, proper nouns should not explode into codepoint escapes. Spaces are escaped by tensor2tensor's design (3 tokens each), so we count only word subtokens, excluding space-escape sequences. """ text = "Cytherea Camenis Roma Gallia" tokens = tokenizer.tokenize(text) # Filter out space-escape tokens (\, 32, ;_) word_tokens = [t for t in tokens if t not in ("\\", "32", ";_")] words = text.split() ratio = len(word_tokens) / len(words) assert ratio < 2.5, f"Word expansion ratio {ratio:.1f}x is too high" def test_do_lower_case_false(self): """With do_lower_case=False, uppercase chars are escaped (old behavior).""" from latincy_latinbert import LatinBertTokenizer tok = LatinBertTokenizer(vocab_file=VOCAB_FILE, do_lower_case=False) tokens = tok.tokenize("Cytherea") # First token should be backslash escape for uppercase C assert tokens[0] == "\\" def test_do_lower_case_default_true(self, tokenizer): """Default tokenizer has do_lower_case=True.""" assert tokenizer.do_lower_case is True class TestSaveLoad: def test_save_and_reload(self, tokenizer, tmp_path): tokenizer.save_pretrained(tmp_path) from latincy_latinbert import LatinBertTokenizer loaded = LatinBertTokenizer.from_pretrained(tmp_path) text = "Gallia est omnis divisa in partes tres" assert tokenizer.encode(text) == loaded.encode(text) def test_vocab_file_saved(self, tokenizer, tmp_path): tokenizer.save_pretrained(tmp_path) assert (tmp_path / "latin.subword.encoder").exists() def test_config_saved(self, tokenizer, tmp_path): tokenizer.save_pretrained(tmp_path) assert (tmp_path / "tokenizer_config.json").exists()