| """ |
| Unit tests for the OpenAIServingEmbedding class from serving_embedding.py. |
| """ |
|
|
| import unittest |
| import uuid |
| from unittest.mock import Mock |
|
|
| from fastapi import Request |
|
|
| from sglang.srt.entrypoints.openai.protocol import ( |
| EmbeddingRequest, |
| MultimodalEmbeddingInput, |
| ) |
| from sglang.srt.entrypoints.openai.serving_embedding import OpenAIServingEmbedding |
| from sglang.srt.managers.io_struct import EmbeddingReqInput |
| from sglang.test.ci.ci_register import register_amd_ci, register_cuda_ci |
|
|
# Register this test module with the CUDA and AMD CI suites named below.
# NOTE(review): est_time looks like an estimated runtime budget for the module;
# confirm its units in sglang.test.ci.ci_register before changing it.
register_cuda_ci(est_time=10, suite="stage-b-test-large-1-gpu")
register_amd_ci(est_time=10, suite="stage-b-test-small-1-gpu-amd")
|
|
|
|
| |
| class _MockTokenizerManager: |
| def __init__(self): |
| self.model_config = Mock() |
| self.model_config.is_multimodal = False |
| self.server_args = Mock() |
| self.server_args.enable_cache_report = False |
| self.model_path = "test-model" |
|
|
| |
| self.tokenizer = Mock() |
| self.tokenizer.encode = Mock(return_value=[1, 2, 3, 4, 5]) |
| self.tokenizer.decode = Mock(return_value="Test embedding input") |
| self.tokenizer.chat_template = None |
| self.tokenizer.bos_token_id = 1 |
|
|
| |
| async def mock_generate_embedding(): |
| yield { |
| "embedding": [0.1, 0.2, 0.3, 0.4, 0.5] * 20, |
| "meta_info": { |
| "id": f"embd-{uuid.uuid4()}", |
| "prompt_tokens": 5, |
| }, |
| } |
|
|
| self.generate_request = Mock(return_value=mock_generate_embedding()) |
|
|
|
|
| |
| class _MockTemplateManager: |
| def __init__(self): |
| self.chat_template_name = None |
| self.jinja_template_content_format = None |
| self.completion_template_name = None |
|
|
|
|
class ServingEmbeddingTestCase(unittest.TestCase):
    """Tests for ``OpenAIServingEmbedding._convert_to_internal_request``.

    Each test converts one kind of ``EmbeddingRequest`` input (single string,
    list of strings, token ids, multimodal items). It then checks the produced
    ``EmbeddingReqInput`` and the request object returned alongside it.
    """

    @staticmethod
    def _embedding_request(payload):
        """Build a float-encoded EmbeddingRequest for the mock model."""
        return EmbeddingRequest(
            model="test-model",
            input=payload,
            encoding_format="float",
        )

    def _convert(self, original_req):
        """Convert *original_req*; assert the adapted half is an EmbeddingReqInput."""
        adapted, processed = self.serving_embedding._convert_to_internal_request(
            original_req
        )
        self.assertIsInstance(adapted, EmbeddingReqInput)
        return adapted, processed

    def setUp(self):
        """Set up test fixtures."""
        self.tokenizer_manager = _MockTokenizerManager()
        self.template_manager = _MockTemplateManager()
        self.serving_embedding = OpenAIServingEmbedding(
            self.tokenizer_manager, self.template_manager
        )

        # Mocked FastAPI request carrying no headers.
        self.request = Mock(spec=Request)
        self.request.headers = {}

        # One request fixture per input shape under test.
        self.basic_req = self._embedding_request("Hello, how are you?")
        self.list_req = self._embedding_request(
            ["Hello, how are you?", "I am fine, thank you!"]
        )
        self.multimodal_req = self._embedding_request(
            [
                MultimodalEmbeddingInput(text="Hello", image="base64_image_data"),
                MultimodalEmbeddingInput(text="World", image=None),
            ]
        )
        self.token_ids_req = self._embedding_request([1, 2, 3, 4, 5])

    def test_convert_single_string_request(self):
        """Test converting single string request to internal format."""
        adapted, processed = self._convert(self.basic_req)
        self.assertEqual(adapted.text, "Hello, how are you?")
        self.assertEqual(processed, self.basic_req)

    def test_convert_list_string_request(self):
        """Test converting list of strings request to internal format."""
        adapted, processed = self._convert(self.list_req)
        expected_texts = ["Hello, how are you?", "I am fine, thank you!"]
        self.assertEqual(adapted.text, expected_texts)
        self.assertEqual(processed, self.list_req)

    def test_convert_token_ids_request(self):
        """Test converting token IDs request to internal format."""
        adapted, processed = self._convert(self.token_ids_req)
        self.assertEqual(adapted.input_ids, [1, 2, 3, 4, 5])
        self.assertEqual(processed, self.token_ids_req)

    def test_convert_multimodal_request(self):
        """Test converting multimodal request to internal format."""
        adapted, _processed = self._convert(self.multimodal_req)
        # Every item contributes one text entry, in some order.
        self.assertEqual(len(adapted.text), 2)
        for expected_text in ("Hello", "World"):
            self.assertIn(expected_text, adapted.text)
        # Image slots line up with the inputs; the missing image stays None.
        self.assertEqual(adapted.image_data[0], "base64_image_data")
        self.assertIsNone(adapted.image_data[1])
| |
|
|
|
|
if __name__ == "__main__":
    # Only runs when this file is executed directly, not when a test runner
    # imports it; verbosity=2 prints one line per test.
    unittest.main(verbosity=2)
|
|