"""FastAPI server for the health-marketing-compliance-rag application.

This is the entry point for the `refactor` branch, replacing
`streamlit run app.py`.

Serves the React SPA from frontend/dist/ and exposes API endpoints for
authentication and streaming compliance queries.
"""

import asyncio
import hashlib
import json
import os
import sys
import threading
from collections.abc import AsyncGenerator
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Optional

# Insert PageIndex/ into sys.path so that src/ modules can import from it
# (mirrors the sys.path manipulation in app.py).
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), "PageIndex"))

import jwt
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel

# The single ASGI application object; routes, middleware and the static mount
# defined later in this module are all registered on it.
app = FastAPI(title="Health Marketing Compliance RAG API")

# ---------------------------------------------------------------------------
# CORS Middleware
# ---------------------------------------------------------------------------
#
# In production the React SPA is served from the same origin as the API
# (FastAPI serves frontend/dist/ at /), so the browser never sends a
# cross-origin request and CORS headers are technically unnecessary.
#
# CORS middleware is only added when ALLOWED_ORIGIN is explicitly set —
# this covers local dev (Vite dev server on :5173 calling the API on :7860)
# without opening a wildcard on the live deployment.
_ALLOWED_ORIGIN = os.environ.get("ALLOWED_ORIGIN")
if _ALLOWED_ORIGIN:
    app.add_middleware(
        CORSMiddleware,
        allow_origins=[_ALLOWED_ORIGIN],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

# Shared password read once at import time. None (variable unset) disables the
# password gate entirely: login always succeeds and tokens are not checked.
_APP_PASSWORD: str | None = os.environ.get("APP_PASSWORD")


def _jwt_secret() -> str:
    """Derive the JWT signing secret from APP_PASSWORD.

    The secret is the hex-encoded SHA-256 digest of the password (a plain
    hash, not an HMAC). When APP_PASSWORD is not set a fixed placeholder is
    hashed instead — the password gate is disabled in that case, so the
    secret value is irrelevant.

    Returns:
        A 64-character lowercase hex string used as the HS256 key.
    """
    password = _APP_PASSWORD or "no-password-set"
    return hashlib.sha256(password.encode()).hexdigest()


def _issue_token() -> str:
    """Issue a signed HS256 JWT valid for 24 hours.

    Returns:
        The encoded token, carrying a fixed ``sub`` claim and an ``exp``
        claim 24 hours after the current UTC time.
    """
    now = datetime.now(tz=timezone.utc)
    payload = {
        "sub": "authenticated",
        "exp": now + timedelta(hours=24),
    }
    return jwt.encode(payload, _jwt_secret(), algorithm="HS256")


def _validate_token(request: Request) -> bool:
    """Validate the JWT Bearer token from the Authorization header.

    Args:
        request: Incoming request whose ``Authorization`` header is inspected.

    Returns:
        True if the token is valid, False otherwise. When APP_PASSWORD is not
        set, always returns True (no auth required).
    """
    if _APP_PASSWORD is None:
        return True

    auth_header = request.headers.get("authorization", "")
    scheme, _, token = auth_header.partition(" ")
    token = token.strip()
    # The auth scheme name is case-insensitive (RFC 7235), so accept
    # "bearer" / "BEARER" as well as the canonical "Bearer".
    if scheme.lower() != "bearer" or not token:
        return False

    try:
        jwt.decode(token, _jwt_secret(), algorithms=["HS256"])
    except jwt.InvalidTokenError:
        # Covers malformed, wrongly signed and expired tokens alike.
        return False
    return True


def _sse_event(event: str, payload: dict) -> str:
    """Format one Server-Sent Events frame with a JSON-encoded data line."""
    return f"event: {event}\ndata: {json.dumps(payload)}\n\n"


# ---------------------------------------------------------------------------
# Request / response models
# ---------------------------------------------------------------------------


class LoginRequest(BaseModel):
    # Shared application password supplied by the client.
    password: str


class TokenResponse(BaseModel):
    # Signed JWT as returned by the login endpoint.
    token: str


class QueryRequest(BaseModel):
    # The user's question.
    query: str
    # Prior conversation turns, forwarded untouched to the pipeline.
    # A literal [] default is safe here: Pydantic copies mutable field
    # defaults per instance.
    history: list[dict] = []
    profession: Optional[str] = None
    language: Optional[str] = None


# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------


@app.post("/api/auth/login")
async def login(body: LoginRequest) -> JSONResponse:
    """Authenticate with the shared password and return a signed JWT.

    - If APP_PASSWORD is not set: return a token unconditionally (no check).
    - If APP_PASSWORD is set and the supplied password matches: return a token.
    - If APP_PASSWORD is set and the password does not match: return HTTP 401.

    The comparison is constant-time so response timing does not leak how much
    of a guessed password was correct.
    """
    # Function-scope import, like the other lazy imports in this module.
    import hmac

    if _APP_PASSWORD is None:
        # Password gate disabled — issue token unconditionally.
        return JSONResponse(content={"token": _issue_token()})

    # compare_digest only accepts non-ASCII input as bytes; "surrogatepass"
    # keeps encoding from raising on lone surrogates in either string.
    supplied = body.password.encode("utf-8", "surrogatepass")
    expected = _APP_PASSWORD.encode("utf-8", "surrogatepass")
    if hmac.compare_digest(supplied, expected):
        return JSONResponse(content={"token": _issue_token()})

    return JSONResponse(
        status_code=401,
        content={"detail": "Invalid password"},
    )


@app.post("/api/query/stream")
async def query_stream(request: Request, body: QueryRequest) -> StreamingResponse:
    """Stream a compliance query response as Server-Sent Events.

    Validates the JWT Bearer token, calls the pipeline retrieval and streaming
    generation steps, and yields SSE events:
      - event: token  — per text chunk from the LLM
      - event: done   — final metadata (citations, timing, etc.)
      - event: error  — on any pipeline exception

    An unauthenticated request gets a plain HTTP 401 JSON response instead of
    a stream.
    """
    # Validate JWT
    if not _validate_token(request):
        return JSONResponse(
            status_code=401,
            content={"detail": "Not authenticated"},
        )

    async def event_generator() -> AsyncGenerator[str, None]:
        # The pipeline (run_query_retrieval, run_query_stream) is synchronous.
        # Calling sync blocking code directly in an async generator blocks the
        # event loop for the full generation duration — FastAPI can't flush SSE
        # events to the client until the generator returns. Fix: run each sync
        # step in a thread so the event loop stays free between yields.
        from src.pipeline import run_query_retrieval, run_query_stream

        try:
            # Phase 1: retrieval — blocking, run in a thread
            retrieval = await asyncio.to_thread(
                run_query_retrieval,
                query=body.query,
                history=body.history,
                profession=body.profession,
                language=body.language or None,
            )
        except Exception as exc:
            yield _sse_event("error", {"message": f"Retrieval error: {exc}"})
            return

        # Phase 2: streaming generation — bridge sync generator → async via queue.
        # The producer thread calls run_query_stream and posts each chunk into the
        # queue via call_soon_threadsafe; the async consumer below awaits each item
        # and yields the SSE event, giving the event loop time to send it.
        queue: asyncio.Queue = asyncio.Queue()
        loop = asyncio.get_running_loop()
        # Set by the consumer when it stops reading (normal end, error, or the
        # client disconnecting) so the producer does not keep generating
        # output that nobody will receive.
        stop_requested = threading.Event()

        def _post(item: object) -> None:
            """Hand *item* to the consumer; give up if the loop is gone."""
            try:
                loop.call_soon_threadsafe(queue.put_nowait, item)
            except RuntimeError:
                # call_soon_threadsafe raises once the loop is closed
                # (e.g. server shutdown); there is no consumer left.
                stop_requested.set()

        def _produce() -> None:
            try:
                for chunk in run_query_stream(
                    query=body.query,
                    retrieval=retrieval,
                    history=body.history,
                    profession=body.profession,
                ):
                    if stop_requested.is_set():
                        break
                    _post(chunk)
            except Exception as exc:
                _post({"_error": str(exc)})
            finally:
                _post(None)  # sentinel

        threading.Thread(target=_produce, daemon=True).start()

        try:
            while True:
                chunk = await queue.get()
                if chunk is None:
                    break
                if isinstance(chunk, dict) and "_error" in chunk:
                    yield _sse_event(
                        "error",
                        {"message": f"Pipeline error: {chunk['_error']}"},
                    )
                    break
                if isinstance(chunk, str):
                    yield _sse_event("token", {"text": chunk})
                else:
                    # NOTE(review): any non-string chunk is treated as the
                    # final metadata mapping — confirm against run_query_stream.
                    yield _sse_event(
                        "done",
                        {
                            "citations": chunk.get("citations", []),
                            "domains_searched": chunk.get("domains_searched", []),
                            "sections_retrieved": chunk.get("sections_retrieved", 0),
                            "timing": chunk.get("timing", {}),
                            "token_usage": chunk.get("token_usage", {}),
                        },
                    )
        finally:
            # Also runs when the generator is closed or cancelled mid-stream.
            stop_requested.set()

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        },
    )


# ---------------------------------------------------------------------------
# Static file serving — mount AFTER API routes so /api/* is not shadowed
# ---------------------------------------------------------------------------

_FRONTEND_DIR = Path(__file__).resolve().parent / "frontend" / "dist"

if _FRONTEND_DIR.is_dir():
    app.mount("/", StaticFiles(directory=str(_FRONTEND_DIR), html=True), name="static")
else:
    print(
        f"WARNING: Frontend dist directory not found at {_FRONTEND_DIR}\n"
        " All GET / requests will return 404.\n"
        " Run: cd frontend && npm ci && npm run build"
    )

# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------

if __name__ == "__main__":
    # Script mode only: uvicorn is imported lazily so that importing this
    # module (e.g. `uvicorn server:app`) does not require it at import time.
    import uvicorn

    # Listen on all interfaces; the port comes from $PORT, defaulting to 7860.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=int(os.environ.get("PORT", "7860")),
    )