Spaces:
Sleeping
Sleeping
| """Tests for OpenAI-compatible API format handling.""" | |
| import json | |
| import pytest | |
| import sys | |
| import os | |
| # Add parent directory to path for imports | |
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
| from openai_compat import ( | |
| ChatCompletionRequest, | |
| ChatMessage, | |
| InferenceParams, | |
| create_chat_response, | |
| create_error_response, | |
| create_stream_chunk, | |
| estimate_tokens, | |
| generate_completion_id, | |
| messages_to_dicts, | |
| stream_response_generator, | |
| ) | |
class TestChatCompletionRequest:
    """Validate parsing, defaults, and validation of completion requests."""

    def test_parse_basic_request(self, sample_request_data):
        """A fully-specified request populates model, messages, and sampling knobs."""
        req = ChatCompletionRequest(**sample_request_data)
        assert req.model == "meta-llama/Llama-3.1-8B-Instruct"
        # Two messages, in system-then-user order.
        assert [m.role for m in req.messages] == ["system", "user"]
        assert req.temperature == 0.7
        assert req.max_tokens == 512
        assert req.stream is False

    def test_parse_streaming_request(self, sample_streaming_request_data):
        """A request with stream=True is parsed as streaming."""
        req = ChatCompletionRequest(**sample_streaming_request_data)
        assert req.stream is True
        assert req.max_tokens == 256

    def test_default_values(self):
        """Omitted optional fields fall back to their documented defaults."""
        req = ChatCompletionRequest(
            model="test-model",
            messages=[{"role": "user", "content": "Hi"}],
        )
        assert req.temperature == 0.7
        assert req.top_p == 0.95
        assert req.max_tokens == 512
        assert req.stream is False
        assert req.stop is None

    def test_validation_temperature_bounds(self):
        """Temperatures outside the accepted range are rejected."""
        base = {"model": "test", "messages": [{"role": "user", "content": "Hi"}]}
        # Below the lower bound and above the upper bound must both raise.
        for bad_temperature in (-0.5, 2.5):
            with pytest.raises(ValueError):
                ChatCompletionRequest(**base, temperature=bad_temperature)
class TestChatCompletionResponse:
    """Validate the non-streaming chat completion response format."""

    def test_create_basic_response(self):
        """A response carries the assistant message plus token accounting."""
        resp = create_chat_response(
            model="test-model",
            content="Hello! How can I help you?",
            prompt_tokens=10,
            completion_tokens=8,
        )
        assert resp.model == "test-model"
        assert resp.object == "chat.completion"
        assert len(resp.choices) == 1
        choice = resp.choices[0]
        assert choice.message.role == "assistant"
        assert choice.message.content == "Hello! How can I help you?"
        assert choice.finish_reason == "stop"
        # total_tokens must be the sum of prompt and completion counts.
        usage = resp.usage
        assert (usage.prompt_tokens, usage.completion_tokens) == (10, 8)
        assert usage.total_tokens == 18

    def test_response_has_unique_id(self):
        """Every response gets a fresh ID with the OpenAI chatcmpl prefix."""
        ids = [create_chat_response(model="test", content="Hi").id for _ in range(2)]
        assert ids[0] != ids[1]
        assert ids[0].startswith("chatcmpl-")

    def test_response_serialization(self):
        """Responses round-trip to JSON via model_dump_json."""
        resp = create_chat_response(model="test-model", content="Test")
        parsed = json.loads(resp.model_dump_json())
        assert parsed["model"] == "test-model"
        assert parsed["choices"][0]["message"]["content"] == "Test"
