| """ |
| FastAPI service: POST /query (RAG pipeline) and GET /health. |
| |
| Wires together api/hybrid.py (retrieval) and api/generate.py (generation) |
| behind a Pydantic-validated HTTP surface with audit logging, rate limiting, |
| and CORS locked down to the Streamlit demo origin only. |
| |
| Security boundaries enforced here (per CLAUDE.md): |
| - Rule 5: query text and chunk text never appear in logs; we log a |
| 16-char query hash plus latency/k/model metrics only. |
| - Rule 7: query length capped at 2000 chars by Pydantic; oversize |
| requests return 400 with a generic error (no schema details leaked). |
| - Rule 8: /health returns only `{"status": "ok"|"degraded"}` with no |
| stack traces, version strings, or schema details. |
| |
| Run locally: uvicorn api.main:app --reload --port 8000 |
| """ |
|
|
| from __future__ import annotations |
|
|
| import json |
| import os |
| import re |
| import secrets |
| import time |
| from pathlib import Path |
| from typing import Any, Literal |
|
|
| import psycopg |
| from dotenv import load_dotenv |
| from fastapi import Depends, FastAPI, Form, HTTPException, Request, status |
| from fastapi.exceptions import RequestValidationError |
| from fastapi.middleware.cors import CORSMiddleware |
| from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse |
| from fastapi.security import HTTPBasic, HTTPBasicCredentials |
| from fastapi.staticfiles import StaticFiles |
| from fastapi.templating import Jinja2Templates |
| from pydantic import BaseModel, Field |
| from slowapi import Limiter, _rate_limit_exceeded_handler |
| from slowapi.errors import RateLimitExceeded |
| from slowapi.util import get_remote_address |
|
|
| from .generate import generate |
| from .hybrid import retrieve_hybrid |
| from .logging_config import configure_logging, hash_query |
|
|
# Load settings from a .env file (python-dotenv) before any of the
# environment lookups below run.
load_dotenv()

# Shared audit logger for every route; verbosity is taken from LOG_LEVEL
# and falls back to INFO.
audit = configure_logging(os.environ.get("LOG_LEVEL", "INFO"))

# CORS_ORIGIN is read as a comma-separated list. Each entry is trimmed and
# empty entries (for example from a trailing comma) are discarded.
ALLOWED_ORIGINS = [
    origin
    for origin in map(
        str.strip,
        os.environ.get("CORS_ORIGIN", "http://localhost:8000").split(","),
    )
    if origin
]

# Required setting: importing this module raises KeyError when it is unset.
DATABASE_URL = os.environ["DATABASE_URL"]
|
|
# Directory that holds this module; "static" and "templates" sit next to it
# and the eval results live one level up under eval/results.
_HERE = Path(__file__).resolve().parent
_EVAL_RESULTS_DIR = _HERE.parent / "eval" / "results"

# Matches a bracketed run of digits such as "[12]"; group 1 is the digits.
_CITATION_RE = re.compile(r"\[(\d+)\]")

# auto_error=False: a request without credentials reaches the dependency
# as None instead of being rejected by the security scheme itself.
_basic_auth = HTTPBasic(auto_error=False)

# Rate limiter bucketed by the address slowapi reports for the request.
# NOTE(review): if the service runs behind a reverse proxy, that address may
# be the proxy's and all clients would share one bucket -- confirm.
limiter = Limiter(key_func=get_remote_address)

# Swagger UI and ReDoc pages are switched off.
# NOTE(review): openapi_url is not overridden, so FastAPI's default
# /openapi.json route presumably stays reachable -- confirm that is intended.
app = FastAPI(
    title="rag-psych",
    version="0.1.0",
    docs_url=None,
    redoc_url=None,
)
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

# Cross-origin access: only the configured origins, no credentials, and
# only the methods/headers the endpoints in this module use.
app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGINS,
    allow_credentials=False,
    allow_methods=["GET", "POST"],
    allow_headers=["Content-Type"],
)

app.mount("/static", StaticFiles(directory=_HERE / "static"), name="static")
templates = Jinja2Templates(directory=_HERE / "templates")
|
|
|
|
def _require_eval_password(
    credentials: HTTPBasicCredentials | None = Depends(_basic_auth),
) -> None:
    """Dependency that guards the /eval routes with HTTP Basic auth.

    Only the password is inspected; whatever username the client sends is
    accepted. The password must equal the EVAL_PASSWORD environment
    variable, compared with `secrets.compare_digest` so the check runs in
    constant time. When EVAL_PASSWORD is unset or empty, every request is
    rejected, so the dashboard is never left open by accident.

    Raises:
        HTTPException: 401 carrying a `WWW-Authenticate: Basic` challenge
            when credentials are absent, the password is wrong, or no
            password is configured.
    """
    configured = os.environ.get("EVAL_PASSWORD", "")
    presented = "" if credentials is None else credentials.password

    # An empty configured password short-circuits: nobody is let through.
    if configured and secrets.compare_digest(
        presented.encode("utf-8"), configured.encode("utf-8")
    ):
        return

    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="unauthorized",
        headers={"WWW-Authenticate": 'Basic realm="rag-psych eval"'},
    )
