# Repository: gpt3-dev-api — tests/test_streaming_contracts.py
# Author: Kyryll Kochkin ("AI added tests", commit 88c0e85)
"""Streaming contract tests for OpenAI-compatible SSE endpoints."""
from __future__ import annotations
import asyncio
import json
from collections import deque
import pytest
from fastapi.responses import StreamingResponse
from app.core.model_registry import ModelSpec
from app.routers import chat, completions, responses
from app.schemas.chat import ChatCompletionRequest
from app.schemas.completions import CompletionRequest
from app.schemas.responses import ResponseRequest
class DummyStream:
    """Stand-in for an engine token stream.

    Yields a fixed token sequence and exposes the usage counters and finish
    reason the routers read after draining the stream.
    """

    def __init__(
        self,
        *,
        tokens: list[str],
        prompt_tokens: int,
        completion_tokens: int,
        finish_reason: str = "stop",
    ) -> None:
        self.prompt_tokens = prompt_tokens
        self.completion_tokens = completion_tokens
        self.finish_reason = finish_reason
        self._tokens = tokens

    def iter_tokens(self):
        """Yield the configured tokens in order."""
        yield from self._tokens
async def _read_stream_body(response: StreamingResponse) -> str:
    """Drain the response's body iterator and return the body as one string.

    Byte chunks are decoded as UTF-8; str chunks are taken verbatim.
    """
    parts: list[str] = []
    async for piece in response.body_iterator:
        parts.append(piece.decode("utf-8") if isinstance(piece, bytes) else piece)
    return "".join(parts)
def _parse_sse_data_frames(raw_body: str) -> list[str]:
    """Split an SSE body on blank lines and return each frame's data payload.

    Asserts that every non-empty frame uses the ``data: `` prefix.
    """
    prefix = "data: "
    payloads: list[str] = []
    for raw_frame in raw_body.split("\n\n"):
        frame = raw_frame.strip()
        if not frame:
            continue
        assert frame.startswith(prefix)
        payloads.append(frame[len(prefix) :])
    return payloads
