File size: 10,911 Bytes
c2f9396
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
import pytest
import json
import asyncio
from httpx import AsyncClient
from fastapi.testclient import TestClient
from unittest.mock import patch, AsyncMock

from app.main import app
from app.models import ChatMessage, ChatRequest


class TestAPIEndpoints:
    """Test all API endpoints."""

    def test_root_endpoint(self, client):
        """Test the root endpoint."""
        response = client.get("/")
        assert response.status_code == 200

        data = response.json()
        assert data["message"] == "LLM API - GPT Clone"
        assert data["version"] == "1.0.0"
        assert "endpoints" in data

    def test_health_endpoint(self, client):
        """Test the health check endpoint."""
        response = client.get("/health")
        assert response.status_code == 200

        data = response.json()
        assert data["status"] == "healthy"
        assert "model_loaded" in data
        assert "model_type" in data
        assert "timestamp" in data

    def test_models_endpoint(self, client):
        """Test the models endpoint."""
        response = client.get("/v1/models")
        assert response.status_code == 200

        data = response.json()
        assert data["object"] == "list"
        assert "data" in data
        assert len(data["data"]) > 0

        model_info = data["data"][0]
        assert model_info["id"] == "llama-2-7b-chat"
        assert model_info["object"] == "model"
        assert model_info["owned_by"] == "huggingface"

    def test_chat_completions_non_streaming(self, client):
        """Test chat completions endpoint with non-streaming response."""
        request_data = {
            "messages": [{"role": "user", "content": "Hello!"}],
            "stream": False,
            "max_tokens": 50,
        }

        response = client.post("/v1/chat/completions", json=request_data)
        assert response.status_code == 200

        data = response.json()
        assert "id" in data
        assert data["object"] == "chat.completion"
        assert "choices" in data
        assert len(data["choices"]) > 0
        assert "message" in data["choices"][0]
        assert data["choices"][0]["finish_reason"] == "stop"

    def test_chat_completions_streaming(self, client):
        """Test chat completions endpoint with streaming response."""
        request_data = {
            "messages": [{"role": "user", "content": "Hello!"}],
            "stream": True,
            "max_tokens": 50,
        }

        response = client.post("/v1/chat/completions", json=request_data)
        assert response.status_code == 200
        assert "text/event-stream" in response.headers["content-type"]

        # Parse SSE response
        lines = response.text.strip().split("\n")
        assert len(lines) > 0

        # Check that we have SSE events
        event_lines = [line for line in lines if line.startswith("data: ")]
        assert len(event_lines) > 0

    def test_chat_completions_empty_messages(self, client):
        """Test chat completions with empty messages."""
        request_data = {"messages": [], "stream": False}

        response = client.post("/v1/chat/completions", json=request_data)
        assert response.status_code == 400
        assert "Messages cannot be empty" in response.json()["error"]["message"]

    def test_chat_completions_invalid_message_format(self, client):
        """Test chat completions with invalid message format."""
        request_data = {
            "messages": [{"role": "invalid_role", "content": "Hello!"}],
            "stream": False,
        }

        response = client.post("/v1/chat/completions", json=request_data)
        assert response.status_code == 422  # Validation error

    def test_chat_completions_invalid_parameters(self, client):
        """Test chat completions with invalid parameters."""
        request_data = {
            "messages": [{"role": "user", "content": "Hello!"}],
            "max_tokens": 5000,  # Too high
            "temperature": 3.0,  # Too high
            "stream": False,
        }

        response = client.post("/v1/chat/completions", json=request_data)
        assert response.status_code == 422  # Validation error


class TestSSEStreaming:
    """Test Server-Sent Events streaming functionality.

    All tests in this class are skipped through a single class-level mark.
    Until the skip is lifted, each test only checks for a 200 status, an
    ``text/event-stream`` content type and a non-empty body.
    """

    # Applies the same skip to every test method in this class.
    pytestmark = pytest.mark.skip(
        reason="SSE streaming tests have event loop conflicts in test environment"
    )

    @staticmethod
    def _stream_chat(client, max_tokens):
        """POST a one-message streaming chat request and check the SSE headers.

        Returns the response so callers can inspect its body.
        """
        payload = {
            "messages": [{"role": "user", "content": "Hello!"}],
            "stream": True,
            "max_tokens": max_tokens,
        }
        streamed = client.post("/v1/chat/completions", json=payload)
        assert streamed.status_code == 200
        assert "text/event-stream" in streamed.headers["content-type"]
        return streamed

    def test_sse_response_format(self, client):
        """Streamed response uses the SSE media type and carries a body."""
        body = self._stream_chat(client, max_tokens=20).text
        assert body

    def test_sse_completion_signal(self, client):
        """Short streamed response arrives with a non-empty SSE body.

        NOTE(review): the terminating event itself is not asserted yet.
        """
        body = self._stream_chat(client, max_tokens=10).text
        assert body

    def test_sse_content_streaming(self, client):
        """Streamed response for a longer completion carries a body.

        NOTE(review): per-token event granularity is not asserted yet.
        """
        body = self._stream_chat(client, max_tokens=20).text
        assert body


