File size: 10,452 Bytes
198ccb0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
"""Tests for tokenization utilities."""

import pytest
import torch
from utils.tokenization import (
    RussianTextTokenizer,
    create_tokenizer,
    tokenize_text_pair,
)


class TestRussianTextTokenizer:
    """Tests for the Russian text tokenizer wrapper.

    Covers construction, single/batch encoding, decoding, special tokens,
    padding, truncation, and subword fallback for out-of-vocabulary words.
    """

    def test_initialization(self):
        """Tokenizer loads the model and records max_length."""
        tokenizer = RussianTextTokenizer(
            model_name="DeepPavlov/rubert-base-cased",
            max_length=128
        )

        assert tokenizer.tokenizer is not None
        assert tokenizer.max_length == 128
        assert tokenizer.get_vocab_size() > 0

    def test_tokenize_russian_text(self):
        """Tokenizing Russian text yields a non-empty list of string tokens."""
        tokenizer = RussianTextTokenizer()

        text = "Привет, мир!"
        tokens = tokenizer.tokenize(text)

        assert isinstance(tokens, list)
        assert len(tokens) > 0
        # Every produced token must be a string; the previous assertion here
        # was a tautology (`any(...) or len(tokens) > 0`) and tested nothing.
        assert all(isinstance(t, str) for t in tokens)

    def test_encode_russian_text(self):
        """Encoding returns input_ids/attention_mask padded to max_length."""
        tokenizer = RussianTextTokenizer(max_length=128)

        text = "Это тестовый текст на русском языке"
        encoded = tokenizer.encode(text)

        assert 'input_ids' in encoded
        assert 'attention_mask' in encoded
        # Shape is (batch=1, max_length) when tensors are returned.
        assert encoded['input_ids'].shape[1] == 128  # max_length
        assert encoded['attention_mask'].shape[1] == 128

    def test_encode_batch(self):
        """Batch encoding stacks all texts into one padded tensor."""
        tokenizer = RussianTextTokenizer(max_length=64)

        texts = [
            "Первая новость",
            "Вторая новость",
            "Третья новость"
        ]

        encoded = tokenizer.encode_batch(texts)

        assert encoded['input_ids'].shape[0] == 3  # batch size
        assert encoded['input_ids'].shape[1] == 64  # max_length
        assert encoded['attention_mask'].shape[0] == 3

    def test_decode(self):
        """Decoding token IDs produces a non-empty string."""
        tokenizer = RussianTextTokenizer()

        text = "Привет, мир!"
        encoded = tokenizer.encode(text, return_tensors=None)

        decoded = tokenizer.decode(encoded['input_ids'][0])

        # Round-trip may differ in casing/punctuation; only check basic validity.
        assert isinstance(decoded, str)
        assert len(decoded) > 0

    def test_special_tokens(self):
        """Special token IDs (PAD/CLS/SEP) are exposed and populated."""
        tokenizer = RussianTextTokenizer()

        special_tokens = tokenizer.get_special_tokens()

        assert 'pad_token_id' in special_tokens
        assert 'cls_token_id' in special_tokens
        assert 'sep_token_id' in special_tokens
        assert special_tokens['pad_token_id'] is not None

    def test_padding(self):
        """Short text is padded out to the configured max_length."""
        tokenizer = RussianTextTokenizer(max_length=20, padding='max_length')

        text = "Короткий текст"
        encoded = tokenizer.encode(text)

        # Should be padded to max_length
        assert encoded['input_ids'].shape[1] == 20
        assert encoded['attention_mask'].shape[1] == 20

    def test_truncation(self):
        """Long text is truncated down to the configured max_length."""
        tokenizer = RussianTextTokenizer(max_length=10, truncation=True)

        # Create a text far longer than max_length.
        long_text = " ".join(["слово"] * 50)
        encoded = tokenizer.encode(long_text)

        # Should be truncated to max_length
        assert encoded['input_ids'].shape[1] == 10

    def test_subword_tokenization(self):
        """Out-of-vocabulary words are split into subword tokens, not dropped."""
        tokenizer = RussianTextTokenizer()

        # A concatenated word very unlikely to be in the vocabulary.
        text = "НеизвестноеСловоКоторогоНетВСловаре"
        tokens = tokenizer.tokenize(text, add_special_tokens=False)

        # Should still tokenize (using subwords, e.g. '##'-prefixed pieces).
        assert len(tokens) > 0
        assert all(isinstance(t, str) for t in tokens)


class TestTokenizerFactory:
    """Tests for the `create_tokenizer` factory function."""

    def test_create_tokenizer(self):
        """Factory returns a configured RussianTextTokenizer instance."""
        tok = create_tokenizer(
            model_name="DeepPavlov/rubert-base-cased",
            max_length=256
        )

        assert isinstance(tok, RussianTextTokenizer)
        assert tok.max_length == 256

    def test_create_multilingual_tokenizer(self):
        """Factory accepts a multilingual BERT checkpoint as well."""
        tok = create_tokenizer(
            model_name="bert-base-multilingual-cased",
            max_length=128
        )

        assert tok.model_name == "bert-base-multilingual-cased"
        assert tok.max_length == 128


