# File: transformers/tests/cli/test_serve.py (uploaded by AbdulElahGwaith using huggingface_hub, commit a9bd396, verified)
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import os
import tempfile
import time
import unittest
from threading import Thread
from unittest.mock import Mock, patch

import httpx
from huggingface_hub import ChatCompletionStreamOutput, InferenceClient, hf_hub_download
from parameterized import parameterized

from transformers import GenerationConfig
from transformers.cli.serve import Modality, Serve
from transformers.testing_utils import require_openai, slow
from transformers.utils.import_utils import is_openai_available


if is_openai_available():
    from openai import APIConnectionError, OpenAI
    from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall, ChoiceDeltaToolCallFunction
    from openai.types.responses import (
        Response,
        ResponseCompletedEvent,
        ResponseContentPartAddedEvent,
        ResponseContentPartDoneEvent,
        ResponseCreatedEvent,
        ResponseInProgressEvent,
        ResponseOutputItemAddedEvent,
        ResponseOutputItemDoneEvent,
        ResponseTextDeltaEvent,
        ResponseTextDoneEvent,
    )
@require_openai
def test_help(cli):
    """Smoke test: the `serve` sub-command exposes a working `--help` page."""
    result = cli("serve", "--help")
    help_text = result.output

    assert result.exit_code == 0
    assert "serve" in help_text
@require_openai
def test_host_port_blocking(cli):
    """
    Minimal test: we can set arguments through the CLI - blocking.

    `uvicorn.Config` and `uvicorn.Server` are patched, so no socket is bound. In blocking mode the server must be
    run in the foreground through `server.run()`.
    """
    with (
        patch("uvicorn.Config") as ConfigMock,
        patch("uvicorn.Server") as ServerMock,
    ):
        server_instance = Mock()
        ServerMock.return_value = server_instance

        # Call the serve CLI with host/port
        out = cli("serve", "--host", "0.0.0.0", "--port", "9000")

        # Check the exit code first: if the CLI failed before building the config, `ConfigMock.call_args` is
        # `None` and unpacking it would raise an unrelated `TypeError` that hides the actual CLI error.
        assert out.exit_code == 0, out.output
        assert ConfigMock.call_args is not None, "uvicorn.Config was never constructed"
        _, kwargs = ConfigMock.call_args
        assert kwargs["host"] == "0.0.0.0"
        assert kwargs["port"] == 9000

        ServerMock.assert_called_once_with(ConfigMock.return_value)
        server_instance.run.assert_called_once()
@require_openai
def test_host_port_non_blocking(cli, caplog):
    """Host/port CLI arguments are forwarded to uvicorn when `--non-blocking` is passed."""
    # Silence every log record: this works around a CI-only clash between click's test runner and logging
    # (https://github.com/pallets/click/issues/824#issuecomment-562581313). No log content is asserted here.
    caplog.set_level(100000)

    with (
        patch("uvicorn.Config") as config_cls,
        patch("uvicorn.Server") as server_cls,
        patch.object(Serve, "start_server") as start_server,
    ):
        fake_server = Mock()
        server_cls.return_value = fake_server

        result = cli("serve", "--host", "0.5.0.0", "--port", "9002", "--non-blocking")
        assert result.exit_code == 0

        # uvicorn.Config must have received the values given on the command line
        _args, config_kwargs = config_cls.call_args
        assert config_kwargs["host"] == "0.5.0.0"
        assert config_kwargs["port"] == 9002

        # Non-blocking mode goes through `start_server()` instead of the foreground `server.run()`
        start_server.assert_called_once()
        fake_server.run.assert_not_called()
