File size: 13,285 Bytes
08fc97e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0cf6082
08fc97e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4b785d9
08fc97e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0cf6082
 
 
 
 
 
08fc97e
 
 
 
 
 
 
 
 
 
 
 
 
d98125d
08fc97e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d98125d
08fc97e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
"""
FastAPI service for the rag-psych RAG pipeline.

Routes:
  - POST /query        JSON API: hybrid retrieval + generation.
  - GET  /health       liveness + DB reachability.
  - GET  /             redirects to /ui.
  - GET  /ui, /help    HTMX-served pages; POST /ui/query returns the
                       rendered results fragment.
  - GET  /eval, /eval/data
                       HTTP-Basic-gated eval dashboard and its JSON feed.

Wires together api/hybrid.py (retrieval) and api/generate.py (generation)
behind a Pydantic-validated HTTP surface with audit logging, rate limiting,
and CORS locked down to the origin(s) listed in the comma-separated
CORS_ORIGIN env var (default: http://localhost:8000).

Security boundaries enforced here (per CLAUDE.md):
  - Rule 5: query text and chunk text never appear in logs; we log a
    16-char query hash plus latency/k/model metrics only.
  - Rule 7: query length capped at 2000 chars by Pydantic; oversize
    requests return 400 with a generic error (no schema details leaked).
  - Rule 8: /health returns only `{"status": "ok"|"degraded"}` with no
    stack traces, version strings, or schema details.

Run locally:  uvicorn api.main:app --reload --port 8000
"""

from __future__ import annotations

import json
import os
import re
import secrets
import time
from pathlib import Path
from typing import Any, Literal

import psycopg
from dotenv import load_dotenv
from fastapi import Depends, FastAPI, Form, HTTPException, Request, status
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from pydantic import BaseModel, Field
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address

from .generate import generate
from .hybrid import retrieve_hybrid
from .logging_config import configure_logging, hash_query

# Populate os.environ from a local .env file before any env lookups below.
load_dotenv()
audit = configure_logging(os.environ.get("LOG_LEVEL", "INFO"))

# CORS_ORIGIN is a comma-separated origin list; blank entries are dropped.
ALLOWED_ORIGINS = [
    o.strip() for o in os.environ.get("CORS_ORIGIN", "http://localhost:8000").split(",")
    if o.strip()
]
# Required: a missing DATABASE_URL fails fast with KeyError at import time.
DATABASE_URL = os.environ["DATABASE_URL"]

_HERE = Path(__file__).resolve().parent
# Rate limits are keyed on the client address as seen by this process.
# NOTE(review): behind a reverse proxy this may be the proxy's address, which
# would make limits shared across users — confirm the deployment topology.
limiter = Limiter(key_func=get_remote_address)
# Interactive docs are disabled, and so is the raw OpenAPI document: leaving
# /openapi.json up would still publish every field constraint (and the
# version string) that the generic-400 handler and /health deliberately hide.
app = FastAPI(
    title="rag-psych",
    version="0.1.0",
    docs_url=None,
    redoc_url=None,
    openapi_url=None,
)
app.state.limiter = limiter  # slowapi looks the limiter up on app.state
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGINS,
    allow_credentials=False,
    allow_methods=["GET", "POST"],
    allow_headers=["Content-Type"],
)
app.mount("/static", StaticFiles(directory=_HERE / "static"), name="static")
templates = Jinja2Templates(directory=_HERE / "templates")
# Matches bracketed numeric citations like "[42]" in generated answers.
_CITATION_RE = re.compile(r"\[(\d+)\]")

# Eval run JSON files live in <repo>/eval/results next to the api package.
_EVAL_RESULTS_DIR = _HERE.parent / "eval" / "results"
# auto_error=False: missing credentials reach _require_eval_password as None
# instead of FastAPI raising its own 401, so the gate controls the response.
_basic_auth = HTTPBasic(auto_error=False)


