""" kerdos_rag/server.py FastAPI REST server exposing the KerdosRAG engine over HTTP. Endpoints: GET /health — liveness probe GET /status — knowledge-base metadata POST /index — upload + index documents (multipart/form-data) POST /chat — ask a question (SSE streaming response) DELETE /reset — clear the knowledge base Authentication (optional): Set API_KEY env var to enable X-Api-Key header validation. Leave unset to run in open mode (suitable for local / trusted environments). """ from __future__ import annotations import os import asyncio from typing import AsyncGenerator from fastapi import FastAPI, File, UploadFile, HTTPException, Depends, Header, Request from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import StreamingResponse, JSONResponse from pydantic import BaseModel from kerdos_rag.core import KerdosRAG # ── App & CORS ──────────────────────────────────────────────────────────────── app = FastAPI( title="Kerdos RAG API", description="Enterprise Document Q&A engine by Kerdos Infrasoft", version="0.1.0", contact={"name": "Kerdos Infrasoft", "url": "https://kerdos.in", "email": "partnership@kerdos.in"}, license_info={"name": "MIT"}, ) app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"], ) # ── Singleton engine ─────────────────────────────────────────────────────────── _engine = KerdosRAG() # ── Auth ────────────────────────────────────────────────────────────────────── _API_KEY = os.environ.get("API_KEY", "") def _check_auth(x_api_key: str | None = Header(default=None)) -> None: """If API_KEY env var is set, validate X-Api-Key header.""" if _API_KEY and x_api_key != _API_KEY: raise HTTPException(status_code=401, detail="Invalid or missing X-Api-Key header.") # ── Request / Response models ────────────────────────────────────────────────── class ChatRequest(BaseModel): query: str history: list[dict] | None = None top_k: int | None = None class StatusResponse(BaseModel): indexed_sources: list[str] chunk_count: int model: str top_k: int min_score: float # ── Endpoints ───────────────────────────────────────────────────────────────── @app.get("/health", tags=["Meta"]) def health() -> dict: """Liveness probe — always returns 200 OK.""" return {"status": "ok", "version": "0.1.0"} @app.get("/status", response_model=StatusResponse, tags=["Meta"]) def status(_: None = Depends(_check_auth)) -> StatusResponse: """Return current knowledge-base metadata.""" return StatusResponse( indexed_sources=list(_engine.indexed_sources), chunk_count=_engine.chunk_count, model=_engine.model, top_k=_engine.top_k, min_score=_engine.min_score, ) @app.post("/index", tags=["RAG"]) async def index_documents( files: list[UploadFile] = File(...), _: None = Depends(_check_auth), ) -> JSONResponse: """ Upload and index one or more documents. Accepts: PDF (.pdf), Word (.docx), plain text (.txt, .md, .csv). Duplicate filenames are automatically skipped. """ import tempfile, shutil from pathlib import Path saved_paths: list[str] = [] tmp_dir = tempfile.mkdtemp(prefix="kerdos_upload_") try: for upload in files: dest = Path(tmp_dir) / upload.filename with open(dest, "wb") as f: shutil.copyfileobj(upload.file, f) saved_paths.append(str(dest)) result = _engine.index(saved_paths) finally: shutil.rmtree(tmp_dir, ignore_errors=True) return JSONResponse(content=result) @app.post("/chat", tags=["RAG"]) async def chat(req: ChatRequest, _: None = Depends(_check_auth)) -> StreamingResponse: """ Ask a question and receive a **Server-Sent Events** stream of tokens. Each SSE event has the form: data: \\n\\n The stream ends with: data: [DONE]\\n\\n Example (curl): curl -X POST http://localhost:8000/chat \\ -H "Content-Type: application/json" \\ -d '{"query": "What is the refund policy?"}' \\ --no-buffer """ if not _engine.is_ready: raise HTTPException( status_code=422, detail="Knowledge base is empty. POST files to /index first.", ) hf_token = _engine.hf_token if not hf_token: raise HTTPException( status_code=401, detail="No Hugging Face token configured. Set HF_TOKEN env var.", ) # Temporarily override top_k if caller specified it original_top_k = _engine.top_k if req.top_k is not None: _engine.top_k = req.top_k async def event_generator() -> AsyncGenerator[str, None]: try: # answer_stream is a sync generator — run in thread pool loop = asyncio.get_event_loop() gen = _engine.chat(req.query, history=req.history) while True: try: token = await loop.run_in_executor(None, next, gen) # SSE format: escape newlines in the data value escaped = token.replace("\n", "\\n") yield f"data: {escaped}\n\n" except StopIteration: break finally: _engine.top_k = original_top_k yield "data: [DONE]\n\n" return StreamingResponse(event_generator(), media_type="text/event-stream") @app.delete("/reset", tags=["RAG"]) def reset(_: None = Depends(_check_auth)) -> dict: """Clear the entire knowledge base.""" _engine.reset() return {"ok": True, "message": "Knowledge base cleared."}