@require_openai
def test_build_chat_completion_chunk():
    """
    Chat Completion chunks are serialized correctly. Each expected `choices` substring lists only the populated
    keys, so matching it also proves that empty fields are left out of the payload.
    """
    serve = Serve.__new__(Serve)

    # Every serialized chunk must carry these keys, whatever the content of `choices`
    mandatory_fields = ["data", "id", "choices", "created", "model", "object", "system_fingerprint"]

    tool_call = ChoiceDeltaToolCall(
        index=0,
        function=ChoiceDeltaToolCallFunction(name="foo_bar", arguments='{"foo1": "bar1", "foo2": "bar2"}'),
        type="function",
    )

    cases = [
        # Case 1: most fields are provided
        (
            {"content": "hello", "finish_reason": "stop", "role": "user"},
            '"choices":[{"delta":{"content":"hello","role":"user"},"finish_reason":"stop","index":0}]',
        ),
        # Case 2: only the role is provided -- other fields in 'choices' are omitted
        (
            {"role": "user"},
            '"choices":[{"delta":{"role":"user"},"index":0}]',
        ),
        # Case 3: only the content is provided -- other fields in 'choices' are omitted
        (
            {"content": "hello"},
            '"choices":[{"delta":{"content":"hello"},"index":0}]',
        ),
        # Case 4: tool calls support a list of ChoiceDeltaToolCall objects
        (
            {"tool_calls": [tool_call]},
            (
                'choices":[{"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\\"foo1\\": \\"bar1\\", '
                '\\"foo2\\": \\"bar2\\"}","name":"foo_bar"},"type":"function"}]},"index":0}]'
            ),
        ),
    ]

    for chunk_fields, expected_choices in cases:
        raw_chunk = serve.build_chat_completion_chunk(request_id="req0", model="dummy_model@main", **chunk_fields)
        sse_element = serve.chunk_to_sse_element(raw_chunk)
        for field in mandatory_fields:
            assert field in sse_element
        assert expected_choices in sse_element
def test_generative_model_list():
    """`Serve.get_gen_models` lists only the generative models found in a Hub cache directory."""
    repo_ids = [
        "Menlo/Jan-nano",
        "Menlo/Jan-nano-128k",
        "Qwen/Qwen2.5-0.5B-Instruct",
        "HuggingFaceTB/SmolVLM-Instruct",
        "google-bert/bert-base-cased",  # non-generative: must not appear in the listing
    ]
    expected_owners = {
        "HuggingFaceTB/SmolVLM-Instruct": ["HuggingFaceTB", "SmolVLM-Instruct"],
        "Qwen/Qwen2.5-0.5B-Instruct": ["Qwen", "Qwen2.5-0.5B-Instruct"],
        "Menlo/Jan-nano": ["Menlo", "Jan-nano"],
        "Menlo/Jan-nano-128k": ["Menlo", "Jan-nano-128k"],
    }

    with tempfile.TemporaryDirectory() as cache_dir:
        # Only `config.json` is fetched for each repo -- no weights are downloaded
        for repo_id in repo_ids:
            hf_hub_download(repo_id, "config.json", cache_dir=cache_dir)

        models = Serve.get_gen_models(cache_dir)
        assert len(expected_owners) == len(models)

        owner_by_id = {}
        for model in models:
            owner_by_id[model["id"]] = model["owned_by"]

        for repo_id, owner in expected_owners.items():
            assert repo_id in owner_by_id
            assert owner_by_id[repo_id] == owner
@require_openai
def test_build_response_event():
    """
    Response API events are serialized correctly.

    The Response API has a much wider set of output objects than the Chat Completion API, so this only checks a
    few basic assumptions; the exact schema is enforced by OpenAI's pydantic models.
    """
    serve = Serve.__new__(Serve)

    queued_response = Response(
        id="resp_0",
        created_at=time.time(),
        status="queued",
        model="dummy_model@main",
        instructions=None,  # set to None -> must be dropped from the output
        text={"format": {"type": "text"}},
        object="response",
        tools=[],  # empty lists must be kept (they are often mandatory fields)
        output=[],
        parallel_tool_calls=False,
        tool_choice="auto",
        metadata=None,
    )
    created_event = ResponseCreatedEvent(type="response.created", sequence_number=0, response=queued_response)

    sse_element = serve.chunk_to_sse_element(created_event)

    # Event framing
    assert sse_element.startswith("data: ")
    # Explicitly set fields and empty lists are present
    for expected in ('"model":"dummy_model@main"', '"status":"queued"', "tools", "output"):
        assert expected in sse_element
    # None-valued fields and unset optional fields are omitted
    for omitted in ("instructions", "metadata", "error", "top_p"):
        assert omitted not in sse_element
def retry(fn, max_attempts=5, delay=2):
    """
    Wrap `fn` so that calls are retried while the server is not reachable yet.

    Useful for testing functions that may fail because the server is still starting up. Only
    `httpx.HTTPError` (and its subclasses) and `APIConnectionError` trigger a retry; any other exception
    propagates immediately.

    Args:
        fn: the callable to wrap.
        max_attempts: maximum number of calls to `fn`. `fn` is always called at least once.
        delay: number of seconds to sleep between two consecutive attempts.

    Returns:
        A wrapper that preserves `fn`'s metadata (`__name__`, `__doc__`, ...). When every attempt fails, the
        exception raised by the last attempt is re-raised.
    """

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        attempt = 0
        while True:
            attempt += 1
            try:
                return fn(*args, **kwargs)
            except (httpx.HTTPError, APIConnectionError):
                # Attempt budget spent: surface the last error unchanged.
                if attempt >= max_attempts:
                    raise
            time.sleep(delay)

    return wrapper
