"""
FastAPI service: POST /query (RAG pipeline) and GET /health.
Wires together api/hybrid.py (retrieval) and api/generate.py (generation)
behind a Pydantic-validated HTTP surface with audit logging, rate limiting,
and CORS locked down to the Streamlit demo origin only.
Security boundaries enforced here (per CLAUDE.md):
- Rule 5: query text and chunk text never appear in logs; we log a
16-char query hash plus latency/k/model metrics only.
- Rule 7: query length capped at 2000 chars by Pydantic; oversize
requests return 400 with a generic error (no schema details leaked).
- Rule 8: /health returns only `{"status": "ok"|"degraded"}` with no
stack traces, version strings, or schema details.
Run locally: uvicorn api.main:app --reload --port 8000
"""
from __future__ import annotations
import json
import os
import re
import secrets
import time
from pathlib import Path
from typing import Any, Literal
import psycopg
from dotenv import load_dotenv
from fastapi import Depends, FastAPI, Form, HTTPException, Request, status
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from pydantic import BaseModel, Field
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address
from .generate import generate
from .hybrid import retrieve_hybrid
from .logging_config import configure_logging, hash_query
# Populate os.environ from a local .env file (python-dotenv) before the
# environment lookups below run.
load_dotenv()
# Logger used for every audit event in this module; level comes from
# LOG_LEVEL and falls back to "INFO".
audit = configure_logging(os.environ.get("LOG_LEVEL", "INFO"))
# CORS allow-list: CORS_ORIGIN is split on commas, each entry is trimmed, and
# empty entries are dropped. Falls back to http://localhost:8000 when unset.
ALLOWED_ORIGINS = [
    o.strip() for o in os.environ.get("CORS_ORIGIN", "http://localhost:8000").split(",")
    if o.strip()
]
# Required setting: a missing DATABASE_URL raises KeyError at import time.
DATABASE_URL = os.environ["DATABASE_URL"]
# Directory containing this module; static/ and templates/ are resolved from it.
_HERE = Path(__file__).resolve().parent
# Rate limiter keyed by the client's remote address.
limiter = Limiter(key_func=get_remote_address)
# docs_url/redoc_url are None, so the interactive API docs pages are disabled.
app = FastAPI(title="rag-psych", version="0.1.0", docs_url=None, redoc_url=None)
# slowapi reads the limiter from app.state; RateLimitExceeded is answered by
# slowapi's stock handler.
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
# CORS: only the configured origins, no credentials, GET/POST only, and only
# the Content-Type request header.
app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGINS,
    allow_credentials=False,
    allow_methods=["GET", "POST"],
    allow_headers=["Content-Type"],
)
# Static assets and Jinja templates live next to this module.
app.mount("/static", StaticFiles(directory=_HERE / "static"), name="static")
templates = Jinja2Templates(directory=_HERE / "templates")
# Matches a bracketed run of digits such as "[12]"; group 1 is the digits.
_CITATION_RE = re.compile(r"\[(\d+)\]")
# eval/results/ one level above this package; read by _load_eval_runs().
_EVAL_RESULTS_DIR = _HERE.parent / "eval" / "results"
# auto_error=False: absent credentials arrive as None instead of an automatic
# error response, so _require_eval_password() decides how to reject.
_basic_auth = HTTPBasic(auto_error=False)
def _require_eval_password(
    credentials: HTTPBasicCredentials | None = Depends(_basic_auth),
) -> None:
    """Allow the request only if the Basic-auth password equals EVAL_PASSWORD.

    The username is accepted but never inspected. The password is compared
    with `secrets.compare_digest` so the comparison runs in constant time.
    When EVAL_PASSWORD is unset or empty, every request is rejected, which
    keeps the dashboard sealed rather than accidentally open.

    Raises:
        HTTPException: 401 with a `WWW-Authenticate` Basic challenge when
            credentials are missing, wrong, or no password is configured.
    """
    configured = os.environ.get("EVAL_PASSWORD", "")
    presented = ""
    if credentials:
        presented = credentials.password

    # Short-circuit on an empty configured password so "" never matches "".
    if configured and secrets.compare_digest(
        presented.encode("utf-8"), configured.encode("utf-8")
    ):
        return

    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="unauthorized",
        headers={"WWW-Authenticate": 'Basic realm="rag-psych eval"'},
    )
