| """ |
| Unit-tests for the refactored completions-serving handler (no pytest). |
| Run with: |
| python -m unittest tests.test_serving_completions_unit -v |
| """ |
|
|
| import json |
| import unittest |
| from http import HTTPStatus |
| from typing import Optional |
| from unittest.mock import AsyncMock, Mock |
|
|
| from fastapi import Request |
|
|
| from sglang.srt.entrypoints.openai.protocol import CompletionRequest |
| from sglang.srt.entrypoints.openai.serving_completions import OpenAIServingCompletion |
| from sglang.srt.managers.tokenizer_manager import TokenizerManager |
| from sglang.srt.utils import get_or_create_event_loop |
| from sglang.test.ci.ci_register import register_amd_ci, register_cuda_ci |
|
|
# Register this test module with both CI backends. The same est_time value is
# passed to each; its unit is not visible here (presumably seconds — confirm
# against sglang.test.ci.ci_register).
register_cuda_ci(
    est_time=10,
    suite="stage-b-test-small-1-gpu",
)
register_amd_ci(
    est_time=10,
    suite="stage-b-test-small-1-gpu-amd",
)
|
|
|
|
class _MockTemplateManager:
    """Lightweight stand-in for ``TemplateManager``.

    Exposes only three template-related attributes, and every one of them
    starts out as ``None`` (nothing configured).
    """

    def __init__(self):
        # A single shared "unset" value keeps the three defaults in lockstep.
        unset: Optional[str] = None
        self.chat_template_name = unset
        self.jinja_template_content_format = unset
        self.completion_template_name = unset
