"""Endpoint tests for the chat-completions API.

Every test receives a ``client`` fixture that is not defined in this file
(presumably provided by ``conftest.py`` -- confirm) and uses it to issue
HTTP requests and assert on the status code and response payload.
"""

import asyncio
import json
import threading
from unittest.mock import AsyncMock, patch

import pytest
from fastapi.testclient import TestClient
from httpx import AsyncClient

from app.main import app
from app.models import ChatMessage, ChatRequest


class TestAPIEndpoints:
    """Test all API endpoints."""

    def test_root_endpoint(self, client):
        """Root endpoint returns the service name, version and endpoint map."""
        response = client.get("/")
        assert response.status_code == 200

        data = response.json()
        assert data["message"] == "LLM API - GPT Clone"
        assert data["version"] == "1.0.0"
        assert "endpoints" in data

    def test_health_endpoint(self, client):
        """Health check reports a healthy status plus model metadata keys."""
        response = client.get("/health")
        assert response.status_code == 200

        data = response.json()
        assert data["status"] == "healthy"
        assert "model_loaded" in data
        assert "model_type" in data
        assert "timestamp" in data

    def test_models_endpoint(self, client):
        """Models listing is a non-empty list whose first entry is the chat model."""
        response = client.get("/v1/models")
        assert response.status_code == 200

        data = response.json()
        assert data["object"] == "list"
        assert "data" in data
        assert len(data["data"]) > 0

        model_info = data["data"][0]
        assert model_info["id"] == "llama-2-7b-chat"
        assert model_info["object"] == "model"
        assert model_info["owned_by"] == "huggingface"

    def test_chat_completions_non_streaming(self, client):
        """Non-streaming completion returns a ``chat.completion`` with a choice."""
        request_data = {
            "messages": [{"role": "user", "content": "Hello!"}],
            "stream": False,
            "max_tokens": 50,
        }
        response = client.post("/v1/chat/completions", json=request_data)
        assert response.status_code == 200

        data = response.json()
        assert "id" in data
        assert data["object"] == "chat.completion"
        assert "choices" in data
        assert len(data["choices"]) > 0
        assert "message" in data["choices"][0]
        assert data["choices"][0]["finish_reason"] == "stop"

    def test_chat_completions_streaming(self, client):
        """Streaming completion answers with SSE containing ``data: `` lines."""
        request_data = {
            "messages": [{"role": "user", "content": "Hello!"}],
            "stream": True,
            "max_tokens": 50,
        }
        response = client.post("/v1/chat/completions", json=request_data)
        assert response.status_code == 200
        assert "text/event-stream" in response.headers["content-type"]

        # Parse SSE response
        lines = response.text.strip().split("\n")
        assert len(lines) > 0

        # Check that we have SSE events
        event_lines = [line for line in lines if line.startswith("data: ")]
        assert len(event_lines) > 0

    def test_chat_completions_empty_messages(self, client):
        """An empty ``messages`` list is rejected with a 400 and an error message."""
        request_data = {"messages": [], "stream": False}
        response = client.post("/v1/chat/completions", json=request_data)
        assert response.status_code == 400
        assert "Messages cannot be empty" in response.json()["error"]["message"]

    def test_chat_completions_invalid_message_format(self, client):
        """An unknown message role is rejected with a 422."""
        request_data = {
            "messages": [{"role": "invalid_role", "content": "Hello!"}],
            "stream": False,
        }
        response = client.post("/v1/chat/completions", json=request_data)
        assert response.status_code == 422  # Validation error

    def test_chat_completions_invalid_parameters(self, client):
        """Out-of-range ``max_tokens`` / ``temperature`` are rejected with a 422."""
        request_data = {
            "messages": [{"role": "user", "content": "Hello!"}],
            "max_tokens": 5000,  # Too high
            "temperature": 3.0,  # Too high
            "stream": False,
        }
        response = client.post("/v1/chat/completions", json=request_data)
        assert response.status_code == 422  # Validation error


class TestSSEStreaming:
    """Test Server-Sent Events streaming functionality.

    All tests in this class are currently skipped; each only checks the
    status code, the content type and that the body is non-empty.
    """

    @pytest.mark.skip(
        reason="SSE streaming tests have event loop conflicts in test environment"
    )
    def test_sse_response_format(self, client):
        """Streaming request yields a non-empty ``text/event-stream`` body."""
        request_data = {
            "messages": [{"role": "user", "content": "Hello!"}],
            "stream": True,
            "max_tokens": 20,
        }
        response = client.post("/v1/chat/completions", json=request_data)
        assert response.status_code == 200
        assert "text/event-stream" in response.headers["content-type"]

        # Basic SSE format check - just verify we get some response
        assert len(response.text) > 0

    @pytest.mark.skip(
        reason="SSE streaming tests have event loop conflicts in test environment"
    )
    def test_sse_completion_signal(self, client):
        """Short streaming request yields a non-empty ``text/event-stream`` body.

        NOTE(review): the name suggests checking a completion signal, but only
        a non-empty body is asserted.
        """
        request_data = {
            "messages": [{"role": "user", "content": "Hello!"}],
            "stream": True,
            "max_tokens": 10,
        }
        response = client.post("/v1/chat/completions", json=request_data)
        assert response.status_code == 200
        assert "text/event-stream" in response.headers["content-type"]

        # Basic check that we get a response
        assert len(response.text) > 0

    @pytest.mark.skip(
        reason="SSE streaming tests have event loop conflicts in test environment"
    )
    def test_sse_content_streaming(self, client):
        """Streaming request yields a non-empty ``text/event-stream`` body.

        NOTE(review): the name suggests checking token-by-token delivery, but
        only a non-empty body is asserted.
        """
        request_data = {
            "messages": [{"role": "user", "content": "Hello!"}],
            "stream": True,
            "max_tokens": 20,
        }
        response = client.post("/v1/chat/completions", json=request_data)
        assert response.status_code == 200
        assert "text/event-stream" in response.headers["content-type"]

        # Basic check that we get a response
        assert len(response.text) > 0


