""" RAG QA API module Provides a small FastAPI application that implements a Retrieval-Augmented Generation (RAG) question-answering service for PDF documents. High-level behavior - /process (POST): Accepts PDF uploads, stores them temporarily, and schedules a background task to extract text, build a vector store (index), and create a conversation chain. Returns a session_id that clients can use with /ask. - /ask (POST): Accepts a session_id and a user question. If the session is still processing, returns 202 with current status. If ready, performs safety checks on the prompt, runs the conversation chain (LLM) to generate an answer, optionally retrieves supporting sources from the vectorstore, applies a safety check to the final answer, and returns the answer, source list, and session_id. - /evaluate (POST): Runs evaluation on provided QA pairs in a background thread (non-blocking) and returns evaluation results and scalar metrics where available. - /health (GET): Lightweight healthcheck endpoint returning {"ok": True}. Key implementation notes - Temporary file handling: Uploaded files are saved to a temporary directory per session. A background worker parses PDFs from those paths and then removes the temporary directory. - Sessions: In-memory dictionary SESSIONS maps session_id -> session state, including status ("processing", "ready", "failed"), vectorstore, and the conversation chain. Sessions are ephemeral and lost on process restart. - Background processing: _background_process performs PDF parsing (via get_documents_from_pdfs), text chunking (get_text_chunks), vectorstore creation (get_vectorstore) and chain creation (get_conversation_chain). Errors during processing mark the session as failed and record an error. - Retrieval for sources: When available, the vectorstore's similarity_search is used to fetch up to k=3 supporting documents and return their metadata "source" fields as source filenames/identifiers. 
- Safety: A safety module (check_safety) is invoked on both the prompt and the generated answer. A SAFETY_SCORE_THRESHOLD environment variable controls the minimum acceptable safety score (default 40.0). Prompts or answers that fail the safety policy are refused with HTTP 403. - Concurrency: CPU-/IO-bound operations that may block (PDF parsing, index building, LLM chain invocation, evaluation) are run in background threads via asyncio.to_thread or as asyncio tasks to avoid blocking the event loop. - CORS: CORSMiddleware is configured to allow all origins (allow_origins=["*"]) so the API can be called from browser-based clients. Restrict this in production as needed. - Logging: A rotating file handler writes to api.log (1 MB max, 3 backups) and logs are also emitted to stdout. Telemetry helpers (make_snippet) are used in logs to avoid including large payloads. Environment variables - SAFETY_SCORE_THRESHOLD (float, default 40.0): Minimum safety score required for prompts and answers. When score is below this threshold or the safety layer returns "refuse", the request will be rejected with HTTP 403. Error handling and HTTP semantics - 202 Accepted: Returned by /ask when the requested session is still processing. - 400 Bad Request: Missing/invalid input (e.g., no files uploaded). - 403 Forbidden: Prompt or generated answer refused by safety policy. - 404 Not Found: session_id does not exist. - 500 Internal Server Error: LLM generation, evaluation, or unexpected errors. - Exceptions during background tasks are logged and set the session status to "failed". Dependencies and integration points (expected to be provided in the codebase) - generate_testset.get_documents_from_pdfs(paths): Asynchronously invoked in a background thread to parse uploaded PDFs into document objects. - app.get_text_chunks(texts_with_meta): Splits long documents into chunks and returns (texts, metadatas). 
- app.get_vectorstore(texts, metadatas): Builds/returns a vectorstore/index compatible with similarity_search. - app.get_conversation_chain(vectorstore): Returns a callable conversation chain that accepts {"question": } and returns model output (dict-like). - safety.check_safety(text, include_meta=True): Performs prompt/answer safety checks and returns (safe_text, passed_bool_or_refuse, score, meta). - telemetry.make_snippet(text): Helper used to produce short snippets for logs. - eval.run_evaluation(...) and eval.extract_scalar_metrics(...): Used by /evaluate. Usage example (server) - Start the server: uvicorn api:app --host 0.0.0.0 --port 8000 - Typical client flow: 1. POST /process with one or more PDF files -> receives session_id. 2. Poll or POST /ask with session_id and question -> if ready returns answer. 3. Optionally POST /evaluate with QA pairs to run batch evaluation. Limitations and considerations - All session state is kept in memory. For production use, persist state (vectorstore, chain configuration) to durable storage to survive restarts. - CORS is permissive by default—tighten origins and credentials for production. - The exact formats returned by the conversation chain and vectorstore are implementation-dependent; the API contains fallbacks to extract text and sources from common shapes but may need adaptation for different backends. - Uploaded files are removed once background processing completes (success or failure). - Ensure the safety module and the LLM backend are properly configured and monitored for cost/latency in production deployments. Security - Inputs are safety-checked before being sent to the LLM and answers are safety-checked before being returned. - The API currently allows file uploads and will execute background parsing; consider authentication, rate-limiting, and content scanning for safe operation in shared environments. API for RAG QA service Provides endpoints to upload PDFs, ask questions, and evaluate QA pairs. 
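The typical client flow can be exercised end-to-end with curl. This fragment is
illustrative only: it assumes a server on localhost:8000, an API key "secret1"
present in API_KEYS, and local files paper1.pdf/paper2.pdf; <session_id> stands
for the value returned by /process.

```shell
# 1. Upload PDFs; the JSON response contains the session_id to poll.
curl -s -X POST http://localhost:8000/process \
  -H "X-API-Key: secret1" \
  -F "files=@paper1.pdf" -F "files=@paper2.pdf"

# 2. Poll processing status until it reports "ready".
curl -s "http://localhost:8000/status?session_id=<session_id>" \
  -H "X-API-Key: secret1"

# 3. Ask a question; a 202 response means indexing is still in progress.
curl -s -X POST http://localhost:8000/ask \
  -H "X-API-Key: secret1" \
  -F "session_id=<session_id>" \
  -F "question=What datasets were used?"
```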
""" from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Security, Depends from fastapi.responses import JSONResponse from fastapi.middleware.cors import CORSMiddleware from fastapi.security import APIKeyHeader, APIKeyQuery from typing import List, Optional import uuid import tempfile import shutil import os import logging import sys from logging.handlers import RotatingFileHandler import asyncio from datetime import datetime, timezone from pydantic import BaseModel # Import core helpers from the existing codebase from original_rag import get_text_chunks, get_vectorstore, get_conversation_chain from telemetry import make_snippet app = FastAPI(title="RAG QA API") # API Key authentication setup API_KEY_NAME = "X-API-Key" api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False) api_key_query = APIKeyQuery(name="api_key", auto_error=False) # Load valid API keys from environment variable # Format: comma-separated list of keys, e.g., "key1,key2,key3" VALID_API_KEYS = set( key.strip() for key in os.getenv("API_KEYS", "").split(",") if key.strip() ) # If no API keys configured, log a warning if not VALID_API_KEYS: print("WARNING: No API keys configured. Set API_KEYS environment variable.") print("Example: API_KEYS='your-secret-key-1,your-secret-key-2'") async def verify_api_key( api_key_header: str = Security(api_key_header), api_key_query: str = Security(api_key_query), ): """Verify API key from header or query parameter. Checks X-API-Key header first, then falls back to ?api_key= query param. Returns the valid API key if authenticated, otherwise raises 401/403. """ # If no keys are configured, skip authentication (dev mode) if not VALID_API_KEYS: logger.warning("API key authentication is disabled (no keys configured)") return None # Check header first, then query parameter api_key = api_key_header or api_key_query if not api_key: logger.warning("API key missing in request") raise HTTPException( status_code=401, detail="API key required. 
Provide via X-API-Key header or api_key query parameter.", ) if api_key not in VALID_API_KEYS: logger.warning("Invalid API key attempted: %s", api_key[:8] + "...") raise HTTPException(status_code=403, detail="Invalid API key") return api_key # Setup logging: rotating file + stdout logger = logging.getLogger("rag_api") logger.setLevel(logging.INFO) formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(name)s - %(message)s") file_handler = RotatingFileHandler( "api.log", maxBytes=1 * 1024 * 1024, backupCount=3, encoding="utf-8" ) file_handler.setFormatter(formatter) logger.addHandler(file_handler) stream_handler = logging.StreamHandler(stream=sys.stdout) stream_handler.setFormatter(formatter) logger.addHandler(stream_handler) # Allow cross-origin requests so JS services can call the API app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) # In-memory sessions: session_id -> {vectorstore, conversation} SESSIONS = {} def _utc_now_iso() -> str: return datetime.now(timezone.utc).isoformat().replace("+00:00", "Z") def _build_status_payload(session_id: str, sess: dict) -> dict: return { "session_id": session_id, "status": sess.get("status", "queued"), "progress_pct": sess.get("progress_pct", 0), "message": sess.get("message", ""), "doc_count": sess.get("doc_count"), "chunk_count": sess.get("chunk_count"), "started_at": sess.get("started_at"), "updated_at": sess.get("updated_at"), "completed_at": sess.get("completed_at"), "error": sess.get("error"), } def _update_session(session_id: str, **updates): sess = SESSIONS.get(session_id) if not sess: return now = _utc_now_iso() updates["updated_at"] = now if updates.get("status") in {"ready", "failed"} and not updates.get("completed_at"): updates["completed_at"] = now sess.update(updates) async def _background_process(session_id: str, tmpdir: str, paths: list): """Background worker to parse uploaded PDFs and build vectorstore/chain. 
Stores results in SESSIONS[session_id] and removes the temporary upload dir. """ try: from generate_testset import get_documents_from_pdfs _update_session( session_id, status="parsing", progress_pct=10, message="Parsing uploaded PDFs", error=None, ) documents = await asyncio.to_thread(get_documents_from_pdfs, paths) if not documents: _update_session( session_id, status="failed", progress_pct=100, message="Parsing failed: no documents parsed", error="No documents parsed from uploads", ) return _update_session( session_id, status="chunking", progress_pct=35, message="Chunking parsed text", doc_count=len(documents), ) texts, metadatas = get_text_chunks( [ { "text": ( d.page_content if hasattr(d, "page_content") else d.get("text", "") ), "source": ( d.metadata.get("source") if getattr(d, "metadata", None) else None ), } for d in documents ] ) _update_session( session_id, status="indexing", progress_pct=60, message="Building vector store", chunk_count=len(texts), ) vectorstore = await asyncio.to_thread(get_vectorstore, texts, metadatas) _update_session( session_id, status="building_chain", progress_pct=85, message="Building conversation chain", ) try: conversation = await asyncio.to_thread(get_conversation_chain, vectorstore) except Exception as e: _update_session( session_id, status="failed", progress_pct=100, message="Conversation chain creation failed", error=f"Conversation creation failed: {e}", ) return _update_session( session_id, vectorstore=vectorstore, conversation=conversation, status="ready", progress_pct=100, message="Session is ready for questions", error=None, ) logger.info("Background processing complete for session %s", session_id) except Exception as e: logger.exception( "Background processing failed for session %s: %s", session_id, e ) _update_session( session_id, status="failed", progress_pct=100, message="Background processing failed", error=str(e), ) finally: try: shutil.rmtree(tmpdir) except Exception: pass @app.get("/health") def health(): return 
{"ok": True} @app.get("/status") def get_status(session_id: str, api_key: str = Depends(verify_api_key)): sess = SESSIONS.get(session_id) if not sess: raise HTTPException(status_code=404, detail="session_id not found") return _build_status_payload(session_id, sess) @app.post("/process") async def process( files: List[UploadFile] = File(...), api_key: str = Depends(verify_api_key) ): """Upload PDFs and create a retrieval index + conversation chain. Returns a `session_id` to use with `/ask`. Requires API key authentication via X-API-Key header or api_key query parameter. """ if not files: raise HTTPException(status_code=400, detail="No files uploaded") # Save uploaded files to a temporary directory so helper code can read them tmpdir = tempfile.mkdtemp(prefix="rag_upload_") paths = [] # create session early and schedule background processing session_id = str(uuid.uuid4()) now = _utc_now_iso() SESSIONS[session_id] = { "status": "queued", "progress_pct": 0, "message": "Upload received and queued", "started_at": now, "updated_at": now, "completed_at": None, "error": None, "doc_count": None, "chunk_count": None, "tmpdir": tmpdir, } try: _update_session( session_id, status="parsing", progress_pct=5, message="Saving uploaded files", ) for f in files: dest = os.path.join(tmpdir, f.filename or f"upload-{uuid.uuid4()}.pdf") with open(dest, "wb") as out: content = await f.read() out.write(content) paths.append(dest) _update_session( session_id, status="queued", progress_pct=8, message="Files saved; waiting for background processing", doc_count=len(paths), ) # schedule background processing to build index & chain asyncio.create_task(_background_process(session_id, tmpdir, paths)) logger.info("Scheduled background processing for session %s", session_id) payload = _build_status_payload(session_id, SESSIONS[session_id]) payload["estimated_poll_interval_ms"] = 2500 return payload except Exception: # on unexpected error, clean up tmpdir and re-raise _update_session( session_id, 
status="failed", progress_pct=100, message="Failed while saving uploads", error="Unexpected error while handling uploaded files", ) try: shutil.rmtree(tmpdir) except Exception: pass raise @app.post("/ask") async def ask( session_id: str = Form(...), question: str = Form(...), api_key: str = Depends(verify_api_key), ): """Ask a question against a previously created session (session_id). Returns the answer text and list of source filenames. Requires API key authentication via X-API-Key header or api_key query parameter. """ sess = SESSIONS.get(session_id) if not sess: raise HTTPException(status_code=404, detail="session_id not found") # If processing not complete, return 202 Accepted with full status payload status = sess.get("status") if status == "failed": payload = _build_status_payload(session_id, sess) return JSONResponse(status_code=409, content=payload) if status != "ready": return JSONResponse( status_code=202, content=_build_status_payload(session_id, sess) ) conv = sess.get("conversation") if conv is None: raise HTTPException(status_code=500, detail="conversation chain missing") try: from safety import check_safety # Check prompt safety before passing to the LLM safe_q, q_passed, q_score, q_meta = check_safety(question, include_meta=True) logger.info( "Session %s: prompt safety passed=%s score=%.1f snippet=%s", session_id, bool(q_passed), float(q_score), make_snippet(question), ) # If safeguard metadata present, include in telemetry log for thresholds if q_meta and isinstance(q_meta, dict): logger.info("Session %s: prompt safety metadata=%s", session_id, q_meta) threshold = float(os.getenv("SAFETY_SCORE_THRESHOLD", "40.0")) if (not q_passed and safe_q == "refuse") or float(q_score) < threshold: logger.warning( "Refusing prompt for session %s: score=%.1f threshold=%.1f", session_id, float(q_score), threshold, ) raise HTTPException( status_code=403, detail="Question refused by safety policy" ) # The conversation chain may be blocking; run in thread to avoid 
blocking the event loop response = await asyncio.to_thread(lambda: conv({"question": safe_q})) except Exception as e: logger.exception("Error generating response for session %s: %s", session_id, e) raise HTTPException(status_code=500, detail="LLM generation error") # Try to extract the final message similarly to the Streamlit handler answer = "" chat_history = response.get("chat_history") if isinstance(response, dict) else None if chat_history: last = chat_history[-1] answer = getattr(last, "content", str(last)) else: # fallback: check common keys answer = ( response.get("answer") if isinstance(response, dict) and "answer" in response else str(response) ) # Try to fetch sources if vectorstore available sources = [] try: vs = sess.get("vectorstore") if vs is not None: docs = vs.similarity_search(question, k=3) for d in docs: src = d.metadata.get("source") if getattr(d, "metadata", None) else None if src and src not in sources: sources.append(src) except Exception: # ignore retrieval errors for now pass # Safety check: ensure we do not return unsafe content try: from safety import check_safety safe_answer, passed, safety_score, a_meta = check_safety( answer, include_meta=True ) logger.info( "Safety check for session %s passed=%s score=%.1f snippet=%s", session_id, bool(passed), float(safety_score), make_snippet(answer), ) if a_meta and isinstance(a_meta, dict): logger.info("Session %s: answer safety metadata=%s", session_id, a_meta) threshold = float(os.getenv("SAFETY_SCORE_THRESHOLD", "40.0")) if (not passed and safe_answer == "refuse") or float(safety_score) < threshold: # refuse to return the content logger.warning( "Refusing answer for session %s: score=%.1f threshold=%.1f", session_id, float(safety_score), threshold, ) raise HTTPException( status_code=403, detail="Response refused by safety policy" ) # prefer the possibly fixed output from the safety layer answer = safe_answer except Exception as e: # pylint: disable=broad-except logger.debug("Safety check 
unavailable or failed: %s", e) return {"answer": answer, "sources": sources, "session_id": session_id} class EvalPayload(BaseModel): questions: List[str] answers: List[str] contexts: List[List[str]] ground_truths: Optional[List[str]] = None @app.post("/evaluate") async def evaluate(payload: EvalPayload, api_key: str = Depends(verify_api_key)): """Run Ragas evaluation on provided QA pairs. Runs in a thread to avoid blocking the server. Requires API key authentication via X-API-Key header or api_key query parameter. """ logger.info("Starting evaluation for %d examples", len(payload.questions)) try: from eval import run_evaluation result = await asyncio.to_thread( lambda: run_evaluation( payload.questions, payload.answers, payload.contexts, payload.ground_truths, ) ) logger.info("Evaluation complete") # Try to extract scalar metrics for convenient API consumption try: from eval import extract_scalar_metrics metrics = extract_scalar_metrics(result) except Exception: metrics = {} # `result` may be a complex object; include repr but return parsed metrics if available return {"result": repr(result), "metrics": metrics} except Exception as e: # pylint: disable=broad-except logger.exception("Evaluation failed: %s", e) raise HTTPException(status_code=500, detail="Evaluation failed")
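
# Two pieces of /ask's logic are easy to restate as standalone helpers for
# documentation and testing: the safety refusal rule and the answer-extraction
# fallback. The sketch below mirrors those branches; the names are illustrative
# (the module itself keeps this logic inline in the /ask handler).
#
# ```python
# def should_refuse(safe_text, passed, score, threshold=40.0):
#     """Mirror of the safety gate in /ask: refuse when the safety layer
#     signals "refuse" alongside a failed check, or the score falls below
#     the threshold (SAFETY_SCORE_THRESHOLD, default 40.0)."""
#     return (not passed and safe_text == "refuse") or float(score) < threshold
#
#
# def extract_answer(response):
#     """Mirror of /ask's fallback order for chain output: the last
#     chat_history message, then an "answer" key, then str(response)."""
#     chat_history = response.get("chat_history") if isinstance(response, dict) else None
#     if chat_history:
#         last = chat_history[-1]
#         return getattr(last, "content", str(last))
#     if isinstance(response, dict) and "answer" in response:
#         return response["answer"]
#     return str(response)
# ```
#
# For example, should_refuse("refuse", False, 95.0) is True (a hard refusal
# trumps a high score), while extract_answer({"answer": "42"}) returns "42".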