Spaces:
Sleeping
Sleeping
| """ | |
| Unit tests for server.py — FastAPI endpoints. | |
| Validates Requirements: 2.1, 2.2, 2.3, 2.4, 2.5 | |
| """ | |
| import hashlib | |
| import json | |
| from datetime import datetime, timedelta, timezone | |
| from unittest.mock import patch | |
| import jwt | |
| import pytest | |
| from httpx import ASGITransport, AsyncClient | |
| from server import app | |
| # --------------------------------------------------------------------------- | |
| # Helpers | |
| # --------------------------------------------------------------------------- | |
# Shared password for these tests: they patch it in as server._APP_PASSWORD
# and sign JWTs from it through _make_token / _jwt_secret_for.
TEST_PASSWORD: str = "test-secret-password"
def _jwt_secret_for(password: str) -> str:
    """Return the JWT signing secret derived from *password*.

    Mirrors ``server._jwt_secret()``: the secret is the hex-encoded SHA-256
    digest of the UTF-8 encoded password.
    """
    hasher = hashlib.sha256()
    hasher.update(password.encode("utf-8"))
    return hasher.hexdigest()
def _make_token(password: str, expired: bool = False) -> str:
    """Build an HS256 JWT signed with the same scheme the server uses.

    Args:
        password: App password; the signing key is derived from it via
            ``_jwt_secret_for``.
        expired: When True, the ``exp`` claim is set one hour in the past so
            the token is already expired. Otherwise it expires 24 hours from
            now.

    Returns:
        The encoded token carrying ``sub="authenticated"`` and ``exp``.
    """
    issued_at = datetime.now(tz=timezone.utc)
    lifetime = timedelta(hours=-1) if expired else timedelta(hours=24)
    claims = {"sub": "authenticated", "exp": issued_at + lifetime}
    signing_key = _jwt_secret_for(password)
    return jwt.encode(claims, signing_key, algorithm="HS256")
| # --------------------------------------------------------------------------- | |
| # Tests: Login endpoint | |
| # --------------------------------------------------------------------------- | |
# NOTE(review): the async tests in this module carry no @pytest.mark.asyncio
# marker (``pytest`` is imported but unused), so they presumably rely on
# pytest-asyncio running in ``asyncio_mode = "auto"`` — confirm in the pytest
# config, otherwise these coroutines are never awaited.
async def test_login_correct_password():
    """Logging in with the right password returns 200 and a verifiable JWT.

    The token must decode against the secret derived from TEST_PASSWORD and
    carry ``sub == "authenticated"``.
    """
    credentials = {"password": TEST_PASSWORD}
    with patch("server._APP_PASSWORD", TEST_PASSWORD):
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            response = await client.post("/api/auth/login", json=credentials)

        assert response.status_code == 200
        body = response.json()
        assert "token" in body

        # The token must verify with the same key derivation the server uses.
        claims = jwt.decode(
            body["token"], _jwt_secret_for(TEST_PASSWORD), algorithms=["HS256"]
        )
        assert claims["sub"] == "authenticated"
async def test_login_wrong_password():
    """Logging in with an incorrect password is rejected with 401."""
    credentials = {"password": "wrong-password"}
    with patch("server._APP_PASSWORD", TEST_PASSWORD):
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            response = await client.post("/api/auth/login", json=credentials)

        assert response.status_code == 401
        assert response.json()["detail"] == "Invalid password"
| # --------------------------------------------------------------------------- | |
| # Tests: Stream endpoint — authentication | |
| # --------------------------------------------------------------------------- | |
async def test_stream_without_token():
    """POST /api/query/stream with no Authorization header returns 401."""
    request_body = {"query": "test question", "history": []}
    with patch("server._APP_PASSWORD", TEST_PASSWORD):
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            response = await client.post("/api/query/stream", json=request_body)

        assert response.status_code == 401
        assert response.json()["detail"] == "Not authenticated"
async def test_stream_with_expired_token():
    """POST /api/query/stream with an already-expired JWT returns 401."""
    stale_token = _make_token(TEST_PASSWORD, expired=True)
    auth_headers = {"Authorization": f"Bearer {stale_token}"}
    request_body = {"query": "test question", "history": []}

    with patch("server._APP_PASSWORD", TEST_PASSWORD):
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            response = await client.post(
                "/api/query/stream", json=request_body, headers=auth_headers
            )

        assert response.status_code == 401
        assert response.json()["detail"] == "Not authenticated"
