Spaces:
Running
Running
File size: 6,471 Bytes
634117a | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 | """
kerdos_rag/server.py
FastAPI REST server exposing the KerdosRAG engine over HTTP.
Endpoints:
GET /health — liveness probe
GET /status — knowledge-base metadata
POST /index — upload + index documents (multipart/form-data)
POST /chat — ask a question (SSE streaming response)
DELETE /reset — clear the knowledge base
Authentication (optional):
Set API_KEY env var to enable X-Api-Key header validation.
Leave unset to run in open mode (suitable for local / trusted environments).
"""
from __future__ import annotations

import asyncio
import hmac
import os
import shutil
import tempfile
from pathlib import Path, PurePosixPath
from typing import AsyncGenerator

from fastapi import FastAPI, File, UploadFile, HTTPException, Depends, Header, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel, Field

from kerdos_rag.core import KerdosRAG
# ── App & CORS ────────────────────────────────────────────────────────────────
# Application object plus its OpenAPI metadata (title, version, contact, licence).
app = FastAPI(
    title="Kerdos RAG API",
    version="0.1.0",
    description="Enterprise Document Q&A engine by Kerdos Infrasoft",
    license_info={"name": "MIT"},
    contact=dict(
        name="Kerdos Infrasoft",
        url="https://kerdos.in",
        email="partnership@kerdos.in",
    ),
)

# Wide-open CORS: every origin, method and header is accepted.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)

# ── Singleton engine ───────────────────────────────────────────────────────────
# A single engine instance for the whole process, shared by the request
# handlers in this module.
_engine = KerdosRAG()
# ── Auth ──────────────────────────────────────────────────────────────────────
# Shared secret, read once at import time; an empty value disables auth.
_API_KEY = os.environ.get("API_KEY", "")


def _check_auth(x_api_key: str | None = Header(default=None)) -> None:
    """Enforce the ``X-Api-Key`` header when the ``API_KEY`` env var is set.

    Used as a FastAPI dependency. In open mode (``API_KEY`` unset or empty)
    every request is accepted.

    Raises:
        HTTPException: 401 if a key is configured and the header is missing
            or does not match it.
    """
    if not _API_KEY:
        return
    # Constant-time comparison so response timing does not reveal how much of
    # a guessed key was correct. Compare bytes, because hmac.compare_digest
    # raises TypeError for str arguments containing non-ASCII characters.
    if x_api_key is None or not hmac.compare_digest(
        x_api_key.encode("utf-8"), _API_KEY.encode("utf-8")
    ):
        raise HTTPException(status_code=401, detail="Invalid or missing X-Api-Key header.")
# ── Request / Response models ──────────────────────────────────────────────────
class ChatRequest(BaseModel):
    """Body of ``POST /chat``."""

    # The user's question.
    query: str
    # Prior conversation turns, forwarded to the engine unchanged; presumably
    # role/content dicts — TODO confirm against KerdosRAG.chat.
    history: list[dict] | None = None
    # Per-request override of the engine's retrieval depth; None keeps the
    # engine's current value. Values below 1 are rejected with a 422 here
    # instead of being written onto the shared engine.
    top_k: int | None = Field(default=None, ge=1)
class StatusResponse(BaseModel):
    """Knowledge-base metadata returned by ``GET /status``.

    Each field is filled from the engine attribute of the same name.
    """

    # Names of the sources currently held by the engine.
    indexed_sources: list[str]
    chunk_count: int  # number of chunks the engine reports as indexed
    model: str  # the engine's `model` attribute (presumably a model id — confirm)
    top_k: int  # the engine's current top_k setting
    min_score: float  # the engine's min_score (presumably a retrieval score cutoff — confirm)
# ── Endpoints ─────────────────────────────────────────────────────────────────
@app.get("/health", tags=["Meta"])
def health() -> dict:
    """Liveness probe — always returns 200 OK."""
    # Report the version configured on the FastAPI app so this endpoint and
    # the OpenAPI schema can never disagree.
    return {"status": "ok", "version": app.version}
@app.get("/status", response_model=StatusResponse, tags=["Meta"])
def status(_: None = Depends(_check_auth)) -> StatusResponse:
    """Report what the knowledge base currently holds and how it is configured.

    Requires a valid ``X-Api-Key`` header when ``API_KEY`` is set.
    """
    engine = _engine
    snapshot = {
        # Copy into a list so the response model receives a concrete sequence.
        "indexed_sources": list(engine.indexed_sources),
        "chunk_count": engine.chunk_count,
        "model": engine.model,
        "top_k": engine.top_k,
        "min_score": engine.min_score,
    }
    return StatusResponse(**snapshot)
def _safe_upload_name(filename: str | None) -> str | None:
    """Reduce a client-supplied filename to a bare basename; None if unusable."""
    if not filename:
        return None
    # Treat both "/" and "\" as separators so neither POSIX ("../../x") nor
    # Windows-style ("..\\..\\x", "C:\\x") names can escape the upload
    # directory; only the final path component is kept.
    name = PurePosixPath(filename.replace("\\", "/")).name
    if name in ("", ".", "..") or "\x00" in name:
        return None
    return name