# Closed set of corpus labels accepted in QueryRequest.source_types.
# NOTE(review): "icd12" is listed alongside "icd11" — confirm it is an intended
# corpus label and not a typo.
SourceType = Literal["mtsamples", "pubmed", "icd11", "icd12"]
# JSON request body for POST /query.
class QueryRequest(BaseModel):
    # Query text; required, 1 to 2000 characters.
    query: str = Field(..., min_length=1, max_length=2000)
    # Forwarded to retrieval as `k`; 1 to 20, default 5.
    k: int = Field(default=5, ge=1, le=20)
    # Optional source filter forwarded to retrieval; None when omitted.
    source_types: list[SourceType] | None = None
# One retrieved chunk as returned to the client in QueryResponse.retrieved_chunks.
# The /query handler fills these from each hit's `.hit` attributes and its
# `rerank_score`.
class ChunkSummary(BaseModel):
    chunk_id: int
    source_type: str
    section: str | None
    title: str | None
    chunk_text: str
    rerank_score: float
# Per-request timings in milliseconds, as reported by the /query handler
# (which rounds each value to one decimal place before building this model).
class Latencies(BaseModel):
    retrieval_ms: float
    generation_ms: float
    total_ms: float
# JSON response body for POST /query. The answer, citation lists, refusal flag
# and model name are copied from the generation result; retrieved_chunks
# mirrors the retrieval hits.
class QueryResponse(BaseModel):
    answer: str
    cited_ids: list[int]
    invalid_cited_ids: list[int]
    refused: bool
    retrieved_chunks: list[ChunkSummary]
    model: str
    latency: Latencies
@app.exception_handler(RequestValidationError)
async def _validation_handler(request: Request, exc: RequestValidationError) -> JSONResponse:
    """Answer every request-validation failure with a generic 400.

    The details in `exc.errors()` are deliberately not echoed back: they
    would expose field-level schema hints in the response body, which the
    project's security rules (CLAUDE.md) forbid.
    """
    generic_body = {"error": "invalid_request"}
    return JSONResponse(content=generic_body, status_code=400)
@app.get("/", include_in_schema=False)
def root() -> RedirectResponse:
    """Bare Space URL → /ui. Also catches HF Spaces' platform healthcheck."""
    destination = "/ui"
    # 307 is a temporary redirect that keeps the request method.
    return RedirectResponse(url=destination, status_code=307)
@app.get("/health")
def health() -> JSONResponse:
    """Report liveness plus database reachability.

    Opens a connection with a 2-second connect timeout and runs `SELECT 1`.
    Returns `{"status": "ok"}` on success, or `{"status": "degraded"}` with
    HTTP 503 on any failure. No error details are included in the response.
    """
    try:
        with psycopg.connect(DATABASE_URL, connect_timeout=2) as conn, conn.cursor() as cur:
            cur.execute("SELECT 1")
    except Exception:
        # Deliberately opaque: the caller learns only that something is wrong.
        return JSONResponse({"status": "degraded"}, status_code=503)
    return JSONResponse({"status": "ok"})
@app.post("/query", response_model=QueryResponse)
@limiter.limit("5/minute;20/hour;30/day")
def query(request: Request, body: QueryRequest) -> QueryResponse:
    """Run retrieval then generation for one query and return the JSON result.

    Audit events carry the query hash, never the query text or chunk text.
    Any exception raised during retrieval or generation is logged with the
    hash and surfaced to the client as a generic HTTP 500 `internal_error`.
    """
    qhash = hash_query(body.query)
    received_event = {
        "query_hash": qhash, "k": body.k,
        "source_types": body.source_types, "client": get_remote_address(request),
    }
    audit.info("query_received", extra={"audit": received_event})

    started = time.perf_counter()
    try:
        # The connection is only needed for retrieval; it is closed before
        # the generation call runs.
        with psycopg.connect(DATABASE_URL) as conn:
            retrieval_started = time.perf_counter()
            hits = retrieve_hybrid(conn, body.query, k=body.k, source_types=body.source_types)
            retrieval_ms = (time.perf_counter() - retrieval_started) * 1000
        gen = generate(body.query, hits)
    except Exception:
        audit.exception("query_failed", extra={"audit": {"query_hash": qhash}})
        raise HTTPException(status_code=500, detail="internal_error")
    total_ms = (time.perf_counter() - started) * 1000

    # Round once; the same figures go to the audit log and to the response.
    retrieval_rounded = round(retrieval_ms, 1)
    generation_rounded = round(gen.latency_ms, 1)
    total_rounded = round(total_ms, 1)

    completed_event = {
        "query_hash": qhash,
        "k": body.k,
        "retrieved_count": len(hits),
        "cited_count": len(gen.cited_ids),
        "invalid_cited_count": len(gen.invalid_cited_ids),
        "refused": gen.refused,
        "model": gen.model,
        "retrieval_ms": retrieval_rounded,
        "generation_ms": generation_rounded,
        "total_ms": total_rounded,
    }
    audit.info("query_completed", extra={"audit": completed_event})

    chunk_summaries = []
    for scored in hits:
        chunk_summaries.append(
            ChunkSummary(
                chunk_id=scored.hit.chunk_id,
                source_type=scored.hit.source_type,
                section=scored.hit.section,
                title=scored.hit.title,
                chunk_text=scored.hit.chunk_text,
                rerank_score=scored.rerank_score,
            )
        )

    return QueryResponse(
        answer=gen.answer,
        cited_ids=gen.cited_ids,
        invalid_cited_ids=gen.invalid_cited_ids,
        refused=gen.refused,
        retrieved_chunks=chunk_summaries,
        model=gen.model,
        latency=Latencies(
            retrieval_ms=retrieval_rounded,
            generation_ms=generation_rounded,
            total_ms=total_rounded,
        ),
    )