class ServeCompletionsMixin:
    """
    Mixin class for the Completions API tests, to seamlessly replicate tests across the two versions of the API
    (`generate` and `continuous_batching`).

    Concrete subclasses must set `port` to the port of a running server (e.g. in `setUpClass`).
    """

    @retry
    def run_server(self, request):
        """Sends `request` to the chat completion endpoint and returns the list of streamed payloads."""
        with InferenceClient(f"http://localhost:{self.port}") as client:
            return list(client.chat_completion(**request))

    @parameterized.expand(
        [
            ("default_request", {}),
            ("one_token", {"max_tokens": 1}),
            ("different_model", {"model": "HuggingFaceTB/SmolLM2-135M-Instruct"}),
            (
                "tool_call",
                {
                    "tools": [
                        {
                            "function": {
                                "name": "foo_bar",
                                "parameters": {"type": "object"},
                                "description": "Foo bar",
                            },
                            "type": "function",
                        }
                    ]
                },
            ),
        ]
    )
    def test_requests(self, test_name: str, request_flags: dict):
        """Tests that the completions app gracefully handles GOOD requests, producing the expected output payloads."""
        request = {
            "model": "Qwen/Qwen3-0.6B",
            "messages": [{"role": "user", "content": "Hello, how are you?"}],
            "stream": True,  # We don't support "stream": False yet
            "max_tokens": 5,  # Small generation by default
        }
        request.update(request_flags)
        all_payloads = self.run_server(request)

        # If a request is successful, the returned payload needs to follow the schema, which we test here.
        # NOTE: the output of our server is wrapped by `InferenceClient`, which sends fields even when they
        # are empty.

        # Finish reason: the last payload should have a finish reason of "length" or "stop", all others should be empty
        finish_reasons = [payload.choices[0].finish_reason for payload in all_payloads]
        self.assertIn(finish_reasons[-1], ["length", "stop"])
        self.assertTrue(all(reason is None for reason in finish_reasons[:-1]), finish_reasons)

        # Role: the first payload should have a role of "assistant", all others should be empty
        roles = [payload.choices[0].delta.role for payload in all_payloads]
        self.assertEqual(roles[0], "assistant")
        self.assertTrue(all(role is None for role in roles[1:]), roles)

        # Content: the first and the last payload shouldn't have content (role and finish reason). It may be empty
        # in some other payload positions, e.g. tool calls.
        contents = [payload.choices[0].delta.content for payload in all_payloads]
        self.assertIsNone(contents[0])
        self.assertIsNone(contents[-1])
        self.assertTrue(any(content is not None for content in contents[1:-1]), contents)

        # TODO: add "usage" field to output and test it

    def test_generation_config_in_request(self):
        """Tests that the generation config is correctly passed into the generation call."""
        generation_config = GenerationConfig(do_sample=False, temperature=0.0)
        request = {
            "model": "Qwen/Qwen3-0.6B",
            "messages": [{"role": "user", "content": "Hello, how are you?"}],
            "stream": True,
            "max_tokens": 10,
            "extra_body": {
                "generation_config": generation_config.to_json_string(),
            },
        }
        all_payloads = self.run_server(request)
        contents = [payload.choices[0].delta.content for payload in all_payloads]
        output_text = "".join([text for text in contents if text is not None])
        # The generation config sets greedy decoding, so the output is reproducible. By default, `Qwen/Qwen3-0.6B`
        # sets `do_sample=True`
        self.assertEqual(output_text, '<think>\nOkay, the user just asked, "')

    def test_early_return_due_to_length(self):
        """Tests that hitting `max_tokens` ends the stream with finish reason "length"."""
        request = {
            "model": "Qwen/Qwen2.5-0.5B-Instruct",
            "messages": [{"role": "user", "content": "Hello, how are you?"}],
            "stream": True,
            "max_tokens": 3,
        }
        all_payloads = self.run_server(request)
        last_payload = all_payloads[-1]
        self.assertEqual(last_payload.choices[0].finish_reason, "length")

    def test_continues_until_stop(self):
        """Tests that a short answer well below `max_tokens` ends the stream with finish reason "stop"."""
        request = {
            "model": "Qwen/Qwen2.5-0.5B-Instruct",
            "messages": [{"role": "user", "content": 'Please only answer with "Hi."'}],
            "stream": True,
            "max_tokens": 30,
        }
        all_payloads = self.run_server(request)
        last_payload = all_payloads[-1]
        self.assertEqual(last_payload.choices[0].finish_reason, "stop")