class TestTextPairTokenization:
    """Tests for `tokenize_text_pair` on (title, snippet) inputs."""

    def test_tokenize_text_pair(self):
        """Both title and snippet are encoded, each to its own length budget."""
        tok = create_tokenizer()

        result = tokenize_text_pair(
            title="Заголовок новости",
            snippet="Краткое описание новости",
            tokenizer=tok,
            max_title_len=64,
            max_snippet_len=128
        )

        # All four tensors must be present for a full pair.
        for key in (
            'title_input_ids',
            'title_attention_mask',
            'snippet_input_ids',
            'snippet_attention_mask',
        ):
            assert key in result

        # Each field is padded/truncated to its own configured length.
        assert result['title_input_ids'].shape[0] == 64
        assert result['snippet_input_ids'].shape[0] == 128

    def test_tokenize_title_only(self):
        """With snippet=None only the title fields are produced."""
        tok = create_tokenizer()

        result = tokenize_text_pair(
            title="Заголовок",
            snippet=None,
            tokenizer=tok
        )

        assert 'title_input_ids' in result
        assert 'snippet_input_ids' not in result


class TestRussianTextHandling:
    """Tests for correct handling of Russian (Cyrillic) input text."""

    def test_cyrillic_characters(self):
        """Cyrillic (upper/lower, with Ё/ё), digits, and mixed-script text
        all encode and decode without raising."""
        tokenizer = RussianTextTokenizer()

        texts = [
            "АБВГДЕЁЖЗИЙКЛМНОПРСТУФХЦЧШЩЪЫЬЭЮЯ",
            "абвгдеёжзийклмнопрстуфхцчшщъыьэюя",
            "1234567890",
            "Смешанный текст: English and русский",
        ]

        for text in texts:
            encoded = tokenizer.encode(text)
            decoded = tokenizer.decode(encoded['input_ids'][0])

            # Smoke check: a non-empty tensor came back and decode is a str.
            assert encoded['input_ids'].shape[0] > 0
            assert isinstance(decoded, str)

    def test_russian_punctuation(self):
        """Punctuation-heavy text encodes to valid (non-negative) token IDs."""
        tokenizer = RussianTextTokenizer()

        text = "Текст с пунктуацией: запятые, точки. Восклицания! Вопросы?"
        encoded = tokenizer.encode(text)

        assert encoded['input_ids'].shape[0] > 0
        # input_ids is an integer tensor, so torch.isnan (used previously) is
        # meaningless for it and raises on older torch versions.  Validate
        # token IDs directly instead: all must be non-negative vocab indices.
        assert bool((encoded['input_ids'] >= 0).all())

    def test_empty_text_handling(self):
        """Empty and whitespace-only input still yields an encoding
        (special tokens alone) instead of raising."""
        tokenizer = RussianTextTokenizer()

        # Empty string
        encoded = tokenizer.encode("")
        assert encoded['input_ids'].shape[0] > 0

        # Whitespace only
        encoded = tokenizer.encode("   ")
        assert encoded['input_ids'].shape[0] > 0

    def test_very_long_text(self):
        """Very long text is truncated to max_length, not passed through."""
        tokenizer = RussianTextTokenizer(max_length=50, truncation=True)

        long_text = " ".join(["слово"] * 200)
        encoded = tokenizer.encode(long_text)

        # Sequence dimension must equal max_length after truncation.
        assert encoded['input_ids'].shape[1] == 50


class TestSubwordTokenization:
    """Tests for WordPiece subword behavior and vocabulary introspection."""

    def test_unknown_word_handling(self):
        """A word absent from the vocabulary is split into subword pieces."""
        tok = RussianTextTokenizer()

        # Deliberately constructed to miss the vocabulary.
        unknown_word = "НесуществующееСловоКоторогоТочноНетВСловаре12345"
        pieces = tok.tokenize(unknown_word, add_special_tokens=False)

        assert len(pieces) > 0
        assert all(isinstance(piece, str) for piece in pieces)

    def test_word_piece_tokenization(self):
        """A common Russian word tokenizes to one or more pieces."""
        tok = RussianTextTokenizer()

        pieces = tok.tokenize("правительство", add_special_tokens=False)

        assert len(pieces) > 0

    def test_vocabulary_coverage(self):
        """Vocabulary size falls in the range expected for BERT models."""
        tok = RussianTextTokenizer()

        size = tok.get_vocab_size()

        # BERT-family vocabularies are typically 30K+ but well under 1M.
        assert 10000 < size < 1000000

    def test_token_info(self):
        """Looking up a known ID (PAD) returns a consistent info dict."""
        tok = RussianTextTokenizer()

        pad_id = tok.get_special_tokens()['pad_token_id']
        info = tok.get_token_info(pad_id)

        for field in ('token_id', 'token', 'is_special'):
            assert field in info
        assert info['token_id'] == pad_id