def _require_eval_password(
    credentials: HTTPBasicCredentials | None = Depends(_basic_auth),
) -> None:
    """FastAPI dependency that gates the /eval routes behind HTTP Basic auth.

    Only the password is checked (any username is accepted). It is compared
    with the EVAL_PASSWORD environment variable, which is read on every call.
    `secrets.compare_digest` is used so that response timing does not reveal
    how much of a guess was correct. When EVAL_PASSWORD is unset or empty,
    every request is rejected, so the dashboard can never be left open by
    accident.

    Raises:
        HTTPException: 401 with a `WWW-Authenticate: Basic` challenge when
            credentials are missing or the password does not match.
    """
    password_on_file = os.environ.get("EVAL_PASSWORD", "")
    offered = "" if credentials is None else credentials.password

    if password_on_file:
        authorized = secrets.compare_digest(
            offered.encode("utf-8"), password_on_file.encode("utf-8")
        )
    else:
        # Sealed off: no configured password means nobody gets in.
        authorized = False

    if authorized:
        return
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="unauthorized",
        headers={"WWW-Authenticate": 'Basic realm="rag-psych eval"'},
    )

# Allowed values for QueryRequest.source_types; any other value fails
# validation and is answered with the generic 400 from _validation_handler.
# NOTE(review): "icd12" — confirm this matches the source_type values
# actually stored in the corpus.
SourceType = Literal["mtsamples", "pubmed", "icd11", "icd12"]


class QueryRequest(BaseModel):
    # JSON body for POST /query. Constraints are enforced by Pydantic and
    # violations surface only as {"error": "invalid_request"} with HTTP 400.
    # (Plain comments, not a docstring: a docstring would become part of the
    # model's JSON-schema description.)
    query: str = Field(..., min_length=1, max_length=2000)  # hashed for logs, never logged raw
    k: int = Field(default=5, ge=1, le=20)  # forwarded to retrieve_hybrid as k
    # Forwarded to retrieve_hybrid unchanged; None presumably means "no source
    # filter" — TODO confirm in api/hybrid.py.
    source_types: list[SourceType] | None = None

class ChunkSummary(BaseModel):
    # One retrieved chunk in QueryResponse.retrieved_chunks. query() builds it
    # from each retrieval hit's `h.hit.*` fields plus `h.rerank_score`.
    chunk_id: int  # the id that bracketed citations like "[42]" in the answer refer to
    source_type: str
    section: str | None
    title: str | None
    chunk_text: str  # returned to the caller, but never written to logs
    # Score assigned by the reranking stage of retrieve_hybrid;
    # presumably higher means more relevant — TODO confirm in api/hybrid.py.
    rerank_score: float


class Latencies(BaseModel):
    # Wall-clock timings in milliseconds (time.perf_counter based), rounded to
    # 0.1 ms by query().
    retrieval_ms: float  # retrieve_hybrid call only (excludes DB connect)
    generation_ms: float  # as reported by generate() via gen.latency_ms
    total_ms: float  # from just before the DB connect to the end of generation


class QueryResponse(BaseModel):
    # Response body for POST /query; fields mirror the generate() result plus
    # the retrieved chunks and timings.
    answer: str  # generated answer text (may contain "[chunk_id]" citations)
    cited_ids: list[int]  # chunk ids generate() reports as cited in the answer
    # Cited ids generate() flagged as invalid — presumably ids not among the
    # retrieved chunks; TODO confirm in api/generate.py.
    invalid_cited_ids: list[int]
    refused: bool  # generate()'s refusal flag; semantics defined in api/generate.py
    retrieved_chunks: list[ChunkSummary]
    model: str  # model identifier reported by generate()
    latency: Latencies


@app.exception_handler(RequestValidationError)
async def _validation_handler(request: Request, exc: RequestValidationError) -> JSONResponse:
    """Collapse every request-validation failure into a bare 400.

    `exc.errors()` is deliberately discarded. Echoing it would reveal field
    names and constraints to the client, which the CLAUDE.md security rules
    (generic 400 on invalid input, no schema details leaked) forbid. The
    handler is registered app-wide, so form-encoded posts to /ui/query get
    this same JSON body on validation failure.
    """
    generic_body = {"error": "invalid_request"}
    return JSONResponse(content=generic_body, status_code=400)