class ServeCompletionsGenerateMockTests(unittest.TestCase):
    """Unit tests for `Serve.get_processor_inputs_from_inbound_messages` (no server needed)."""

    def test_processor_inputs_from_inbound_messages_llm(self):
        """LLM inputs are flattened to plain-string contents; several text parts are joined with a space."""
        modality = Modality.LLM
        messages = [
            {"role": "user", "content": "How are you doing?"},
            {"role": "assistant", "content": "I'm doing great, thank you for asking! How can I assist you today?"},
            {"role": "user", "content": "Can you help me write tests?"},
        ]
        # Copy each message: comparing against the very list passed to the function would make the check pass
        # trivially if the function mutated its input in place. (Contents are strings, so a shallow copy suffices.)
        expected_outputs = [dict(message) for message in messages]
        outputs = Serve.get_processor_inputs_from_inbound_messages(messages, modality)
        self.assertListEqual(expected_outputs, outputs)

        messages_with_type = [
            {"role": "user", "content": [{"type": "text", "text": "How are you doing?"}]},
            {
                "role": "assistant",
                "content": [
                    {"type": "text", "text": "I'm doing great, thank you for asking! How can I assist you today?"}
                ],
            },
            {"role": "user", "content": [{"type": "text", "text": "Can you help me write tests?"}]},
        ]
        outputs = Serve.get_processor_inputs_from_inbound_messages(messages_with_type, modality)
        self.assertListEqual(expected_outputs, outputs)

        messages_multiple_text = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "How are you doing?"},
                    {"type": "text", "text": "I'm doing great, thank you for asking! How can I assist you today?"},
                ],
            },
        ]
        expected_outputs_multiple_text = [
            {
                "role": "user",
                "content": "How are you doing? I'm doing great, thank you for asking! How can I assist you today?",
            },
        ]
        outputs = Serve.get_processor_inputs_from_inbound_messages(messages_multiple_text, modality)
        self.assertListEqual(expected_outputs_multiple_text, outputs)

    def test_processor_inputs_from_inbound_messages_vlm_text_only(self):
        """VLM inputs always use list contents, even when the inbound content is a plain string."""
        modality = Modality.VLM
        messages = [
            {"role": "user", "content": "How are you doing?"},
            {"role": "assistant", "content": "I'm doing great, thank you for asking! How can I assist you today?"},
            {"role": "user", "content": "Can you help me write tests?"},
        ]

        expected_outputs = [
            {"role": "user", "content": [{"type": "text", "text": "How are you doing?"}]},
            {
                "role": "assistant",
                "content": [
                    {"type": "text", "text": "I'm doing great, thank you for asking! How can I assist you today?"}
                ],
            },
            {"role": "user", "content": [{"type": "text", "text": "Can you help me write tests?"}]},
        ]

        outputs = Serve.get_processor_inputs_from_inbound_messages(messages, modality)
        self.assertListEqual(expected_outputs, outputs)

    def test_processor_inputs_from_inbound_messages_vlm_text_and_image_in_base_64(self):
        """A base64 `image_url` part is converted into an `image` part whose `url` points to an existing file."""
        modality = Modality.VLM
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "How many pixels are in the image?"},
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": "data:image/jpeg;base64,/9j/4AAQSkZJRgABAQAASABIAAD/4QBARXhpZgAATU0AKgAAAAgAAYdpAAQAAAABAAAAGgAAAAAAAqACAAQAAAABAAAABaADAAQAAAABAAAABQAAAAD/7QA4UGhvdG9zaG9wIDMuMAA4QklNBAQAAAAAAAA4QklNBCUAAAAAABDUHYzZjwCyBOmACZjs+EJ+/8AAEQgABQAFAwEiAAIRAQMRAf/EAB8AAAEFAQEBAQEBAAAAAAAAAAABAgMEBQYHCAkKC//EALUQAAIBAwMCBAMFBQQEAAABfQECAwAEEQUSITFBBhNRYQcicRQygZGhCCNCscEVUtHwJDNicoIJChYXGBkaJSYnKCkqNDU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6g4SFhoeIiYqSk5SVlpeYmZqio6Slpqeoqaqys7S1tre4ubrCw8TFxsfIycrS09TV1tfY2drh4uPk5ebn6Onq8fLz9PX29/j5+v/EAB8BAAMBAQEBAQEBAQEAAAAAAAABAgMEBQYHCAkKC//EALURAAIBAgQEAwQHBQQEAAECdwABAgMRBAUhMQYSQVEHYXETIjKBCBRCkaGxwQkjM1LwFWJy0QoWJDThJfEXGBkaJicoKSo1Njc4OTpDREVGR0hJSlNUVVZXWFlaY2RlZmdoaWpzdHV2d3h5eoKDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uLj5OXm5+jp6vLz9PX29/j5+v/bAEMAAQEBAQEBAgEBAgICAgICAwICAgIDBAMDAwMDBAUEBAQEBAQFBQUFBQUFBQYGBgYGBgcHBwcHCAgICAgICAgICP/bAEMBAQEBAgICAwICAwgFBAUICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICP/dAAQAAf/aAAwDAQACEQMRAD8A/v4ooooA/9k="
                        },
                    },
                ],
            },
            {
                "role": "assistant",
                "content": "The number of pixels in the image cannot be determined from the provided information.",
            },
            {"role": "user", "content": "Alright"},
        ]

        expected_outputs = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "How many pixels are in the image?"},
                    {"type": "image", "url": "/var/folders/4v/64sxdhsd3gz3r8vhhnyc0mqw0000gn/T/tmp50oyghk6.png"},
                ],
            },
            {
                "role": "assistant",
                "content": [
                    {
                        "type": "text",
                        "text": "The number of pixels in the image cannot be determined from the provided information.",
                    }
                ],
            },
            {"role": "user", "content": [{"type": "text", "text": "Alright"}]},
        ]

        outputs = Serve.get_processor_inputs_from_inbound_messages(messages, modality)

        # `zip` silently truncates: check lengths first so a dropped or extra message is caught.
        self.assertEqual(len(expected_outputs), len(outputs))
        for expected_output, output in zip(expected_outputs, outputs):
            self.assertEqual(expected_output["role"], output["role"])
            expected_output_content = expected_output["content"]
            output_content = output["content"]

            self.assertEqual(type(expected_output_content), type(output_content))
            if not isinstance(expected_output_content, list):
                self.fail("VLMs should only receive content as lists.")

            self.assertEqual(len(expected_output_content), len(output_content))
            for expected_item, output_item in zip(expected_output_content, output_content):
                self.assertIn("type", expected_item)
                self.assertIn("type", output_item)
                self.assertEqual(expected_item["type"], output_item["type"])

                if expected_item["type"] == "text":
                    self.assertEqual(expected_item["text"], output_item["text"])
                elif expected_item["type"] == "image":
                    # The expected URL is a temporary path captured from one run and is not reproducible, so
                    # only check that the path returned by this run exists.
                    self.assertTrue(os.path.exists(output_item["url"]))