def _save_and_index(uploads: list[tuple[str, UploadFile]]) -> object:
    """Copy uploads into a private temp dir, index them, then delete the copies.

    This is blocking work (disk I/O plus ``_engine.index``), so the endpoint
    runs it in a worker thread.
    """
    tmp_dir = Path(tempfile.mkdtemp(prefix="kerdos_upload_"))
    try:
        saved_paths: list[str] = []
        for position, (name, upload) in enumerate(uploads):
            # One sub-directory per upload keeps the original basename (the
            # engine may use it as the source name) while preventing two
            # uploads with the same name from overwriting each other.
            slot = tmp_dir / str(position)
            slot.mkdir()
            dest = slot / name
            with dest.open("wb") as fh:
                shutil.copyfileobj(upload.file, fh)
            saved_paths.append(str(dest))
        return _engine.index(saved_paths)
    finally:
        # Best-effort cleanup: a failed delete must not mask the indexing result.
        shutil.rmtree(tmp_dir, ignore_errors=True)


@app.post("/index", tags=["RAG"])
async def index_documents(
    files: list[UploadFile] = File(...),
    _: None = Depends(_check_auth),
) -> JSONResponse:
    """
    Upload and index one or more documents.
    Accepts: PDF (.pdf), Word (.docx), plain text (.txt, .md, .csv).
    Duplicate filenames are automatically skipped.
    Uploads without a usable filename are rejected with 400.
    """
    uploads: list[tuple[str, UploadFile]] = []
    for upload in files:
        name = _safe_upload_name(upload.filename)
        if name is None:
            raise HTTPException(
                status_code=400,
                detail=f"Invalid upload filename: {upload.filename!r}",
            )
        uploads.append((name, upload))
    # Copying and indexing block; run them off the event loop so other
    # requests (e.g. /health) are still served while indexing is in progress.
    result = await asyncio.to_thread(_save_and_index, uploads)
    return JSONResponse(content=result)
# Serialises temporary per-request top_k overrides on the shared engine.
# Without it, two overlapping overrides can interleave their save/restore
# steps and permanently leave one request's value in place.
_top_k_override_lock = asyncio.Lock()

# Default returned by next(gen, default) once the engine's token generator is
# exhausted. StopIteration itself cannot be used here: asyncio refuses to store
# it on a Future, so it never reaches the awaiting coroutine as StopIteration.
_STREAM_END = object()


def _sse_event(data: str) -> str:
    """Format *data* as one SSE ``data:`` event.

    SSE treats CR, LF and CRLF all as line terminators. Every kind of line
    break is therefore normalised and escaped as the two characters ``\\n``,
    which keeps the payload on a single ``data:`` line.
    """
    one_line = data.replace("\r\n", "\n").replace("\r", "\n").replace("\n", "\\n")
    return f"data: {one_line}\n\n"


@app.post("/chat", tags=["RAG"])
async def chat(req: ChatRequest, _: None = Depends(_check_auth)) -> StreamingResponse:
    """
    Ask a question and receive a **Server-Sent Events** stream of tokens.
    Each SSE event has the form:
        data: <partial answer so far>\\n\\n
    Line breaks inside the data are sent as the two characters \\n.
    When the answer completes successfully the stream ends with:
        data: [DONE]\\n\\n
    Example (curl):
        curl -X POST http://localhost:8000/chat \\
            -H "Content-Type: application/json" \\
            -d '{"query": "What is the refund policy?"}' \\
            --no-buffer
    """
    if not _engine.is_ready:
        raise HTTPException(
            status_code=422,
            detail="Knowledge base is empty. POST files to /index first.",
        )
    if not _engine.hf_token:
        raise HTTPException(
            status_code=401,
            detail="No Hugging Face token configured. Set HF_TOKEN env var.",
        )

    async def stream_tokens() -> AsyncGenerator[str, None]:
        """Drive the engine's synchronous token generator from worker threads."""
        loop = asyncio.get_running_loop()
        # Create the generator off the event loop too, in case the call does
        # eager work before returning it.
        gen = await loop.run_in_executor(
            None, lambda: _engine.chat(req.query, history=req.history)
        )
        while True:
            # The sentinel default turns exhaustion into an ordinary value;
            # see the _STREAM_END comment for why StopIteration cannot be used.
            token = await loop.run_in_executor(None, next, gen, _STREAM_END)
            if token is _STREAM_END:
                break
            yield _sse_event(token)

    async def event_generator() -> AsyncGenerator[str, None]:
        if req.top_k is None:
            async for event in stream_tokens():
                yield event
        else:
            # Apply and restore the override inside the stream itself. If the
            # client disconnects before streaming starts, nothing was changed,
            # so nothing can leak onto the shared engine.
            # NOTE(review): requests without an override do not take this lock
            # and may observe another request's temporary value. Passing top_k
            # per call (if KerdosRAG supports it) would remove the shared
            # mutation entirely.
            async with _top_k_override_lock:
                original_top_k = _engine.top_k
                _engine.top_k = req.top_k
                try:
                    async for event in stream_tokens():
                        yield event
                finally:
                    _engine.top_k = original_top_k
        # Reached only on successful completion. Yielding from a finally block
        # would break close/cancellation of this async generator.
        yield "data: [DONE]\n\n"

    return StreamingResponse(event_generator(), media_type="text/event-stream")
@app.delete("/reset", tags=["RAG"])
def reset(_: None = Depends(_check_auth)) -> dict:
    """Wipe the whole knowledge base held by the shared engine.

    Requires a valid ``X-Api-Key`` header when ``API_KEY`` is set.
    """
    _engine.reset()
    confirmation = {"ok": True, "message": "Knowledge base cleared."}
    return confirmation
|