# chat-bot/tests/test_models.py
# surahj's picture
# Initial commit: LLM Chat Interface for HF Spaces
# c2f9396
import pytest
from pydantic import ValidationError
from app.models import ChatMessage, ChatRequest, ChatResponse, ModelInfo, ErrorResponse
class TestChatMessage:
    """Checks on constructing ChatMessage with various roles and contents."""

    def test_valid_chat_message(self):
        """A user message exposes the role and content it was built with."""
        text = "Hello, world!"
        msg = ChatMessage(role="user", content=text)
        assert msg.role == "user"
        assert msg.content == text

    def test_invalid_role(self):
        """An unrecognised role value must raise ValidationError."""
        bad_fields = {"role": "invalid_role", "content": "Hello"}
        with pytest.raises(ValidationError):
            ChatMessage(**bad_fields)

    def test_empty_content(self):
        """An empty content string is accepted and stored unchanged."""
        blank = ChatMessage(role="assistant", content="")
        assert blank.content == ""

    def test_system_message(self):
        """The "system" role is accepted."""
        system_msg = ChatMessage(
            role="system",
            content="You are a helpful assistant.",
        )
        assert system_msg.role == "system"

    def test_assistant_message(self):
        """The "assistant" role is accepted."""
        reply = ChatMessage(role="assistant", content="I'm here to help!")
        assert reply.role == "assistant"
class TestChatRequest:
    """Checks on ChatRequest defaults, overrides and field range validation."""

    def test_valid_chat_request(self):
        """A request built from messages alone carries the expected defaults."""
        conversation = [
            ChatMessage(role=role, content=content)
            for role, content in (
                ("system", "You are helpful."),
                ("user", "Hello!"),
            )
        ]
        req = ChatRequest(messages=conversation)
        assert len(req.messages) == 2
        # Values expected when only `messages` is supplied.
        assert req.model == "llama-2-7b-chat"
        assert req.max_tokens == 2048
        assert req.temperature == 0.7
        assert req.stream is True

    def test_custom_parameters(self):
        """Explicitly supplied parameters are stored on the request as given."""
        overrides = dict(
            model="custom-model",
            max_tokens=100,
            temperature=0.5,
            top_p=0.8,
            stream=False,
        )
        user_turn = [ChatMessage(role="user", content="Hello!")]
        req = ChatRequest(messages=user_turn, **overrides)
        assert req.model == "custom-model"
        assert req.max_tokens == 100
        assert req.temperature == 0.5
        assert req.top_p == 0.8
        assert req.stream is False

    def test_max_tokens_validation(self):
        """max_tokens accepts 1 and 4096 and rejects 0 and 5000."""
        user_turn = [ChatMessage(role="user", content="Hello!")]
        # Lower and upper boundary values must be accepted unchanged.
        for accepted in (1, 4096):
            req = ChatRequest(messages=user_turn, max_tokens=accepted)
            assert req.max_tokens == accepted
        # Out-of-range values (below the minimum, then above the maximum).
        for rejected in (0, 5000):
            with pytest.raises(ValidationError):
                ChatRequest(messages=user_turn, max_tokens=rejected)

    def test_temperature_validation(self):
        """temperature accepts 0.0 and 2.0 and rejects -0.1 and 2.1."""
        user_turn = [ChatMessage(role="user", content="Hello!")]
        # Both ends of the valid range must be accepted unchanged.
        for accepted in (0.0, 2.0):
            req = ChatRequest(messages=user_turn, temperature=accepted)
            assert req.temperature == accepted
        # Values just outside either end must be rejected.
        for rejected in (-0.1, 2.1):
            with pytest.raises(ValidationError):
                ChatRequest(messages=user_turn, temperature=rejected)

    def test_top_p_validation(self):
        """top_p accepts 0.0 and 1.0 and rejects -0.1 and 1.1."""
        user_turn = [ChatMessage(role="user", content="Hello!")]
        # Both ends of the valid range must be accepted unchanged.
        for accepted in (0.0, 1.0):
            req = ChatRequest(messages=user_turn, top_p=accepted)
            assert req.top_p == accepted
        # Values just outside either end must be rejected.
        for rejected in (-0.1, 1.1):
            with pytest.raises(ValidationError):
                ChatRequest(messages=user_turn, top_p=rejected)

    def test_empty_messages(self):
        """A request with an empty messages list is accepted."""
        no_turns = []
        req = ChatRequest(messages=no_turns)
        assert len(req.messages) == 0