@slow  # server startup time is slow on our push CI
@require_openai
class ServeCompletionsGenerateIntegrationTest(ServeCompletionsMixin, unittest.TestCase):
    """Tests the `generate` version of the Completions API."""

    @classmethod
    def setUpClass(cls):
        """Starts a server for tests to connect to."""
        cls.port = 8001
        cls.server = Serve(port=cls.port, non_blocking=True)

    @classmethod
    def tearDownClass(cls):
        """Stops the server started in `setUpClass`."""
        cls.server.kill_server()

    @slow
    def test_tool_call(self):
        """Tests that the tool call is correctly handled and that the payloads are correctly structured."""
        # TODO: move to the mixin when CB also supports tool calls
        request = {
            # This model is a small model that's very eager to call tools
            # TODO: this is a 4B model. Find a smaller model that's eager to call tools
            "model": "Menlo/Jan-nano",
            # The request should produce a tool call
            "messages": [{"role": "user", "content": "Generate an image of a cat."}],
            "stream": True,
            "max_tokens": 50,
            # Reproducibility
            "temperature": 0.0,
            # This tool is a copy from the tool in the original tiny-agents demo
            "tools": [
                {
                    "function": {
                        "name": "flux1_schnell_infer",
                        "parameters": {
                            "type": "object",
                            "properties": {
                                "prompt": {"type": "string"},
                                "seed": {"type": "number", "description": "numeric value between 0 and 2147483647"},
                                "randomize_seed": {"type": "boolean", "default": True},
                                "width": {
                                    "type": "number",
                                    "description": "numeric value between 256 and 2048",
                                    "default": 1024,
                                },
                                "height": {
                                    "type": "number",
                                    "description": "numeric value between 256 and 2048",
                                    "default": 1024,
                                },
                                "num_inference_steps": {
                                    "type": "number",
                                    "description": "numeric value between 1 and 16",
                                    "default": 4,
                                },
                            },
                        },
                        "description": "Generate an image using the Flux 1 Schnell Image Generator.",
                    },
                    "type": "function",
                }
            ],
        }
        all_payloads = self.run_server(request)

        # The first payload should contain the role
        roles = [payload.choices[0].delta.role for payload in all_payloads]
        self.assertEqual(roles[0], "assistant")
        self.assertTrue(all(role is None for role in roles[1:]), roles)

        # For this specific request no payload carries text content: everything in between is tool call related
        contents = [payload.choices[0].delta.content for payload in all_payloads]
        self.assertTrue(all(content is None for content in contents), contents)

        # Every payload between the role payload and the finish payload must be a tool call delta. Check this
        # explicitly so a model that does not call the tool fails with a clear message instead of an
        # `IndexError`/`TypeError` when indexing `tool_calls`.
        tool_call_payloads = all_payloads[1:-1]
        self.assertGreater(len(tool_call_payloads), 0, "The model did not emit any tool call delta")
        for payload in tool_call_payloads:
            self.assertTrue(payload.choices[0].delta.tool_calls, f"Expected a tool call delta, got {payload}")

        # The first tool call delta should contain the tool name. The other tool call deltas should contain the tool
        # arguments.
        tool_calls = [payload.choices[0].delta.tool_calls[0] for payload in tool_call_payloads]
        first_tool_call = tool_calls[0]
        self.assertEqual(first_tool_call["function"]["name"], "flux1_schnell_infer")
        self.assertEqual(first_tool_call["function"]["arguments"], None)
        other_tool_calls = tool_calls[1:]
        self.assertTrue(all(tool_call["function"]["name"] is None for tool_call in other_tool_calls))
        self.assertTrue(all(tool_call["function"]["arguments"] is not None for tool_call in other_tool_calls))

        # Finally, the last payload should contain a finish reason
        finish_reasons = [payload.choices[0].finish_reason for payload in all_payloads]
        # TODO: I think the finish reason for a tool call is different? double check this
        self.assertIn(finish_reasons[-1], ["stop", "length"])
        self.assertTrue(all(reason is None for reason in finish_reasons[:-1]), finish_reasons)