# ─── HTMX-served UI ────────────────────────────────────────────────────────
@app.get("/ui", response_class=HTMLResponse)
def ui_index(request: Request) -> HTMLResponse:
    """Serve the main page (index.html) with an empty template context.

    The results section starts empty; HTMX swaps results in afterwards.
    """
    page_context: dict[str, Any] = {}
    return templates.TemplateResponse(request, "index.html", page_context)
@app.get("/help", response_class=HTMLResponse)
def ui_help(request: Request) -> HTMLResponse:
    """Serve the static help page (help.html): offerings, examples, limits, pipeline."""
    page_context: dict[str, Any] = {}
    return templates.TemplateResponse(request, "help.html", page_context)
# ─── Password-gated /eval dashboard ────────────────────────────────────────
@app.get(
    "/eval",
    response_class=HTMLResponse,
    dependencies=[Depends(_require_eval_password)],
)
def eval_dashboard(request: Request) -> HTMLResponse:
    """Serve the eval visualization dashboard (eval.html).

    Access is gated by the `_require_eval_password` dependency.
    """
    page_context: dict[str, Any] = {}
    return templates.TemplateResponse(request, "eval.html", page_context)
@app.get(
    "/eval/data",
    dependencies=[Depends(_require_eval_password)],
)
def eval_data() -> JSONResponse:
    """Return the dashboard's JSON feed.

    The payload has two keys: `runs` (stored eval run history) and `corpus`
    (live corpus statistics). Access is gated by `_require_eval_password`.
    """
    run_history = _load_eval_runs()
    corpus_summary = _corpus_stats()
    return JSONResponse({"runs": run_history, "corpus": corpus_summary})
def _load_eval_runs() -> list[dict[str, Any]]:
    """Load every parseable `*.json` file under the eval results directory.

    Files are visited in sorted path order (oldest first, assuming the
    filenames sort chronologically). A file that cannot be read, is not
    valid UTF-8, or is not valid JSON is skipped, so one bad file does not
    take down the whole dashboard feed.

    Returns:
        The decoded JSON payloads in path order. Empty when the directory
        does not exist (fresh clone, Docker without the volume mount) or
        holds no usable files.
    """
    if not _EVAL_RESULTS_DIR.is_dir():
        return []
    runs: list[dict[str, Any]] = []
    for path in sorted(_EVAL_RESULTS_DIR.glob("*.json")):
        try:
            # JSON is UTF-8; do not depend on the platform's locale encoding.
            runs.append(json.loads(path.read_text(encoding="utf-8")))
        except (OSError, ValueError):
            # ValueError covers both json.JSONDecodeError and
            # UnicodeDecodeError; OSError covers unreadable files.
            continue
    return runs