|
|
# Closed set of corpus identifiers a /query caller may filter retrieval by.
# NOTE(review): "icd12" does not look like a published ICD revision --
# confirm it matches a real source_type value in the corpus.
SourceType = Literal[
    "mtsamples",
    "pubmed",
    "icd11",
    "icd12",
]
|
|
|
|
class QueryRequest(BaseModel):
    # JSON body accepted by POST /query.

    # Question text: required, between 1 and 2000 characters.
    query: str = Field(
        ...,
        min_length=1,
        max_length=2000,
    )

    # Number of chunks requested from retrieval: 1 to 20, default 5.
    k: int = Field(
        default=5,
        ge=1,
        le=20,
    )

    # Optional source filter handed to retrieve_hybrid; None when omitted.
    source_types: list[SourceType] | None = None
|
|
|
|
class ChunkSummary(BaseModel):
    # One retrieved chunk as exposed in a QueryResponse. The /query handler
    # fills every field from a retrieval hit plus its rerank score.

    chunk_id: int
    source_type: str
    section: str | None  # may be None
    title: str | None  # may be None
    chunk_text: str
    rerank_score: float
|
|
|
|
class Latencies(BaseModel):
    # Timing breakdown for one request, in milliseconds. The /query handler
    # rounds each value to one decimal place before storing it here.

    retrieval_ms: float  # time spent inside retrieve_hybrid
    generation_ms: float  # latency reported by the generation result
    total_ms: float  # measured from just before the DB connection is opened
|
|
|
|
class QueryResponse(BaseModel):
    # Body returned by POST /query. The answer, citation lists, refusal
    # flag and model name are copied from the generation result.

    answer: str
    cited_ids: list[int]
    invalid_cited_ids: list[int]
    refused: bool

    # The hits that were passed to generation, in retrieval order.
    retrieved_chunks: list[ChunkSummary]

    model: str
    latency: Latencies
|
|
|
|
@app.exception_handler(RequestValidationError)
async def _validation_handler(request: Request, exc: RequestValidationError) -> JSONResponse:
    """Answer every request-validation failure with a fixed 400 body.

    The details in `exc.errors()` are deliberately not echoed back: they
    would expose field-level schema hints, which CLAUDE.md rule 8 forbids.
    """
    generic_body = {"error": "invalid_request"}
    return JSONResponse(content=generic_body, status_code=400)
|
|
|
|
@app.get("/", include_in_schema=False)
def root() -> RedirectResponse:
    """Bare Space URL → /ui. Also catches HF Spaces' platform healthcheck."""
    # 307 is a temporary redirect that keeps the request method.
    landing_page = "/ui"
    return RedirectResponse(url=landing_page, status_code=307)
|
|
|
|
@app.get("/health")
def health() -> JSONResponse:
    """Report liveness and database reachability without exposing internals.

    Returns `{"status": "ok"}` when a `SELECT 1` round-trip succeeds, and
    `{"status": "degraded"}` with HTTP 503 when anything at all goes wrong.
    """
    try:
        # connect_timeout=2 keeps the probe from hanging on an unreachable DB.
        with psycopg.connect(DATABASE_URL, connect_timeout=2) as conn, conn.cursor() as cur:
            cur.execute("SELECT 1")
    except Exception:
        # The failure reason is swallowed on purpose so nothing leaks out.
        return JSONResponse({"status": "degraded"}, status_code=503)
    return JSONResponse({"status": "ok"})
