# tests/test_tokenization.py
"""Tests for tokenization utilities."""
import pytest
import torch
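# The module under test wraps HuggingFace tokenizers, so model files must be
# available (downloaded from the hub or cached locally). A minimal guard,
# assuming the standard `transformers` package; skip the module cleanly if
# it is missing rather than failing on import.
pytest.importorskip("transformers")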
from utils.tokenization import (
RussianTextTokenizer,
create_tokenizer,
tokenize_text_pair,
)
class TestRussianTextTokenizer:
"""Tests for Russian text tokenizer."""
def test_initialization(self):
"""Test tokenizer initialization."""
tokenizer = RussianTextTokenizer(
model_name="DeepPavlov/rubert-base-cased",
max_length=128
)
assert tokenizer.tokenizer is not None
assert tokenizer.max_length == 128
assert tokenizer.get_vocab_size() > 0
def test_tokenize_russian_text(self):
"""Test tokenization of Russian text."""
tokenizer = RussianTextTokenizer()
text = "Привет, мир!"
        tokens = tokenizer.tokenize(text, add_special_tokens=True)
        assert isinstance(tokens, list)
        assert len(tokens) > 0
        # With add_special_tokens=True a BERT-style tokenizer emits [CLS] and [SEP]
        assert '[CLS]' in tokens
        assert '[SEP]' in tokens
def test_encode_russian_text(self):
"""Test encoding of Russian text."""
tokenizer = RussianTextTokenizer(max_length=128)
text = "Это тестовый текст на русском языке"
encoded = tokenizer.encode(text)
assert 'input_ids' in encoded
assert 'attention_mask' in encoded
assert encoded['input_ids'].shape[1] == 128 # max_length
assert encoded['attention_mask'].shape[1] == 128
def test_encode_batch(self):
"""Test batch encoding."""
tokenizer = RussianTextTokenizer(max_length=64)
texts = [
"Первая новость",
"Вторая новость",
"Третья новость"
]
encoded = tokenizer.encode_batch(texts)
assert encoded['input_ids'].shape[0] == 3 # batch size
assert encoded['input_ids'].shape[1] == 64 # max_length
assert encoded['attention_mask'].shape[0] == 3
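    # Hedged addition: exercises the padding-mask semantics of encode_batch.
    # Assumes standard HuggingFace behaviour (padded positions get mask value 0),
    # which is an assumption about this wrapper rather than a documented guarantee.
    def test_encode_batch_mask_padding(self):
        """Shorter texts in a padded batch should carry zero-mask positions."""
        tokenizer = RussianTextTokenizer(max_length=64)
        encoded = tokenizer.encode_batch(["Новость", "Гораздо более длинная новость"])
        # A one-word text cannot fill 64 positions, so its mask must contain zeros
        assert (encoded['attention_mask'][0] == 0).any()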
def test_decode(self):
"""Test decoding token IDs back to text."""
tokenizer = RussianTextTokenizer()
text = "Привет, мир!"
        encoded = tokenizer.encode(text)
        # Decode the single encoded sequence (same pattern as the tests below);
        # with return_tensors=None the ids would be a flat list and [0] would
        # pick out a single token ID instead of the whole sequence
        decoded = tokenizer.decode(encoded['input_ids'][0])
# Decoded text should be similar (may have different casing/punctuation)
assert isinstance(decoded, str)
assert len(decoded) > 0
def test_special_tokens(self):
"""Test special token handling."""
tokenizer = RussianTextTokenizer()
special_tokens = tokenizer.get_special_tokens()
assert 'pad_token_id' in special_tokens
assert 'cls_token_id' in special_tokens
assert 'sep_token_id' in special_tokens
assert special_tokens['pad_token_id'] is not None
def test_padding(self):
"""Test padding behavior."""
tokenizer = RussianTextTokenizer(max_length=20, padding='max_length')
text = "Короткий текст"
encoded = tokenizer.encode(text)
# Should be padded to max_length
assert encoded['input_ids'].shape[1] == 20
assert encoded['attention_mask'].shape[1] == 20
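    # Hedged companion to test_padding: checks that padded positions hold the
    # pad token ID. Assumes get_special_tokens() exposes 'pad_token_id', as it
    # is used in test_special_tokens above.
    def test_padding_uses_pad_token(self):
        """Positions with a zero attention mask should contain pad_token_id."""
        tokenizer = RussianTextTokenizer(max_length=20, padding='max_length')
        encoded = tokenizer.encode("Короткий текст")
        pad_id = tokenizer.get_special_tokens()['pad_token_id']
        ids = encoded['input_ids'][0]
        mask = encoded['attention_mask'][0]
        assert (ids[mask == 0] == pad_id).all()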
def test_truncation(self):
"""Test truncation of long texts."""
tokenizer = RussianTextTokenizer(max_length=10, truncation=True)
# Create a long text
long_text = " ".join(["слово"] * 50)
encoded = tokenizer.encode(long_text)
# Should be truncated to max_length
assert encoded['input_ids'].shape[1] == 10
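    # Hedged addition: HuggingFace truncation keeps the special tokens at the
    # sequence boundaries. Assumes encode adds [CLS]/[SEP] by default, as the
    # other tests in this class imply.
    def test_truncation_keeps_special_tokens(self):
        """A truncated sequence should still start with [CLS] and end with [SEP]."""
        tokenizer = RussianTextTokenizer(max_length=10, truncation=True)
        encoded = tokenizer.encode(" ".join(["слово"] * 50))
        specials = tokenizer.get_special_tokens()
        ids = encoded['input_ids'][0]
        assert ids[0].item() == specials['cls_token_id']
        assert ids[-1].item() == specials['sep_token_id']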
def test_subword_tokenization(self):
"""Test that subword tokenization handles unknown words."""
tokenizer = RussianTextTokenizer()
# Use a word that might not be in vocabulary
text = "НеизвестноеСловоКоторогоНетВСловаре"
tokens = tokenizer.tokenize(text, add_special_tokens=False)
        # A long out-of-vocabulary word should be split into several subword pieces
        assert len(tokens) > 1
        assert all(isinstance(t, str) for t in tokens)
class TestTokenizerFactory:
"""Tests for tokenizer factory function."""
def test_create_tokenizer(self):
"""Test tokenizer creation."""
tokenizer = create_tokenizer(
model_name="DeepPavlov/rubert-base-cased",
max_length=256
)
assert isinstance(tokenizer, RussianTextTokenizer)
assert tokenizer.max_length == 256
def test_create_multilingual_tokenizer(self):
"""Test creating multilingual tokenizer."""
tokenizer = create_tokenizer(
model_name="bert-base-multilingual-cased",
max_length=128
)
assert tokenizer.model_name == "bert-base-multilingual-cased"
assert tokenizer.max_length == 128
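    # Hedged addition: verifies the configured max_length actually drives the
    # encoded shape. Assumes create_tokenizer() forwards max_length to the
    # wrapper and that encode pads/truncates to it, as the tests above suggest.
    def test_factory_max_length_applied(self):
        """The factory-configured max_length should shape the encoded output."""
        tokenizer = create_tokenizer(max_length=32)
        encoded = tokenizer.encode("Проверка длины последовательности")
        assert encoded['input_ids'].shape[1] == 32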
class TestTextPairTokenization:
"""Tests for title-snippet pair tokenization."""
def test_tokenize_text_pair(self):
"""Test tokenizing title and snippet pair."""
tokenizer = create_tokenizer()
title = "Заголовок новости"
snippet = "Краткое описание новости"
encoded = tokenize_text_pair(
title=title,
snippet=snippet,
tokenizer=tokenizer,
max_title_len=64,
max_snippet_len=128
)
assert 'title_input_ids' in encoded
assert 'title_attention_mask' in encoded
assert 'snippet_input_ids' in encoded
assert 'snippet_attention_mask' in encoded
assert encoded['title_input_ids'].shape[0] == 64
assert encoded['snippet_input_ids'].shape[0] == 128
def test_tokenize_title_only(self):
"""Test tokenizing title without snippet."""
tokenizer = create_tokenizer()
title = "Заголовок"
encoded = tokenize_text_pair(
title=title,
snippet=None,
tokenizer=tokenizer
)
assert 'title_input_ids' in encoded
assert 'snippet_input_ids' not in encoded
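    # Hedged addition: the separate title_*/snippet_* keys above suggest the two
    # texts are encoded independently, so the title encoding should not change
    # when a snippet is supplied. This is an assumption about tokenize_text_pair.
    def test_title_encoding_independent_of_snippet(self):
        """Title encoding should be identical with and without a snippet."""
        tokenizer = create_tokenizer()
        with_snippet = tokenize_text_pair(
            title="Заголовок", snippet="Описание", tokenizer=tokenizer
        )
        title_only = tokenize_text_pair(
            title="Заголовок", snippet=None, tokenizer=tokenizer
        )
        assert torch.equal(
            with_snippet['title_input_ids'], title_only['title_input_ids']
        )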
class TestRussianTextHandling:
"""Tests for proper Russian text handling."""
def test_cyrillic_characters(self):
"""Test handling of Cyrillic characters."""
tokenizer = RussianTextTokenizer()
# Test various Cyrillic characters
texts = [
"АБВГДЕЁЖЗИЙКЛМНОПРСТУФХЦЧШЩЪЫЬЭЮЯ",
"абвгдеёжзийклмнопрстуфхцчшщъыьэюя",
"1234567890",
"Смешанный текст: English and русский",
]
for text in texts:
encoded = tokenizer.encode(text)
decoded = tokenizer.decode(encoded['input_ids'][0])
# Should handle all without errors
assert encoded['input_ids'].shape[0] > 0
assert isinstance(decoded, str)
def test_russian_punctuation(self):
"""Test handling of Russian punctuation."""
tokenizer = RussianTextTokenizer()
text = "Текст с пунктуацией: запятые, точки. Восклицания! Вопросы?"
encoded = tokenizer.encode(text)
assert encoded['input_ids'].shape[0] > 0
        # input_ids are integer token IDs, so a NaN check is meaningless;
        # verify that every ID falls inside the vocabulary instead
        assert (encoded['input_ids'] >= 0).all()
        assert (encoded['input_ids'] < tokenizer.get_vocab_size()).all()
def test_empty_text_handling(self):
"""Test handling of empty or whitespace-only text."""
tokenizer = RussianTextTokenizer()
# Empty string
encoded = tokenizer.encode("")
assert encoded['input_ids'].shape[0] > 0
# Whitespace only
encoded = tokenizer.encode(" ")
assert encoded['input_ids'].shape[0] > 0
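    # Hedged addition: if special tokens are added by default (as the other
    # tests imply), an empty string encodes to [CLS] [SEP] plus padding, so
    # exactly two positions should be attended to.
    def test_empty_text_attention_span(self):
        """An empty string should attend only to its two special tokens."""
        tokenizer = RussianTextTokenizer()
        encoded = tokenizer.encode("")
        assert encoded['attention_mask'][0].sum().item() == 2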
def test_very_long_text(self):
"""Test handling of very long texts (should truncate)."""
tokenizer = RussianTextTokenizer(max_length=50, truncation=True)
# Create very long text
long_text = " ".join(["слово"] * 200)
encoded = tokenizer.encode(long_text)
# Should be truncated
assert encoded['input_ids'].shape[1] == 50
class TestSubwordTokenization:
"""Tests for subword tokenization features."""
def test_unknown_word_handling(self):
"""Test that unknown words are handled via subword tokenization."""
tokenizer = RussianTextTokenizer()
# Word that likely doesn't exist in vocabulary
unknown_word = "НесуществующееСловоКоторогоТочноНетВСловаре12345"
tokens = tokenizer.tokenize(unknown_word, add_special_tokens=False)
        # A long out-of-vocabulary word should be split into several subword pieces
        assert len(tokens) > 1
# All should be valid tokens
assert all(isinstance(t, str) for t in tokens)
def test_word_piece_tokenization(self):
"""Test WordPiece subword tokenization."""
tokenizer = RussianTextTokenizer()
# Common Russian word
text = "правительство"
tokens = tokenizer.tokenize(text, add_special_tokens=False)
# Should tokenize (may be single token or multiple subwords)
assert len(tokens) > 0
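    # Hedged addition: WordPiece marks word-internal pieces with a '##' prefix.
    # This holds for BERT-family vocabularies such as DeepPavlov/rubert-base-cased;
    # other tokenizer types use different continuation markers.
    def test_wordpiece_continuation_prefix(self):
        """If a single word splits, every piece after the first starts with '##'."""
        tokenizer = RussianTextTokenizer()
        tokens = tokenizer.tokenize("электрификация", add_special_tokens=False)
        if len(tokens) > 1:
            assert all(t.startswith('##') for t in tokens[1:])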
def test_vocabulary_coverage(self):
"""Test that tokenizer has good vocabulary coverage."""
tokenizer = RussianTextTokenizer()
vocab_size = tokenizer.get_vocab_size()
# BERT models typically have 30K+ vocabulary
assert vocab_size > 10000
assert vocab_size < 1000000 # Reasonable upper bound
def test_token_info(self):
"""Test getting token information."""
tokenizer = RussianTextTokenizer()
# Get a token ID
special_tokens = tokenizer.get_special_tokens()
pad_id = special_tokens['pad_token_id']
info = tokenizer.get_token_info(pad_id)
assert 'token_id' in info
assert 'token' in info
assert 'is_special' in info
assert info['token_id'] == pad_id
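# Convenience entry point; the usual invocation is `pytest tests/test_tokenization.py`.
if __name__ == "__main__":
    pytest.main([__file__])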