hmc-rag / test_server.py
webmuppet
Refactor: replace Streamlit with React+Vite SPA + FastAPI server
b5c0df4
Raw
History Blame Contribute Delete
8.82 kB
"""
Unit tests for server.py — FastAPI endpoints.
Validates Requirements: 2.1, 2.2, 2.3, 2.4, 2.5
"""
import hashlib
import json
from datetime import datetime, timedelta, timezone
from unittest.mock import patch
import jwt
import pytest
from httpx import ASGITransport, AsyncClient
from server import app
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
# Password the tests patch into server._APP_PASSWORD. Tokens signed with the
# secret derived from it by _jwt_secret_for() are expected to be accepted.
TEST_PASSWORD: str = "test-secret-password"
def _jwt_secret_for(password: str) -> str:
"""Mirror server._jwt_secret() logic."""
return hashlib.sha256(password.encode()).hexdigest()
def _make_token(password: str, expired: bool = False) -> str:
    """Issue an HS256 JWT signed the same way the server signs its tokens.

    Args:
        password: Password from which the signing secret is derived.
        expired: If True, ``exp`` is one hour in the past, so the token is
            already expired. Otherwise it expires 24 hours from now.

    Returns:
        The encoded JWT with claim ``sub="authenticated"``.
    """
    lifetime = timedelta(hours=-1) if expired else timedelta(hours=24)
    claims = {
        "sub": "authenticated",
        "exp": datetime.now(tz=timezone.utc) + lifetime,
    }
    return jwt.encode(claims, _jwt_secret_for(password), algorithm="HS256")
# ---------------------------------------------------------------------------
# Tests: Login endpoint
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_login_correct_password():
    """POST /api/auth/login with the configured password returns a usable JWT."""
    credentials = {"password": TEST_PASSWORD}
    with patch("server._APP_PASSWORD", TEST_PASSWORD):
        async with AsyncClient(
            transport=ASGITransport(app=app), base_url="http://test"
        ) as client:
            response = await client.post("/api/auth/login", json=credentials)

    assert response.status_code == 200
    payload = response.json()
    assert "token" in payload

    # The issued token must verify against the password-derived secret.
    claims = jwt.decode(
        payload["token"], _jwt_secret_for(TEST_PASSWORD), algorithms=["HS256"]
    )
    assert claims["sub"] == "authenticated"
@pytest.mark.asyncio
async def test_login_wrong_password():
    """A login attempt with an incorrect password is rejected with 401."""
    bad_credentials = {"password": "wrong-password"}
    with patch("server._APP_PASSWORD", TEST_PASSWORD):
        async with AsyncClient(
            transport=ASGITransport(app=app), base_url="http://test"
        ) as client:
            response = await client.post("/api/auth/login", json=bad_credentials)

    assert response.status_code == 401
    assert response.json()["detail"] == "Invalid password"
# ---------------------------------------------------------------------------
# Tests: Stream endpoint — authentication
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_stream_without_token():
    """With a password configured, the stream endpoint rejects requests that
    carry no Authorization header (401)."""
    query_body = {"query": "test question", "history": []}
    with patch("server._APP_PASSWORD", TEST_PASSWORD):
        async with AsyncClient(
            transport=ASGITransport(app=app), base_url="http://test"
        ) as client:
            response = await client.post("/api/query/stream", json=query_body)

    assert response.status_code == 401
    assert response.json()["detail"] == "Not authenticated"
@pytest.mark.asyncio
async def test_stream_with_expired_token():
    """A correctly signed but expired JWT is rejected by the stream endpoint (401)."""
    stale_auth = {"Authorization": f"Bearer {_make_token(TEST_PASSWORD, expired=True)}"}
    query_body = {"query": "test question", "history": []}
    with patch("server._APP_PASSWORD", TEST_PASSWORD):
        async with AsyncClient(
            transport=ASGITransport(app=app), base_url="http://test"
        ) as client:
            response = await client.post(
                "/api/query/stream", json=query_body, headers=stale_auth
            )

    assert response.status_code == 401
    assert response.json()["detail"] == "Not authenticated"