|
|
|
|
@app.post("/query", response_model=QueryResponse)
@limiter.limit("5/minute;20/hour;30/day")
def query(request: Request, body: QueryRequest) -> QueryResponse:
    """Run retrieval and generation for one question and report timings."""
    # Audit entries carry a hash of the query, never the query text itself.
    query_digest = hash_query(body.query)
    received_fields = {
        "query_hash": query_digest,
        "k": body.k,
        "source_types": body.source_types,
        "client": get_remote_address(request),
    }
    audit.info("query_received", extra={"audit": received_fields})

    started = time.perf_counter()
    try:
        with psycopg.connect(DATABASE_URL) as conn:
            retrieval_started = time.perf_counter()
            hits = retrieve_hybrid(
                conn,
                body.query,
                k=body.k,
                source_types=body.source_types,
            )
            retrieval_ms = (time.perf_counter() - retrieval_started) * 1000
        gen = generate(body.query, hits)
    except Exception:
        # Any pipeline failure becomes an opaque 500 for the client.
        # NOTE(review): audit.exception attaches the traceback -- confirm the
        # audit formatter cannot emit exception text containing query or
        # chunk content.
        audit.exception("query_failed", extra={"audit": {"query_hash": query_digest}})
        raise HTTPException(status_code=500, detail="internal_error")

    total_ms = (time.perf_counter() - started) * 1000

    # The same one-decimal figures go to the audit log and to the client.
    retrieval_rounded = round(retrieval_ms, 1)
    generation_rounded = round(gen.latency_ms, 1)
    total_rounded = round(total_ms, 1)

    completed_fields = {
        "query_hash": query_digest,
        "k": body.k,
        "retrieved_count": len(hits),
        "cited_count": len(gen.cited_ids),
        "invalid_cited_count": len(gen.invalid_cited_ids),
        "refused": gen.refused,
        "model": gen.model,
        "retrieval_ms": retrieval_rounded,
        "generation_ms": generation_rounded,
        "total_ms": total_rounded,
    }
    audit.info("query_completed", extra={"audit": completed_fields})

    summaries = [
        ChunkSummary(
            chunk_id=ranked.hit.chunk_id,
            source_type=ranked.hit.source_type,
            section=ranked.hit.section,
            title=ranked.hit.title,
            chunk_text=ranked.hit.chunk_text,
            rerank_score=ranked.rerank_score,
        )
        for ranked in hits
    ]
    timings = Latencies(
        retrieval_ms=retrieval_rounded,
        generation_ms=generation_rounded,
        total_ms=total_rounded,
    )
    return QueryResponse(
        answer=gen.answer,
        cited_ids=gen.cited_ids,
        invalid_cited_ids=gen.invalid_cited_ids,
        refused=gen.refused,
        retrieved_chunks=summaries,
        model=gen.model,
        latency=timings,
    )
|
|
|
|
| |
|
|
|
|
@app.get("/ui", response_class=HTMLResponse)
def ui_index(request: Request) -> HTMLResponse:
    """Serve the main page; its results section starts empty and HTMX fills it."""
    # index.html is rendered with no template variables.
    empty_context: dict[str, Any] = {}
    return templates.TemplateResponse(request, "index.html", empty_context)
|
|
|
|
@app.get("/help", response_class=HTMLResponse)
def ui_help(request: Request) -> HTMLResponse:
    """Serve the static help page (what the system offers, examples, limits, pipeline)."""
    # help.html is rendered with no template variables.
    empty_context: dict[str, Any] = {}
    return templates.TemplateResponse(request, "help.html", empty_context)
|
|
|
|
| |
|
|
|
|
@app.get(
    "/eval",
    response_class=HTMLResponse,
    dependencies=[Depends(_require_eval_password)],
)
def eval_dashboard(request: Request) -> HTMLResponse:
    """Serve the eval visualization dashboard behind the password dependency."""
    # eval.html is rendered with no template variables.
    empty_context: dict[str, Any] = {}
    return templates.TemplateResponse(request, "eval.html", empty_context)
|
|
|
|
@app.get(
    "/eval/data",
    dependencies=[Depends(_require_eval_password)],
)
def eval_data() -> JSONResponse:
    """Return the dashboard's JSON feed: stored eval runs plus live corpus stats."""
    run_history = _load_eval_runs()
    corpus_snapshot = _corpus_stats()
    return JSONResponse({"runs": run_history, "corpus": corpus_snapshot})
