# opencode-zerogpu / tests / test_openai_compat.py
# Initial implementation of ZeroGPU OpenCode Provider
# (commit adcb9bd, author: serenichron)
"""Tests for OpenAI-compatible API format handling."""
import json
import pytest
import sys
import os
# Add parent directory to path for imports
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from openai_compat import (
ChatCompletionRequest,
ChatMessage,
InferenceParams,
create_chat_response,
create_error_response,
create_stream_chunk,
estimate_tokens,
generate_completion_id,
messages_to_dicts,
stream_response_generator,
)
class TestChatCompletionRequest:
    """Request-parsing behaviour of ChatCompletionRequest."""

    def test_parse_basic_request(self, sample_request_data):
        """A plain (non-streaming) request round-trips through the model."""
        parsed = ChatCompletionRequest(**sample_request_data)
        assert parsed.model == "meta-llama/Llama-3.1-8B-Instruct"
        # Message count and ordering must survive parsing.
        assert len(parsed.messages) == 2
        assert [m.role for m in parsed.messages] == ["system", "user"]
        assert parsed.temperature == 0.7
        assert parsed.max_tokens == 512
        assert parsed.stream is False

    def test_parse_streaming_request(self, sample_streaming_request_data):
        """stream=True and per-request max_tokens are honoured."""
        parsed = ChatCompletionRequest(**sample_streaming_request_data)
        assert parsed.stream is True
        assert parsed.max_tokens == 256

    def test_default_values(self):
        """Omitted optional fields fall back to the documented defaults."""
        parsed = ChatCompletionRequest(
            model="test-model",
            messages=[{"role": "user", "content": "Hi"}],
        )
        assert parsed.temperature == 0.7
        assert parsed.top_p == 0.95
        assert parsed.max_tokens == 512
        assert parsed.stream is False
        assert parsed.stop is None

    def test_validation_temperature_bounds(self):
        """Temperatures outside the accepted range are rejected."""
        # Below-range and above-range values must both raise.
        for bad_temperature in (-0.5, 2.5):
            with pytest.raises(ValueError):
                ChatCompletionRequest(
                    model="test",
                    messages=[{"role": "user", "content": "Hi"}],
                    temperature=bad_temperature,
                )
class TestChatCompletionResponse:
    """Shape and serialization of non-streaming chat responses."""

    def test_create_basic_response(self):
        """create_chat_response fills choices and token usage correctly."""
        resp = create_chat_response(
            model="test-model",
            content="Hello! How can I help you?",
            prompt_tokens=10,
            completion_tokens=8,
        )
        assert resp.model == "test-model"
        assert resp.object == "chat.completion"
        assert len(resp.choices) == 1
        choice = resp.choices[0]
        assert choice.message.role == "assistant"
        assert choice.message.content == "Hello! How can I help you?"
        assert choice.finish_reason == "stop"
        # total_tokens must equal prompt + completion.
        usage = resp.usage
        assert (usage.prompt_tokens, usage.completion_tokens) == (10, 8)
        assert usage.total_tokens == 18

    def test_response_has_unique_id(self):
        """Two responses never share an id; ids carry the chatcmpl- prefix."""
        first = create_chat_response(model="test", content="Hi")
        second = create_chat_response(model="test", content="Hi")
        assert first.id != second.id
        assert first.id.startswith("chatcmpl-")

    def test_response_serialization(self):
        """model_dump_json emits valid JSON with the expected fields."""
        resp = create_chat_response(model="test-model", content="Test")
        payload = json.loads(resp.model_dump_json())
        assert payload["model"] == "test-model"
        assert payload["choices"][0]["message"]["content"] == "Test"
