Spaces:
Sleeping
Sleeping
| """ | |
| FastAPI server for the health-marketing-compliance-rag application. | |
| This is the entry point for the `refactor` branch, replacing `streamlit run app.py`. | |
| Serves the React SPA from frontend/dist/ and exposes API endpoints for | |
| authentication and streaming compliance queries. | |
| """ | |
| import asyncio | |
| import hashlib | |
| import json | |
| import os | |
| import sys | |
| import threading | |
| from collections.abc import AsyncGenerator | |
| from datetime import datetime, timedelta, timezone | |
| from pathlib import Path | |
| from typing import Optional | |
| # Insert PageIndex/ into sys.path so that src/ modules can import from it | |
| # (mirrors the sys.path manipulation in app.py). | |
| sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), "PageIndex")) | |
| import jwt | |
| from fastapi import FastAPI, Request | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import JSONResponse, StreamingResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from pydantic import BaseModel | |
# Single application object for this module: the CORS middleware and the
# static-file mount further down are both attached to it, and it is what
# uvicorn serves from the __main__ entry point.
app = FastAPI(title="Health Marketing Compliance RAG API")
| # --------------------------------------------------------------------------- | |
| # CORS Middleware | |
| # --------------------------------------------------------------------------- | |
| # | |
| # In production the React SPA is served from the same origin as the API | |
| # (FastAPI serves frontend/dist/ at /), so the browser never sends a | |
| # cross-origin request and CORS headers are technically unnecessary. | |
| # | |
| # CORS middleware is only added when ALLOWED_ORIGIN is explicitly set — | |
| # this covers local dev (Vite dev server on :5173 calling the API on :7860) | |
| # without opening a wildcard on the live deployment. | |
# Read once at import time; changing ALLOWED_ORIGIN later has no effect on an
# already-running process.
_ALLOWED_ORIGIN = os.environ.get("ALLOWED_ORIGIN")
if _ALLOWED_ORIGIN:
    # Unset or empty ALLOWED_ORIGIN skips this branch, so no CORS middleware
    # is installed at all. When set, exactly one origin is allowed (the value
    # is used verbatim, not split on commas), with credentials and any
    # method/header permitted for that origin.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=[_ALLOWED_ORIGIN],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
| # --------------------------------------------------------------------------- | |
| # Helpers | |
| # --------------------------------------------------------------------------- | |
| _APP_PASSWORD: str | None = os.environ.get("APP_PASSWORD") | |
| def _jwt_secret() -> str: | |
| """Derive the JWT signing secret from APP_PASSWORD via HMAC-SHA256. | |
| When APP_PASSWORD is not set we use a fixed placeholder secret — the | |
| password gate is disabled in that case so the secret value is irrelevant. | |
| """ | |
| password = _APP_PASSWORD or "no-password-set" | |
| return hashlib.sha256(password.encode()).hexdigest() | |
def _issue_token() -> str:
    """Create a signed HS256 JWT that expires 24 hours from now.

    The token carries a fixed subject ("authenticated") and a UTC expiry;
    it identifies no individual user. It is signed with ``_jwt_secret()``.

    Returns:
        The encoded token as returned by ``jwt.encode``.
    """
    expires_at = datetime.now(tz=timezone.utc) + timedelta(hours=24)
    claims = {"sub": "authenticated", "exp": expires_at}
    return jwt.encode(claims, _jwt_secret(), algorithm="HS256")
def _validate_token(request: Request) -> bool:
    """Check the JWT Bearer token carried in the Authorization header.

    When APP_PASSWORD is not set the password gate is disabled and every
    request is accepted without looking at the header.

    Args:
        request: Incoming request; only its ``authorization`` header is read.

    Returns:
        True if authentication is disabled or the header carries a token that
        ``jwt.decode`` accepts for the HS256 secret; False otherwise. This
        function never raises for a bad or missing token.
    """
    if _APP_PASSWORD is None:
        return True

    header = request.headers.get("authorization", "")
    # Split "<scheme> <credentials>" on the first space. HTTP authentication
    # scheme names are case-insensitive, so "bearer"/"BEARER" are accepted
    # as well as "Bearer"; surrounding whitespace on the token is ignored.
    scheme, _, credentials = header.partition(" ")
    token = credentials.strip()
    if scheme.lower() != "bearer" or not token:
        return False

    try:
        jwt.decode(token, _jwt_secret(), algorithms=["HS256"])
    except Exception:
        # Deliberately broad: this is the authentication boundary, and any
        # failure to decode (bad signature, expired, malformed, ...) must
        # deny access rather than surface as a server error.
        return False
    return True
