import pytest
from pydantic import ValidationError

from app.models import ChatMessage, ChatRequest, ChatResponse, ModelInfo, ErrorResponse
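
# NOTE: The assertions below imply, but do not define, the contract of app.models.
# Treat the specifics as assumptions inferred from these tests rather than from the
# model definitions themselves:
#   * ChatMessage.role accepts only "system", "user", or "assistant".
#   * ChatRequest defaults: model="llama-2-7b-chat", max_tokens=2048, temperature=0.7,
#     stream=True; max_tokens is bounded to 1..4096, temperature to 0.0..2.0,
#     and top_p to 0.0..1.0.
#   * ChatResponse.object defaults to "chat.completion"; ModelInfo.object defaults
#     to "model" with owned_by="huggingface".
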
class TestChatMessage:
    """Test ChatMessage model validation and behavior."""

    def test_valid_chat_message(self):
        """Test creating a valid chat message."""
        message = ChatMessage(role="user", content="Hello, world!")
        assert message.role == "user"
        assert message.content == "Hello, world!"

    def test_invalid_role(self):
        """Test that invalid roles raise ValidationError."""
        with pytest.raises(ValidationError):
            ChatMessage(role="invalid_role", content="Hello")

    def test_empty_content(self):
        """Test that empty content is allowed."""
        message = ChatMessage(role="assistant", content="")
        assert message.content == ""

    def test_system_message(self):
        """Test system message creation."""
        message = ChatMessage(role="system", content="You are a helpful assistant.")
        assert message.role == "system"

    def test_assistant_message(self):
        """Test assistant message creation."""
        message = ChatMessage(role="assistant", content="I'm here to help!")
        assert message.role == "assistant"


class TestChatRequest:
    """Test ChatRequest model validation and behavior."""

    def test_valid_chat_request(self):
        """Test creating a valid chat request."""
        messages = [
            ChatMessage(role="system", content="You are helpful."),
            ChatMessage(role="user", content="Hello!")
        ]
        request = ChatRequest(messages=messages)
        assert len(request.messages) == 2
        assert request.model == "llama-2-7b-chat"
        assert request.max_tokens == 2048
        assert request.temperature == 0.7
        assert request.stream is True

    def test_custom_parameters(self):
        """Test chat request with custom parameters."""
        messages = [ChatMessage(role="user", content="Hello!")]
        request = ChatRequest(
            messages=messages,
            model="custom-model",
            max_tokens=100,
            temperature=0.5,
            top_p=0.8,
            stream=False
        )
        assert request.model == "custom-model"
        assert request.max_tokens == 100
        assert request.temperature == 0.5
        assert request.top_p == 0.8
        assert request.stream is False

    def test_max_tokens_validation(self):
        """Test max_tokens validation."""
        messages = [ChatMessage(role="user", content="Hello!")]

        request = ChatRequest(messages=messages, max_tokens=1)
        assert request.max_tokens == 1

        request = ChatRequest(messages=messages, max_tokens=4096)
        assert request.max_tokens == 4096

        with pytest.raises(ValidationError):
            ChatRequest(messages=messages, max_tokens=0)

        with pytest.raises(ValidationError):
            ChatRequest(messages=messages, max_tokens=5000)

    def test_temperature_validation(self):
        """Test temperature validation."""
        messages = [ChatMessage(role="user", content="Hello!")]

        request = ChatRequest(messages=messages, temperature=0.0)
        assert request.temperature == 0.0

        request = ChatRequest(messages=messages, temperature=2.0)
        assert request.temperature == 2.0

        with pytest.raises(ValidationError):
            ChatRequest(messages=messages, temperature=-0.1)

        with pytest.raises(ValidationError):
            ChatRequest(messages=messages, temperature=2.1)

    def test_top_p_validation(self):
        """Test top_p validation."""
        messages = [ChatMessage(role="user", content="Hello!")]

        request = ChatRequest(messages=messages, top_p=0.0)
        assert request.top_p == 0.0

        request = ChatRequest(messages=messages, top_p=1.0)
        assert request.top_p == 1.0

        with pytest.raises(ValidationError):
            ChatRequest(messages=messages, top_p=-0.1)

        with pytest.raises(ValidationError):
            ChatRequest(messages=messages, top_p=1.1)

    def test_empty_messages(self):
        """Test that empty messages list is allowed."""
        request = ChatRequest(messages=[])
        assert len(request.messages) == 0


class TestChatResponse:
    """Test ChatResponse model validation and behavior."""

    def test_valid_chat_response(self):
        """Test creating a valid chat response."""
        response = ChatResponse(
            id="test-id",
            created=1234567890,
            model="llama-2-7b-chat",
            choices=[{
                "index": 0,
                "message": {"role": "assistant", "content": "Hello!"},
                "finish_reason": "stop"
            }]
        )
        assert response.id == "test-id"
        assert response.object == "chat.completion"
        assert response.created == 1234567890
        assert response.model == "llama-2-7b-chat"
        assert len(response.choices) == 1

    def test_chat_response_with_usage(self):
        """Test chat response with usage statistics."""
        response = ChatResponse(
            id="test-id",
            created=1234567890,
            model="llama-2-7b-chat",
            choices=[{
                "index": 0,
                "message": {"role": "assistant", "content": "Hello!"},
                "finish_reason": "stop"
            }],
            usage={
                "prompt_tokens": 10,
                "completion_tokens": 5,
                "total_tokens": 15
            }
        )
        assert response.usage is not None
        assert response.usage["prompt_tokens"] == 10


class TestModelInfo:
    """Test ModelInfo model validation and behavior."""

    def test_valid_model_info(self):
        """Test creating valid model info."""
        model_info = ModelInfo(
            id="llama-2-7b-chat",
            created=1234567890
        )
        assert model_info.id == "llama-2-7b-chat"
        assert model_info.object == "model"
        assert model_info.created == 1234567890
        assert model_info.owned_by == "huggingface"


class TestErrorResponse:
    """Test ErrorResponse model validation and behavior."""

    def test_valid_error_response(self):
        """Test creating a valid error response."""
        error_response = ErrorResponse(
            error={
                "message": "Invalid request",
                "type": "invalid_request_error",
                "code": 400
            }
        )
        assert error_response.error["message"] == "Invalid request"
        assert error_response.error["type"] == "invalid_request_error"
        assert error_response.error["code"] == 400


class TestModelSerialization:
    """Test model serialization and deserialization."""

    def test_chat_message_serialization(self):
        """Test ChatMessage JSON serialization."""
        message = ChatMessage(role="user", content="Hello!")
        data = message.model_dump()
        assert data["role"] == "user"
        assert data["content"] == "Hello!"

    def test_chat_request_serialization(self):
        """Test ChatRequest JSON serialization."""
        messages = [ChatMessage(role="user", content="Hello!")]
        request = ChatRequest(messages=messages)
        data = request.model_dump()
        assert "messages" in data
        assert len(data["messages"]) == 1
        assert data["model"] == "llama-2-7b-chat"

    def test_chat_request_deserialization(self):
        """Test ChatRequest JSON deserialization."""
        data = {
            "messages": [
                {"role": "user", "content": "Hello!"}
            ],
            "model": "custom-model",
            "max_tokens": 100
        }
        request = ChatRequest.model_validate(data)
        assert len(request.messages) == 1
        assert request.model == "custom-model"
        assert request.max_tokens == 100
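
    # Illustrative sketch, assuming the models are plain pydantic v2 BaseModels (so
    # field-wise equality holds): a round trip of the model_dump()/model_validate()
    # behaviour exercised in the tests above. The test name is a hypothetical addition.
    def test_chat_request_round_trip(self):
        """Hedged example: a dumped ChatRequest should validate back to an equal model."""
        original = ChatRequest(
            messages=[ChatMessage(role="user", content="Hello!")],
            max_tokens=100
        )
        restored = ChatRequest.model_validate(original.model_dump())
        assert restored == original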