def _corpus_stats() -> dict[str, Any]:
    """Collect live corpus counts from Postgres for the eval dashboard.

    Returns:
        A dict with three keys:
        `docs` maps source_type to its row count in `documents`;
        `chunks` maps source_type to its row count in `chunks_with_source`;
        `sections` lists up to 40 (source_type, section) pairs with the
        highest counts, each as `{"source_type", "section", "n"}`.
        On any failure all three are returned empty.
    """
    docs_sql = "SELECT source_type, COUNT(*) FROM documents GROUP BY 1 ORDER BY 1"
    chunks_sql = (
        "SELECT source_type, COUNT(*) FROM chunks_with_source "
        "GROUP BY 1 ORDER BY 1"
    )
    sections_sql = (
        "SELECT source_type, section, COUNT(*) AS n "
        "FROM chunks_with_source WHERE section IS NOT NULL "
        "GROUP BY 1, 2 ORDER BY n DESC LIMIT 40"
    )
    try:
        with psycopg.connect(DATABASE_URL, connect_timeout=2) as conn, conn.cursor() as cur:
            cur.execute(docs_sql)
            docs = {source: count for source, count in cur.fetchall()}

            cur.execute(chunks_sql)
            chunks = {source: count for source, count in cur.fetchall()}

            cur.execute(sections_sql)
            sections = [
                {"source_type": source, "section": section, "n": count}
                for source, section, count in cur.fetchall()
            ]
        return {"docs": docs, "chunks": chunks, "sections": sections}
    except Exception:
        # Best effort: the dashboard renders with empty stats instead of failing.
        return {"docs": {}, "chunks": {}, "sections": []}
@app.post("/ui/query", response_class=HTMLResponse)
@limiter.limit("5/minute;20/hour;30/day")
def ui_query(
    request: Request,
    query: str = Form(..., min_length=1, max_length=2000),
    k: int = Form(5, ge=1, le=20),
) -> HTMLResponse:
    """HTMX endpoint: run the pipeline and return the `_results.html` fragment.

    Audit events carry the query hash, never the query text. If retrieval or
    generation raises, the failure is logged with the hash and the
    `_error.html` fragment is returned with HTTP 500 and a generic message.
    """
    qhash = hash_query(query)
    received_event = {
        "query_hash": qhash, "k": k, "client": get_remote_address(request),
    }
    audit.info("ui_query_received", extra={"audit": received_event})

    started = time.perf_counter()
    try:
        # The connection is only needed for retrieval; it is closed before
        # the generation call runs.
        with psycopg.connect(DATABASE_URL) as conn:
            retrieval_started = time.perf_counter()
            hits = retrieve_hybrid(conn, query, k=k)
            retrieval_ms = (time.perf_counter() - retrieval_started) * 1000
        gen = generate(query, hits)
    except Exception:
        audit.exception("ui_query_failed", extra={"audit": {"query_hash": qhash}})
        error_context = {"message": "Something went wrong. Please try again."}
        return templates.TemplateResponse(
            request, "_error.html", error_context, status_code=500,
        )
    total_ms = (time.perf_counter() - started) * 1000

    # The audit log keeps one decimal place of latency.
    completed_event = {
        "query_hash": qhash, "k": k, "retrieved_count": len(hits),
        "cited_count": len(gen.cited_ids), "invalid_cited_count": len(gen.invalid_cited_ids),
        "refused": gen.refused, "model": gen.model,
        "retrieval_ms": round(retrieval_ms, 1),
        "generation_ms": round(gen.latency_ms, 1),
        "total_ms": round(total_ms, 1),
    }
    audit.info("ui_query_completed", extra={"audit": completed_event})

    # The rendered fragment shows latencies rounded to whole milliseconds.
    results_context = {
        "answer_html": _render_citations(gen.answer),
        "cited_ids": gen.cited_ids,
        "invalid_cited_ids": gen.invalid_cited_ids,
        "refused": gen.refused,
        "hits": hits,
        "model": gen.model,
        "retrieval_ms": round(retrieval_ms, 0),
        "generation_ms": round(gen.latency_ms, 0),
        "total_ms": round(total_ms, 0),
    }
    return templates.TemplateResponse(request, "_results.html", results_context)
def _render_citations(answer: str) -> str:
    """Turn each `[digits]` marker in *answer* into a citation `<span>`.

    The answer is LLM output, so it is HTML-escaped in full before any markup
    is added. Escaping leaves square brackets and digits untouched, so the
    markers still match afterwards. The captured id is digits only (enforced
    by the pattern), which makes it safe to interpolate into the attribute.

    Returns:
        Escaped HTML in which every marker is wrapped in
        `<span class="citation" data-chunk="..." tabindex="0">`.
    """
    from html import escape

    def _to_span(match: re.Match) -> str:
        chunk_ref = match.group(1)
        return f'<span class="citation" data-chunk="{chunk_ref}" tabindex="0">[{chunk_ref}]</span>'

    escaped_answer = escape(answer)
    return _CITATION_RE.sub(_to_span, escaped_answer)
|