class TestErrorHandling:
    """Test error handling in the API."""

    def test_invalid_json_request(self, client):
        """Test handling of invalid JSON in request."""
        response = client.post(
            "/v1/chat/completions",
            data="invalid json",
            headers={"Content-Type": "application/json"},
        )
        assert response.status_code == 422

    def test_missing_required_fields(self, client):
        """Test handling of missing required fields."""
        request_data = {
            "stream": False
            # Missing messages field
        }

        response = client.post("/v1/chat/completions", json=request_data)
        assert response.status_code == 422

    def test_invalid_model_parameter(self, client):
        """Test handling of invalid model parameters."""
        request_data = {
            "messages": [{"role": "user", "content": "Hello!"}],
            "max_tokens": -1,  # Invalid
            "stream": False,
        }

        response = client.post("/v1/chat/completions", json=request_data)
        assert response.status_code == 422

    def test_nonexistent_endpoint(self, client):
        """Test handling of nonexistent endpoints."""
        response = client.get("/nonexistent")
        assert response.status_code == 404


class TestModelLoading:
    """Test model loading scenarios."""

    def test_health_with_model_loaded(self, client):
        """Test health endpoint when model is loaded."""
        response = client.get("/health")
        assert response.status_code == 200

        data = response.json()
        # Should work even with mock model
        assert data["status"] == "healthy"

    def test_models_endpoint_model_info(self, client):
        """Test that models endpoint returns correct model information."""
        response = client.get("/v1/models")
        assert response.status_code == 200

        data = response.json()
        model_info = data["data"][0]

        # Check required fields
        required_fields = ["id", "object", "created", "owned_by"]
        for field in required_fields:
            assert field in model_info


class TestConcurrentRequests:
    """Test handling of concurrent requests."""

    def test_multiple_concurrent_requests(self, client):
        """Test that multiple concurrent requests are handled properly."""
        import threading
        import time

        results = []
        errors = []

        def make_request():
            try:
                request_data = {
                    "messages": [{"role": "user", "content": "Hello!"}],
                    "stream": False,
                    "max_tokens": 10,
                }

                response = client.post("/v1/chat/completions", json=request_data)
                results.append(response.status_code)
            except Exception as e:
                errors.append(str(e))

        # Start multiple threads
        threads = []
        for _ in range(5):
            thread = threading.Thread(target=make_request)
            threads.append(thread)
            thread.start()

        # Wait for all threads to complete
        for thread in threads:
            thread.join()

        # Check results
        assert len(errors) == 0, f"Errors occurred: {errors}"
        assert len(results) == 5
        assert all(status == 200 for status in results)


class TestAPIValidation:
    """Test API input validation."""

    def test_message_validation(self, client):
        """Test message structure validation."""
        # Test missing content
        request_data = {
            "messages": [{"role": "user"}],  # Missing content
            "stream": False,
        }

        response = client.post("/v1/chat/completions", json=request_data)
        assert response.status_code == 422

    def test_parameter_bounds(self, client):
        """Test parameter bounds validation."""
        request_data = {
            "messages": [{"role": "user", "content": "Hello!"}],
            "temperature": 0.0,  # Valid minimum
            "top_p": 1.0,  # Valid maximum
            "stream": False,
        }

        response = client.post("/v1/chat/completions", json=request_data)
        assert response.status_code == 200

    def test_parameter_bounds_invalid(self, client):
        """Test invalid parameter bounds."""
        request_data = {
            "messages": [{"role": "user", "content": "Hello!"}],
            "temperature": -0.1,  # Invalid minimum
            "stream": False,
        }

        response = client.post("/v1/chat/completions", json=request_data)
        assert response.status_code == 422