class TestStreamingResponse:
    """Validate the SSE-style streaming chunk format."""

    def test_create_stream_chunk_with_content(self):
        """A content chunk carries text in the delta and no finish reason."""
        chunk = create_stream_chunk(
            completion_id="test-id",
            model="test-model",
            content="Hello",
        )
        assert chunk.id == "test-id"
        assert chunk.object == "chat.completion.chunk"
        delta = chunk.choices[0].delta
        assert delta.content == "Hello"
        assert chunk.choices[0].finish_reason is None

    def test_create_stream_chunk_with_role(self):
        """The opening chunk announces the assistant role with empty content."""
        chunk = create_stream_chunk(
            completion_id="test-id",
            model="test-model",
            role="assistant",
        )
        delta = chunk.choices[0].delta
        assert delta.role == "assistant"
        assert delta.content is None

    def test_create_stream_chunk_with_finish_reason(self):
        """The closing chunk reports why generation stopped."""
        chunk = create_stream_chunk(
            completion_id="test-id",
            model="test-model",
            finish_reason="stop",
        )
        assert chunk.choices[0].finish_reason == "stop"

    def test_stream_response_generator(self):
        """The generator yields role, content, finish, then the [DONE] sentinel."""

        def fake_tokens():
            yield from ("Hello", " World", "!")

        def payload(raw):
            # Strip the SSE "data: " framing and decode the JSON body.
            return json.loads(raw.replace("data: ", "").strip())

        chunks = list(stream_response_generator("test-model", fake_tokens()))
        # 1 role chunk + 3 content chunks + 1 finish chunk + [DONE].
        assert len(chunks) == 6
        assert payload(chunks[0])["choices"][0]["delta"]["role"] == "assistant"
        assert payload(chunks[1])["choices"][0]["delta"]["content"] == "Hello"
        assert payload(chunks[4])["choices"][0]["finish_reason"] == "stop"
        assert chunks[5] == "data: [DONE]\n\n"
class TestInferenceParams:
    """Validate extraction of backend inference parameters from a request."""

    def test_extract_params_from_request(self, sample_request_data):
        """InferenceParams.from_request mirrors the request's fields."""
        req = ChatCompletionRequest(**sample_request_data)
        extracted = InferenceParams.from_request(req)
        assert extracted.model_id == "meta-llama/Llama-3.1-8B-Instruct"
        assert len(extracted.messages) == 2
        assert extracted.max_new_tokens == 512
        assert extracted.temperature == 0.7
        assert extracted.stream is False

    def test_messages_to_dicts(self):
        """ChatMessage objects convert to plain role/content dicts."""
        conversation = [
            ChatMessage(role="user", content="Hello"),
            ChatMessage(role="assistant", content="Hi there!"),
        ]
        expected = [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there!"},
        ]
        assert messages_to_dicts(conversation) == expected
class TestErrorResponse:
    """Validate the OpenAI-style error envelope."""

    def test_create_error_response(self):
        """Message, type, and param all land in the nested error object."""
        resp = create_error_response(
            message="Model not found",
            error_type="invalid_request_error",
            param="model",
        )
        err = resp.error
        assert err.message == "Model not found"
        assert err.type == "invalid_request_error"
        assert err.param == "model"

    def test_error_response_serialization(self):
        """Serialized errors nest message/type/code under an 'error' key."""
        resp = create_error_response(
            message="Test error",
            error_type="server_error",
            code="internal_error",
        )
        parsed = json.loads(resp.model_dump_json())
        err = parsed["error"]
        assert err["message"] == "Test error"
        assert err["type"] == "server_error"
        assert err["code"] == "internal_error"
class TestUtilityFunctions:
    """Test utility functions."""

    def test_generate_completion_id_format(self):
        """Test completion ID format."""
        id1 = generate_completion_id()
        id2 = generate_completion_id()
        # OpenAI-style IDs: "chatcmpl-" prefix followed by a 24-char suffix.
        assert id1.startswith("chatcmpl-")
        assert len(id1) == len("chatcmpl-") + 24
        assert id1 != id2  # Should be unique

    def test_estimate_tokens(self):
        """Test rough token estimation."""
        # Heuristic: roughly 4 characters per token, with a floor of 1.
        assert estimate_tokens("Hello World!") == 3  # 12 chars / 4 = 3
        assert estimate_tokens("A") == 1  # Min 1
        # NOTE(review): this string is 31 chars, not 32 as the old comment
        # claimed — expecting 8 implies the estimator rounds up
        # (ceil(31 / 4) == 8); confirm against estimate_tokens' implementation.
        assert estimate_tokens("This is a longer piece of text.") == 8