| # --------------------------------------------------------------------------- | |
| # Tests: Stream endpoint — valid token with mocked pipeline | |
| # --------------------------------------------------------------------------- | |
async def test_stream_with_valid_token():
    """A valid bearer token on /api/query/stream yields a 200 SSE stream.

    ``run_query_retrieval`` and ``run_query_stream`` are patched with canned
    output. The test expects the two string chunks back as ``token`` events,
    each with a ``{"text": ...}`` payload, and the trailing metadata dict as
    exactly one ``done`` event.
    """
    bearer = _make_token(TEST_PASSWORD)
    retrieval_stub = {"sections": [], "domains": ["example.com"], "timing": {}}
    final_metadata = {
        "citations": [{"index": 1, "domain": "example.com", "title": "Test"}],
        "domains_searched": ["example.com"],
        "sections_retrieved": 1,
        "timing": {"total": 0.5},
        "token_usage": {"total_tokens": 10},
    }
    stream_stub = ["Hello ", "world", final_metadata]

    # NOTE(review): these patches only take effect if server.py resolves the
    # pipeline functions through the src.pipeline module at call time (not via
    # a `from src.pipeline import ...` at import time) — confirm in server.py.
    with (
        patch("server._APP_PASSWORD", TEST_PASSWORD),
        patch("src.pipeline.run_query_retrieval", return_value=retrieval_stub),
        patch("src.pipeline.run_query_stream", return_value=iter(stream_stub)),
    ):
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            response = await client.post(
                "/api/query/stream",
                json={"query": "test question", "history": []},
                headers={"Authorization": f"Bearer {bearer}"},
            )

        assert response.status_code == 200
        assert "text/event-stream" in response.headers["content-type"]

        events = _parse_sse(response.text)
        tokens = [ev for ev in events if ev["event"] == "token"]
        dones = [ev for ev in events if ev["event"] == "done"]

        # Exactly two token events, carrying the chunks in order.
        assert len(tokens) == 2
        assert [json.loads(ev["data"])["text"] for ev in tokens] == ["Hello ", "world"]

        # Exactly one done event, echoing the metadata fields.
        assert len(dones) == 1
        summary = json.loads(dones[0]["data"])
        assert summary["citations"] == [
            {"index": 1, "domain": "example.com", "title": "Test"}
        ]
        assert summary["domains_searched"] == ["example.com"]
        assert summary["sections_retrieved"] == 1
| # --------------------------------------------------------------------------- | |
| # Tests: APP_PASSWORD unset — all requests authenticated | |
| # --------------------------------------------------------------------------- | |
async def test_login_without_password_set():
    """With _APP_PASSWORD patched to None, login issues a token for any password."""
    credentials = {"password": "anything"}
    with patch("server._APP_PASSWORD", None):
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            response = await client.post("/api/auth/login", json=credentials)

        assert response.status_code == 200
        assert "token" in response.json()
async def test_stream_without_password_set():
    """With _APP_PASSWORD patched to None, streaming needs no bearer token."""
    retrieval_stub = {"sections": [], "domains": [], "timing": {}}
    final_metadata = {
        "citations": [],
        "domains_searched": [],
        "sections_retrieved": 0,
        "timing": {},
        "token_usage": {},
    }
    stream_stub = ["response", final_metadata]

    with (
        patch("server._APP_PASSWORD", None),
        patch("src.pipeline.run_query_retrieval", return_value=retrieval_stub),
        patch("src.pipeline.run_query_stream", return_value=iter(stream_stub)),
    ):
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            # Deliberately no Authorization header.
            response = await client.post(
                "/api/query/stream",
                json={"query": "test question", "history": []},
            )

        assert response.status_code == 200
        assert "text/event-stream" in response.headers["content-type"]
| # --------------------------------------------------------------------------- | |
| # SSE parsing helper | |
| # --------------------------------------------------------------------------- | |
def _parse_sse(body: str) -> list[dict]:
    """Parse a raw Server-Sent Events body into ``{"event", "data"}`` dicts.

    The parser follows the line-oriented SSE wire format closely enough for
    these tests:

    * Lines may end in LF, CRLF or a lone CR. All three are valid SSE line
      terminators, and some SSE libraries emit CRLF.
    * ``field: value`` and ``field:value`` are both accepted. One space after
      the colon is stripped, as the SSE format specifies.
    * Repeated ``data:`` lines within one event are joined with a newline.
    * Comment lines (starting with ``:``) and unknown fields are ignored.

    A blank line ends an event and resets the pending fields. The event is
    emitted only if it had an ``event:`` field; its ``data`` is ``None`` when
    no ``data:`` line was seen. Blocks without an ``event:`` field are
    skipped, because these tests only look for named events.

    A last event without a terminating blank line is still emitted, provided
    it has both an ``event:`` and a ``data:`` field.

    Args:
        body: The full decoded response text.

    Returns:
        The parsed events, in the order they appear in *body*.
    """
    events = []
    event_name = None
    data_lines = []

    # Normalise every SSE line terminator to "\n". str.splitlines() is not
    # used because it also splits on characters (e.g. U+2028) that may
    # legitimately appear inside a data payload.
    normalized = body.replace("\r\n", "\n").replace("\r", "\n")
    for line in normalized.split("\n"):
        if line == "":
            # Blank line: dispatch the pending event (if named), then reset.
            if event_name is not None:
                data = "\n".join(data_lines) if data_lines else None
                events.append({"event": event_name, "data": data})
            event_name = None
            data_lines = []
            continue

        field, _, value = line.partition(":")
        if value.startswith(" "):
            value = value[1:]
        if field == "event":
            event_name = value
        elif field == "data":
            data_lines.append(value)

    # Handle a final event whose terminating blank line is missing.
    if event_name is not None and data_lines:
        events.append({"event": event_name, "data": "\n".join(data_lines)})
    return events