| """ |
| Unit tests for enable_tokenizer_batch_encode feature. |
| |
| This tests the batch tokenization functionality which allows processing |
| multiple text inputs in a single batch for improved performance. |
| |
| Usage: |
| python3 -m unittest test_tokenizer_batch_encode.TestTokenizerBatchEncode.test_batch_validation_constraints |
| python3 -m unittest test_tokenizer_batch_encode.TestTokenizerBatchEncodeUnit.test_batch_tokenize_and_process_logic |
| python3 -m unittest test_tokenizer_batch_encode.TestTokenizerBatchEncodeLogic.test_batch_processing_path |
| """ |
|
|
| import unittest |
| from unittest.mock import Mock, patch |
|
|
| from sglang.srt.managers.io_struct import GenerateReqInput |
| from sglang.srt.managers.tokenizer_manager import TokenizerManager |
| from sglang.srt.server_args import PortArgs, ServerArgs |
| from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST |
|
|
|
|
| class TestTokenizerBatchEncode(unittest.TestCase): |
| """Test cases for tokenizer batch encoding validation and setup.""" |
|
|
| def setUp(self): |
| """Set up test fixtures.""" |
| self.server_args = ServerArgs( |
| model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST, |
| enable_tokenizer_batch_encode=True, |
| ) |
| self.port_args = PortArgs.init_new(self.server_args) |
|
|
| with patch("zmq.asyncio.Context"), patch( |
| "sglang.srt.utils.get_zmq_socket" |
| ), patch( |
| "sglang.srt.utils.hf_transformers_utils.get_tokenizer" |
| ) as mock_tokenizer: |
|
|
| mock_tokenizer.return_value = Mock(vocab_size=32000) |
| self.tokenizer_manager = TokenizerManager(self.server_args, self.port_args) |
|
|
| def test_batch_encode_enabled(self): |
| """Test that batch encoding is enabled when configured.""" |
| self.assertTrue(self.server_args.enable_tokenizer_batch_encode) |
|
|
| def test_batch_encode_disabled(self): |
| """Test that batch encoding can be disabled.""" |
| server_args_disabled = ServerArgs( |
| model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST, |
| enable_tokenizer_batch_encode=False, |
| ) |
| self.assertFalse(server_args_disabled.enable_tokenizer_batch_encode) |
|
|
| def test_multimodal_input_validation(self): |
| """Test that multimodal inputs are rejected in batch mode.""" |
| req = GenerateReqInput(text="test", image_data=["dummy"]) |
| req.contains_mm_input = Mock(return_value=True) |
|
|
| batch_obj = Mock() |
| batch_obj.__getitem__ = lambda self, i: req |
|
|
| self.tokenizer_manager.is_generation = True |
|
|
| with self.assertRaises(ValueError) as cm: |
| self.tokenizer_manager._validate_batch_tokenization_constraints( |
| 1, batch_obj |
| ) |
|
|
| self.assertIn("multimodal", str(cm.exception)) |
|
|
| def test_pretokenized_input_validation(self): |
| """Test that pre-tokenized inputs are rejected in batch mode.""" |
| req = GenerateReqInput(input_ids=[1, 2, 3]) |
|
|
| batch_obj = Mock() |
| batch_obj.__getitem__ = lambda self, i: req |
|
|
| with self.assertRaises(ValueError) as cm: |
| self.tokenizer_manager._validate_batch_tokenization_constraints( |
| 1, batch_obj |
| ) |
|
|
| self.assertIn("pre-tokenized", str(cm.exception)) |
|
|
| def test_input_embeds_validation(self): |
| """Test that input embeds are rejected in batch mode.""" |
| req = GenerateReqInput(input_embeds=[0.1, 0.2]) |
|
|
| batch_obj = Mock() |
| batch_obj.__getitem__ = lambda self, i: req |
|
|
| with self.assertRaises(ValueError) as cm: |
| self.tokenizer_manager._validate_batch_tokenization_constraints( |
| 1, batch_obj |
| ) |
|
|
| self.assertIn("input_embeds", str(cm.exception)) |
|
|
| def test_valid_text_only_requests_pass_validation(self): |
| """Test that valid text-only requests pass validation.""" |
| |
| requests = [] |
| for i in range(3): |
| req = GenerateReqInput(text=f"test text {i}") |
| req.contains_mm_input = Mock(return_value=False) |
| requests.append(req) |
|
|
| batch_obj = Mock() |
| batch_obj.__getitem__ = Mock(side_effect=lambda i: requests[i]) |
|
|
| |
| try: |
| self.tokenizer_manager._validate_batch_tokenization_constraints( |
| 3, batch_obj |
| ) |
| except Exception as e: |
| self.fail(f"Validation failed for valid text-only requests: {e}") |
|
|
|
|
| if __name__ == "__main__": |
| unittest.main(verbosity=2) |
|
|