# NOTE: removed Hugging Face Spaces status banner ("Spaces: Running") — UI paste artifact, not code.
"""
kerdos_rag/server.py

FastAPI REST server exposing the KerdosRAG engine over HTTP.

Endpoints:
    GET    /health — liveness probe
    GET    /status — knowledge-base metadata
    POST   /index  — upload + index documents (multipart/form-data)
    POST   /chat   — ask a question (SSE streaming response)
    DELETE /reset  — clear the knowledge base

Authentication (optional):
    Set the API_KEY env var to enable X-Api-Key header validation.
    Leave it unset to run in open mode (suitable for local / trusted environments).
"""
from __future__ import annotations

import asyncio
import hmac
import os
from typing import AsyncGenerator

from fastapi import FastAPI, File, UploadFile, HTTPException, Depends, Header, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel

from kerdos_rag.core import KerdosRAG
# ── App & CORS ────────────────────────────────────────────────────────────────
app = FastAPI(
    title="Kerdos RAG API",
    description="Enterprise Document Q&A engine by Kerdos Infrasoft",
    version="0.1.0",
    contact={
        "name": "Kerdos Infrasoft",
        "url": "https://kerdos.in",
        "email": "partnership@kerdos.in",
    },
    license_info={"name": "MIT"},
)

# Wide-open CORS: any origin, method, and header are accepted.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# ── Singleton engine ──────────────────────────────────────────────────────────
# One engine instance shared by every request in this process.
_engine = KerdosRAG()

# ── Auth ──────────────────────────────────────────────────────────────────────
# Read once at import time; an empty value means auth is disabled (open mode).
_API_KEY = os.environ.get("API_KEY", "")
def _check_auth(x_api_key: str | None = Header(default=None)) -> None:
    """Validate the ``X-Api-Key`` header when an API key is configured.

    If the ``API_KEY`` env var is unset (empty string), the server runs in
    open mode and every request is allowed through.

    Raises:
        HTTPException: 401 when a key is configured and the header is
            missing or does not match.
    """
    if not _API_KEY:
        return  # open mode — no key configured
    # hmac.compare_digest runs in constant time, avoiding a timing
    # side-channel on the key comparison (plain != short-circuits).
    if x_api_key is None or not hmac.compare_digest(x_api_key, _API_KEY):
        raise HTTPException(status_code=401, detail="Invalid or missing X-Api-Key header.")
# ── Request / Response models ─────────────────────────────────────────────────
class ChatRequest(BaseModel):
    """Request body for POST /chat."""

    query: str  # the user's question
    history: list[dict] | None = None  # prior conversation turns, engine-defined shape
    top_k: int | None = None  # optional per-request retrieval override
class StatusResponse(BaseModel):
    """Knowledge-base metadata returned by GET /status."""

    indexed_sources: list[str]  # names of documents currently indexed
    chunk_count: int  # total number of stored chunks
    model: str  # generation model identifier
    top_k: int  # retrieval depth
    min_score: float  # minimum similarity score for retrieval
# ── Endpoints ─────────────────────────────────────────────────────────────────
@app.get("/health")
def health() -> dict:
    """Liveness probe — always returns 200 OK.

    Route decorator restored: the module docstring documents GET /health,
    but no decorator was attached, so the route was never registered.
    """
    return {"status": "ok", "version": "0.1.0"}
@app.get("/status", response_model=StatusResponse)
def status(_: None = Depends(_check_auth)) -> StatusResponse:
    """Return current knowledge-base metadata.

    Route decorator restored: the module docstring documents GET /status,
    but no decorator was attached, so the route was never registered.

    Raises:
        HTTPException: 401 via the auth dependency when an API key is
            configured and the header is wrong/missing.
    """
    return StatusResponse(
        indexed_sources=list(_engine.indexed_sources),
        chunk_count=_engine.chunk_count,
        model=_engine.model,
        top_k=_engine.top_k,
        min_score=_engine.min_score,
    )