def _get_scheduler(serve_command):
    """
    Returns the continuous batching scheduler of a running `Serve` instance.

    Walks `serve_command.running_continuous_batching_manager.batch_processor.scheduler` defensively: if any layer
    is missing or `None` (e.g. renamed in the future), an `AssertionError` names that layer instead of a bare
    `AttributeError` deep in the test.
    """
    layers = (
        ("running_continuous_batching_manager", "Serve instance has no running_continuous_batching_manager"),
        ("batch_processor", "running_continuous_batching_manager has no batch_processor"),
        ("scheduler", "batch_processor has no scheduler"),
    )
    current = serve_command
    for attribute, missing_message in layers:
        current = getattr(current, attribute, None)
        assert current is not None, missing_message
    return current
def _call_healthcheck(base_url: str, retries: int = 10, delay: float = 0.1):
    """
    Polls `{base_url}/health` until the server answers.

    Args:
        base_url: root URL of the server, e.g. `http://localhost:8000`.
        retries: maximum number of request attempts.
        delay: seconds to wait after each failed attempt.

    Returns:
        The first `httpx.Response` received (whatever its status code), or `None` if every attempt failed at the
        transport level.
    """
    for _ in range(retries):
        try:
            return httpx.get(f"{base_url}/health")
        except httpx.TransportError:
            # The server may still be starting up: connection refused/reset or a connect timeout. `TransportError`
            # covers both `NetworkError` and `TimeoutException`, whereas `NetworkError` alone missed the latter.
            time.sleep(delay)
    return None
def _open_stream_and_cancel(base_url: str, request_id: str):
    """
    Opens a streamed chat completion tagged with `request_id` and closes it mid-stream.

    The response is closed explicitly after a few chunks have been received, simulating a client that goes away
    while generation is still in progress.
    """
    chunks_before_cancel = 3
    payload = {
        "model": "Qwen/Qwen2.5-0.5B-Instruct",
        "stream": True,
        "messages": [{"role": "user", "content": "Count slowly so I can cancel you."}],
    }
    with (
        httpx.Client() as client,
        client.stream(
            "POST",
            f"{base_url}/v1/chat/completions",
            headers={"X-Request-ID": request_id},
            json=payload,
            timeout=30,
        ) as response,
    ):
        assert response.status_code == 200
        received = 0
        for _chunk in response.iter_bytes(chunk_size=None):
            if received >= chunks_before_cancel:
                response.close()
                break
            received += 1
