File size: 10,823 Bytes
61ba51e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 | """
Unit tests for AsyncDynamicbatchTokenizer.
Tests the async dynamic batching functionality for tokenization,
including batch efficiency, timeout handling, and error cases.
"""
import asyncio
import logging
import time
from unittest.mock import Mock
import pytest
from transformers import AutoTokenizer
from sglang.srt.managers.async_dynamic_batch_tokenizer import AsyncDynamicbatchTokenizer
class _MockBatchEncoding(dict):
    """Minimal stand-in for ``transformers.BatchEncoding``.

    A plain ``dict`` whose keys are additionally exposed as attributes, so the
    code under test may use either ``enc["input_ids"]`` or ``enc.input_ids``.
    """

    def __init__(self, data):
        super().__init__(data)
        for key, value in data.items():
            setattr(self, key, value)


class TestAsyncDynamicbatchTokenizer:
    """Test suite for AsyncDynamicbatchTokenizer."""

    @pytest.fixture
    def mock_tokenizer(self):
        """Create a mock tokenizer that behaves like HuggingFace tokenizer.

        The returned callable accepts a single string or a list of strings and
        produces one token per whitespace-separated word, with ids ``0..n-1``.
        ``token_type_ids`` (all zeros) are included only when
        ``return_token_type_ids=True`` is passed. A single string yields
        unwrapped per-key values; a list yields per-key lists.
        """

        def mock_encode(texts, **kwargs):
            is_single = isinstance(texts, str)
            batch = [texts] if is_single else texts
            want_type_ids = kwargs.get("return_token_type_ids", False)

            # Word count (not character count) determines the number of tokens.
            input_ids = [list(range(len(text.split()))) for text in batch]
            result = {"input_ids": input_ids}
            if want_type_ids:
                result["token_type_ids"] = [[0] * len(ids) for ids in input_ids]

            # For single inputs, return individual results (not wrapped in a list)
            if is_single:
                result = {key: value[0] for key, value in result.items()}
            return _MockBatchEncoding(result)

        # Return the function directly - the AsyncDynamicbatchTokenizer will call it
        return mock_encode

    @pytest.fixture
    def async_tokenizer(self, mock_tokenizer):
        """Create AsyncDynamicbatchTokenizer instance.

        A small batch size and a 10 ms batching window keep the tests fast.
        """
        # NOTE(review): nothing here cancels the background batcher task that is
        # created lazily on first use; confirm whether the event-loop teardown
        # reports it as pending and add explicit cleanup if it does.
        return AsyncDynamicbatchTokenizer(
            tokenizer=mock_tokenizer, max_batch_size=4, batch_wait_timeout_s=0.01
        )

    @pytest.mark.asyncio
    async def test_single_request(self, async_tokenizer):
        """Test tokenizing a single request."""
        text = "hello world"
        result = await async_tokenizer.encode(text)
        assert "input_ids" in result
        assert result["input_ids"] == [0, 1]  # 2 words -> 2 tokens

    @pytest.mark.asyncio
    async def test_single_request_with_token_type_ids(self, async_tokenizer):
        """Test tokenizing with token type IDs."""
        text = "hello world"
        result = await async_tokenizer.encode(text, return_token_type_ids=True)
        assert "input_ids" in result
        assert "token_type_ids" in result
        assert result["input_ids"] == [0, 1]
        assert result["token_type_ids"] == [0, 0]

    @pytest.mark.asyncio
    async def test_concurrent_requests_same_kwargs(self, async_tokenizer):
        """Test that concurrent requests with same kwargs get batched."""
        texts = ["hello world", "how are you", "fine thanks", "good morning"]
        # Start all requests concurrently
        tasks = [async_tokenizer.encode(text) for text in texts]
        results = await asyncio.gather(*tasks)
        # Verify all results
        assert len(results) == 4
        for text, result in zip(texts, results):
            assert "input_ids" in result
            assert result["input_ids"] == list(range(len(text.split())))

    @pytest.mark.asyncio
    async def test_concurrent_requests_different_kwargs(self, async_tokenizer):
        """Test that requests with different kwargs are processed individually."""
        text1 = "hello world"
        text2 = "how are you"
        # One with token_type_ids, one without
        task1 = async_tokenizer.encode(text1, return_token_type_ids=True)
        task2 = async_tokenizer.encode(text2)
        result1, result2 = await asyncio.gather(task1, task2)
        # First result should have token_type_ids
        assert "input_ids" in result1
        assert "token_type_ids" in result1
        assert result1["input_ids"] == [0, 1]
        assert result1["token_type_ids"] == [0, 0]
        # Second result should not have token_type_ids
        assert "input_ids" in result2
        assert "token_type_ids" not in result2
        assert result2["input_ids"] == [0, 1, 2]

    @pytest.mark.asyncio
    async def test_batch_timeout(self, async_tokenizer):
        """Test that batching respects timeout."""
        # Send first request
        task1 = asyncio.create_task(async_tokenizer.encode("hello world"))
        # Wait longer than batch timeout
        await asyncio.sleep(0.02)  # Longer than 0.01s timeout
        # Send second request
        task2 = asyncio.create_task(async_tokenizer.encode("how are you"))
        results = await asyncio.gather(task1, task2)
        # Both should complete successfully
        assert len(results) == 2
        assert results[0]["input_ids"] == [0, 1]
        assert results[1]["input_ids"] == [0, 1, 2]

    @pytest.mark.asyncio
    async def test_max_batch_size_limit(self, async_tokenizer):
        """Test that batching respects max_batch_size."""
        # Send more requests than max_batch_size (4)
        texts = [f"text {i}" for i in range(6)]
        tasks = [async_tokenizer.encode(text) for text in texts]
        results = await asyncio.gather(*tasks)
        # All should complete successfully
        assert len(results) == 6
        for result in results:
            assert "input_ids" in result
            assert result["input_ids"] == [0, 1]  # "text i" -> 2 tokens

    @pytest.mark.asyncio
    async def test_callable_interface(self, async_tokenizer):
        """Test that the tokenizer is callable."""
        text = "hello world"
        result = await async_tokenizer(text)
        assert "input_ids" in result
        assert result["input_ids"] == [0, 1]

    @pytest.mark.asyncio
    async def test_lazy_initialization(self, mock_tokenizer):
        """Test that initialization happens lazily."""
        tokenizer = AsyncDynamicbatchTokenizer(mock_tokenizer)
        # Should not be initialized yet
        assert not tokenizer._initialized
        # First encode should initialize
        await tokenizer.encode("hello")
        # Should now be initialized
        assert tokenizer._initialized

    @pytest.mark.asyncio
    async def test_error_handling_in_tokenizer(self):
        """Test error handling when tokenizer fails."""

        # Create a new async tokenizer with a failing tokenizer
        def failing_tokenizer(*args, **kwargs):
            raise ValueError("Tokenizer error")

        async_tokenizer = AsyncDynamicbatchTokenizer(
            tokenizer=failing_tokenizer, max_batch_size=4, batch_wait_timeout_s=0.01
        )
        with pytest.raises(ValueError, match="Tokenizer error"):
            await async_tokenizer.encode("hello world")

    @pytest.mark.asyncio
    async def test_batch_processing_logs(self, async_tokenizer, caplog):
        """Test that batch processing logs are generated."""
        caplog.set_level(logging.DEBUG)
        # Send multiple requests to trigger batching
        tasks = [
            async_tokenizer.encode("hello world"),
            async_tokenizer.encode("how are you"),
        ]
        await asyncio.gather(*tasks)
        # Should have batch processing log
        assert any(
            "Processing dynamic batch of size" in record.message
            for record in caplog.records
        )

    @pytest.mark.asyncio
    async def test_empty_queue_immediate_processing(self, mock_tokenizer):
        """Test that single requests are processed immediately when queue is empty."""
        # Use a deliberately long batching window: a request that waited for it
        # would take at least 1 s, far above the bound asserted below, while an
        # immediate dispatch completes in milliseconds. This avoids a flaky
        # millisecond-level threshold on slow or loaded machines.
        batch_wait_timeout_s = 1.0
        async_tokenizer = AsyncDynamicbatchTokenizer(
            tokenizer=mock_tokenizer,
            max_batch_size=4,
            batch_wait_timeout_s=batch_wait_timeout_s,
        )
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()
        result = await async_tokenizer.encode("hello world")
        elapsed = time.perf_counter() - start
        # Should complete quickly (much less than batch timeout)
        assert elapsed < batch_wait_timeout_s / 2
        assert result["input_ids"] == [0, 1]

    @pytest.mark.asyncio
    async def test_real_tokenizer_integration(self):
        """Test with a real HuggingFace tokenizer."""
        # Only loading the tokenizer may legitimately be unavailable (e.g. no
        # network or model cache). Keep the assertions outside the try block so
        # that a real failure is reported instead of being turned into a skip.
        try:
            # Use a small, fast tokenizer for testing
            real_tokenizer = AutoTokenizer.from_pretrained("gpt2")
        except Exception as e:
            pytest.skip(f"Real tokenizer test skipped: {e}")

        async_tokenizer = AsyncDynamicbatchTokenizer(
            tokenizer=real_tokenizer, max_batch_size=2, batch_wait_timeout_s=0.01
        )
        text = "Hello, world!"
        result = await async_tokenizer.encode(text)
        # Should get actual token IDs
        assert "input_ids" in result
        assert isinstance(result["input_ids"], list)
        assert len(result["input_ids"]) > 0
        assert all(isinstance(token_id, int) for token_id in result["input_ids"])

    @pytest.mark.asyncio
    async def test_concurrent_mixed_requests(self, async_tokenizer):
        """Test mixing single and batched requests."""
        # Start some requests
        task1 = asyncio.create_task(async_tokenizer.encode("hello"))
        task2 = asyncio.create_task(async_tokenizer.encode("world"))
        # Wait a bit (shorter than the 10 ms batching window)
        await asyncio.sleep(0.005)
        # Start more requests
        task3 = asyncio.create_task(async_tokenizer.encode("how are"))
        task4 = asyncio.create_task(async_tokenizer.encode("you doing"))
        results = await asyncio.gather(task1, task2, task3, task4)
        # All should complete successfully
        assert len(results) == 4
        for result in results:
            assert "input_ids" in result
            assert isinstance(result["input_ids"], list)

    def test_cleanup_on_destruction(self, mock_tokenizer):
        """Test that resources are cleaned up properly."""
        tokenizer = AsyncDynamicbatchTokenizer(mock_tokenizer)
        # Mock the executor and task
        tokenizer._executor = Mock()
        tokenizer._batcher_task = Mock()
        tokenizer._batcher_task.done.return_value = False
        # Call destructor
        tokenizer.__del__()
        # Should cancel task and shutdown executor
        tokenizer._batcher_task.cancel.assert_called_once()
        tokenizer._executor.shutdown.assert_called_once_with(wait=False)
if __name__ == "__main__":
    # Propagate pytest's exit code so that running this file directly exits
    # non-zero when any test fails (the return value was previously dropped).
    raise SystemExit(pytest.main([__file__]))
|