"""Tests for tokenization utilities."""

import pytest
import torch

from utils.tokenization import (
    RussianTextTokenizer,
    create_tokenizer,
    tokenize_text_pair,
)
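

# Each test below constructs its own tokenizer, which reloads the pretrained
# vocabulary every time. As an optional optimization, a module-scoped fixture
# like this sketch (assuming the default RussianTextTokenizer() constructor
# used throughout this file) could be shared across tests instead:
@pytest.fixture(scope="module")
def shared_tokenizer():
    """One tokenizer instance reused across the module's tests."""
    return RussianTextTokenizer()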


class TestRussianTextTokenizer:
    """Tests for Russian text tokenizer."""

    def test_initialization(self):
        """Test tokenizer initialization."""
        tokenizer = RussianTextTokenizer(
            model_name="DeepPavlov/rubert-base-cased",
            max_length=128,
        )

        assert tokenizer.tokenizer is not None
        assert tokenizer.max_length == 128
        assert tokenizer.get_vocab_size() > 0

    def test_tokenize_russian_text(self):
        """Test tokenization of Russian text."""
        tokenizer = RussianTextTokenizer()

        text = "Привет, мир!"
        tokens = tokenizer.tokenize(text)

        assert isinstance(tokens, list)
        assert len(tokens) > 0
        # tokenize() is expected to add special tokens by default (later tests
        # opt out explicitly via add_special_tokens=False), so [CLS] should
        # appear among the tokens.
        assert any('[CLS]' in str(t) for t in tokens)

    def test_encode_russian_text(self):
        """Test encoding of Russian text."""
        tokenizer = RussianTextTokenizer(max_length=128)

        text = "Это тестовый текст на русском языке"
        encoded = tokenizer.encode(text)

        assert 'input_ids' in encoded
        assert 'attention_mask' in encoded
        assert encoded['input_ids'].shape[1] == 128
        assert encoded['attention_mask'].shape[1] == 128

    def test_encode_batch(self):
        """Test batch encoding."""
        tokenizer = RussianTextTokenizer(max_length=64)

        texts = [
            "Первая новость",
            "Вторая новость",
            "Третья новость",
        ]

        encoded = tokenizer.encode_batch(texts)

        assert encoded['input_ids'].shape[0] == 3
        assert encoded['input_ids'].shape[1] == 64
        assert encoded['attention_mask'].shape[0] == 3

    def test_decode(self):
        """Test decoding token IDs back to text."""
        tokenizer = RussianTextTokenizer()

        text = "Привет, мир!"
        encoded = tokenizer.encode(text, return_tensors=None)

        decoded = tokenizer.decode(encoded['input_ids'][0])

        assert isinstance(decoded, str)
        assert len(decoded) > 0

    def test_special_tokens(self):
        """Test special token handling."""
        tokenizer = RussianTextTokenizer()

        special_tokens = tokenizer.get_special_tokens()

        assert 'pad_token_id' in special_tokens
        assert 'cls_token_id' in special_tokens
        assert 'sep_token_id' in special_tokens
        assert special_tokens['pad_token_id'] is not None

    def test_padding(self):
        """Test padding behavior."""
        tokenizer = RussianTextTokenizer(max_length=20, padding='max_length')

        text = "Короткий текст"
        encoded = tokenizer.encode(text)

        assert encoded['input_ids'].shape[1] == 20
        assert encoded['attention_mask'].shape[1] == 20

    def test_truncation(self):
        """Test truncation of long texts."""
        tokenizer = RussianTextTokenizer(max_length=10, truncation=True)

        long_text = " ".join(["слово"] * 50)
        encoded = tokenizer.encode(long_text)

        assert encoded['input_ids'].shape[1] == 10

    def test_subword_tokenization(self):
        """Test that subword tokenization handles unknown words."""
        tokenizer = RussianTextTokenizer()

        text = "НеизвестноеСловоКоторогоНетВСловаре"
        tokens = tokenizer.tokenize(text, add_special_tokens=False)

        assert len(tokens) > 0
        assert all(isinstance(t, str) for t in tokens)
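
    # Hedged extra check: encode() should behave consistently across several
    # max_length settings. A sketch assuming the constructor accepts
    # max_length and pads/truncates output to it, as the tests above show.
    @pytest.mark.parametrize("max_length", [16, 32, 128])
    def test_encode_respects_max_length(self, max_length):
        """Encoded output width should match the configured max_length."""
        tokenizer = RussianTextTokenizer(max_length=max_length)
        encoded = tokenizer.encode("Проверка длины последовательности")
        assert encoded['input_ids'].shape[1] == max_length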


class TestTokenizerFactory:
    """Tests for tokenizer factory function."""

    def test_create_tokenizer(self):
        """Test tokenizer creation."""
        tokenizer = create_tokenizer(
            model_name="DeepPavlov/rubert-base-cased",
            max_length=256,
        )

        assert isinstance(tokenizer, RussianTextTokenizer)
        assert tokenizer.max_length == 256

    def test_create_multilingual_tokenizer(self):
        """Test creating multilingual tokenizer."""
        tokenizer = create_tokenizer(
            model_name="bert-base-multilingual-cased",
            max_length=128,
        )

        assert tokenizer.model_name == "bert-base-multilingual-cased"
        assert tokenizer.max_length == 128
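
    # Hedged extra check: a factory-built tokenizer should produce the same
    # encode() output keys as a directly constructed one. A sketch based only
    # on the keys and shapes asserted elsewhere in this file.
    def test_factory_tokenizer_encodes(self):
        """Factory-created tokenizer should produce standard encode outputs."""
        tokenizer = create_tokenizer(max_length=32)
        encoded = tokenizer.encode("Проверка фабрики")
        assert 'input_ids' in encoded
        assert 'attention_mask' in encoded
        assert encoded['input_ids'].shape[1] == 32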


class TestTextPairTokenization:
    """Tests for title-snippet pair tokenization."""

    def test_tokenize_text_pair(self):
        """Test tokenizing title and snippet pair."""
        tokenizer = create_tokenizer()

        title = "Заголовок новости"
        snippet = "Краткое описание новости"

        encoded = tokenize_text_pair(
            title=title,
            snippet=snippet,
            tokenizer=tokenizer,
            max_title_len=64,
            max_snippet_len=128,
        )

        assert 'title_input_ids' in encoded
        assert 'title_attention_mask' in encoded
        assert 'snippet_input_ids' in encoded
        assert 'snippet_attention_mask' in encoded

        assert encoded['title_input_ids'].shape[0] == 64
        assert encoded['snippet_input_ids'].shape[0] == 128

    def test_tokenize_title_only(self):
        """Test tokenizing title without snippet."""
        tokenizer = create_tokenizer()

        title = "Заголовок"

        encoded = tokenize_text_pair(
            title=title,
            snippet=None,
            tokenizer=tokenizer,
        )

        assert 'title_input_ids' in encoded
        assert 'snippet_input_ids' not in encoded
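
    # Hedged extra check: attention masks from HuggingFace-style tokenizers
    # contain only 0s and 1s. A sketch assuming tokenize_text_pair returns
    # such masks, as the key names above suggest.
    def test_attention_masks_are_binary(self):
        """Attention masks should contain only 0 and 1."""
        tokenizer = create_tokenizer()
        encoded = tokenize_text_pair(
            title="Заголовок",
            snippet="Описание",
            tokenizer=tokenizer,
        )
        for key in ('title_attention_mask', 'snippet_attention_mask'):
            mask = encoded[key]
            assert ((mask == 0) | (mask == 1)).all()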


class TestRussianTextHandling:
    """Tests for proper Russian text handling."""

    def test_cyrillic_characters(self):
        """Test handling of Cyrillic characters."""
        tokenizer = RussianTextTokenizer()

        texts = [
            "АБВГДЕЁЖЗИЙКЛМНОПРСТУФХЦЧШЩЪЫЬЭЮЯ",
            "абвгдеёжзийклмнопрстуфхцчшщъыьэюя",
            "1234567890",
            "Смешанный текст: English and русский",
        ]

        for text in texts:
            encoded = tokenizer.encode(text)
            decoded = tokenizer.decode(encoded['input_ids'][0])

            assert encoded['input_ids'].shape[0] > 0
            assert isinstance(decoded, str)

    def test_russian_punctuation(self):
        """Test handling of Russian punctuation."""
        tokenizer = RussianTextTokenizer()

        text = "Текст с пунктуацией: запятые, точки. Восклицания! Вопросы?"
        encoded = tokenizer.encode(text)

        assert encoded['input_ids'].shape[0] > 0
        # input_ids are integer token IDs, so a NaN check does not apply;
        # verify instead that every ID is a valid non-negative vocabulary index.
        assert (encoded['input_ids'] >= 0).all()

    def test_empty_text_handling(self):
        """Test handling of empty or whitespace-only text."""
        tokenizer = RussianTextTokenizer()

        encoded = tokenizer.encode("")
        assert encoded['input_ids'].shape[0] > 0

        encoded = tokenizer.encode(" ")
        assert encoded['input_ids'].shape[0] > 0

    def test_very_long_text(self):
        """Test handling of very long texts (should truncate)."""
        tokenizer = RussianTextTokenizer(max_length=50, truncation=True)

        long_text = " ".join(["слово"] * 200)
        encoded = tokenizer.encode(long_text)

        assert encoded['input_ids'].shape[1] == 50
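
    # Hedged roundtrip check: decoding should reconstruct the Cyrillic surface
    # form (possibly alongside special and padding tokens). A sketch assuming
    # decode() joins subword pieces back into words, as in the decode test
    # earlier in this file.
    def test_cyrillic_roundtrip(self):
        """Encoded Cyrillic text should survive an encode/decode roundtrip."""
        tokenizer = RussianTextTokenizer()
        encoded = tokenizer.encode("Привет, мир!")
        decoded = tokenizer.decode(encoded['input_ids'][0])
        assert "Привет" in decoded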


class TestSubwordTokenization:
    """Tests for subword tokenization features."""

    def test_unknown_word_handling(self):
        """Test that unknown words are handled via subword tokenization."""
        tokenizer = RussianTextTokenizer()

        unknown_word = "НесуществующееСловоКоторогоТочноНетВСловаре12345"
        tokens = tokenizer.tokenize(unknown_word, add_special_tokens=False)

        assert len(tokens) > 0
        assert all(isinstance(t, str) for t in tokens)

    def test_word_piece_tokenization(self):
        """Test WordPiece subword tokenization."""
        tokenizer = RussianTextTokenizer()

        text = "правительство"
        tokens = tokenizer.tokenize(text, add_special_tokens=False)

        assert len(tokens) > 0

    def test_vocabulary_coverage(self):
        """Test that tokenizer has good vocabulary coverage."""
        tokenizer = RussianTextTokenizer()

        vocab_size = tokenizer.get_vocab_size()

        assert vocab_size > 10000
        assert vocab_size < 1000000

    def test_token_info(self):
        """Test getting token information."""
        tokenizer = RussianTextTokenizer()

        special_tokens = tokenizer.get_special_tokens()
        pad_id = special_tokens['pad_token_id']

        info = tokenizer.get_token_info(pad_id)

        assert 'token_id' in info
        assert 'token' in info
        assert 'is_special' in info
        assert info['token_id'] == pad_id
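
    # Hedged extra check: BERT-style WordPiece marks word-internal subwords
    # with a "##" prefix, so a long out-of-vocabulary word should yield at
    # least one continuation piece. A sketch assuming a WordPiece vocabulary
    # (true for DeepPavlov/rubert-base-cased).
    def test_wordpiece_continuation_prefix(self):
        """Long unknown words should split into '##'-prefixed subwords."""
        tokenizer = RussianTextTokenizer()
        tokens = tokenizer.tokenize(
            "НесуществующееСловоКоторогоТочноНетВСловаре12345",
            add_special_tokens=False,
        )
        assert any(t.startswith('##') for t in tokens)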