# rags_api/app.py
"""
RAG QA API module
Provides a small FastAPI application that implements a Retrieval-Augmented
Generation (RAG) question-answering service for PDF documents.

High-level behavior
- /process (POST): Accepts PDF uploads, stores them temporarily, and schedules
a background task to extract text, build a vector store (index), and create
a conversation chain. Returns a session_id that clients can use with /ask.
- /ask (POST): Accepts a session_id and a user question. If the session is
still processing, returns 202 with the current status; if processing failed,
returns 409. If ready, performs safety checks on the prompt, runs the
conversation chain (LLM) to generate an answer, optionally retrieves
supporting sources from the vectorstore, applies a safety check to the final
answer, and returns the answer, source list, and session_id.
- /evaluate (POST): Runs evaluation on provided QA pairs in a background
thread (non-blocking) and returns evaluation results and scalar metrics
where available.
- /status (GET): Returns the current processing-status payload for a session_id.
- /health (GET): Lightweight health-check endpoint returning {"ok": True}.

Key implementation notes
- Temporary file handling: Uploaded files are saved to a temporary directory
per session. A background worker parses PDFs from those paths and then
removes the temporary directory.
- Sessions: In-memory dictionary SESSIONS maps session_id -> session state,
including status ("queued", "parsing", "chunking", "indexing",
"building_chain", "ready", "failed"), vectorstore, and the conversation
chain. Sessions are ephemeral and lost on process restart.
- Background processing: _background_process performs PDF parsing (via
get_documents_from_pdfs), text chunking (get_text_chunks), vectorstore
creation (get_vectorstore) and chain creation (get_conversation_chain).
Errors during processing mark the session as failed and record an error.
- Retrieval for sources: When available, the vectorstore's similarity_search
is used to fetch up to k=3 supporting documents and return their metadata
"source" fields as source filenames/identifiers.
- Safety: A safety module (check_safety) is invoked on both the prompt and
the generated answer. A SAFETY_SCORE_THRESHOLD environment variable
controls the minimum acceptable safety score (default 40.0). Prompts or
answers that fail the safety policy are refused with HTTP 403.
- Concurrency: CPU-/IO-bound operations that may block (PDF parsing, index
building, LLM chain invocation, evaluation) are run in background threads
via asyncio.to_thread or as asyncio tasks to avoid blocking the event loop.
- CORS: CORSMiddleware is configured to allow all origins (allow_origins=["*"])
so the API can be called from browser-based clients. Restrict this in
production as needed.
- Logging: A rotating file handler writes to api.log (1 MB max, 3 backups)
and logs are also emitted to stdout. Telemetry helpers (make_snippet)
are used in logs to avoid including large payloads.

Environment variables
- API_KEYS (comma-separated strings): Valid API keys for request
authentication. When unset, authentication is disabled (development mode)
and a warning is printed at startup.
- SAFETY_SCORE_THRESHOLD (float, default 40.0): Minimum safety score required
for prompts and answers. When the score is below this threshold or the safety
layer returns "refuse", the request is rejected with HTTP 403.
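
The threshold gating can be sketched as a standalone predicate (illustrative
only; `should_refuse` is not a function in this module):

```python
import os

def should_refuse(safe_text, passed, score, threshold=None):
    # Mirrors the module's gating, applied to both prompts and answers:
    # refuse when the safety layer itself returned "refuse", or when the
    # score falls below SAFETY_SCORE_THRESHOLD.
    if threshold is None:
        threshold = float(os.getenv("SAFETY_SCORE_THRESHOLD", "40.0"))
    return (not passed and safe_text == "refuse") or float(score) < threshold

print(should_refuse("ok", True, 75.0, threshold=40.0))       # False -> proceed
print(should_refuse("refuse", False, 90.0, threshold=40.0))  # True  -> HTTP 403
print(should_refuse("ok", True, 10.0, threshold=40.0))       # True  -> below threshold
```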

Error handling and HTTP semantics
- 202 Accepted: Returned by /ask when the requested session is still processing.
- 400 Bad Request: Missing/invalid input (e.g., no files uploaded).
- 401 Unauthorized: API key missing (when API_KEYS is configured).
- 403 Forbidden: Invalid API key, or prompt/answer refused by safety policy.
- 404 Not Found: session_id does not exist.
- 409 Conflict: Returned by /ask when the session's background processing failed.
- 500 Internal Server Error: LLM generation, evaluation, or unexpected errors.
- Exceptions during background tasks are logged and set the session status to "failed".
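
The per-session HTTP semantics of /ask can be summarized as a small mapping
(a sketch; the real handler also attaches the full status payload to 202 and
409 responses):

```python
def ask_http_status(session_status):
    # HTTP code /ask responds with, based on the session's processing state
    if session_status == "failed":
        return 409  # Conflict: background processing failed
    if session_status != "ready":
        return 202  # Accepted: still queued/parsing/chunking/indexing
    return 200      # Ready: an answer is generated and returned

print([ask_http_status(s) for s in ("queued", "parsing", "failed", "ready")])
# [202, 202, 409, 200]
```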

Dependencies and integration points (expected to be provided in the codebase)
- generate_testset.get_documents_from_pdfs(paths): Invoked in a background
thread (via asyncio.to_thread) to parse uploaded PDFs into document objects.
- app.get_text_chunks(texts_with_meta): Splits long documents into chunks
and returns (texts, metadatas).
- app.get_vectorstore(texts, metadatas): Builds/returns a vectorstore/index
compatible with similarity_search.
- app.get_conversation_chain(vectorstore): Returns a callable conversation chain
that accepts {"question": <str>} and returns model output (dict-like).
- safety.check_safety(text, include_meta=True): Performs prompt/answer safety
checks and returns (safe_text, passed_bool_or_refuse, score, meta).
- telemetry.make_snippet(text): Helper used to produce short snippets for logs.
- eval.run_evaluation(...) and eval.extract_scalar_metrics(...): Used by /evaluate.
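
The get_text_chunks contract assumed above can be illustrated with a minimal
stand-in (a hypothetical fixed-size splitter; the real implementation may
chunk with overlap or token-aware limits):

```python
def get_text_chunks_stub(docs, chunk_size=1000):
    # docs: list of {"text": ..., "source": ...} dicts, as built by the
    # background worker; returns parallel (texts, metadatas) lists.
    texts, metadatas = [], []
    for doc in docs:
        text = doc.get("text", "")
        for i in range(0, len(text), chunk_size):
            texts.append(text[i:i + chunk_size])
            metadatas.append({"source": doc.get("source")})
    return texts, metadatas

texts, metas = get_text_chunks_stub(
    [{"text": "a" * 2500, "source": "x.pdf"}], chunk_size=1000
)
print(len(texts), metas[0])  # 3 {'source': 'x.pdf'}
```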

Usage example (server)
- Start the server:
    uvicorn app:app --host 0.0.0.0 --port 8000
- Typical client flow:
1. POST /process with one or more PDF files -> receives session_id.
2. Poll or POST /ask with session_id and question -> if ready returns answer.
3. Optionally POST /evaluate with QA pairs to run batch evaluation.
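
Step 2's polling can be sketched as a generic loop (a hedged example;
`get_status` stands in for any callable that fetches the /status JSON):

```python
import time

def poll_until_done(get_status, interval_s=2.5, timeout_s=120.0):
    # Polls until the session reports a terminal status ("ready" or "failed")
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        payload = get_status()
        if payload["status"] in ("ready", "failed"):
            return payload
        time.sleep(interval_s)
    raise TimeoutError("session did not finish processing in time")

# Usage with a stub that becomes ready on the third call:
states = iter(["queued", "indexing", "ready"])
print(poll_until_done(lambda: {"status": next(states)}, interval_s=0.0))
# {'status': 'ready'}
```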

Limitations and considerations
- All session state is kept in memory. For production use, persist state
(vectorstore, chain configuration) to durable storage to survive restarts.
- CORS is permissive by default; tighten origins and credentials for production.
- The exact formats returned by the conversation chain and vectorstore are
implementation-dependent; the API contains fallbacks to extract text and
sources from common shapes but may need adaptation for different backends.
- Uploaded files are removed once background processing completes (success or failure).
- Ensure the safety module and the LLM backend are properly configured and
monitored for cost/latency in production deployments.
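
The response-shape fallback mentioned above roughly amounts to the following
(a sketch of the extraction order used by /ask):

```python
def extract_answer(response):
    # Prefer the last chat_history message, then an "answer" key, then str()
    if isinstance(response, dict):
        history = response.get("chat_history")
        if history:
            last = history[-1]
            return getattr(last, "content", str(last))
        if "answer" in response:
            return response["answer"]
    return str(response)

print(extract_answer({"answer": "42"}))                    # 42
print(extract_answer({"chat_history": ["hi", "hello!"]}))  # hello!
```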

Security
- All non-health endpoints require an API key (X-API-Key header or api_key
query parameter) whenever the API_KEYS environment variable is set.
- Inputs are safety-checked before being sent to the LLM and answers are
safety-checked before being returned.
- The API accepts file uploads and runs background parsing on them; consider
rate-limiting and content scanning for safe operation in shared environments.
"""
from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Security, Depends
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import APIKeyHeader, APIKeyQuery
from typing import List, Optional
import uuid
import tempfile
import shutil
import os
import logging
import sys
from logging.handlers import RotatingFileHandler
import asyncio
from datetime import datetime, timezone
from pydantic import BaseModel
# Import core helpers from the existing codebase
from original_rag import get_text_chunks, get_vectorstore, get_conversation_chain
from telemetry import make_snippet
app = FastAPI(title="RAG QA API")
# API Key authentication setup
API_KEY_NAME = "X-API-Key"
api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)
api_key_query = APIKeyQuery(name="api_key", auto_error=False)
# Load valid API keys from environment variable
# Format: comma-separated list of keys, e.g., "key1,key2,key3"
VALID_API_KEYS = set(
    key.strip() for key in os.getenv("API_KEYS", "").split(",") if key.strip()
)
# If no API keys configured, log a warning (logger is not configured yet here)
if not VALID_API_KEYS:
    print("WARNING: No API keys configured. Set API_KEYS environment variable.")
    print("Example: API_KEYS='your-secret-key-1,your-secret-key-2'")

async def verify_api_key(
    api_key_header: str = Security(api_key_header),
    api_key_query: str = Security(api_key_query),
):
    """Verify API key from header or query parameter.

    Checks the X-API-Key header first, then falls back to the ?api_key= query
    param. Returns the valid API key if authenticated, otherwise raises 401/403.
    """
    # If no keys are configured, skip authentication (dev mode)
    if not VALID_API_KEYS:
        logger.warning("API key authentication is disabled (no keys configured)")
        return None
    # Check header first, then query parameter
    api_key = api_key_header or api_key_query
    if not api_key:
        logger.warning("API key missing in request")
        raise HTTPException(
            status_code=401,
            detail="API key required. Provide via X-API-Key header or api_key query parameter.",
        )
    if api_key not in VALID_API_KEYS:
        logger.warning("Invalid API key attempted: %s", api_key[:8] + "...")
        raise HTTPException(status_code=403, detail="Invalid API key")
    return api_key
# Setup logging: rotating file + stdout
logger = logging.getLogger("rag_api")
logger.setLevel(logging.INFO)
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(name)s - %(message)s")
file_handler = RotatingFileHandler(
    "api.log", maxBytes=1 * 1024 * 1024, backupCount=3, encoding="utf-8"
)
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
stream_handler = logging.StreamHandler(stream=sys.stdout)
stream_handler.setFormatter(formatter)
logger.addHandler(stream_handler)
# Allow cross-origin requests so JS services can call the API
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# In-memory sessions: session_id -> {vectorstore, conversation}
SESSIONS = {}

def _utc_now_iso() -> str:
    return datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")

def _build_status_payload(session_id: str, sess: dict) -> dict:
    return {
        "session_id": session_id,
        "status": sess.get("status", "queued"),
        "progress_pct": sess.get("progress_pct", 0),
        "message": sess.get("message", ""),
        "doc_count": sess.get("doc_count"),
        "chunk_count": sess.get("chunk_count"),
        "started_at": sess.get("started_at"),
        "updated_at": sess.get("updated_at"),
        "completed_at": sess.get("completed_at"),
        "error": sess.get("error"),
    }

def _update_session(session_id: str, **updates):
    sess = SESSIONS.get(session_id)
    if not sess:
        return
    now = _utc_now_iso()
    updates["updated_at"] = now
    if updates.get("status") in {"ready", "failed"} and not updates.get("completed_at"):
        updates["completed_at"] = now
    sess.update(updates)

async def _background_process(session_id: str, tmpdir: str, paths: list):
    """Background worker to parse uploaded PDFs and build the vectorstore/chain.

    Stores results in SESSIONS[session_id] and removes the temporary upload dir.
    """
    try:
        from generate_testset import get_documents_from_pdfs

        _update_session(
            session_id,
            status="parsing",
            progress_pct=10,
            message="Parsing uploaded PDFs",
            error=None,
        )
        documents = await asyncio.to_thread(get_documents_from_pdfs, paths)
        if not documents:
            _update_session(
                session_id,
                status="failed",
                progress_pct=100,
                message="Parsing failed: no documents parsed",
                error="No documents parsed from uploads",
            )
            return
        _update_session(
            session_id,
            status="chunking",
            progress_pct=35,
            message="Chunking parsed text",
            doc_count=len(documents),
        )
        texts, metadatas = get_text_chunks(
            [
                {
                    "text": (
                        d.page_content
                        if hasattr(d, "page_content")
                        else d.get("text", "")
                    ),
                    "source": (
                        d.metadata.get("source")
                        if getattr(d, "metadata", None)
                        else None
                    ),
                }
                for d in documents
            ]
        )
        _update_session(
            session_id,
            status="indexing",
            progress_pct=60,
            message="Building vector store",
            chunk_count=len(texts),
        )
        vectorstore = await asyncio.to_thread(get_vectorstore, texts, metadatas)
        _update_session(
            session_id,
            status="building_chain",
            progress_pct=85,
            message="Building conversation chain",
        )
        try:
            conversation = await asyncio.to_thread(get_conversation_chain, vectorstore)
        except Exception as e:
            _update_session(
                session_id,
                status="failed",
                progress_pct=100,
                message="Conversation chain creation failed",
                error=f"Conversation creation failed: {e}",
            )
            return
        _update_session(
            session_id,
            vectorstore=vectorstore,
            conversation=conversation,
            status="ready",
            progress_pct=100,
            message="Session is ready for questions",
            error=None,
        )
        logger.info("Background processing complete for session %s", session_id)
    except Exception as e:
        logger.exception(
            "Background processing failed for session %s: %s", session_id, e
        )
        _update_session(
            session_id,
            status="failed",
            progress_pct=100,
            message="Background processing failed",
            error=str(e),
        )
    finally:
        try:
            shutil.rmtree(tmpdir)
        except Exception:
            pass

@app.get("/health")
def health():
    return {"ok": True}

@app.get("/status")
def get_status(session_id: str, api_key: str = Depends(verify_api_key)):
    sess = SESSIONS.get(session_id)
    if not sess:
        raise HTTPException(status_code=404, detail="session_id not found")
    return _build_status_payload(session_id, sess)

@app.post("/process")
async def process(
    files: List[UploadFile] = File(...), api_key: str = Depends(verify_api_key)
):
    """Upload PDFs and create a retrieval index + conversation chain.

    Returns a `session_id` to use with `/ask`.
    Requires API key authentication via X-API-Key header or api_key query parameter.
    """
    if not files:
        raise HTTPException(status_code=400, detail="No files uploaded")
    # Save uploaded files to a temporary directory so helper code can read them
    tmpdir = tempfile.mkdtemp(prefix="rag_upload_")
    paths = []
    # create session early and schedule background processing
    session_id = str(uuid.uuid4())
    now = _utc_now_iso()
    SESSIONS[session_id] = {
        "status": "queued",
        "progress_pct": 0,
        "message": "Upload received and queued",
        "started_at": now,
        "updated_at": now,
        "completed_at": None,
        "error": None,
        "doc_count": None,
        "chunk_count": None,
        "tmpdir": tmpdir,
    }
    try:
        _update_session(
            session_id,
            status="parsing",
            progress_pct=5,
            message="Saving uploaded files",
        )
        for f in files:
            # Use only the basename so crafted filenames cannot escape tmpdir
            safe_name = os.path.basename(f.filename or "") or f"upload-{uuid.uuid4()}.pdf"
            dest = os.path.join(tmpdir, safe_name)
            with open(dest, "wb") as out:
                content = await f.read()
                out.write(content)
            paths.append(dest)
        _update_session(
            session_id,
            status="queued",
            progress_pct=8,
            message="Files saved; waiting for background processing",
            doc_count=len(paths),
        )
        # schedule background processing to build the index & chain; keep a
        # reference on the session so the task is not garbage-collected mid-flight
        SESSIONS[session_id]["task"] = asyncio.create_task(
            _background_process(session_id, tmpdir, paths)
        )
        logger.info("Scheduled background processing for session %s", session_id)
        payload = _build_status_payload(session_id, SESSIONS[session_id])
        payload["estimated_poll_interval_ms"] = 2500
        return payload
    except Exception:
        # on unexpected error, clean up tmpdir and re-raise
        _update_session(
            session_id,
            status="failed",
            progress_pct=100,
            message="Failed while saving uploads",
            error="Unexpected error while handling uploaded files",
        )
        try:
            shutil.rmtree(tmpdir)
        except Exception:
            pass
        raise

@app.post("/ask")
async def ask(
    session_id: str = Form(...),
    question: str = Form(...),
    api_key: str = Depends(verify_api_key),
):
    """Ask a question against a previously created session (session_id).

    Returns the answer text and list of source filenames.
    Requires API key authentication via X-API-Key header or api_key query parameter.
    """
    sess = SESSIONS.get(session_id)
    if not sess:
        raise HTTPException(status_code=404, detail="session_id not found")
    # If processing failed, return 409 Conflict; if not yet complete,
    # return 202 Accepted with the full status payload
    status = sess.get("status")
    if status == "failed":
        payload = _build_status_payload(session_id, sess)
        return JSONResponse(status_code=409, content=payload)
    if status != "ready":
        return JSONResponse(
            status_code=202, content=_build_status_payload(session_id, sess)
        )
    conv = sess.get("conversation")
    if conv is None:
        raise HTTPException(status_code=500, detail="conversation chain missing")
    try:
        from safety import check_safety

        # Check prompt safety before passing to the LLM
        safe_q, q_passed, q_score, q_meta = check_safety(question, include_meta=True)
        logger.info(
            "Session %s: prompt safety passed=%s score=%.1f snippet=%s",
            session_id,
            bool(q_passed),
            float(q_score),
            make_snippet(question),
        )
        # If safeguard metadata present, include in telemetry log for thresholds
        if q_meta and isinstance(q_meta, dict):
            logger.info("Session %s: prompt safety metadata=%s", session_id, q_meta)
        threshold = float(os.getenv("SAFETY_SCORE_THRESHOLD", "40.0"))
        if (not q_passed and safe_q == "refuse") or float(q_score) < threshold:
            logger.warning(
                "Refusing prompt for session %s: score=%.1f threshold=%.1f",
                session_id,
                float(q_score),
                threshold,
            )
            raise HTTPException(
                status_code=403, detail="Question refused by safety policy"
            )
        # The conversation chain may be blocking; run in a thread to avoid
        # blocking the event loop
        response = await asyncio.to_thread(lambda: conv({"question": safe_q}))
    except HTTPException:
        # Re-raise safety refusals (403) instead of masking them as 500 below
        raise
    except Exception as e:
        logger.exception("Error generating response for session %s: %s", session_id, e)
        raise HTTPException(status_code=500, detail="LLM generation error")
    # Try to extract the final message similarly to the Streamlit handler
    answer = ""
    chat_history = response.get("chat_history") if isinstance(response, dict) else None
    if chat_history:
        last = chat_history[-1]
        answer = getattr(last, "content", str(last))
    else:
        # fallback: check common keys
        answer = (
            response.get("answer")
            if isinstance(response, dict) and "answer" in response
            else str(response)
        )
    # Try to fetch sources if a vectorstore is available
    sources = []
    try:
        vs = sess.get("vectorstore")
        if vs is not None:
            docs = vs.similarity_search(question, k=3)
            for d in docs:
                src = d.metadata.get("source") if getattr(d, "metadata", None) else None
                if src and src not in sources:
                    sources.append(src)
    except Exception:
        # ignore retrieval errors for now
        pass
    # Safety check: ensure we do not return unsafe content
    try:
        from safety import check_safety

        safe_answer, passed, safety_score, a_meta = check_safety(
            answer, include_meta=True
        )
        logger.info(
            "Safety check for session %s passed=%s score=%.1f snippet=%s",
            session_id,
            bool(passed),
            float(safety_score),
            make_snippet(answer),
        )
        if a_meta and isinstance(a_meta, dict):
            logger.info("Session %s: answer safety metadata=%s", session_id, a_meta)
        threshold = float(os.getenv("SAFETY_SCORE_THRESHOLD", "40.0"))
        if (not passed and safe_answer == "refuse") or float(safety_score) < threshold:
            # refuse to return the content
            logger.warning(
                "Refusing answer for session %s: score=%.1f threshold=%.1f",
                session_id,
                float(safety_score),
                threshold,
            )
            raise HTTPException(
                status_code=403, detail="Response refused by safety policy"
            )
        # prefer the possibly fixed output from the safety layer
        answer = safe_answer
    except HTTPException:
        # Propagate safety refusals (403); do not swallow them and return the answer
        raise
    except Exception as e:  # pylint: disable=broad-except
        logger.debug("Safety check unavailable or failed: %s", e)
    return {"answer": answer, "sources": sources, "session_id": session_id}

class EvalPayload(BaseModel):
    questions: List[str]
    answers: List[str]
    contexts: List[List[str]]
    ground_truths: Optional[List[str]] = None

@app.post("/evaluate")
async def evaluate(payload: EvalPayload, api_key: str = Depends(verify_api_key)):
    """Run Ragas evaluation on provided QA pairs. Runs in a thread to avoid
    blocking the server.

    Requires API key authentication via X-API-Key header or api_key query parameter.
    """
    logger.info("Starting evaluation for %d examples", len(payload.questions))
    try:
        from eval import run_evaluation

        result = await asyncio.to_thread(
            lambda: run_evaluation(
                payload.questions,
                payload.answers,
                payload.contexts,
                payload.ground_truths,
            )
        )
        logger.info("Evaluation complete")
        # Try to extract scalar metrics for convenient API consumption
        try:
            from eval import extract_scalar_metrics

            metrics = extract_scalar_metrics(result)
        except Exception:
            metrics = {}
        # `result` may be a complex object; include its repr but return parsed
        # metrics if available
        return {"result": repr(result), "metrics": metrics}
    except Exception as e:  # pylint: disable=broad-except
        logger.exception("Evaluation failed: %s", e)
        raise HTTPException(status_code=500, detail="Evaluation failed")