"""
Unit-tests for the refactored completions-serving handler (no pytest).
Run with:
    python -m unittest tests.test_serving_completions_unit -v
"""

import json
import unittest
from http import HTTPStatus
from typing import Optional
from unittest.mock import AsyncMock, Mock

from fastapi import Request

from sglang.srt.entrypoints.openai.protocol import CompletionRequest
from sglang.srt.entrypoints.openai.serving_completions import OpenAIServingCompletion
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.utils import get_or_create_event_loop
from sglang.test.ci.ci_register import register_amd_ci, register_cuda_ci

register_cuda_ci(est_time=10, suite="stage-b-test-small-1-gpu")
register_amd_ci(est_time=10, suite="stage-b-test-small-1-gpu-amd")


class _MockTemplateManager:
    """Minimal mock for TemplateManager."""

    def __init__(self):
        self.chat_template_name: Optional[str] = None
        self.jinja_template_content_format: Optional[str] = None
        # Set to None to avoid template processing
        self.completion_template_name: Optional[str] = None


class ServingCompletionTestCase(unittest.TestCase):
    """Bundle all prompt/echo tests in one TestCase."""

    # ---------- shared test fixtures ----------
    def setUp(self):
        # Build a fresh mock TokenizerManager for every test, so per-test
        # tweaks (e.g. decode.return_value) never leak between tests.
        tm = Mock(spec=TokenizerManager)
        tm.tokenizer = Mock()
        tm.tokenizer.encode.return_value = [1, 2, 3, 4]
        tm.tokenizer.decode.return_value = "decoded text"
        tm.tokenizer.bos_token_id = 1
        tm.model_config = Mock(is_multimodal=False)
        tm.server_args = Mock(enable_cache_report=False)
        tm.generate_request = AsyncMock()
        tm.create_abort_task = Mock()

        self.template_manager = _MockTemplateManager()
        self.sc = OpenAIServingCompletion(tm, self.template_manager)
        self.fastapi_request = Mock(spec=Request)

    # ---------- prompt-handling ----------
    def test_single_string_prompt(self):
        """A plain string prompt is forwarded as internal ``text``."""
        req = CompletionRequest(model="x", prompt="Hello world", max_tokens=100)
        internal, _ = self.sc._convert_to_internal_request(req)
        self.assertEqual(internal.text, "Hello world")

    def test_single_token_ids_prompt(self):
        """A flat list of ints is forwarded as internal ``input_ids``."""
        req = CompletionRequest(model="x", prompt=[1, 2, 3, 4], max_tokens=100)
        internal, _ = self.sc._convert_to_internal_request(req)
        self.assertEqual(internal.input_ids, [1, 2, 3, 4])

    # ---------- echo-handling ----------
    def test_echo_with_string_prompt_streaming(self):
        req = CompletionRequest(model="x", prompt="Hello", max_tokens=1, echo=True)
        self.assertEqual(self.sc._get_echo_text(req, 0), "Hello")

    def test_echo_with_list_of_strings_streaming(self):
        req = CompletionRequest(
            model="x", prompt=["A", "B"], max_tokens=1, echo=True, n=1
        )
        self.assertEqual(self.sc._get_echo_text(req, 0), "A")
        self.assertEqual(self.sc._get_echo_text(req, 1), "B")

    def test_echo_with_token_ids_streaming(self):
        req = CompletionRequest(model="x", prompt=[1, 2, 3], max_tokens=1, echo=True)
        # Token-id prompts are echoed via the tokenizer's decode().
        self.sc.tokenizer_manager.tokenizer.decode.return_value = "decoded_prompt"
        self.assertEqual(self.sc._get_echo_text(req, 0), "decoded_prompt")

    def test_echo_with_multiple_token_ids_streaming(self):
        req = CompletionRequest(
            model="x", prompt=[[1, 2], [3, 4]], max_tokens=1, echo=True, n=1
        )
        self.sc.tokenizer_manager.tokenizer.decode.return_value = "decoded"
        self.assertEqual(self.sc._get_echo_text(req, 0), "decoded")

    def test_prepare_echo_prompts_non_streaming(self):
        # single string
        req = CompletionRequest(model="x", prompt="Hi", echo=True)
        self.assertEqual(self.sc._prepare_echo_prompts(req), ["Hi"])

        # list of strings
        req = CompletionRequest(model="x", prompt=["Hi", "Yo"], echo=True)
        self.assertEqual(self.sc._prepare_echo_prompts(req), ["Hi", "Yo"])

        # token IDs
        req = CompletionRequest(model="x", prompt=[1, 2, 3], echo=True)
        self.sc.tokenizer_manager.tokenizer.decode.return_value = "decoded"
        self.assertEqual(self.sc._prepare_echo_prompts(req), ["decoded"])

    # ---------- response_format handling ----------
    def test_response_format_json_object(self):
        """Test that response_format json_object is correctly processed in sampling params."""
        req = CompletionRequest(
            model="x",
            prompt="Generate a JSON object:",
            max_tokens=100,
            response_format={"type": "json_object"},
        )
        sampling_params = self.sc._build_sampling_params(req)
        self.assertEqual(sampling_params["json_schema"], '{"type": "object"}')

    def test_response_format_json_schema(self):
        """Test that response_format json_schema is correctly processed in sampling params."""
        schema = {
            "type": "object",
            "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
        }
        req = CompletionRequest(
            model="x",
            prompt="Generate a JSON object:",
            max_tokens=100,
            response_format={
                "type": "json_schema",
                "json_schema": {"name": "person", "schema": schema},
            },
        )
        sampling_params = self.sc._build_sampling_params(req)
        # The schema should be converted to string by convert_json_schema_to_str
        self.assertIn("json_schema", sampling_params)
        self.assertIsInstance(sampling_params["json_schema"], str)

    def test_response_format_structural_tag(self):
        """Test that response_format structural_tag is correctly processed in sampling params."""
        req = CompletionRequest(
            model="x",
            prompt="Generate structured output:",
            max_tokens=100,
            response_format={
                "type": "structural_tag",
                "structures": [{"begin": "", "end": ""}],
                "triggers": [""],
            },
        )
        sampling_params = self.sc._build_sampling_params(req)
        # The structural_tag should be processed
        self.assertIn("structural_tag", sampling_params)
        self.assertIsInstance(sampling_params["structural_tag"], str)

    def test_response_format_none(self):
        """Test that no response_format doesn't add extra constraints."""
        req = CompletionRequest(model="x", prompt="Generate text:", max_tokens=100)
        sampling_params = self.sc._build_sampling_params(req)
        # Should not have json_schema or structural_tag from response_format
        # (but might have json_schema from the legacy json_schema field)
        self.assertIsNone(sampling_params.get("structural_tag"))

    def test_logprobs_false_non_streaming(self):
        """Test that logprobs=False doesn't cause KeyError in non-streaming response."""
        req = CompletionRequest(
            model="x", prompt="Hello", max_tokens=10, logprobs=False
        )
        # meta_info deliberately omits any logprob fields.
        mock_ret = [
            {
                "text": " world",
                "meta_info": {
                    "id": "test-id",
                    "prompt_tokens": 1,
                    "completion_tokens": 2,
                    "finish_reason": {"type": "stop"},
                    "weight_version": "v1",
                },
            }
        ]
        response = self.sc._build_completion_response(req, mock_ret, 1234567890)
        self.assertEqual(len(response.choices), 1)
        self.assertEqual(response.choices[0].text, " world")
        self.assertEqual(len(response.choices[0].logprobs.top_logprobs), 0)

    def test_streaming_abort_yields_error(self):
        """Test that an abort finish reason during streaming correctly yields an error and stops."""
        err_msg = "Aborted by scheduler"
        err_code = HTTPStatus.INTERNAL_SERVER_ERROR

        async def _mock_generate_abort(*args, **kwargs):
            # Single generation step whose finish_reason signals an abort.
            yield {
                "text": "Partial ",
                "meta_info": {
                    "id": "cmpl-test",
                    "prompt_tokens": 10,
                    "completion_tokens": 2,
                    "cached_tokens": 0,
                    "finish_reason": {
                        "type": "abort",
                        "status_code": err_code,
                        "message": err_msg,
                    },
                    "output_token_logprobs": None,
                    "output_top_logprobs": None,
                },
                "index": 0,
            }

        self.sc.tokenizer_manager.generate_request = _mock_generate_abort

        req = CompletionRequest(
            model="x",
            prompt="Hello world",
            max_tokens=100,
            stream=True,
        )
        adapted_request, _ = self.sc._convert_to_internal_request(req)

        async def run_stream():
            # Do NOT swallow exceptions here: if the stream raises, the test
            # must fail with the real traceback rather than a misleading
            # "Error chunk not found" assertion further down.
            chunks = []
            async for chunk in self.sc._generate_completion_stream(
                adapted_request, req, self.fastapi_request
            ):
                chunks.append(chunk)
            return chunks

        loop = get_or_create_event_loop()
        chunks = loop.run_until_complete(run_stream())

        error_chunk = next((c for c in chunks if "error" in c), None)
        self.assertIsNotNone(error_chunk, "Error chunk not found in stream")
        # Chunks are SSE-framed as "data: <json>\n\n".
        error_chunk_data = json.loads(error_chunk[len("data: ") :])
        self.assertEqual(error_chunk_data["error"]["message"], err_msg)
        self.assertEqual(error_chunk_data["error"]["code"], err_code.value)

        # Ensure the stream stops after the abort error:
        # the last chunk must be the "[DONE]" terminator.
        self.assertEqual(chunks[-1], "data: [DONE]\n\n")

        # The stream holds at least the error chunk and the [DONE] chunk,
        # and the error must be the very first chunk (no text chunk precedes it).
        self.assertGreaterEqual(len(chunks), 2)
        self.assertIn("error", chunks[0])


if __name__ == "__main__":
    unittest.main(verbosity=2)