def test_completions_stream_emits_sse_chunks_usage_and_done(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Completions SSE contract: token chunks, finish_reason, usage tail, [DONE]."""
    monkeypatch.setattr("app.routers.completions.get_model_spec", lambda _: None)
    monkeypatch.setattr(
        "app.routers.completions.engine.create_stream",
        lambda *_, **__: DummyStream(
            tokens=["Hel", "lo"],
            prompt_tokens=3,
            completion_tokens=2,
            finish_reason="stop",
        ),
    )
    request = CompletionRequest.model_validate(
        {
            "model": "GPT3-dev",
            "prompt": "Hello",
            "stream": True,
        }
    )
    result = asyncio.run(completions.create_completion(request))
    assert isinstance(result, StreamingResponse)
    frames = _parse_sse_data_frames(asyncio.run(_read_stream_body(result)))
    # Stream must terminate with the OpenAI-style sentinel frame.
    assert frames[-1] == "[DONE]"
    events = [json.loads(frame) for frame in frames[:-1]]
    assert events[0]["object"] == "text_completion.chunk"
    assert events[0]["choices"][0]["text"] == "Hel"
    assert events[1]["choices"][0]["text"] == "lo"
    assert events[2]["choices"][0]["finish_reason"] == "stop"
    # Final chunk carries usage only — no choices.
    usage_chunk = events[-1]
    assert usage_chunk["choices"] == []
    assert usage_chunk["usage"] == {
        "prompt_tokens": 3,
        "completion_tokens": 2,
        "total_tokens": 5,
    }
def test_chat_stream_emits_initial_role_delta_and_done(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Chat SSE contract: leading assistant-role delta, content deltas, [DONE]."""
    monkeypatch.setattr(
        "app.routers.chat.get_model_spec",
        lambda model: ModelSpec(name=model, hf_repo="dummy/instruct", is_instruct=True),
    )
    monkeypatch.setattr("app.routers.chat.engine.apply_chat_template", lambda *_: "formatted")
    monkeypatch.setattr(
        "app.routers.chat.engine.create_stream",
        lambda *_, **__: DummyStream(
            tokens=["Hi", " there"],
            prompt_tokens=4,
            completion_tokens=2,
            finish_reason="stop",
        ),
    )
    request = ChatCompletionRequest.model_validate(
        {
            "model": "GPT4-dev-177M-1511-Instruct",
            "messages": [{"role": "user", "content": "hello"}],
            "stream": True,
        }
    )
    result = asyncio.run(chat.create_chat_completion(request))
    assert isinstance(result, StreamingResponse)
    frames = _parse_sse_data_frames(asyncio.run(_read_stream_body(result)))
    assert frames[-1] == "[DONE]"
    deltas = [json.loads(frame) for frame in frames[:-1]]
    # First delta announces the assistant role before any content arrives.
    assert deltas[0]["choices"][0]["delta"]["role"] == "assistant"
    assert deltas[1]["choices"][0]["delta"]["content"] == "Hi"
    assert deltas[2]["choices"][0]["delta"]["content"] == " there"
    assert deltas[3]["choices"][0]["finish_reason"] == "stop"
def test_responses_stream_emits_created_delta_completed_done(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Responses SSE contract: created event, text deltas, completed event, [DONE]."""
    monkeypatch.setattr(
        "app.routers.responses.get_model_spec",
        lambda model: ModelSpec(name=model, hf_repo="dummy/base", is_instruct=False),
    )
    monkeypatch.setattr(
        "app.routers.responses.engine.create_stream",
        lambda *_, **__: DummyStream(
            tokens=["Hi", " there"],
            prompt_tokens=5,
            completion_tokens=2,
            finish_reason="stop",
        ),
    )
    request = ResponseRequest.model_validate(
        {
            "model": "GPT3-dev",
            "input": "Say hi",
            "stream": True,
        }
    )
    result = asyncio.run(responses.create_response(request))
    assert isinstance(result, StreamingResponse)
    frames = _parse_sse_data_frames(asyncio.run(_read_stream_body(result)))
    assert frames[-1] == "[DONE]"
    events = [json.loads(frame) for frame in frames[:-1]]
    assert events[0]["type"] == "response.created"
    assert events[1]["type"] == "response.output_text.delta"
    assert events[1]["delta"] == "Hi"
    assert events[2]["type"] == "response.output_text.delta"
    assert events[2]["delta"] == " there"
    # Completed event carries the full assembled text plus usage totals.
    completed = events[3]
    assert completed["type"] == "response.completed"
    assert completed["response"]["output"][0]["content"][0]["text"] == "Hi there"
    assert completed["response"]["usage"] == {
        "input_tokens": 5,
        "output_tokens": 2,
        "total_tokens": 7,
    }
def test_completions_stream_usage_aggregates_prompt_and_completion_tokens(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Usage tail for multi-prompt, n>1 streams: prompt tokens counted once per
    prompt, completion tokens summed across every choice.

    The 999/888 prompt_tokens on the second stream of each prompt are decoys —
    the expected total (30) proves they are ignored.
    """
    seen_prompts: list[str] = []
    pending = deque(
        [
            DummyStream(tokens=["a1"], prompt_tokens=10, completion_tokens=1),
            DummyStream(tokens=["a2"], prompt_tokens=999, completion_tokens=2),
            DummyStream(tokens=["b1"], prompt_tokens=20, completion_tokens=3),
            DummyStream(tokens=["b2"], prompt_tokens=888, completion_tokens=4),
        ]
    )

    def fake_create_stream(model: str, prompt: str, **_: object) -> DummyStream:
        # Record call order so we can assert prompt x n fan-out below.
        seen_prompts.append(prompt)
        return pending.popleft()

    monkeypatch.setattr("app.routers.completions.get_model_spec", lambda _: None)
    monkeypatch.setattr("app.routers.completions.engine.create_stream", fake_create_stream)
    request = CompletionRequest.model_validate(
        {
            "model": "GPT3-dev",
            "prompt": ["alpha", "beta"],
            "n": 2,
            "stream": True,
        }
    )
    result = asyncio.run(completions.create_completion(request))
    frames = _parse_sse_data_frames(asyncio.run(_read_stream_body(result)))
    assert frames[-1] == "[DONE]"
    usage_chunk = json.loads(frames[-2])
    # Each prompt is streamed n times, in prompt-major order.
    assert seen_prompts == ["alpha", "alpha", "beta", "beta"]
    assert usage_chunk["usage"] == {
        "prompt_tokens": 30,
        "completion_tokens": 10,
        "total_tokens": 40,
    }