|
|
|
|
| def _load_eval_runs() -> list[dict[str, Any]]: |
| """All eval/results/*.json files, oldest first. Empty list if the |
| directory hasn't been populated yet (fresh clone, Docker without |
| the volume mount).""" |
| if not _EVAL_RESULTS_DIR.is_dir(): |
| return [] |
| runs: list[dict[str, Any]] = [] |
| for path in sorted(_EVAL_RESULTS_DIR.glob("*.json")): |
| try: |
| runs.append(json.loads(path.read_text())) |
| except json.JSONDecodeError: |
| continue |
| return runs |
|
|
|
|
def _corpus_stats() -> dict[str, Any]:
    """Collect live corpus counts from Postgres for the eval dashboard.

    Returns:
        A dict with three keys: "docs" and "chunks" map each source_type
        to its row count, and "sections" lists up to 40
        (source_type, section, n) entries ordered by descending count.
        If the database cannot be queried for any reason, the same keys
        are returned with empty values.
    """
    docs_sql = "SELECT source_type, COUNT(*) FROM documents GROUP BY 1 ORDER BY 1"
    chunks_sql = (
        "SELECT source_type, COUNT(*) FROM chunks_with_source "
        "GROUP BY 1 ORDER BY 1"
    )
    sections_sql = (
        "SELECT source_type, section, COUNT(*) AS n "
        "FROM chunks_with_source WHERE section IS NOT NULL "
        "GROUP BY 1, 2 ORDER BY n DESC LIMIT 40"
    )

    try:
        with psycopg.connect(DATABASE_URL, connect_timeout=2) as conn, conn.cursor() as cur:
            cur.execute(docs_sql)
            docs = dict(cur.fetchall())

            cur.execute(chunks_sql)
            chunks = dict(cur.fetchall())

            cur.execute(sections_sql)
            sections = [
                {"source_type": source, "section": section, "n": count}
                for source, section, count in cur.fetchall()
            ]
    except Exception:
        # Best effort: the dashboard still renders, just without corpus numbers.
        return {"docs": {}, "chunks": {}, "sections": []}
    return {"docs": docs, "chunks": chunks, "sections": sections}
|
|
|
|
@app.post("/ui/query", response_class=HTMLResponse)
@limiter.limit("5/minute;20/hour;30/day")
def ui_query(
    request: Request,
    query: str = Form(..., min_length=1, max_length=2000),
    k: int = Form(5, ge=1, le=20),
) -> HTMLResponse:
    """HTMX endpoint: run the pipeline and return the _results.html fragment."""
    # Audit entries carry a hash of the query, never the query text itself.
    query_digest = hash_query(query)
    received_fields = {
        "query_hash": query_digest,
        "k": k,
        "client": get_remote_address(request),
    }
    audit.info("ui_query_received", extra={"audit": received_fields})

    started = time.perf_counter()
    try:
        with psycopg.connect(DATABASE_URL) as conn:
            retrieval_started = time.perf_counter()
            hits = retrieve_hybrid(conn, query, k=k)
            retrieval_ms = (time.perf_counter() - retrieval_started) * 1000
        gen = generate(query, hits)
    except Exception:
        # Failures render a generic error fragment with status 500.
        # NOTE(review): audit.exception attaches the traceback -- confirm the
        # audit formatter cannot emit exception text containing query or
        # chunk content.
        audit.exception("ui_query_failed", extra={"audit": {"query_hash": query_digest}})
        error_context = {"message": "Something went wrong. Please try again."}
        return templates.TemplateResponse(
            request,
            "_error.html",
            error_context,
            status_code=500,
        )

    total_ms = (time.perf_counter() - started) * 1000
    completed_fields = {
        "query_hash": query_digest,
        "k": k,
        "retrieved_count": len(hits),
        "cited_count": len(gen.cited_ids),
        "invalid_cited_count": len(gen.invalid_cited_ids),
        "refused": gen.refused,
        "model": gen.model,
        "retrieval_ms": round(retrieval_ms, 1),
        "generation_ms": round(gen.latency_ms, 1),
        "total_ms": round(total_ms, 1),
    }
    audit.info("ui_query_completed", extra={"audit": completed_fields})

    # The template receives whole-millisecond figures and an answer whose
    # citations have already been converted to escaped HTML.
    fragment_context = {
        "answer_html": _render_citations(gen.answer),
        "cited_ids": gen.cited_ids,
        "invalid_cited_ids": gen.invalid_cited_ids,
        "refused": gen.refused,
        "hits": hits,
        "model": gen.model,
        "retrieval_ms": round(retrieval_ms, 0),
        "generation_ms": round(gen.latency_ms, 0),
        "total_ms": round(total_ms, 0),
    }
    return templates.TemplateResponse(request, "_results.html", fragment_context)
|
|
|
|
def _render_citations(answer: str) -> str:
    """Turn each ``[digits]`` marker in the answer into a citation span.

    The answer is LLM output, so the whole text is HTML-escaped before any
    markup is added. The interpolated id is whatever the citation pattern
    captured, i.e. digits only, so it needs no further escaping.

    Returns:
        Escaped HTML in which every marker is wrapped in a focusable
        ``<span class="citation">`` carrying the id in ``data-chunk``.
    """
    from html import escape

    def _as_span(match: re.Match) -> str:
        number = match.group(1)
        return f'<span class="citation" data-chunk="{number}" tabindex="0">[{number}]</span>'

    # Escape first, then wrap: the spans added here must stay real markup.
    return _CITATION_RE.sub(_as_span, escape(answer))
|
|