class TestChatResponse:
    """Checks on constructing ChatResponse with and without usage data."""

    def test_valid_chat_response(self):
        """A response built from id, created, model and one choice is readable back."""
        only_choice = {
            "index": 0,
            "message": {"role": "assistant", "content": "Hello!"},
            "finish_reason": "stop",
        }
        resp = ChatResponse(
            id="test-id",
            created=1234567890,
            model="llama-2-7b-chat",
            choices=[only_choice],
        )
        assert resp.id == "test-id"
        # `object` is not passed in, so this checks the value the model supplies.
        assert resp.object == "chat.completion"
        assert resp.created == 1234567890
        assert resp.model == "llama-2-7b-chat"
        assert len(resp.choices) == 1

    def test_chat_response_with_usage(self):
        """Usage statistics passed to the response are retained and indexable."""
        only_choice = {
            "index": 0,
            "message": {"role": "assistant", "content": "Hello!"},
            "finish_reason": "stop",
        }
        token_counts = {
            "prompt_tokens": 10,
            "completion_tokens": 5,
            "total_tokens": 15,
        }
        resp = ChatResponse(
            id="test-id",
            created=1234567890,
            model="llama-2-7b-chat",
            choices=[only_choice],
            usage=token_counts,
        )
        assert resp.usage is not None
        assert resp.usage["prompt_tokens"] == 10
class TestModelInfo:
    """Checks on constructing ModelInfo from an id and a creation timestamp."""

    def test_valid_model_info(self):
        """Supplied fields are kept; object and owned_by take their expected values."""
        model_id = "llama-2-7b-chat"
        created_at = 1234567890
        info = ModelInfo(id=model_id, created=created_at)
        assert info.id == model_id
        # Not passed in above, so these check the values the model supplies.
        assert info.object == "model"
        assert info.created == created_at
        assert info.owned_by == "huggingface"
class TestErrorResponse:
    """Checks on constructing ErrorResponse from an error payload."""

    def test_valid_error_response(self):
        """The message, type and code of the error payload are readable back."""
        payload = {
            "message": "Invalid request",
            "type": "invalid_request_error",
            "code": 400,
        }
        err = ErrorResponse(error=payload).error
        assert err["message"] == "Invalid request"
        assert err["type"] == "invalid_request_error"
        assert err["code"] == 400
class TestModelSerialization:
    """Checks on dumping models with model_dump() and loading with model_validate()."""

    def test_chat_message_serialization(self):
        """model_dump() on a ChatMessage exposes its role and content by key."""
        dumped = ChatMessage(role="user", content="Hello!").model_dump()
        assert dumped["role"] == "user"
        assert dumped["content"] == "Hello!"

    def test_chat_request_serialization(self):
        """model_dump() on a ChatRequest includes the messages and the model name."""
        user_turn = [ChatMessage(role="user", content="Hello!")]
        dumped = ChatRequest(messages=user_turn).model_dump()
        assert "messages" in dumped
        assert len(dumped["messages"]) == 1
        assert dumped["model"] == "llama-2-7b-chat"

    def test_chat_request_deserialization(self):
        """model_validate() builds a ChatRequest from a plain dict payload."""
        payload = {
            "messages": [
                {"role": "user", "content": "Hello!"},
            ],
            "model": "custom-model",
            "max_tokens": 100,
        }
        req = ChatRequest.model_validate(payload)
        assert len(req.messages) == 1
        assert req.model == "custom-model"
        assert req.max_tokens == 100