@app.get("/", include_in_schema=False)
def root() -> RedirectResponse:
    """Redirect the bare Space URL -> /ui with a 307 Temporary Redirect.

    Also catches HF Spaces' platform healthcheck.
    """
    return RedirectResponse(url="/ui", status_code=307)


@app.get("/health")
def health() -> JSONResponse:
    """Report liveness plus database reachability.

    Responds {"status": "ok"} (200) when a `SELECT 1` round-trip succeeds
    with a 2-second connect timeout, and {"status": "degraded"} (503)
    otherwise. Errors are swallowed on purpose so that no exception text,
    stack trace or connection detail reaches the response body.
    """
    db_reachable = True
    try:
        with psycopg.connect(DATABASE_URL, connect_timeout=2) as conn, conn.cursor() as cur:
            cur.execute("SELECT 1")
    except Exception:
        db_reachable = False

    if not db_reachable:
        return JSONResponse({"status": "degraded"}, status_code=503)
    return JSONResponse({"status": "ok"})


@app.post("/query", response_model=QueryResponse)
@limiter.limit("5/minute;20/hour;30/day")
def query(request: Request, body: QueryRequest) -> QueryResponse:
    """Run the RAG pipeline end-to-end: hybrid retrieval, then generation.

    `request` is needed by the slowapi limiter, which keys on the client
    address. Only the query hash and numeric metrics are logged; the query
    and chunk text never are. Any failure in retrieval or generation is
    logged with its traceback and surfaced as a generic 500
    (`detail="internal_error"`), so no internals reach the client.
    """
    qhash = hash_query(body.query)
    audit.info("query_received", extra={"audit": {
        "query_hash": qhash, "k": body.k,
        "source_types": body.source_types, "client": get_remote_address(request),
    }})

    t0 = time.perf_counter()
    try:
        # Hold the DB connection only for retrieval. generate() does not take
        # the connection, so close it before generation rather than keeping
        # an idle Postgres connection open for the whole generation step.
        with psycopg.connect(DATABASE_URL) as conn:
            t_retrieve_start = time.perf_counter()
            hits = retrieve_hybrid(conn, body.query, k=body.k, source_types=body.source_types)
            retrieval_ms = (time.perf_counter() - t_retrieve_start) * 1000
        gen = generate(body.query, hits)
    except Exception:
        audit.exception("query_failed", extra={"audit": {"query_hash": qhash}})
        raise HTTPException(status_code=500, detail="internal_error")

    total_ms = (time.perf_counter() - t0) * 1000
    audit.info("query_completed", extra={"audit": {
        "query_hash": qhash,
        "k": body.k,
        "retrieved_count": len(hits),
        "cited_count": len(gen.cited_ids),
        "invalid_cited_count": len(gen.invalid_cited_ids),
        "refused": gen.refused,
        "model": gen.model,
        "retrieval_ms": round(retrieval_ms, 1),
        "generation_ms": round(gen.latency_ms, 1),
        "total_ms": round(total_ms, 1),
    }})

    return QueryResponse(
        answer=gen.answer,
        cited_ids=gen.cited_ids,
        invalid_cited_ids=gen.invalid_cited_ids,
        refused=gen.refused,
        retrieved_chunks=[
            ChunkSummary(
                chunk_id=h.hit.chunk_id,
                source_type=h.hit.source_type,
                section=h.hit.section,
                title=h.hit.title,
                chunk_text=h.hit.chunk_text,
                rerank_score=h.rerank_score,
            )
            for h in hits
        ],
        model=gen.model,
        latency=Latencies(
            retrieval_ms=round(retrieval_ms, 1),
            generation_ms=round(gen.latency_ms, 1),
            total_ms=round(total_ms, 1),
        ),
    )


# ─── HTMX-served UI ────────────────────────────────────────────────────────