class TestStreamingResponse:
    """SSE chunk construction and the end-to-end stream generator."""

    def test_create_stream_chunk_with_content(self):
        """A content-bearing chunk carries the text and no finish reason."""
        chunk = create_stream_chunk(
            completion_id="test-id",
            model="test-model",
            content="Hello",
        )
        assert chunk.id == "test-id"
        assert chunk.object == "chat.completion.chunk"
        first_choice = chunk.choices[0]
        assert first_choice.delta.content == "Hello"
        assert first_choice.finish_reason is None

    def test_create_stream_chunk_with_role(self):
        """The opening chunk announces the assistant role, with no content."""
        chunk = create_stream_chunk(
            completion_id="test-id",
            model="test-model",
            role="assistant",
        )
        delta = chunk.choices[0].delta
        assert delta.role == "assistant"
        assert delta.content is None

    def test_create_stream_chunk_with_finish_reason(self):
        """The closing chunk carries the finish reason."""
        chunk = create_stream_chunk(
            completion_id="test-id",
            model="test-model",
            finish_reason="stop",
        )
        assert chunk.choices[0].finish_reason == "stop"

    def test_stream_response_generator(self):
        """Generator yields role, content, finish, then [DONE], in order."""

        def decode_frame(frame):
            # Strip the "data: " SSE framing and decode the JSON payload.
            return json.loads(frame.replace("data: ", "").strip())

        frames = list(
            stream_response_generator("test-model", iter(["Hello", " World", "!"]))
        )
        # Expect: 1 role chunk + 3 content chunks + 1 finish chunk + [DONE].
        assert len(frames) == 6
        assert decode_frame(frames[0])["choices"][0]["delta"]["role"] == "assistant"
        assert decode_frame(frames[1])["choices"][0]["delta"]["content"] == "Hello"
        assert decode_frame(frames[4])["choices"][0]["finish_reason"] == "stop"
        # The stream terminates with the literal SSE done sentinel.
        assert frames[5] == "data: [DONE]\n\n"
class TestInferenceParams:
    """Mapping of API requests onto internal inference parameters."""

    def test_extract_params_from_request(self, sample_request_data):
        """from_request copies model id, messages, and generation knobs."""
        req = ChatCompletionRequest(**sample_request_data)
        extracted = InferenceParams.from_request(req)
        assert extracted.model_id == "meta-llama/Llama-3.1-8B-Instruct"
        assert len(extracted.messages) == 2
        assert extracted.max_new_tokens == 512
        assert extracted.temperature == 0.7
        assert extracted.stream is False

    def test_messages_to_dicts(self):
        """ChatMessage objects collapse to plain role/content dicts."""
        converted = messages_to_dicts(
            [
                ChatMessage(role="user", content="Hello"),
                ChatMessage(role="assistant", content="Hi there!"),
            ]
        )
        expected = [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there!"},
        ]
        assert converted == expected
class TestErrorResponse:
    """OpenAI-style error envelope construction and serialization."""

    def test_create_error_response(self):
        """Message, type, and offending param land on the error object."""
        err = create_error_response(
            message="Model not found",
            error_type="invalid_request_error",
            param="model",
        )
        detail = err.error
        assert detail.message == "Model not found"
        assert detail.type == "invalid_request_error"
        assert detail.param == "model"

    def test_error_response_serialization(self):
        """Errors serialize to the nested {"error": {...}} JSON shape."""
        err = create_error_response(
            message="Test error",
            error_type="server_error",
            code="internal_error",
        )
        detail = json.loads(err.model_dump_json())["error"]
        assert detail["message"] == "Test error"
        assert detail["type"] == "server_error"
        assert detail["code"] == "internal_error"
class TestUtilityFunctions:
    """Small helpers: completion-id generation and token estimation."""

    def test_generate_completion_id_format(self):
        """IDs are 'chatcmpl-' plus 24 characters, and unique per call."""
        ids = [generate_completion_id() for _ in range(2)]
        assert all(one.startswith("chatcmpl-") for one in ids)
        assert len(ids[0]) == len("chatcmpl-") + 24
        # Consecutive calls must not collide.
        assert ids[0] != ids[1]

    def test_estimate_tokens(self):
        """Estimation is roughly one token per 4 characters, minimum 1."""
        assert estimate_tokens("Hello World!") == 3  # 12 chars -> 3 tokens
        assert estimate_tokens("A") == 1  # floor of 1 token
        # 31 chars -> 8 tokens at ~4 chars/token (rounded up)
        assert estimate_tokens("This is a longer piece of text.") == 8