| # --------------------------------------------------------------------------- | |
| # Request / response models | |
| # --------------------------------------------------------------------------- | |
class LoginRequest(BaseModel):
    """Request body for the login endpoint."""

    # Candidate shared password; compared against APP_PASSWORD in login().
    password: str
class TokenResponse(BaseModel):
    """Shape of a successful login response: ``{"token": "<jwt>"}``.

    NOTE(review): not referenced by any code visible in this file (login()
    builds its JSON by hand) — presumably kept for documentation or a
    response_model declaration; confirm before removing.
    """

    # Encoded JWT as produced by _issue_token().
    token: str
class QueryRequest(BaseModel):
    """Request body for the streaming query endpoint.

    All fields are forwarded to the pipeline functions by query_stream();
    their exact meaning is defined there, not here.
    """

    # The user's question text.
    query: str
    # Prior conversation turns, passed through to the pipeline unchanged.
    # The dict schema is not constrained here — TODO confirm against
    # src.pipeline. NOTE(review): pydantic is documented to copy field
    # defaults per instance, so this `[]` should not be shared between
    # requests — confirm for the pinned pydantic version.
    history: list[dict] = []
    # Optional pipeline hints; None when the client omits them.
    profession: Optional[str] = None
    language: Optional[str] = None
| # --------------------------------------------------------------------------- | |
| # Endpoints | |
| # --------------------------------------------------------------------------- | |
async def login(body: LoginRequest) -> JSONResponse:
    """Authenticate with the shared password and return a signed JWT.

    NOTE(review): no route decorator is visible on this function in the
    reviewed text — confirm how it is registered on ``app``.

    Behaviour:
      - APP_PASSWORD not set: a token is issued unconditionally.
      - APP_PASSWORD set and the supplied password matches: a token is issued.
      - APP_PASSWORD set and the password does not match: HTTP 401.

    Args:
        body: Parsed request body carrying the candidate password.

    Returns:
        ``{"token": "<jwt>"}`` on success, or a 401 response with
        ``{"detail": "Invalid password"}``.
    """
    # Function-scope import keeps this change self-contained.
    import hmac

    if _APP_PASSWORD is None:
        # Password gate disabled — issue token unconditionally.
        return JSONResponse(content={"token": _issue_token()})

    # Compare in constant time so response timing does not reveal how much of
    # the password matched. compare_digest only accepts non-ASCII text as
    # bytes, hence the encoding; "surrogatepass" keeps a lone-surrogate
    # password from raising, so it is simply rejected like any other mismatch.
    supplied = body.password.encode("utf-8", "surrogatepass")
    expected = _APP_PASSWORD.encode("utf-8", "surrogatepass")
    if hmac.compare_digest(supplied, expected):
        return JSONResponse(content={"token": _issue_token()})

    return JSONResponse(
        status_code=401,
        content={"detail": "Invalid password"},
    )