|
|
|
|
class ServingCompletionTestCase(unittest.TestCase):
    """Bundle all prompt/echo/response-format/streaming tests in one TestCase."""

    # ------------- fixtures -------------
    def setUp(self):
        """Build an OpenAIServingCompletion backed by a fully mocked TokenizerManager."""
        tm = Mock(spec=TokenizerManager)

        # Stub tokenizer with fixed encode/decode outputs so results are deterministic.
        tm.tokenizer = Mock()
        tm.tokenizer.encode.return_value = [1, 2, 3, 4]
        tm.tokenizer.decode.return_value = "decoded text"
        tm.tokenizer.bos_token_id = 1

        # Non-multimodal model config; cache reporting disabled.
        tm.model_config = Mock(is_multimodal=False)
        tm.server_args = Mock(enable_cache_report=False)

        # Async generation entry point; individual tests may replace it.
        tm.generate_request = AsyncMock()
        tm.create_abort_task = Mock()

        self.template_manager = _MockTemplateManager()
        self.sc = OpenAIServingCompletion(tm, self.template_manager)
        self.fastapi_request = Mock(spec=Request)

    # ------------- prompt handling -------------
    def test_single_string_prompt(self):
        """A plain string prompt is forwarded as ``text``."""
        req = CompletionRequest(model="x", prompt="Hello world", max_tokens=100)
        internal, _ = self.sc._convert_to_internal_request(req)
        self.assertEqual(internal.text, "Hello world")

    def test_single_token_ids_prompt(self):
        """A flat list of ints is forwarded as ``input_ids``."""
        req = CompletionRequest(model="x", prompt=[1, 2, 3, 4], max_tokens=100)
        internal, _ = self.sc._convert_to_internal_request(req)
        self.assertEqual(internal.input_ids, [1, 2, 3, 4])

    # ------------- echo handling -------------
    def test_echo_with_string_prompt_streaming(self):
        """Echo of a single string prompt returns that string."""
        req = CompletionRequest(model="x", prompt="Hello", max_tokens=1, echo=True)
        self.assertEqual(self.sc._get_echo_text(req, 0), "Hello")

    def test_echo_with_list_of_strings_streaming(self):
        """Echo of a list of strings returns the prompt at the requested index."""
        req = CompletionRequest(
            model="x", prompt=["A", "B"], max_tokens=1, echo=True, n=1
        )
        self.assertEqual(self.sc._get_echo_text(req, 0), "A")
        self.assertEqual(self.sc._get_echo_text(req, 1), "B")

    def test_echo_with_token_ids_streaming(self):
        """Echo of a token-id prompt returns the tokenizer's decoded text."""
        req = CompletionRequest(model="x", prompt=[1, 2, 3], max_tokens=1, echo=True)
        self.sc.tokenizer_manager.tokenizer.decode.return_value = "decoded_prompt"
        self.assertEqual(self.sc._get_echo_text(req, 0), "decoded_prompt")

    def test_echo_with_multiple_token_ids_streaming(self):
        """Echo of a batch of token-id prompts returns decoded text."""
        req = CompletionRequest(
            model="x", prompt=[[1, 2], [3, 4]], max_tokens=1, echo=True, n=1
        )
        self.sc.tokenizer_manager.tokenizer.decode.return_value = "decoded"
        self.assertEqual(self.sc._get_echo_text(req, 0), "decoded")

    def test_prepare_echo_prompts_non_streaming(self):
        """_prepare_echo_prompts handles string, list-of-string and token-id prompts."""
        # single string
        req = CompletionRequest(model="x", prompt="Hi", echo=True)
        self.assertEqual(self.sc._prepare_echo_prompts(req), ["Hi"])

        # list of strings
        req = CompletionRequest(model="x", prompt=["Hi", "Yo"], echo=True)
        self.assertEqual(self.sc._prepare_echo_prompts(req), ["Hi", "Yo"])

        # token ids
        req = CompletionRequest(model="x", prompt=[1, 2, 3], echo=True)
        self.sc.tokenizer_manager.tokenizer.decode.return_value = "decoded"
        self.assertEqual(self.sc._prepare_echo_prompts(req), ["decoded"])

    # ------------- response_format -------------
    def test_response_format_json_object(self):
        """Test that response_format json_object is correctly processed in sampling params."""
        req = CompletionRequest(
            model="x",
            prompt="Generate a JSON object:",
            max_tokens=100,
            response_format={"type": "json_object"},
        )
        sampling_params = self.sc._build_sampling_params(req)
        self.assertEqual(sampling_params["json_schema"], '{"type": "object"}')

    def test_response_format_json_schema(self):
        """Test that response_format json_schema is correctly processed in sampling params."""
        schema = {
            "type": "object",
            "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
        }
        req = CompletionRequest(
            model="x",
            prompt="Generate a JSON object:",
            max_tokens=100,
            response_format={
                "type": "json_schema",
                "json_schema": {"name": "person", "schema": schema},
            },
        )
        sampling_params = self.sc._build_sampling_params(req)

        self.assertIn("json_schema", sampling_params)
        self.assertIsInstance(sampling_params["json_schema"], str)
        # The serialized constraint must describe the schema that was supplied,
        # not merely be some string.
        self.assertEqual(json.loads(sampling_params["json_schema"]), schema)

    def test_response_format_structural_tag(self):
        """Test that response_format structural_tag is correctly processed in sampling params."""
        req = CompletionRequest(
            model="x",
            prompt="Generate structured output:",
            max_tokens=100,
            response_format={
                "type": "structural_tag",
                "structures": [{"begin": "<data>", "end": "</data>"}],
                "triggers": ["<data>"],
            },
        )
        sampling_params = self.sc._build_sampling_params(req)

        self.assertIn("structural_tag", sampling_params)
        self.assertIsInstance(sampling_params["structural_tag"], str)

    def test_response_format_none(self):
        """Test that no response_format doesn't add extra constraints."""
        req = CompletionRequest(model="x", prompt="Generate text:", max_tokens=100)
        sampling_params = self.sc._build_sampling_params(req)

        # Neither constraint key may be populated when no response_format (and no
        # explicit json_schema field) is given. .get() tolerates absent keys.
        self.assertIsNone(sampling_params.get("structural_tag"))
        self.assertIsNone(sampling_params.get("json_schema"))

    # ------------- non-streaming response building -------------
    def test_logprobs_false_non_streaming(self):
        """Test that logprobs=False doesn't cause KeyError in non-streaming response."""
        req = CompletionRequest(
            model="x", prompt="Hello", max_tokens=10, logprobs=False
        )

        # meta_info intentionally omits any logprob keys.
        mock_ret = [
            {
                "text": " world",
                "meta_info": {
                    "id": "test-id",
                    "prompt_tokens": 1,
                    "completion_tokens": 2,
                    "finish_reason": {"type": "stop"},
                    "weight_version": "v1",
                },
            }
        ]

        response = self.sc._build_completion_response(req, mock_ret, 1234567890)

        self.assertEqual(len(response.choices), 1)
        self.assertEqual(response.choices[0].text, " world")
        self.assertEqual(len(response.choices[0].logprobs.top_logprobs), 0)

    # ------------- streaming -------------
    def test_streaming_abort_yields_error(self):
        """Test that an abort finish reason during streaming correctly yields an error and stops."""
        err_msg = "Aborted by scheduler"
        err_code = HTTPStatus.INTERNAL_SERVER_ERROR

        # Fake generator producing a single chunk whose finish_reason is an abort.
        async def _mock_generate_abort(*args, **kwargs):
            yield {
                "text": "Partial ",
                "meta_info": {
                    "id": "cmpl-test",
                    "prompt_tokens": 10,
                    "completion_tokens": 2,
                    "cached_tokens": 0,
                    "finish_reason": {
                        "type": "abort",
                        "status_code": err_code,
                        "message": err_msg,
                    },
                    "output_token_logprobs": None,
                    "output_top_logprobs": None,
                },
                "index": 0,
            }

        self.sc.tokenizer_manager.generate_request = _mock_generate_abort

        req = CompletionRequest(
            model="x",
            prompt="Hello world",
            max_tokens=100,
            stream=True,
        )

        adapted_request, _ = self.sc._convert_to_internal_request(req)

        async def run_stream():
            # Collect every SSE chunk. Exceptions are deliberately NOT caught:
            # swallowing them would hide the real traceback behind a later,
            # misleading "Error chunk not found" assertion.
            chunks = []
            async for chunk in self.sc._generate_completion_stream(
                adapted_request, req, self.fastapi_request
            ):
                chunks.append(chunk)
            return chunks

        loop = get_or_create_event_loop()
        chunks = loop.run_until_complete(run_stream())

        # Locate the error chunk and decode its JSON payload (strip "data: ").
        error_chunk_data = None
        for c in chunks:
            if "error" in c:
                error_chunk_data = json.loads(c[len("data: ") :])
                break
        self.assertIsNotNone(error_chunk_data, "Error chunk not found in stream")
        self.assertEqual(error_chunk_data["error"]["message"], err_msg)
        self.assertEqual(error_chunk_data["error"]["code"], err_code.value)

        # The stream must still be terminated with the SSE sentinel.
        self.assertEqual(chunks[-1], "data: [DONE]\n\n")

        # The abort error is emitted first, before any partial text.
        self.assertGreaterEqual(len(chunks), 2)
        self.assertIn("error", chunks[0])
|
|
|
|
if __name__ == "__main__":
    # unittest.main is an alias of unittest.TestProgram; construct it directly
    # so running this file as a script discovers and runs the TestCases above
    # with per-test (verbosity=2) output.
    unittest.TestProgram(verbosity=2)
|
|