@slow  # server startup time is slow on our push CI
@require_openai
class ServeCompletionsContinuousBatchingIntegrationTest(ServeCompletionsMixin, unittest.TestCase):
    """Tests the `continuous_batching` version of the Completions API."""

    # Expected opening of the answer to `_sports_request()`. The server is started with `default_seed=42`,
    # presumably what keeps this output stable across runs.
    _SPORTS_ANSWER_PREFIX = (
        "I can assist you with a wide range of tasks, from answering questions to providing information on various "
        "sports topics."
    )

    @classmethod
    def setUpClass(cls):
        """Starts a server for tests to connect to."""
        cls.port = 8002
        cls.server = Serve(
            port=cls.port, continuous_batching=True, attn_implementation="sdpa", default_seed=42, non_blocking=True
        )

    @classmethod
    def tearDownClass(cls):
        """Stops the server started in `setUpClass`."""
        cls.server.kill_server()

    @staticmethod
    def _sports_request(**extra_fields):
        """Builds a streamed chat request with a sports-assistant system prompt; `extra_fields` are merged in."""
        request = {
            "model": "Qwen/Qwen2.5-0.5B-Instruct",
            "messages": [
                {"role": "system", "content": "You are a sports assistant designed to craft sports programs."},
                {"role": "user", "content": "Tell me what you can do."},
            ],
            "stream": True,
        }
        request.update(extra_fields)
        return request

    def _stream_text(self, request):
        """Sends `request` to the server and concatenates the text content of every streamed delta."""
        pieces = []
        for payload in self.run_server(request):
            if isinstance(payload, ChatCompletionStreamOutput) and payload.choices:
                content = payload.choices[0].delta.get("content", "")
                if content is not None:
                    pieces.append(content)
        return "".join(pieces)

    def test_full_request(self):
        """Tests that an inference using the Chat Completions API and Continuous Batching works"""
        full_text = self._stream_text(self._sports_request(max_tokens=30))

        # Verify that the system prompt went through.
        self.assertTrue(full_text.startswith(self._SPORTS_ANSWER_PREFIX), f"Unexpected output: {full_text!r}")

    def test_max_tokens_not_set_in_req(self):
        """Tests that a request without `max_tokens` still completes and follows the system prompt."""
        full_text = self._stream_text(self._sports_request())

        # Verify that the system prompt went through.
        self.assertTrue(full_text.startswith(self._SPORTS_ANSWER_PREFIX), f"Unexpected output: {full_text!r}")

    def test_request_cancellation(self):
        """Tests that a request can be cancelled."""
        base_url = f"http://127.0.0.1:{self.port}"
        request_id = "test-cancel"

        # Ensure the server is up before sending a request
        response = _call_healthcheck(base_url)
        self.assertIsNotNone(response, "Failed to connect to the server health endpoint.")
        self.assertEqual(response.status_code, 200)

        _open_stream_and_cancel(base_url, request_id)

        scheduler = _get_scheduler(self.server)

        # Because cancellation is non-blocking, poll for a short, bounded time.
        deadline = time.time() + 8.0  # generous but still CI-friendly
        last_seen = None
        while time.time() < deadline:
            if scheduler.request_is_cancelled(request_id):
                break
            last_seen = time.time()
            time.sleep(0.1)  # don't spin the CPU

        is_cancelled = scheduler.request_is_cancelled(request_id)
        self.assertTrue(
            is_cancelled,
            f"Request {request_id} still present in scheduler after cancellation "
            f"(last seen at {last_seen}). Check cancellation propagation.",
        )