class TestErrorHandling:
    """Test error handling in the API."""

    def test_invalid_json_request(self, client):
        """A body that is not valid JSON is rejected with a 422."""
        response = client.post(
            "/v1/chat/completions",
            data="invalid json",
            headers={"Content-Type": "application/json"},
        )
        assert response.status_code == 422

    def test_missing_required_fields(self, client):
        """A request without ``messages`` is rejected with a 422."""
        request_data = {"stream": False}  # Missing messages field
        response = client.post("/v1/chat/completions", json=request_data)
        assert response.status_code == 422

    def test_invalid_model_parameter(self, client):
        """A negative ``max_tokens`` is rejected with a 422."""
        request_data = {
            "messages": [{"role": "user", "content": "Hello!"}],
            "max_tokens": -1,  # Invalid
            "stream": False,
        }
        response = client.post("/v1/chat/completions", json=request_data)
        assert response.status_code == 422

    def test_nonexistent_endpoint(self, client):
        """An unknown path returns a 404."""
        response = client.get("/nonexistent")
        assert response.status_code == 404


class TestModelLoading:
    """Test model loading scenarios."""

    def test_health_with_model_loaded(self, client):
        """Health endpoint reports ``healthy``."""
        response = client.get("/health")
        assert response.status_code == 200

        data = response.json()
        # Should work even with mock model
        assert data["status"] == "healthy"

    def test_models_endpoint_model_info(self, client):
        """First model entry exposes every required metadata field."""
        response = client.get("/v1/models")
        assert response.status_code == 200

        model_info = response.json()["data"][0]

        # Check required fields; collect the absent ones so that a failure
        # names every missing key at once.
        required_fields = ["id", "object", "created", "owned_by"]
        missing = [field for field in required_fields if field not in model_info]
        assert not missing, f"Missing model fields: {missing}"


class TestConcurrentRequests:
    """Test handling of concurrent requests."""

    def test_multiple_concurrent_requests(self, client):
        """Five requests issued from separate threads all succeed with a 200."""
        results = []
        errors = []

        def make_request():
            # Exceptions raised inside a thread would otherwise be lost, so
            # they are recorded and asserted on from the main thread.
            try:
                request_data = {
                    "messages": [{"role": "user", "content": "Hello!"}],
                    "stream": False,
                    "max_tokens": 10,
                }
                response = client.post("/v1/chat/completions", json=request_data)
                results.append(response.status_code)
            except Exception as e:
                errors.append(str(e))

        # Start multiple threads
        threads = []
        for _ in range(5):
            thread = threading.Thread(target=make_request)
            threads.append(thread)
            thread.start()

        # Wait for all threads to complete
        for thread in threads:
            thread.join()

        # Check results
        assert not errors, f"Errors occurred: {errors}"
        assert len(results) == 5
        assert all(status == 200 for status in results)


class TestAPIValidation:
    """Test API input validation."""

    def test_message_validation(self, client):
        """A message without ``content`` is rejected with a 422."""
        # Test missing content
        request_data = {
            "messages": [{"role": "user"}],  # Missing content
            "stream": False,
        }
        response = client.post("/v1/chat/completions", json=request_data)
        assert response.status_code == 422

    def test_parameter_bounds(self, client):
        """Boundary values ``temperature=0.0`` and ``top_p=1.0`` are accepted."""
        request_data = {
            "messages": [{"role": "user", "content": "Hello!"}],
            "temperature": 0.0,  # Valid minimum
            "top_p": 1.0,  # Valid maximum
            "stream": False,
        }
        response = client.post("/v1/chat/completions", json=request_data)
        assert response.status_code == 200

    def test_parameter_bounds_invalid(self, client):
        """A negative ``temperature`` is rejected with a 422."""
        request_data = {
            "messages": [{"role": "user", "content": "Hello!"}],
            "temperature": -0.1,  # Invalid minimum
            "stream": False,
        }
        response = client.post("/v1/chat/completions", json=request_data)
        assert response.status_code == 422