@app.post("/index")
async def index_documents(
    files: list[UploadFile] = File(...),
    _: None = Depends(_check_auth),
) -> JSONResponse:
    """
    Upload and index one or more documents.

    Accepts: PDF (.pdf), Word (.docx), plain text (.txt, .md, .csv).
    Duplicate filenames are automatically skipped.

    Returns:
        JSONResponse with the engine's indexing summary.
    """
    import shutil
    import tempfile
    from pathlib import Path

    saved_paths: list[str] = []
    tmp_dir = tempfile.mkdtemp(prefix="kerdos_upload_")
    try:
        for upload in files:
            # Use only the basename: a crafted filename such as "../../x"
            # must not escape tmp_dir (path traversal). filename may also
            # be None for anonymous multipart parts.
            name = Path(upload.filename or "").name
            if not name:
                continue  # nothing usable to save
            dest = Path(tmp_dir) / name
            with open(dest, "wb") as f:
                shutil.copyfileobj(upload.file, f)
            saved_paths.append(str(dest))
        result = _engine.index(saved_paths)
    finally:
        # Always remove the uploaded copies, even if indexing raised.
        shutil.rmtree(tmp_dir, ignore_errors=True)
    return JSONResponse(content=result)
@app.post("/chat")
async def chat(req: ChatRequest, _: None = Depends(_check_auth)) -> StreamingResponse:
    """
    Ask a question and receive a **Server-Sent Events** stream of tokens.

    Each SSE event has the form:
        data: <partial answer so far>\\n\\n
    The stream ends with:
        data: [DONE]\\n\\n

    Example (curl):
        curl -X POST http://localhost:8000/chat \\
          -H "Content-Type: application/json" \\
          -d '{"query": "What is the refund policy?"}' \\
          --no-buffer

    Raises:
        HTTPException: 422 when the knowledge base is empty; 401 when no
            Hugging Face token is configured (or via the auth dependency).
    """
    if not _engine.is_ready:
        raise HTTPException(
            status_code=422,
            detail="Knowledge base is empty. POST files to /index first.",
        )
    if not _engine.hf_token:
        raise HTTPException(
            status_code=401,
            detail="No Hugging Face token configured. Set HF_TOKEN env var.",
        )

    # Temporarily override top_k if the caller specified it.
    # NOTE(review): _engine is a process-wide singleton, so this mutation is
    # visible to concurrent /chat requests until restored below — prefer
    # passing top_k through to the engine call if its API allows it.
    original_top_k = _engine.top_k
    if req.top_k is not None:
        _engine.top_k = req.top_k

    async def event_generator() -> AsyncGenerator[str, None]:
        try:
            # _engine.chat is a sync generator — pull each token in the
            # default thread pool so the event loop stays responsive.
            # get_running_loop(), not get_event_loop(): the latter is
            # deprecated inside coroutines.
            loop = asyncio.get_running_loop()
            gen = _engine.chat(req.query, history=req.history)
            while True:
                try:
                    token = await loop.run_in_executor(None, next, gen)
                    # SSE data values must be single-line: escape newlines.
                    escaped = token.replace("\n", "\\n")
                    yield f"data: {escaped}\n\n"
                except StopIteration:
                    break
            # Terminator on the success path only. Yielding it from a
            # `finally` block raised RuntimeError ("generator ignored
            # GeneratorExit") whenever the client disconnected mid-stream.
            yield "data: [DONE]\n\n"
        finally:
            _engine.top_k = original_top_k

    return StreamingResponse(event_generator(), media_type="text/event-stream")
@app.delete("/reset")
def reset(_: None = Depends(_check_auth)) -> dict:
    """Clear the entire knowledge base.

    Route decorator restored: the module docstring documents DELETE /reset,
    but no decorator was attached, so the route was never registered.
    """
    _engine.reset()
    return {"ok": True, "message": "Knowledge base cleared."}