# ---------------------------------------------------------------------------
# Tests: Stream endpoint — valid token with mocked pipeline
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_stream_with_valid_token():
    """A valid bearer token unlocks the SSE stream.

    ``src.pipeline.run_query_retrieval`` and ``run_query_stream`` are mocked.
    The mocked stream yields two text chunks and then a metadata dict. The test
    expects these to arrive as two ``token`` events and one ``done`` event.
    """
    auth_header = {"Authorization": f"Bearer {_make_token(TEST_PASSWORD)}"}
    fake_retrieval = {"sections": [], "domains": ["example.com"], "timing": {}}
    final_metadata = {
        "citations": [{"index": 1, "domain": "example.com", "title": "Test"}],
        "domains_searched": ["example.com"],
        "sections_retrieved": 1,
        "timing": {"total": 0.5},
        "token_usage": {"total_tokens": 10},
    }
    fake_stream = iter(["Hello ", "world", final_metadata])

    with (
        patch("server._APP_PASSWORD", TEST_PASSWORD),
        patch("src.pipeline.run_query_retrieval", return_value=fake_retrieval),
        patch("src.pipeline.run_query_stream", return_value=fake_stream),
    ):
        async with AsyncClient(
            transport=ASGITransport(app=app), base_url="http://test"
        ) as client:
            response = await client.post(
                "/api/query/stream",
                json={"query": "test question", "history": []},
                headers=auth_header,
            )

    assert response.status_code == 200
    assert "text/event-stream" in response.headers["content-type"]

    events = _parse_sse(response.text)

    # Each text chunk is expected as one "token" event carrying its text.
    token_texts = [
        json.loads(evt["data"])["text"] for evt in events if evt["event"] == "token"
    ]
    assert token_texts == ["Hello ", "world"]

    # The trailing metadata dict is expected as exactly one "done" event.
    done_payloads = [evt["data"] for evt in events if evt["event"] == "done"]
    assert len(done_payloads) == 1
    summary = json.loads(done_payloads[0])
    assert summary["citations"] == [
        {"index": 1, "domain": "example.com", "title": "Test"}
    ]
    assert summary["domains_searched"] == ["example.com"]
    assert summary["sections_retrieved"] == 1
# ---------------------------------------------------------------------------
# Tests: APP_PASSWORD unset — all requests authenticated
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_login_without_password_set():
    """With APP_PASSWORD unset, any login attempt is granted a token."""
    any_credentials = {"password": "anything"}
    with patch("server._APP_PASSWORD", None):
        async with AsyncClient(
            transport=ASGITransport(app=app), base_url="http://test"
        ) as client:
            response = await client.post("/api/auth/login", json=any_credentials)

    assert response.status_code == 200
    assert "token" in response.json()
@pytest.mark.asyncio
async def test_stream_without_password_set():
    """When APP_PASSWORD is unset, the stream endpoint works without a token.

    Besides the 200 / ``text/event-stream`` response, the streamed body is
    parsed. This confirms the mocked pipeline output reached the client as
    ``token`` and ``done`` events, not just an empty successful response.
    """
    mock_retrieval = {"sections": [], "domains": [], "timing": {}}
    mock_chunks = [
        "response",
        {
            "citations": [],
            "domains_searched": [],
            "sections_retrieved": 0,
            "timing": {},
            "token_usage": {},
        },
    ]
    with (
        patch("server._APP_PASSWORD", None),
        patch("src.pipeline.run_query_retrieval", return_value=mock_retrieval),
        patch("src.pipeline.run_query_stream", return_value=iter(mock_chunks)),
    ):
        async with AsyncClient(
            transport=ASGITransport(app=app), base_url="http://test"
        ) as client:
            resp = await client.post(
                "/api/query/stream",
                json={"query": "test question", "history": []},
                # No Authorization header
            )

    assert resp.status_code == 200
    assert "text/event-stream" in resp.headers["content-type"]

    # The single text chunk should arrive as one "token" event...
    events = _parse_sse(resp.text)
    token_texts = [
        json.loads(e["data"])["text"] for e in events if e["event"] == "token"
    ]
    assert token_texts == ["response"]

    # ...followed by exactly one "done" event carrying the empty metadata.
    done_events = [e for e in events if e["event"] == "done"]
    assert len(done_events) == 1
    done_data = json.loads(done_events[0]["data"])
    assert done_data["citations"] == []
    assert done_data["domains_searched"] == []
    assert done_data["sections_retrieved"] == 0
# ---------------------------------------------------------------------------
# SSE parsing helper
# ---------------------------------------------------------------------------
def _parse_sse(body: str) -> list[dict]:
"""Parse a raw SSE response body into a list of {event, data} dicts."""
events = []
current_event = None
current_data = None
for line in body.split("\n"):
if line.startswith("event: "):
current_event = line[7:]
elif line.startswith("data: "):
current_data = line[6:]
elif line == "" and current_event is not None:
events.append({"event": current_event, "data": current_data})
current_event = None
current_data = None
# Handle case where last event doesn't have trailing newline
if current_event is not None and current_data is not None:
events.append({"event": current_event, "data": current_data})
return events