@app.get("/ui", response_class=HTMLResponse)
def ui_index(request: Request) -> HTMLResponse:
    """Serve the landing page.

    The template gets an empty context, so the results area starts blank;
    HTMX later fills it with the fragment returned by POST /ui/query.
    """
    empty_context: dict[str, Any] = {}
    return templates.TemplateResponse(request, "index.html", empty_context)


@app.get("/help", response_class=HTMLResponse)
def ui_help(request: Request) -> HTMLResponse:
    """Serve the static help page (capabilities, examples, limits, pipeline).

    No per-request data is passed to the template.
    """
    page_name = "help.html"
    return templates.TemplateResponse(request, page_name, {})


# ─── Password-gated /eval dashboard ────────────────────────────────────────


@app.get(
    "/eval",
    response_class=HTMLResponse,
    dependencies=[Depends(_require_eval_password)],
)
def eval_dashboard(request: Request) -> HTMLResponse:
    """Serve the eval visualization page; HTTP Basic password required.

    Rendered with an empty context. The page's data comes from the
    separately gated /eval/data JSON feed.
    """
    page_name = "eval.html"
    return templates.TemplateResponse(request, page_name, {})


@app.get(
    "/eval/data",
    dependencies=[Depends(_require_eval_password)],
)
def eval_data() -> JSONResponse:
    """Return the dashboard's JSON payload (HTTP Basic password required).

    Keys: "runs" (stored eval result files) and "corpus" (live Postgres
    counts).
    """
    payload: dict[str, Any] = {
        "runs": _load_eval_runs(),
        "corpus": _corpus_stats(),
    }
    return JSONResponse(payload)


def _load_eval_runs() -> list[dict[str, Any]]:
    """Load every eval/results/*.json run file for the dashboard.

    Files are read in sorted filename order. This is presumably oldest-first
    only if result files are named with a sortable timestamp prefix — TODO
    confirm the naming scheme. Returns an empty list if the directory does not
    exist (fresh clone, Docker without the volume mount).

    Files are decoded explicitly as UTF-8 rather than with the locale
    encoding. A file that cannot be read, is not valid UTF-8, is not valid
    JSON, or whose top-level value is not a JSON object is skipped, so one
    bad file cannot turn the whole dashboard feed into a 500.
    """
    if not _EVAL_RESULTS_DIR.is_dir():
        return []
    runs: list[dict[str, Any]] = []
    for path in sorted(_EVAL_RESULTS_DIR.glob("*.json")):
        try:
            payload = json.loads(path.read_text(encoding="utf-8"))
        except (OSError, ValueError):
            # ValueError covers both json.JSONDecodeError and
            # UnicodeDecodeError; OSError covers vanished/unreadable paths
            # and directories whose name happens to end in ".json".
            continue
        if isinstance(payload, dict):
            runs.append(payload)
    return runs


def _corpus_stats() -> dict[str, Any]:
    """Query live corpus statistics from Postgres for the eval dashboard.

    Returns a dict with:
      - "docs":     {source_type: row count} from `documents`
      - "chunks":   {source_type: row count} from `chunks_with_source`
      - "sections": the 40 most frequent non-null (source_type, section)
                    pairs across ALL sources (a global top-40, not a
                    per-source top-N), as {"source_type", "section", "n"}
                    dicts in descending count order.

    Best-effort: on any DB error the failure is logged and empty stats are
    returned, so the dashboard still renders.
    """
    try:
        with psycopg.connect(DATABASE_URL, connect_timeout=2) as conn:
            with conn.cursor() as cur:
                cur.execute(
                    "SELECT source_type, COUNT(*) FROM documents GROUP BY 1 ORDER BY 1"
                )
                docs = {row[0]: row[1] for row in cur.fetchall()}
                cur.execute(
                    "SELECT source_type, COUNT(*) FROM chunks_with_source "
                    "GROUP BY 1 ORDER BY 1"
                )
                chunks = {row[0]: row[1] for row in cur.fetchall()}
                cur.execute(
                    "SELECT source_type, section, COUNT(*) AS n "
                    "FROM chunks_with_source WHERE section IS NOT NULL "
                    "GROUP BY 1, 2 ORDER BY n DESC LIMIT 40"
                )
                sections = [
                    {"source_type": r[0], "section": r[1], "n": r[2]}
                    for r in cur.fetchall()
                ]
        return {"docs": docs, "chunks": chunks, "sections": sections}
    except Exception:
        # Previously swallowed silently, which hid a broken DB behind an empty
        # dashboard. Log it the same way the query endpoints log failures.
        audit.exception("corpus_stats_failed", extra={"audit": {}})
        return {"docs": {}, "chunks": {}, "sections": []}


