"""Tests for OpenAI-compatible API format handling.""" import json import pytest import sys import os # Add parent directory to path for imports sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from openai_compat import ( ChatCompletionRequest, ChatMessage, InferenceParams, create_chat_response, create_error_response, create_stream_chunk, estimate_tokens, generate_completion_id, messages_to_dicts, stream_response_generator, ) class TestChatCompletionRequest: """Test request parsing.""" def test_parse_basic_request(self, sample_request_data): """Test parsing a basic chat completion request.""" request = ChatCompletionRequest(**sample_request_data) assert request.model == "meta-llama/Llama-3.1-8B-Instruct" assert len(request.messages) == 2 assert request.messages[0].role == "system" assert request.messages[1].role == "user" assert request.temperature == 0.7 assert request.max_tokens == 512 assert request.stream is False def test_parse_streaming_request(self, sample_streaming_request_data): """Test parsing a streaming request.""" request = ChatCompletionRequest(**sample_streaming_request_data) assert request.stream is True assert request.max_tokens == 256 def test_default_values(self): """Test that defaults are applied correctly.""" minimal_request = { "model": "test-model", "messages": [{"role": "user", "content": "Hi"}], } request = ChatCompletionRequest(**minimal_request) assert request.temperature == 0.7 assert request.top_p == 0.95 assert request.max_tokens == 512 assert request.stream is False assert request.stop is None def test_validation_temperature_bounds(self): """Test temperature validation.""" with pytest.raises(ValueError): ChatCompletionRequest( model="test", messages=[{"role": "user", "content": "Hi"}], temperature=-0.5, ) with pytest.raises(ValueError): ChatCompletionRequest( model="test", messages=[{"role": "user", "content": "Hi"}], temperature=2.5, ) class TestChatCompletionResponse: """Test response generation.""" def test_create_basic_response(self): """Test creating a basic chat response.""" response = create_chat_response( model="test-model", content="Hello! How can I help you?", prompt_tokens=10, completion_tokens=8, ) assert response.model == "test-model" assert response.object == "chat.completion" assert len(response.choices) == 1 assert response.choices[0].message.role == "assistant" assert response.choices[0].message.content == "Hello! How can I help you?" assert response.choices[0].finish_reason == "stop" assert response.usage.prompt_tokens == 10 assert response.usage.completion_tokens == 8 assert response.usage.total_tokens == 18 def test_response_has_unique_id(self): """Test that each response has a unique ID.""" response1 = create_chat_response(model="test", content="Hi") response2 = create_chat_response(model="test", content="Hi") assert response1.id != response2.id assert response1.id.startswith("chatcmpl-") def test_response_serialization(self): """Test that response can be serialized to JSON.""" response = create_chat_response( model="test-model", content="Test", ) json_str = response.model_dump_json() parsed = json.loads(json_str) assert parsed["model"] == "test-model" assert parsed["choices"][0]["message"]["content"] == "Test" class TestStreamingResponse: """Test streaming response format.""" def test_create_stream_chunk_with_content(self): """Test creating a streaming chunk with content.""" chunk = create_stream_chunk( completion_id="test-id", model="test-model", content="Hello", ) assert chunk.id == "test-id" assert chunk.object == "chat.completion.chunk" assert chunk.choices[0].delta.content == "Hello" assert chunk.choices[0].finish_reason is None def test_create_stream_chunk_with_role(self): """Test creating a streaming chunk with role (first chunk).""" chunk = create_stream_chunk( completion_id="test-id", model="test-model", role="assistant", ) assert chunk.choices[0].delta.role == "assistant" assert chunk.choices[0].delta.content is None def test_create_stream_chunk_with_finish_reason(self): """Test creating a final streaming chunk.""" chunk = create_stream_chunk( completion_id="test-id", model="test-model", finish_reason="stop", ) assert chunk.choices[0].finish_reason == "stop" def test_stream_response_generator(self): """Test the full streaming response generator.""" def token_gen(): yield "Hello" yield " World" yield "!" chunks = list(stream_response_generator("test-model", token_gen())) # Should have: role chunk, 3 content chunks, finish chunk, [DONE] assert len(chunks) == 6 # First chunk has role first_data = json.loads(chunks[0].replace("data: ", "").strip()) assert first_data["choices"][0]["delta"]["role"] == "assistant" # Content chunks second_data = json.loads(chunks[1].replace("data: ", "").strip()) assert second_data["choices"][0]["delta"]["content"] == "Hello" # Last data chunk has finish reason last_data = json.loads(chunks[4].replace("data: ", "").strip()) assert last_data["choices"][0]["finish_reason"] == "stop" # Very last is [DONE] assert chunks[5] == "data: [DONE]\n\n" class TestInferenceParams: """Test parameter extraction.""" def test_extract_params_from_request(self, sample_request_data): """Test extracting inference parameters from request.""" request = ChatCompletionRequest(**sample_request_data) params = InferenceParams.from_request(request) assert params.model_id == "meta-llama/Llama-3.1-8B-Instruct" assert len(params.messages) == 2 assert params.max_new_tokens == 512 assert params.temperature == 0.7 assert params.stream is False def test_messages_to_dicts(self): """Test converting ChatMessage objects to dicts.""" messages = [ ChatMessage(role="user", content="Hello"), ChatMessage(role="assistant", content="Hi there!"), ] dicts = messages_to_dicts(messages) assert dicts == [ {"role": "user", "content": "Hello"}, {"role": "assistant", "content": "Hi there!"}, ] class TestErrorResponse: """Test error response format.""" def test_create_error_response(self): """Test creating an error response.""" error = create_error_response( message="Model not found", error_type="invalid_request_error", param="model", ) assert error.error.message == "Model not found" assert error.error.type == "invalid_request_error" assert error.error.param == "model" def test_error_response_serialization(self): """Test error response JSON serialization.""" error = create_error_response( message="Test error", error_type="server_error", code="internal_error", ) parsed = json.loads(error.model_dump_json()) assert parsed["error"]["message"] == "Test error" assert parsed["error"]["type"] == "server_error" assert parsed["error"]["code"] == "internal_error" class TestUtilityFunctions: """Test utility functions.""" def test_generate_completion_id_format(self): """Test completion ID format.""" id1 = generate_completion_id() id2 = generate_completion_id() assert id1.startswith("chatcmpl-") assert len(id1) == len("chatcmpl-") + 24 assert id1 != id2 # Should be unique def test_estimate_tokens(self): """Test rough token estimation.""" # ~4 chars per token assert estimate_tokens("Hello World!") == 3 # 12 chars / 4 = 3 assert estimate_tokens("A") == 1 # Min 1 assert estimate_tokens("This is a longer piece of text.") == 8 # 32 / 4 = 8