async def query_stream(request: Request, body: QueryRequest) -> StreamingResponse:
    """Stream a compliance query response as Server-Sent Events.

    NOTE(review): no route decorator is visible on this function in the
    reviewed text — confirm how it is registered on ``app``.

    The request is rejected with 401 unless ``_validate_token`` accepts it.
    Otherwise the pipeline's retrieval step runs first, then its streaming
    generation step, and the results are emitted as SSE events:

      - event: token — one per ``str`` chunk, data ``{"text": <chunk>}``
      - event: done  — for any non-``str`` chunk, data holding citations,
        domains_searched, sections_retrieved, timing and token_usage
        (presumably a single final metadata dict — confirm in src.pipeline)
      - event: error — data ``{"message": ...}`` when retrieval or
        generation raises; the stream ends after it

    Args:
        request: Incoming request, used only for token validation.
        body: Query text plus optional history, profession and language.

    Returns:
        A ``text/event-stream`` StreamingResponse, or (despite the
        annotation) a 401 JSONResponse when the token is rejected.
    """
    # Validate JWT
    if not _validate_token(request):
        return JSONResponse(
            status_code=401,
            content={"detail": "Not authenticated"},
        )

    async def event_generator() -> AsyncGenerator[str, None]:
        # The pipeline functions are synchronous. Running them directly in
        # this coroutine would block the event loop for the whole generation,
        # so nothing could be sent to the client until it finished. Each
        # blocking step therefore runs in a thread.
        from src.pipeline import run_query_retrieval, run_query_stream

        try:
            # Phase 1: retrieval — blocking, run in a worker thread.
            retrieval = await asyncio.to_thread(
                run_query_retrieval,
                query=body.query,
                history=body.history,
                profession=body.profession,
                language=body.language or None,  # empty string -> None
            )
        except Exception as exc:
            error_data = json.dumps({"message": f"Retrieval error: {exc}"})
            yield f"event: error\ndata: {error_data}\n\n"
            return

        # Phase 2: streaming generation — bridge the sync generator to async
        # through a queue. The producer thread posts each chunk onto the loop
        # with call_soon_threadsafe; the consumer below awaits them one by one.
        queue: asyncio.Queue = asyncio.Queue()
        loop = asyncio.get_running_loop()
        # Set by the consumer when it stops (finished, errored, or the client
        # went away) so the producer does not keep driving the pipeline and
        # filling a queue nobody reads.
        stop_requested = threading.Event()

        def _post(item: object) -> None:
            """Hand one item to the consumer; tolerate a closed event loop."""
            try:
                loop.call_soon_threadsafe(queue.put_nowait, item)
            except RuntimeError:
                # The loop has been closed (e.g. server shutdown): there is
                # no consumer any more, so stop producing.
                stop_requested.set()

        def _produce() -> None:
            """Run the blocking generation step and forward its chunks."""
            try:
                for chunk in run_query_stream(
                    query=body.query,
                    retrieval=retrieval,
                    history=body.history,
                    profession=body.profession,
                ):
                    if stop_requested.is_set():
                        break
                    _post(chunk)
            except Exception as exc:
                _post({"_error": str(exc)})
            finally:
                _post(None)  # sentinel: no more items

        threading.Thread(target=_produce, daemon=True).start()

        try:
            while True:
                chunk = await queue.get()
                if chunk is None:
                    break
                if isinstance(chunk, dict) and "_error" in chunk:
                    error_data = json.dumps(
                        {"message": f"Pipeline error: {chunk['_error']}"}
                    )
                    yield f"event: error\ndata: {error_data}\n\n"
                    break
                if isinstance(chunk, str):
                    data = json.dumps({"text": chunk})
                    yield f"event: token\ndata: {data}\n\n"
                else:
                    done_data = json.dumps({
                        "citations": chunk.get("citations", []),
                        "domains_searched": chunk.get("domains_searched", []),
                        "sections_retrieved": chunk.get("sections_retrieved", 0),
                        "timing": chunk.get("timing", {}),
                        "token_usage": chunk.get("token_usage", {}),
                    })
                    yield f"event: done\ndata: {done_data}\n\n"
        finally:
            # Runs on normal completion, on error, and when the generator is
            # cancelled or closed because the client disconnected.
            stop_requested.set()

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            # Ask nginx-style reverse proxies not to buffer the stream.
            "X-Accel-Buffering": "no",
        },
    )
| # --------------------------------------------------------------------------- | |
| # Static file serving — mount AFTER API routes so /api/* is not shadowed | |
| # --------------------------------------------------------------------------- | |
# Built SPA assets are expected in frontend/dist/ next to this file.
_FRONTEND_DIR = Path(__file__).resolve().parent / "frontend" / "dist"
if not _FRONTEND_DIR.is_dir():
    # No build output: start anyway (the API still works) but say why the
    # site root will not be served.
    print(
        f"WARNING: Frontend dist directory not found at {_FRONTEND_DIR}\n"
        " All GET / requests will return 404.\n"
        " Run: cd frontend && npm ci && npm run build"
    )
else:
    # html=True makes the mount serve index.html for directory requests.
    app.mount("/", StaticFiles(directory=str(_FRONTEND_DIR), html=True), name="static")
| # --------------------------------------------------------------------------- | |
| # Entry point | |
| # --------------------------------------------------------------------------- | |
if __name__ == "__main__":
    # Imported here so uvicorn is only needed when this file is run directly.
    import uvicorn

    # PORT overrides the default 7860; 0.0.0.0 listens on all interfaces.
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "7860")))