@app.post("/ui/query", response_class=HTMLResponse)
@limiter.limit("5/minute;20/hour;30/day")
def ui_query(
    request: Request,
    query: str = Form(..., min_length=1, max_length=2000),
    k: int = Form(5, ge=1, le=20),
) -> HTMLResponse:
    """HTMX endpoint: run the RAG pipeline and return the _results.html fragment.

    Uses the same pipeline, field limits and rate-limit string as POST /query,
    but takes form-encoded input, applies no source_types filter, and renders
    HTML for HTMX to swap in. Only the query hash and metrics are logged. If
    the pipeline fails, the traceback is logged and a generic _error.html
    fragment is returned with status 500.
    """
    qhash = hash_query(query)
    audit.info("ui_query_received", extra={"audit": {
        "query_hash": qhash, "k": k, "client": get_remote_address(request),
    }})
    t0 = time.perf_counter()
    try:
        # Hold the DB connection only for retrieval. generate() does not take
        # the connection, so close it before generation rather than keeping
        # an idle Postgres connection open for the whole generation step.
        with psycopg.connect(DATABASE_URL) as conn:
            t_r = time.perf_counter()
            hits = retrieve_hybrid(conn, query, k=k)
            retrieval_ms = (time.perf_counter() - t_r) * 1000
        gen = generate(query, hits)
    except Exception:
        audit.exception("ui_query_failed", extra={"audit": {"query_hash": qhash}})
        return templates.TemplateResponse(
            request, "_error.html", {"message": "Something went wrong. Please try again."},
            status_code=500,
        )
    total_ms = (time.perf_counter() - t0) * 1000
    audit.info("ui_query_completed", extra={"audit": {
        "query_hash": qhash, "k": k, "retrieved_count": len(hits),
        "cited_count": len(gen.cited_ids), "invalid_cited_count": len(gen.invalid_cited_ids),
        "refused": gen.refused, "model": gen.model,
        "retrieval_ms": round(retrieval_ms, 1),
        "generation_ms": round(gen.latency_ms, 1),
        "total_ms": round(total_ms, 1),
    }})

    # The LLM answer is HTML-escaped before citation spans are injected.
    answer_html = _render_citations(gen.answer)
    return templates.TemplateResponse(request, "_results.html", {
        "answer_html": answer_html,
        "cited_ids": gen.cited_ids,
        "invalid_cited_ids": gen.invalid_cited_ids,
        "refused": gen.refused,
        "hits": hits,
        "model": gen.model,
        "retrieval_ms": round(retrieval_ms, 0),
        "generation_ms": round(gen.latency_ms, 0),
        "total_ms": round(total_ms, 0),
    })


def _render_citations(answer: str) -> str:
    """Turn "[N]" citations in an LLM answer into focusable HTML spans.

    The answer is untrusted model output, so it is HTML-escaped in full
    first. Only then are the citation markers (which survive escaping,
    since escaping touches neither brackets nor digits) replaced with
    `<span class="citation" data-chunk="N" tabindex="0">[N]</span>`, the
    hook the front-end JS attaches to. N is always a run of digits because
    of `_CITATION_RE`, so interpolating it into the attribute is safe.
    """
    import html

    def _to_span(match: re.Match) -> str:
        chunk_ref = match.group(1)
        return (
            f'<span class="citation" data-chunk="{chunk_ref}" tabindex="0">'
            f"[{chunk_ref}]</span>"
        )

    escaped = html.escape(answer)
    return _CITATION_RE.sub(_to_span, escaped)