hmc-rag / server.py
webmuppet
Fix language selector: wire user choice through to generator
016503a
Raw
History Blame Contribute Delete
9.73 kB
"""
FastAPI server for the health-marketing-compliance-rag application.
This is the entry point for the `refactor` branch, replacing `streamlit run app.py`.
Serves the React SPA from frontend/dist/ and exposes API endpoints for
authentication and streaming compliance queries.
"""
import asyncio
import hashlib
import hmac
import json
import os
import sys
import threading
from collections.abc import AsyncGenerator
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Optional
# Insert PageIndex/ into sys.path so that src/ modules can import from it
# (mirrors the sys.path manipulation in app.py).
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), "PageIndex"))
import jwt
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
app = FastAPI(title="Health Marketing Compliance RAG API")

# ---------------------------------------------------------------------------
# CORS Middleware
# ---------------------------------------------------------------------------
#
# In production the React SPA is served from the same origin as the API
# (FastAPI serves frontend/dist/ at /), so the browser never sends a
# cross-origin request and CORS headers are technically unnecessary.
#
# CORS middleware is only added when ALLOWED_ORIGIN is explicitly set —
# this covers local dev (Vite dev server on :5173 calling the API on :7860)
# without opening a wildcard on the live deployment.
#
# ALLOWED_ORIGIN may hold one origin or a comma-separated list, e.g.
# "http://localhost:5173,http://127.0.0.1:5173". Whitespace around entries
# and empty entries are ignored; if nothing usable remains, no CORS
# middleware is installed.
_ALLOWED_ORIGIN = os.environ.get("ALLOWED_ORIGIN")
_ALLOWED_ORIGINS: list[str] = [
    origin.strip()
    for origin in (_ALLOWED_ORIGIN or "").split(",")
    if origin.strip()
]
if _ALLOWED_ORIGINS:
    app.add_middleware(
        CORSMiddleware,
        allow_origins=_ALLOWED_ORIGINS,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
_APP_PASSWORD: str | None = os.environ.get("APP_PASSWORD")
def _jwt_secret() -> str:
    """Return the key used to sign and verify the app's HS256 JWTs.

    The key is the hex-encoded plain SHA-256 digest of APP_PASSWORD. It is
    not an HMAC: HMAC-SHA256 is applied only later, when the token is signed
    with HS256. Because the key is a pure function of the password, changing
    APP_PASSWORD invalidates every token issued under the old password.

    When the password is falsy, a fixed, public placeholder is hashed instead.
    That is harmless when it is None, because the password gate is then
    disabled. An empty-string password, however, would sign tokens with this
    publicly known key.

    NOTE(review): the digest is unsalted, so anyone holding an issued token
    can brute-force a weak APP_PASSWORD offline. Consider a separate,
    high-entropy signing secret.

    Returns:
        64-character lowercase hex string.
    """
    password = _APP_PASSWORD or "no-password-set"
    return hashlib.sha256(password.encode()).hexdigest()
def _issue_token() -> str:
    """Create a session token for a client that passed the password check.

    The token is an HS256-signed JWT keyed by _jwt_secret(). It carries a
    constant ``sub`` claim and an ``exp`` claim 24 hours after issuance.
    """
    expires_at = datetime.now(tz=timezone.utc) + timedelta(hours=24)
    claims = {"sub": "authenticated", "exp": expires_at}
    return jwt.encode(claims, _jwt_secret(), algorithm="HS256")
def _validate_token(request: Request) -> bool:
    """Check the request's ``Authorization: Bearer <jwt>`` header.

    Args:
        request: Incoming request whose ``authorization`` header is read.

    Returns:
        True when the password gate is disabled (APP_PASSWORD unset), or when
        the header carries a JWT that meets all of these conditions:

        - it verifies against _jwt_secret() with HS256;
        - it has not expired;
        - it contains the ``exp`` and ``sub`` claims that _issue_token()
          always sets.

        False otherwise.
    """
    if _APP_PASSWORD is None:
        return True

    # Auth scheme names are case-insensitive (RFC 7235 §2.1), so "bearer" and
    # "BEARER" are accepted as well as "Bearer".
    scheme, _, token = request.headers.get("authorization", "").partition(" ")
    token = token.strip()
    if scheme.lower() != "bearer" or not token:
        return False

    try:
        jwt.decode(
            token,
            _jwt_secret(),
            algorithms=["HS256"],
            options={"require": ["exp", "sub"]},
        )
    except jwt.InvalidTokenError:
        # Covers bad signatures, malformed tokens, expiry and missing claims.
        # Any other exception is a server-side bug; let it surface as a 500
        # rather than disguising it as "Not authenticated".
        return False
    return True
# ---------------------------------------------------------------------------
# Request / response models
# ---------------------------------------------------------------------------
# Field descriptions are kept in comments, not class docstrings, so the
# generated OpenAPI schema is unchanged.


# JSON body of POST /api/auth/login.
class LoginRequest(BaseModel):
    password: str


# Mirrors the {"token": ...} body that a successful login returns.
class TokenResponse(BaseModel):
    token: str


# JSON body of POST /api/query/stream.
class QueryRequest(BaseModel):
    query: str
    # Prior conversation turns, forwarded unchanged to src.pipeline.
    # NOTE(review): the per-turn dict schema is defined there — confirm.
    history: list[dict] = []
    # Optional audience / output hints, forwarded to the pipeline.
    profession: str | None = None
    language: str | None = None
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@app.post("/api/auth/login")
async def login(body: LoginRequest) -> JSONResponse:
    """Authenticate with the shared password and return a signed JWT.

    - If APP_PASSWORD is not set: return a token unconditionally (no check).
    - If APP_PASSWORD is set and the supplied password matches: return a token.
    - If APP_PASSWORD is set and the password does not match: return HTTP 401.

    The comparison uses hmac.compare_digest, so response timing does not
    reveal how much of a guessed password was correct.
    """
    # Both sides are encoded with "surrogatepass" so that strings containing
    # lone surrogates compare as False instead of raising UnicodeEncodeError.
    if _APP_PASSWORD is None or hmac.compare_digest(
        body.password.encode("utf-8", "surrogatepass"),
        _APP_PASSWORD.encode("utf-8", "surrogatepass"),
    ):
        # Gate disabled, or password correct: issue a token.
        return JSONResponse(content={"token": _issue_token()})
    return JSONResponse(
        status_code=401,
        content={"detail": "Invalid password"},
    )
@app.post("/api/query/stream")
async def query_stream(request: Request, body: QueryRequest) -> StreamingResponse:
    """Stream a compliance query response as Server-Sent Events.

    Validates the JWT Bearer token, runs the pipeline's retrieval step, then
    runs its streaming generation step, and yields these SSE events:

    - event: token: one per text chunk from the LLM.
    - event: done: final metadata (citations, timing, etc.).
    - event: error: on any pipeline exception.

    If authentication fails, a 401 JSONResponse is returned instead of a
    stream. If the consumer stops early (for example, the client
    disconnects), the background producer thread stops pulling further
    chunks from the pipeline.
    """
    # Validate JWT
    if not _validate_token(request):
        return JSONResponse(
            status_code=401,
            content={"detail": "Not authenticated"},
        )

    def _sse(event: str, payload: dict) -> str:
        # One SSE frame: an "event:" line, a single-line JSON "data:" line,
        # and the blank line that terminates the frame.
        return f"event: {event}\ndata: {json.dumps(payload)}\n\n"

    async def event_generator() -> AsyncGenerator[str, None]:
        # The pipeline (run_query_retrieval, run_query_stream) is synchronous.
        # Calling sync blocking code directly in an async generator would
        # block the event loop for the whole generation, so FastAPI could not
        # flush SSE events until the generator returned. Each sync step
        # therefore runs in a thread, leaving the event loop free between
        # yields.
        from src.pipeline import run_query_retrieval, run_query_stream

        try:
            # Phase 1: retrieval (blocking, so run it in a thread).
            retrieval = await asyncio.to_thread(
                run_query_retrieval,
                query=body.query,
                history=body.history,
                profession=body.profession,
                language=body.language or None,
            )
        except Exception as exc:
            yield _sse("error", {"message": f"Retrieval error: {exc}"})
            return

        # Phase 2: streaming generation, bridging the sync generator to async
        # code through a queue. The producer thread iterates run_query_stream
        # and posts each chunk with call_soon_threadsafe. The async consumer
        # below awaits each item and yields it as an SSE event.
        queue: asyncio.Queue = asyncio.Queue()
        loop = asyncio.get_running_loop()
        # The consumer sets this once it stops reading (normal end, error
        # event, client disconnect or cancellation). The producer then stops
        # iterating instead of draining the rest of the LLM stream for nobody.
        stop = threading.Event()

        def _produce() -> None:
            try:
                # NOTE(review): body.language is passed to retrieval only.
                # Confirm run_query_stream picks it up from `retrieval`, or
                # pass it here if the generator accepts a `language` argument.
                for chunk in run_query_stream(
                    query=body.query,
                    retrieval=retrieval,
                    history=body.history,
                    profession=body.profession,
                ):
                    if stop.is_set():
                        break
                    loop.call_soon_threadsafe(queue.put_nowait, chunk)
            except Exception as exc:
                loop.call_soon_threadsafe(queue.put_nowait, {"_error": str(exc)})
            finally:
                loop.call_soon_threadsafe(queue.put_nowait, None)  # sentinel

        threading.Thread(target=_produce, daemon=True).start()

        try:
            while True:
                chunk = await queue.get()
                if chunk is None:
                    break
                if isinstance(chunk, dict) and "_error" in chunk:
                    yield _sse(
                        "error",
                        {"message": f"Pipeline error: {chunk['_error']}"},
                    )
                    break
                if isinstance(chunk, str):
                    yield _sse("token", {"text": chunk})
                else:
                    # A non-str, non-error chunk carries the final metadata.
                    yield _sse("done", {
                        "citations": chunk.get("citations", []),
                        "domains_searched": chunk.get("domains_searched", []),
                        "sections_retrieved": chunk.get("sections_retrieved", 0),
                        "timing": chunk.get("timing", {}),
                        "token_usage": chunk.get("token_usage", {}),
                    })
        finally:
            # Runs on normal completion and also when the generator is
            # cancelled or closed mid-stream (client went away).
            stop.set()

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        },
    )
# ---------------------------------------------------------------------------
# Static file serving — mount AFTER API routes so /api/* is not shadowed
# ---------------------------------------------------------------------------
_FRONTEND_DIR = Path(__file__).resolve().parent / "frontend" / "dist"
if not _FRONTEND_DIR.is_dir():
    # No built SPA: keep the API usable, but tell the operator how to build
    # the frontend rather than failing at startup.
    print(
        f"WARNING: Frontend dist directory not found at {_FRONTEND_DIR}\n"
        " All GET / requests will return 404.\n"
        " Run: cd frontend && npm ci && npm run build"
    )
else:
    app.mount("/", StaticFiles(directory=str(_FRONTEND_DIR), html=True), name="static")
# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Direct launch (`python server.py`): listen on all interfaces, on the
    # port given by $PORT, or 7860 when it is unset.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "7860")))