"""
RAG QA API module

Provides a small FastAPI application that implements a Retrieval-Augmented
Generation (RAG) question-answering service for PDF documents.

High-level behavior
- /process (POST): Accepts PDF uploads, stores them temporarily, and schedules
  a background task to extract text, build a vector store (index), and create
  a conversation chain. Returns a session_id that clients can use with /ask.
- /ask (POST): Accepts a session_id and a user question. If the session is
  still processing, returns 202 with the current status. If ready, performs
  safety checks on the prompt, runs the conversation chain (LLM) to generate
  an answer, optionally retrieves supporting sources from the vectorstore,
  applies a safety check to the final answer, and returns the answer,
  source list, and session_id.
- /evaluate (POST): Runs evaluation on provided QA pairs in a background
  thread (non-blocking) and returns evaluation results and scalar metrics
  where available.
- /health (GET): Lightweight healthcheck endpoint returning {"ok": True}.

Key implementation notes
- Temporary file handling: Uploaded files are saved to a temporary directory
  per session. A background worker parses PDFs from those paths and then
  removes the temporary directory.
- Sessions: An in-memory dictionary SESSIONS maps session_id -> session state,
  including status ("queued", "parsing", "chunking", "indexing",
  "building_chain", "ready", or "failed"), the vectorstore, and the
  conversation chain. Sessions are ephemeral and lost on process restart.
- Background processing: _background_process performs PDF parsing (via
  get_documents_from_pdfs), text chunking (get_text_chunks), vectorstore
  creation (get_vectorstore), and chain creation (get_conversation_chain).
  Errors during processing mark the session as failed and record an error.
- Retrieval for sources: When available, the vectorstore's similarity_search
  is used to fetch up to k=3 supporting documents and return their metadata
  "source" fields as source filenames/identifiers.
- Safety: A safety module (check_safety) is invoked on both the prompt and
  the generated answer. A SAFETY_SCORE_THRESHOLD environment variable
  controls the minimum acceptable safety score (default 40.0). Prompts or
  answers that fail the safety policy are refused with HTTP 403.
- Concurrency: CPU-/IO-bound operations that may block (PDF parsing, index
  building, LLM chain invocation, evaluation) are run in background threads
  via asyncio.to_thread or as asyncio tasks to avoid blocking the event loop.
- CORS: CORSMiddleware is configured to allow all origins (allow_origins=["*"])
  so the API can be called from browser-based clients. Restrict this in
  production as needed.
- Logging: A rotating file handler writes to api.log (1 MB max, 3 backups)
  and logs are also emitted to stdout. Telemetry helpers (make_snippet)
  are used in logs to avoid including large payloads.
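
The source-collection step described above can be sketched as follows
(illustrative only: collect_sources is a hypothetical helper, not part of
this module, and documents are assumed to expose a LangChain-style
.metadata dict):

```python
# Illustrative helper; assumes each document exposes a .metadata dict
# containing a "source" key, as LangChain Document objects do.
def collect_sources(docs, limit=3):
    """Return up to `limit` distinct 'source' identifiers, preserving order."""
    sources = []
    for d in docs[:limit]:
        meta = getattr(d, "metadata", None) or {}
        src = meta.get("source")
        if src and src not in sources:
            sources.append(src)
    return sources
```

Duplicates are dropped so a document chunked into several hits is reported
as a single source.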

Environment variables
- API_KEYS (comma-separated list of keys): Valid API keys for request
  authentication. If unset, authentication is disabled (dev mode) and a
  warning is printed at startup.
- SAFETY_SCORE_THRESHOLD (float, default 40.0): Minimum safety score required
  for prompts and answers. When the score is below this threshold or the
  safety layer returns "refuse", the request is rejected with HTTP 403.
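
The refusal rule can be sketched as follows (illustrative only:
should_refuse is a hypothetical helper, and check_safety's exact return
shape is implementation-dependent):

```python
import os

# Hypothetical helper illustrating the refusal rule; not part of this module.
def should_refuse(passed: bool, verdict: str, score: float) -> bool:
    """Refuse when the safety layer says 'refuse' or the score is too low."""
    threshold = float(os.getenv("SAFETY_SCORE_THRESHOLD", "40.0"))
    return (not passed and verdict == "refuse") or float(score) < threshold
```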

Error handling and HTTP semantics
- 202 Accepted: Returned by /ask when the requested session is still processing.
- 400 Bad Request: Missing/invalid input (e.g., no files uploaded).
- 401 Unauthorized: No API key supplied while API_KEYS is configured.
- 403 Forbidden: Invalid API key, or a prompt/answer refused by safety policy.
- 404 Not Found: session_id does not exist.
- 409 Conflict: Returned by /ask when the session's background processing failed.
- 500 Internal Server Error: LLM generation, evaluation, or unexpected errors.
- Exceptions during background tasks are logged and set the session status
  to "failed".

Dependencies and integration points (expected to be provided in the codebase)
- generate_testset.get_documents_from_pdfs(paths): Invoked in a background
  thread to parse uploaded PDFs into document objects.
- original_rag.get_text_chunks(texts_with_meta): Splits long documents into
  chunks and returns (texts, metadatas).
- original_rag.get_vectorstore(texts, metadatas): Builds/returns a
  vectorstore/index compatible with similarity_search.
- original_rag.get_conversation_chain(vectorstore): Returns a callable
  conversation chain that accepts {"question": <str>} and returns model
  output (dict-like).
- safety.check_safety(text, include_meta=True): Performs prompt/answer safety
  checks and returns (safe_text, passed, score, meta).
- telemetry.make_snippet(text): Helper used to produce short snippets for logs.
- eval.run_evaluation(...) and eval.extract_scalar_metrics(...): Used by /evaluate.

Usage example (server)
- Start the server:
    uvicorn api:app --host 0.0.0.0 --port 8000
- Typical client flow:
  1. POST /process with one or more PDF files -> receives session_id.
  2. Poll or POST /ask with session_id and question -> returns the answer
     once the session is ready.
  3. Optionally POST /evaluate with QA pairs to run batch evaluation.
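
That flow can be sketched with curl (form-field names follow the handlers
below; the host, port, and API key are placeholders for your deployment):

```shell
# 1. Upload PDFs; note the session_id in the JSON response
curl -X POST http://localhost:8000/process \
     -H "X-API-Key: your-secret-key-1" \
     -F "files=@paper.pdf"

# 2. Ask a question; a 202 status payload is returned until the session is ready
curl -X POST http://localhost:8000/ask \
     -H "X-API-Key: your-secret-key-1" \
     -F "session_id=<session_id>" \
     -F "question=What is the main contribution?"
```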

Limitations and considerations
- All session state is kept in memory. For production use, persist state
  (vectorstore, chain configuration) to durable storage to survive restarts.
- CORS is permissive by default; tighten origins and credentials for production.
- The exact formats returned by the conversation chain and vectorstore are
  implementation-dependent; the API contains fallbacks to extract text and
  sources from common shapes but may need adaptation for different backends.
- Uploaded files are removed once background processing completes (success
  or failure).
- Ensure the safety module and the LLM backend are properly configured and
  monitored for cost/latency in production deployments.
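
The "common shapes" fallback mentioned above can be sketched as follows
(illustrative only: extract_answer is a hypothetical helper; the /ask
handler below inlines this logic):

```python
# Illustrative fallback for pulling answer text out of common chain outputs.
def extract_answer(response):
    """Prefer chat_history's last message, then an 'answer' key, then str()."""
    if isinstance(response, dict):
        history = response.get("chat_history")
        if history:
            last = history[-1]
            return getattr(last, "content", str(last))
        if "answer" in response:
            return response["answer"]
    return str(response)
```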

Security
- Inputs are safety-checked before being sent to the LLM, and answers are
  safety-checked before being returned.
- The API currently allows file uploads and runs background parsing on them;
  consider authentication, rate-limiting, and content scanning for safe
  operation in shared environments.
"""
from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Security, Depends
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import APIKeyHeader, APIKeyQuery
from typing import List, Optional
import uuid
import tempfile
import shutil
import os
import logging
import sys
from logging.handlers import RotatingFileHandler
import asyncio
from datetime import datetime, timezone
from pydantic import BaseModel

# Import core helpers from the existing codebase
from original_rag import get_text_chunks, get_vectorstore, get_conversation_chain
from telemetry import make_snippet
app = FastAPI(title="RAG QA API")

# API key authentication setup
API_KEY_NAME = "X-API-Key"
api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)
api_key_query = APIKeyQuery(name="api_key", auto_error=False)

# Load valid API keys from the environment.
# Format: comma-separated list of keys, e.g. "key1,key2,key3"
VALID_API_KEYS = set(
    key.strip() for key in os.getenv("API_KEYS", "").split(",") if key.strip()
)

# Warn loudly if no API keys are configured
if not VALID_API_KEYS:
    print("WARNING: No API keys configured. Set API_KEYS environment variable.")
    print("Example: API_KEYS='your-secret-key-1,your-secret-key-2'")

async def verify_api_key(
    api_key_header: str = Security(api_key_header),
    api_key_query: str = Security(api_key_query),
):
    """Verify the API key from a header or query parameter.

    Checks the X-API-Key header first, then falls back to the ?api_key= query
    parameter. Returns the valid API key if authenticated; otherwise raises
    401 (missing key) or 403 (invalid key).
    """
    # If no keys are configured, skip authentication (dev mode)
    if not VALID_API_KEYS:
        logger.warning("API key authentication is disabled (no keys configured)")
        return None
    # Check the header first, then the query parameter
    api_key = api_key_header or api_key_query
    if not api_key:
        logger.warning("API key missing in request")
        raise HTTPException(
            status_code=401,
            detail="API key required. Provide via X-API-Key header or api_key query parameter.",
        )
    if api_key not in VALID_API_KEYS:
        logger.warning("Invalid API key attempted: %s", api_key[:8] + "...")
        raise HTTPException(status_code=403, detail="Invalid API key")
    return api_key

# Set up logging: rotating file + stdout
logger = logging.getLogger("rag_api")
logger.setLevel(logging.INFO)
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(name)s - %(message)s")
file_handler = RotatingFileHandler(
    "api.log", maxBytes=1 * 1024 * 1024, backupCount=3, encoding="utf-8"
)
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
stream_handler = logging.StreamHandler(stream=sys.stdout)
stream_handler.setFormatter(formatter)
logger.addHandler(stream_handler)

# Allow cross-origin requests so browser-based clients can call the API
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# In-memory sessions: session_id -> session state
# (status, progress, vectorstore, conversation chain, timestamps, ...)
SESSIONS = {}

def _utc_now_iso() -> str:
    return datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")


def _build_status_payload(session_id: str, sess: dict) -> dict:
    return {
        "session_id": session_id,
        "status": sess.get("status", "queued"),
        "progress_pct": sess.get("progress_pct", 0),
        "message": sess.get("message", ""),
        "doc_count": sess.get("doc_count"),
        "chunk_count": sess.get("chunk_count"),
        "started_at": sess.get("started_at"),
        "updated_at": sess.get("updated_at"),
        "completed_at": sess.get("completed_at"),
        "error": sess.get("error"),
    }


def _update_session(session_id: str, **updates):
    sess = SESSIONS.get(session_id)
    if not sess:
        return
    now = _utc_now_iso()
    updates["updated_at"] = now
    if updates.get("status") in {"ready", "failed"} and not updates.get("completed_at"):
        updates["completed_at"] = now
    sess.update(updates)

async def _background_process(session_id: str, tmpdir: str, paths: list):
    """Background worker to parse uploaded PDFs and build the vectorstore/chain.

    Stores results in SESSIONS[session_id] and removes the temporary upload dir.
    """
    try:
        from generate_testset import get_documents_from_pdfs

        _update_session(
            session_id,
            status="parsing",
            progress_pct=10,
            message="Parsing uploaded PDFs",
            error=None,
        )
        documents = await asyncio.to_thread(get_documents_from_pdfs, paths)
        if not documents:
            _update_session(
                session_id,
                status="failed",
                progress_pct=100,
                message="Parsing failed: no documents parsed",
                error="No documents parsed from uploads",
            )
            return
        _update_session(
            session_id,
            status="chunking",
            progress_pct=35,
            message="Chunking parsed text",
            doc_count=len(documents),
        )
        texts, metadatas = get_text_chunks(
            [
                {
                    "text": (
                        d.page_content
                        if hasattr(d, "page_content")
                        else d.get("text", "")
                    ),
                    "source": (
                        d.metadata.get("source")
                        if getattr(d, "metadata", None)
                        else None
                    ),
                }
                for d in documents
            ]
        )
        _update_session(
            session_id,
            status="indexing",
            progress_pct=60,
            message="Building vector store",
            chunk_count=len(texts),
        )
        vectorstore = await asyncio.to_thread(get_vectorstore, texts, metadatas)
        _update_session(
            session_id,
            status="building_chain",
            progress_pct=85,
            message="Building conversation chain",
        )
        try:
            conversation = await asyncio.to_thread(get_conversation_chain, vectorstore)
        except Exception as e:
            _update_session(
                session_id,
                status="failed",
                progress_pct=100,
                message="Conversation chain creation failed",
                error=f"Conversation creation failed: {e}",
            )
            return
        _update_session(
            session_id,
            vectorstore=vectorstore,
            conversation=conversation,
            status="ready",
            progress_pct=100,
            message="Session is ready for questions",
            error=None,
        )
        logger.info("Background processing complete for session %s", session_id)
    except Exception as e:
        logger.exception(
            "Background processing failed for session %s: %s", session_id, e
        )
        _update_session(
            session_id,
            status="failed",
            progress_pct=100,
            message="Background processing failed",
            error=str(e),
        )
    finally:
        try:
            shutil.rmtree(tmpdir)
        except Exception:
            pass

@app.get("/health")
def health():
    return {"ok": True}

@app.get("/status/{session_id}")
def get_status(session_id: str, api_key: str = Depends(verify_api_key)):
    sess = SESSIONS.get(session_id)
    if not sess:
        raise HTTPException(status_code=404, detail="session_id not found")
    return _build_status_payload(session_id, sess)

@app.post("/process")
async def process(
    files: List[UploadFile] = File(...), api_key: str = Depends(verify_api_key)
):
    """Upload PDFs and create a retrieval index + conversation chain.

    Returns a `session_id` to use with `/ask`. Requires API key authentication
    via the X-API-Key header or api_key query parameter.
    """
    if not files:
        raise HTTPException(status_code=400, detail="No files uploaded")
    # Save uploaded files to a temporary directory so helper code can read them
    tmpdir = tempfile.mkdtemp(prefix="rag_upload_")
    paths = []
    # Create the session early and schedule background processing
    session_id = str(uuid.uuid4())
    now = _utc_now_iso()
    SESSIONS[session_id] = {
        "status": "queued",
        "progress_pct": 0,
        "message": "Upload received and queued",
        "started_at": now,
        "updated_at": now,
        "completed_at": None,
        "error": None,
        "doc_count": None,
        "chunk_count": None,
        "tmpdir": tmpdir,
    }
    try:
        _update_session(
            session_id,
            status="parsing",
            progress_pct=5,
            message="Saving uploaded files",
        )
        for f in files:
            dest = os.path.join(tmpdir, f.filename or f"upload-{uuid.uuid4()}.pdf")
            with open(dest, "wb") as out:
                content = await f.read()
                out.write(content)
            paths.append(dest)
        _update_session(
            session_id,
            status="queued",
            progress_pct=8,
            message="Files saved; waiting for background processing",
            doc_count=len(paths),
        )
        # Schedule background processing to build the index & chain
        asyncio.create_task(_background_process(session_id, tmpdir, paths))
        logger.info("Scheduled background processing for session %s", session_id)
        payload = _build_status_payload(session_id, SESSIONS[session_id])
        payload["estimated_poll_interval_ms"] = 2500
        return payload
    except Exception:
        # On unexpected error, mark the session failed, clean up tmpdir, and re-raise
        _update_session(
            session_id,
            status="failed",
            progress_pct=100,
            message="Failed while saving uploads",
            error="Unexpected error while handling uploaded files",
        )
        try:
            shutil.rmtree(tmpdir)
        except Exception:
            pass
        raise

@app.post("/ask")
async def ask(
    session_id: str = Form(...),
    question: str = Form(...),
    api_key: str = Depends(verify_api_key),
):
    """Ask a question against a previously created session (session_id).

    Returns the answer text and a list of source filenames. Requires API key
    authentication via the X-API-Key header or api_key query parameter.
    """
    sess = SESSIONS.get(session_id)
    if not sess:
        raise HTTPException(status_code=404, detail="session_id not found")
    # If processing is not complete, return 202 Accepted with the full status payload
    status = sess.get("status")
    if status == "failed":
        payload = _build_status_payload(session_id, sess)
        return JSONResponse(status_code=409, content=payload)
    if status != "ready":
        return JSONResponse(
            status_code=202, content=_build_status_payload(session_id, sess)
        )
    conv = sess.get("conversation")
    if conv is None:
        raise HTTPException(status_code=500, detail="conversation chain missing")
    try:
        from safety import check_safety

        # Check prompt safety before passing the question to the LLM
        safe_q, q_passed, q_score, q_meta = check_safety(question, include_meta=True)
        logger.info(
            "Session %s: prompt safety passed=%s score=%.1f snippet=%s",
            session_id,
            bool(q_passed),
            float(q_score),
            make_snippet(question),
        )
        # If safeguard metadata is present, include it in the telemetry log
        if q_meta and isinstance(q_meta, dict):
            logger.info("Session %s: prompt safety metadata=%s", session_id, q_meta)
        threshold = float(os.getenv("SAFETY_SCORE_THRESHOLD", "40.0"))
        if (not q_passed and safe_q == "refuse") or float(q_score) < threshold:
            logger.warning(
                "Refusing prompt for session %s: score=%.1f threshold=%.1f",
                session_id,
                float(q_score),
                threshold,
            )
            raise HTTPException(
                status_code=403, detail="Question refused by safety policy"
            )
        # The conversation chain may block; run it in a thread to keep the
        # event loop responsive
        response = await asyncio.to_thread(lambda: conv({"question": safe_q}))
    except HTTPException:
        # Propagate safety refusals (403) instead of masking them as 500s
        raise
    except Exception as e:
        logger.exception("Error generating response for session %s: %s", session_id, e)
        raise HTTPException(status_code=500, detail="LLM generation error")
    # Try to extract the final message similarly to the Streamlit handler
    answer = ""
    chat_history = response.get("chat_history") if isinstance(response, dict) else None
    if chat_history:
        last = chat_history[-1]
        answer = getattr(last, "content", str(last))
    else:
        # Fallback: check common keys
        answer = (
            response.get("answer")
            if isinstance(response, dict) and "answer" in response
            else str(response)
        )
    # Try to fetch sources if a vectorstore is available
    sources = []
    try:
        vs = sess.get("vectorstore")
        if vs is not None:
            docs = vs.similarity_search(question, k=3)
            for d in docs:
                src = d.metadata.get("source") if getattr(d, "metadata", None) else None
                if src and src not in sources:
                    sources.append(src)
    except Exception:
        # Ignore retrieval errors; sources are best-effort
        pass
    # Safety check: ensure we do not return unsafe content
    try:
        from safety import check_safety

        safe_answer, passed, safety_score, a_meta = check_safety(
            answer, include_meta=True
        )
        logger.info(
            "Safety check for session %s passed=%s score=%.1f snippet=%s",
            session_id,
            bool(passed),
            float(safety_score),
            make_snippet(answer),
        )
        if a_meta and isinstance(a_meta, dict):
            logger.info("Session %s: answer safety metadata=%s", session_id, a_meta)
        threshold = float(os.getenv("SAFETY_SCORE_THRESHOLD", "40.0"))
        if (not passed and safe_answer == "refuse") or float(safety_score) < threshold:
            # Refuse to return the content
            logger.warning(
                "Refusing answer for session %s: score=%.1f threshold=%.1f",
                session_id,
                float(safety_score),
                threshold,
            )
            raise HTTPException(
                status_code=403, detail="Response refused by safety policy"
            )
        # Prefer the possibly corrected output from the safety layer
        answer = safe_answer
    except HTTPException:
        # Propagate the refusal instead of silently returning the answer
        raise
    except Exception as e:  # pylint: disable=broad-except
        logger.debug("Safety check unavailable or failed: %s", e)
    return {"answer": answer, "sources": sources, "session_id": session_id}

class EvalPayload(BaseModel):
    questions: List[str]
    answers: List[str]
    contexts: List[List[str]]
    ground_truths: Optional[List[str]] = None

@app.post("/evaluate")
async def evaluate(payload: EvalPayload, api_key: str = Depends(verify_api_key)):
    """Run Ragas evaluation on the provided QA pairs.

    Runs in a thread to avoid blocking the server. Requires API key
    authentication via the X-API-Key header or api_key query parameter.
    """
    logger.info("Starting evaluation for %d examples", len(payload.questions))
    try:
        from eval import run_evaluation

        result = await asyncio.to_thread(
            lambda: run_evaluation(
                payload.questions,
                payload.answers,
                payload.contexts,
                payload.ground_truths,
            )
        )
        logger.info("Evaluation complete")
        # Try to extract scalar metrics for convenient API consumption
        try:
            from eval import extract_scalar_metrics

            metrics = extract_scalar_metrics(result)
        except Exception:
            metrics = {}
        # `result` may be a complex object; include its repr but return parsed
        # metrics when available
        return {"result": repr(result), "metrics": metrics}
    except Exception as e:  # pylint: disable=broad-except
        logger.exception("Evaluation failed: %s", e)
        raise HTTPException(status_code=500, detail="Evaluation failed")
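

# Local development entrypoint, a convenience sketch only (assumptions:
# uvicorn is installed and this module is saved as api.py; the docstring's
# `uvicorn api:app ...` command is the usual way to run the server).
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)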