""" Unit tests for server.py — FastAPI endpoints. Validates Requirements: 2.1, 2.2, 2.3, 2.4, 2.5 """ import hashlib import json from datetime import datetime, timedelta, timezone from unittest.mock import patch import jwt import pytest from httpx import ASGITransport, AsyncClient from server import app # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- TEST_PASSWORD = "test-secret-password" def _jwt_secret_for(password: str) -> str: """Mirror server._jwt_secret() logic.""" return hashlib.sha256(password.encode()).hexdigest() def _make_token(password: str, expired: bool = False) -> str: """Issue a JWT matching the server's signing scheme.""" secret = _jwt_secret_for(password) now = datetime.now(tz=timezone.utc) if expired: exp = now - timedelta(hours=1) else: exp = now + timedelta(hours=24) payload = {"sub": "authenticated", "exp": exp} return jwt.encode(payload, secret, algorithm="HS256") # --------------------------------------------------------------------------- # Tests: Login endpoint # --------------------------------------------------------------------------- @pytest.mark.asyncio async def test_login_correct_password(): """Login with correct password returns 200 and a token.""" with patch("server._APP_PASSWORD", TEST_PASSWORD): async with AsyncClient( transport=ASGITransport(app=app), base_url="http://test" ) as client: resp = await client.post( "/api/auth/login", json={"password": TEST_PASSWORD} ) assert resp.status_code == 200 data = resp.json() assert "token" in data # Verify the token is a valid JWT secret = _jwt_secret_for(TEST_PASSWORD) decoded = jwt.decode(data["token"], secret, algorithms=["HS256"]) assert decoded["sub"] == "authenticated" @pytest.mark.asyncio async def test_login_wrong_password(): """Login with wrong password returns 401.""" with patch("server._APP_PASSWORD", TEST_PASSWORD): async with AsyncClient( transport=ASGITransport(app=app), base_url="http://test" ) as client: resp = await client.post( "/api/auth/login", json={"password": "wrong-password"} ) assert resp.status_code == 401 assert resp.json()["detail"] == "Invalid password" # --------------------------------------------------------------------------- # Tests: Stream endpoint — authentication # --------------------------------------------------------------------------- @pytest.mark.asyncio async def test_stream_without_token(): """Stream endpoint without Authorization header returns 401.""" with patch("server._APP_PASSWORD", TEST_PASSWORD): async with AsyncClient( transport=ASGITransport(app=app), base_url="http://test" ) as client: resp = await client.post( "/api/query/stream", json={"query": "test question", "history": []}, ) assert resp.status_code == 401 assert resp.json()["detail"] == "Not authenticated" @pytest.mark.asyncio async def test_stream_with_expired_token(): """Stream endpoint with expired JWT returns 401.""" expired_token = _make_token(TEST_PASSWORD, expired=True) with patch("server._APP_PASSWORD", TEST_PASSWORD): async with AsyncClient( transport=ASGITransport(app=app), base_url="http://test" ) as client: resp = await client.post( "/api/query/stream", json={"query": "test question", "history": []}, headers={"Authorization": f"Bearer {expired_token}"}, ) assert resp.status_code == 401 assert resp.json()["detail"] == "Not authenticated" # --------------------------------------------------------------------------- # Tests: Stream endpoint — valid token with mocked pipeline # --------------------------------------------------------------------------- @pytest.mark.asyncio async def test_stream_with_valid_token(): """Stream endpoint with valid token returns 200 SSE stream.""" valid_token = _make_token(TEST_PASSWORD) mock_retrieval = {"sections": [], "domains": ["example.com"], "timing": {}} mock_chunks = [ "Hello ", "world", { "citations": [{"index": 1, "domain": "example.com", "title": "Test"}], "domains_searched": ["example.com"], "sections_retrieved": 1, "timing": {"total": 0.5}, "token_usage": {"total_tokens": 10}, }, ] with ( patch("server._APP_PASSWORD", TEST_PASSWORD), patch("src.pipeline.run_query_retrieval", return_value=mock_retrieval), patch("src.pipeline.run_query_stream", return_value=iter(mock_chunks)), ): async with AsyncClient( transport=ASGITransport(app=app), base_url="http://test" ) as client: resp = await client.post( "/api/query/stream", json={"query": "test question", "history": []}, headers={"Authorization": f"Bearer {valid_token}"}, ) assert resp.status_code == 200 assert "text/event-stream" in resp.headers["content-type"] # Parse SSE events from the response body body = resp.text events = _parse_sse(body) # Should have 2 token events and 1 done event token_events = [e for e in events if e["event"] == "token"] done_events = [e for e in events if e["event"] == "done"] assert len(token_events) == 2 assert json.loads(token_events[0]["data"])["text"] == "Hello " assert json.loads(token_events[1]["data"])["text"] == "world" assert len(done_events) == 1 done_data = json.loads(done_events[0]["data"]) assert done_data["citations"] == [ {"index": 1, "domain": "example.com", "title": "Test"} ] assert done_data["domains_searched"] == ["example.com"] assert done_data["sections_retrieved"] == 1 # --------------------------------------------------------------------------- # Tests: APP_PASSWORD unset — all requests authenticated # --------------------------------------------------------------------------- @pytest.mark.asyncio async def test_login_without_password_set(): """When APP_PASSWORD is unset, login returns a token unconditionally.""" with patch("server._APP_PASSWORD", None): async with AsyncClient( transport=ASGITransport(app=app), base_url="http://test" ) as client: resp = await client.post( "/api/auth/login", json={"password": "anything"} ) assert resp.status_code == 200 assert "token" in resp.json() @pytest.mark.asyncio async def test_stream_without_password_set(): """When APP_PASSWORD is unset, stream endpoint works without a token.""" mock_retrieval = {"sections": [], "domains": [], "timing": {}} mock_chunks = [ "response", {"citations": [], "domains_searched": [], "sections_retrieved": 0, "timing": {}, "token_usage": {}}, ] with ( patch("server._APP_PASSWORD", None), patch("src.pipeline.run_query_retrieval", return_value=mock_retrieval), patch("src.pipeline.run_query_stream", return_value=iter(mock_chunks)), ): async with AsyncClient( transport=ASGITransport(app=app), base_url="http://test" ) as client: resp = await client.post( "/api/query/stream", json={"query": "test question", "history": []}, # No Authorization header ) assert resp.status_code == 200 assert "text/event-stream" in resp.headers["content-type"] # --------------------------------------------------------------------------- # SSE parsing helper # --------------------------------------------------------------------------- def _parse_sse(body: str) -> list[dict]: """Parse a raw SSE response body into a list of {event, data} dicts.""" events = [] current_event = None current_data = None for line in body.split("\n"): if line.startswith("event: "): current_event = line[7:] elif line.startswith("data: "): current_data = line[6:] elif line == "" and current_event is not None: events.append({"event": current_event, "data": current_data}) current_event = None current_data = None # Handle case where last event doesn't have trailing newline if current_event is not None and current_data is not None: events.append({"event": current_event, "data": current_data}) return events