Spaces:
Running
Running
| """Streaming contract tests for OpenAI-compatible SSE endpoints.""" | |
| from __future__ import annotations | |
| import asyncio | |
| import json | |
| from collections import deque | |
| import pytest | |
| from fastapi.responses import StreamingResponse | |
| from app.core.model_registry import ModelSpec | |
| from app.routers import chat, completions, responses | |
| from app.schemas.chat import ChatCompletionRequest | |
| from app.schemas.completions import CompletionRequest | |
| from app.schemas.responses import ResponseRequest | |
class DummyStream:
    """Minimal stand-in for the engine's token stream.

    Exposes exactly what the routers consume: the usage counters
    (``prompt_tokens`` / ``completion_tokens``), the ``finish_reason``,
    and ``iter_tokens`` for the generated text pieces.
    """

    def __init__(
        self,
        *,
        tokens: list[str],
        prompt_tokens: int,
        completion_tokens: int,
        finish_reason: str = "stop",
    ) -> None:
        self._tokens = tokens
        self.prompt_tokens = prompt_tokens
        self.completion_tokens = completion_tokens
        self.finish_reason = finish_reason

    def iter_tokens(self):
        """Yield the configured tokens one at a time, in order."""
        yield from self._tokens
| async def _read_stream_body(response: StreamingResponse) -> str: | |
| chunks: list[str] = [] | |
| async for chunk in response.body_iterator: | |
| if isinstance(chunk, bytes): | |
| chunks.append(chunk.decode("utf-8")) | |
| else: | |
| chunks.append(chunk) | |
| return "".join(chunks) | |
| def _parse_sse_data_frames(raw_body: str) -> list[str]: | |
| frames = [frame.strip() for frame in raw_body.split("\n\n") if frame.strip()] | |
| data_frames: list[str] = [] | |
| for frame in frames: | |
| assert frame.startswith("data: ") | |
| data_frames.append(frame[len("data: ") :]) | |
| return data_frames | |
def test_completions_stream_emits_sse_chunks_usage_and_done(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A streamed completion emits text chunks, a usage tail chunk, and [DONE]."""
    monkeypatch.setattr("app.routers.completions.get_model_spec", lambda _: None)
    monkeypatch.setattr(
        "app.routers.completions.engine.create_stream",
        lambda *_, **__: DummyStream(
            tokens=["Hel", "lo"],
            prompt_tokens=3,
            completion_tokens=2,
            finish_reason="stop",
        ),
    )

    request = CompletionRequest.model_validate(
        {"model": "GPT3-dev", "prompt": "Hello", "stream": True}
    )
    result = asyncio.run(completions.create_completion(request))
    assert isinstance(result, StreamingResponse)

    frames = _parse_sse_data_frames(asyncio.run(_read_stream_body(result)))
    # Stream must terminate with the OpenAI-style sentinel frame.
    assert frames[-1] == "[DONE]"

    parsed = [json.loads(frame) for frame in frames[:-1]]
    assert parsed[0]["object"] == "text_completion.chunk"
    assert parsed[0]["choices"][0]["text"] == "Hel"
    assert parsed[1]["choices"][0]["text"] == "lo"
    assert parsed[2]["choices"][0]["finish_reason"] == "stop"

    # The final chunk carries usage only, with no choices.
    final = parsed[-1]
    assert final["choices"] == []
    assert final["usage"] == {
        "prompt_tokens": 3,
        "completion_tokens": 2,
        "total_tokens": 5,
    }
def test_chat_stream_emits_initial_role_delta_and_done(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A streamed chat completion opens with an assistant role delta and ends with [DONE]."""
    monkeypatch.setattr(
        "app.routers.chat.get_model_spec",
        lambda model: ModelSpec(name=model, hf_repo="dummy/instruct", is_instruct=True),
    )
    monkeypatch.setattr("app.routers.chat.engine.apply_chat_template", lambda *_: "formatted")
    monkeypatch.setattr(
        "app.routers.chat.engine.create_stream",
        lambda *_, **__: DummyStream(
            tokens=["Hi", " there"],
            prompt_tokens=4,
            completion_tokens=2,
            finish_reason="stop",
        ),
    )

    request = ChatCompletionRequest.model_validate(
        {
            "model": "GPT4-dev-177M-1511-Instruct",
            "messages": [{"role": "user", "content": "hello"}],
            "stream": True,
        }
    )
    result = asyncio.run(chat.create_chat_completion(request))
    assert isinstance(result, StreamingResponse)

    frames = _parse_sse_data_frames(asyncio.run(_read_stream_body(result)))
    assert frames[-1] == "[DONE]"

    deltas = [json.loads(frame) for frame in frames[:-1]]
    # First chunk announces the role; content follows token by token.
    assert deltas[0]["choices"][0]["delta"]["role"] == "assistant"
    assert deltas[1]["choices"][0]["delta"]["content"] == "Hi"
    assert deltas[2]["choices"][0]["delta"]["content"] == " there"
    assert deltas[3]["choices"][0]["finish_reason"] == "stop"
def test_responses_stream_emits_created_delta_completed_done(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A streamed response emits created, per-token deltas, completed, then [DONE]."""
    monkeypatch.setattr(
        "app.routers.responses.get_model_spec",
        lambda model: ModelSpec(name=model, hf_repo="dummy/base", is_instruct=False),
    )
    monkeypatch.setattr(
        "app.routers.responses.engine.create_stream",
        lambda *_, **__: DummyStream(
            tokens=["Hi", " there"],
            prompt_tokens=5,
            completion_tokens=2,
            finish_reason="stop",
        ),
    )

    request = ResponseRequest.model_validate(
        {"model": "GPT3-dev", "input": "Say hi", "stream": True}
    )
    result = asyncio.run(responses.create_response(request))
    assert isinstance(result, StreamingResponse)

    frames = _parse_sse_data_frames(asyncio.run(_read_stream_body(result)))
    assert frames[-1] == "[DONE]"

    events = [json.loads(frame) for frame in frames[:-1]]
    assert events[0]["type"] == "response.created"
    assert events[1]["type"] == "response.output_text.delta"
    assert events[1]["delta"] == "Hi"
    assert events[2]["type"] == "response.output_text.delta"
    assert events[2]["delta"] == " there"

    # Terminal event carries the assembled text and token usage.
    assert events[3]["type"] == "response.completed"
    assert events[3]["response"]["output"][0]["content"][0]["text"] == "Hi there"
    assert events[3]["response"]["usage"] == {
        "input_tokens": 5,
        "output_tokens": 2,
        "total_tokens": 7,
    }
def test_completions_stream_usage_aggregates_prompt_and_completion_tokens(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Final usage sums across prompts and n samples.

    Expected totals (30 prompt / 10 completion) show that prompt tokens are
    counted once per distinct prompt (10 + 20; the 999/888 values from the
    repeat samples are ignored) while completion tokens sum over all four
    streams (1 + 2 + 3 + 4).
    """
    seen_prompts: list[str] = []
    pending = deque(
        [
            DummyStream(tokens=["a1"], prompt_tokens=10, completion_tokens=1),
            DummyStream(tokens=["a2"], prompt_tokens=999, completion_tokens=2),
            DummyStream(tokens=["b1"], prompt_tokens=20, completion_tokens=3),
            DummyStream(tokens=["b2"], prompt_tokens=888, completion_tokens=4),
        ]
    )

    def fake_create_stream(model: str, prompt: str, **_: object) -> DummyStream:
        # Record call order and hand out the queued streams one by one.
        seen_prompts.append(prompt)
        return pending.popleft()

    monkeypatch.setattr("app.routers.completions.get_model_spec", lambda _: None)
    monkeypatch.setattr("app.routers.completions.engine.create_stream", fake_create_stream)

    request = CompletionRequest.model_validate(
        {"model": "GPT3-dev", "prompt": ["alpha", "beta"], "n": 2, "stream": True}
    )
    result = asyncio.run(completions.create_completion(request))
    frames = _parse_sse_data_frames(asyncio.run(_read_stream_body(result)))
    assert frames[-1] == "[DONE]"

    final = json.loads(frames[-2])
    # Two prompts x n=2 samples, grouped by prompt.
    assert seen_prompts == ["alpha", "alpha", "beta", "beta"]
    assert final["usage"] == {
        "prompt_tokens": 30,
        "completion_tokens": 10,
        "total_tokens": 40,
    }