@require_openai
class ServeResponsesMixin:
    """
    Mixin class for the Responses API tests, shared by the test classes that exercise that endpoint.

    Concrete subclasses must set `port` to the port of a running server (e.g. in `setUpClass`).
    """

    @retry
    def run_server(self, request):
        """Sends `request` to the Responses endpoint through the OpenAI client and returns every streamed event."""
        # Use the client as a context manager so its underlying HTTP connection pool is closed after each call.
        with OpenAI(base_url=f"http://localhost:{self.port}/v1", api_key="<KEY>") as client:
            stream = client.responses.create(**request)
            return list(stream)

    def test_request(self):
        """Tests that an inference using the Responses API works"""
        request = {
            "model": "Qwen/Qwen2.5-0.5B-Instruct",
            "instructions": "You are a helpful assistant.",
            "input": "Hello!",
            "stream": True,
            "max_output_tokens": 1,
        }
        all_payloads = self.run_server(request)

        # Allow variable number of delta events depending on tokenizer/streamer behavior
        self.assertGreaterEqual(len(all_payloads), 8)

        # Start markers
        self.assertIsInstance(all_payloads[0], ResponseCreatedEvent)
        self.assertIsInstance(all_payloads[1], ResponseInProgressEvent)
        self.assertIsInstance(all_payloads[2], ResponseOutputItemAddedEvent)
        self.assertIsInstance(all_payloads[3], ResponseContentPartAddedEvent)

        # At least one delta event during streaming
        self.assertTrue(any(isinstance(p, ResponseTextDeltaEvent) for p in all_payloads[4:-4]))

        # Closing markers
        self.assertIsInstance(all_payloads[-4], ResponseTextDoneEvent)
        self.assertIsInstance(all_payloads[-3], ResponseContentPartDoneEvent)
        self.assertIsInstance(all_payloads[-2], ResponseOutputItemDoneEvent)
        self.assertIsInstance(all_payloads[-1], ResponseCompletedEvent)

    # TODO: one test for each request flag, to confirm it is working as expected
    # TODO: speed-based test to confirm that KV cache is working across requests
@slow  # server startup time is slow on our push CI
@require_openai
class ServeResponsesIntegrationTest(ServeResponsesMixin, unittest.TestCase):
    """Tests the Responses API."""

    @classmethod
    def setUpClass(cls):
        """Starts a server for tests to connect to."""
        cls.port = 8003
        cls.server = Serve(port=cls.port, default_seed=42, non_blocking=True)

    @classmethod
    def tearDownClass(cls):
        """Stops the server started in `setUpClass`."""
        cls.server.kill_server()

    @slow
    def test_full_request(self):
        """Tests that a streamed inference using the Responses API works and takes the instructions into account"""
        request = {
            "model": "Qwen/Qwen2.5-0.5B-Instruct",
            "instructions": "You are a sports assistant designed to craft sports programs.",
            "input": "Tell me what you can do.",
            "stream": True,
            "max_output_tokens": 30,
            # Disable sampling for deterministic output
            "temperature": 0,
        }
        all_payloads = self.run_server(request)
        full_text = "".join(event.delta for event in all_payloads if isinstance(event, ResponseTextDeltaEvent))

        # Verify that the system prompt went through.
        # With deterministic decoding, exact wording can still vary across versions.
        # Assert non-empty output and that it references sports.
        self.assertGreater(len(full_text), 0)
        self.assertIn("sports", full_text.lower())

    @slow
    def test_non_streaming_request(self):
        """Tests that an inference using the Responses API with stream=False returns a single Response payload."""
        # `OpenAI` and `Response` come from the module-level guarded imports (this class requires openai).
        with OpenAI(base_url=f"http://localhost:{self.port}/v1", api_key="<KEY>") as client:
            resp = client.responses.create(
                model="Qwen/Qwen2.5-0.5B-Instruct",
                instructions="You are a helpful assistant.",
                input="Hello!",
                stream=False,
                max_output_tokens=5,
            )

        # Should be a single Response object with completed status and one output item containing text
        self.assertIsInstance(resp, Response)
        self.assertEqual(resp.status, "completed")
        self.assertGreaterEqual(len(resp.output), 1)
        first_item = resp.output[0]
        self.assertEqual(first_item.type, "message")
        self.assertEqual(first_item.status, "completed")
        self.assertGreaterEqual(len(first_item.content), 1)
        first_part = first_item.content[0]
        self.assertEqual(first_part.type, "output_text")
        self.assertIsInstance(first_part.text, str)
@require_openai
class ServeInfrastructureTest(unittest.TestCase):
    """
    Tests server infrastructure endpoints (currently the healthcheck).

    Gated with `@require_openai` like every other test in this file that starts a `Serve`. Without the serve
    dependencies the server thread would die silently and the healthcheck would fail with a misleading
    connection error.
    """

    @classmethod
    def setUpClass(cls):
        """Starts a server in a daemon thread so it never keeps the test process alive."""
        cls.port = 8042
        # `Serve` is constructed without `non_blocking=True`, so it runs in its own thread.
        cls.server_thread = Thread(target=Serve, kwargs={"port": cls.port}, daemon=True)
        cls.server_thread.start()

    def test_healthcheck(self):
        """Tests that the healthcheck endpoint works."""
        response = _call_healthcheck(f"http://localhost:{self.port}")
        self.assertIsNotNone(response, "Failed to connect to the server health endpoint.")
        self.assertEqual(response.status_code, 200)
        self.assertEqual(response.json(), {"status": "ok"})