# Author: Bhaskar Ram
# feat: Kerdos AI RAG API v1.0 (commit b1a3dce)
"""
Kerdos AI — Custom LLM Chat REST API
FastAPI application exposing the full RAG pipeline as HTTP endpoints.
"""
from __future__ import annotations
import asyncio
import logging
import os
import time
from contextlib import asynccontextmanager
from dotenv import load_dotenv
from fastapi import FastAPI, File, HTTPException, Path, UploadFile, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from models import (
ChatRequest,
ChatResponse,
HealthResponse,
IndexResponse,
MessageResponse,
SessionCreateResponse,
SessionStatusResponse,
Source,
)
from rag_core import call_llm
from sessions import store
# Load environment variables (HF token, upload limits, host/port) from a local .env file.
load_dotenv()

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)-8s | %(name)s — %(message)s",
)
logger = logging.getLogger("kerdos.api")

# Process start timestamp; /health reports uptime relative to this.
_START_TIME = time.time()
# Published API version (shown in OpenAPI docs and the / and /health payloads).
API_VERSION = "1.0.0"
# ── Lifespan: background cleanup task ────────────────────────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Start a background task that purges expired sessions every 10 minutes.

    The task is cancelled — and awaited, so it finishes unwinding — on
    shutdown. Startup/shutdown code is wrapped in try/finally so the
    cleanup runs even if the application body raises.
    """

    async def _cleanup_loop():
        while True:
            await asyncio.sleep(600)  # purge interval: 10 minutes
            removed = store.cleanup_expired()
            if removed:
                logger.info("Cleaned up %d expired session(s).", removed)

    task = asyncio.create_task(_cleanup_loop())
    logger.info("Kerdos AI RAG API started.")
    try:
        yield
    finally:
        task.cancel()
        # Await the cancelled task so its CancelledError is consumed here
        # instead of surfacing as a "task was never retrieved" warning.
        try:
            await task
        except asyncio.CancelledError:
            pass
        logger.info("Kerdos AI RAG API shutting down.")
# ── App ───────────────────────────────────────────────────────────────────────
app = FastAPI(
    title="Kerdos AI — Custom LLM RAG API",
    description=(
        "REST API for the Kerdos AI document Q&A system.\n\n"
        "Upload your documents, index them, and ask questions — "
        "answers are strictly grounded in your uploaded content.\n\n"
        "**LLM**: `meta-llama/Llama-3.1-8B-Instruct` via HuggingFace Inference API \n"
        "**Embeddings**: `sentence-transformers/all-MiniLM-L6-v2` \n"
        "**Vector Store**: FAISS (in-memory, per-session) \n\n"
        "© 2024–2025 [Kerdos Infrasoft Private Limited](https://kerdos.in)"
    ),
    version=API_VERSION,
    contact={
        "name": "Kerdos Infrasoft",
        "url": "https://kerdos.in/contact",
        "email": "partnership@kerdos.in",
    },
    license_info={"name": "MIT"},
    lifespan=lifespan,  # registers the background session-cleanup task
)

# NOTE(review): wildcard origins combined with allow_credentials=True is
# disallowed by the CORS spec — browsers reject `Access-Control-Allow-Origin: *`
# on credentialed requests. Consider an explicit origin list; TODO confirm
# deployment requirements.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Per-file upload cap in bytes (MB value overridable via MAX_UPLOAD_MB) and the
# extension whitelist — both enforced by the /documents endpoint below.
MAX_UPLOAD_BYTES = int(os.getenv("MAX_UPLOAD_MB", "50")) * 1024 * 1024
ALLOWED_EXTENSIONS = {".pdf", ".docx", ".txt", ".md", ".csv"}
# ── Helpers ───────────────────────────────────────────────────────────────────
def _get_session_or_404(session_id: str):
    """Look up a session in the store, translating a miss into HTTP 404.

    Returns whatever ``store.get`` yields for the id (callers in this file
    unpack it as ``(rag_pipeline, lock)``).

    Raises:
        HTTPException 404: the session is unknown or has expired.
    """
    try:
        return store.get(session_id)
    except KeyError:
        # `from None` suppresses the internal KeyError from the exception
        # chain — the "During handling of the above exception..." noise adds
        # nothing for API consumers or the logs.
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Session '{session_id}' not found or has expired.",
        ) from None
# ── Routes ────────────────────────────────────────────────────────────────────
@app.get(
    "/",
    tags=["Info"],
    summary="API root",
    response_model=dict,
)
async def root():
    """Top-level service descriptor: name, version, and useful links."""
    payload = {
        "name": "Kerdos AI RAG API",
        "version": API_VERSION,
        "docs": "/docs",
        "health": "/health",
        "website": "https://kerdos.in",
    }
    return payload
@app.get(
    "/health",
    tags=["Info"],
    summary="Health check",
    response_model=HealthResponse,
)
async def health():
    """Report liveness, API version, uptime, and the active session count."""
    uptime = round(time.time() - _START_TIME, 2)
    return HealthResponse(
        status="ok",
        version=API_VERSION,
        uptime_seconds=uptime,
        active_sessions=store.active_count,
    )
# ── Sessions ──────────────────────────────────────────────────────────────────
@app.post(
    "/sessions",
    tags=["Sessions"],
    summary="Create a new RAG session",
    response_model=SessionCreateResponse,
    status_code=status.HTTP_201_CREATED,
)
async def create_session():
    """
    Creates a new isolated session with its own FAISS index and conversation history.

    Returns a `session_id` that must be passed to all subsequent requests.
    """
    sid = store.create()
    # Lazy %-formatting: the message is only rendered if INFO is enabled.
    logger.info("Session created: %s", sid)
    return SessionCreateResponse(session_id=sid)
@app.get(
    "/sessions/{session_id}",
    tags=["Sessions"],
    summary="Get session status",
    response_model=SessionStatusResponse,
)
async def get_session(session_id: str = Path(..., description="Session ID")):
    """Returns metadata about the session: document count, chunk count, history length, TTL."""
    rag, _lock = _get_session_or_404(session_id)
    meta = store.get_meta(session_id)
    # Assemble the status payload from the pipeline state plus store metadata.
    return SessionStatusResponse(
        session_id=session_id,
        document_count=rag.document_count,
        chunk_count=rag.chunk_count,
        history_length=len(rag.history),
        created_at=meta["created_at"],
        expires_at=meta["expires_at"],
    )
@app.delete(
    "/sessions/{session_id}",
    tags=["Sessions"],
    summary="Delete a session",
    response_model=MessageResponse,
)
async def delete_session(session_id: str = Path(...)):
    """Immediately removes the session and frees all in-memory resources.

    Raises:
        HTTPException 404: the session does not exist (or already expired).
    """
    deleted = store.delete(session_id)
    if not deleted:
        # Named status constant for consistency with the other handlers
        # (original used a bare 404 literal here).
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Session '{session_id}' not found.",
        )
    logger.info("Session deleted: %s", session_id)
    return MessageResponse(message=f"Session '{session_id}' deleted.")
# ── Documents ─────────────────────────────────────────────────────────────────
@app.post(
    "/sessions/{session_id}/documents",
    tags=["Documents"],
    summary="Upload and index documents",
    response_model=IndexResponse,
)
async def upload_documents(
    session_id: str = Path(..., description="Session ID"),
    files: list[UploadFile] = File(..., description="Files to index (PDF, DOCX, TXT, MD, CSV)"),
):
    """
    Upload one or more files to the session's FAISS index.

    Supported formats: PDF, DOCX, TXT, MD, CSV.
    Can be called multiple times to add more documents to an existing index.

    Raises:
        HTTPException 404: unknown/expired session.
        HTTPException 415: a file has an unsupported extension.
        HTTPException 413: one or more files exceed the size limit.
        HTTPException 400: no valid files remained after filtering.
    """
    rag, lock = _get_session_or_404(session_id)

    # fastapi.Path shadows pathlib.Path at module level, so alias it locally.
    # Hoisted out of the loop — the original re-ran the import per file.
    from pathlib import Path as FsPath

    file_pairs: list[tuple[str, bytes]] = []
    oversized: list[str] = []
    for upload in files:
        content = await upload.read()
        if len(content) > MAX_UPLOAD_BYTES:
            oversized.append(upload.filename or "unknown")
            continue
        ext = FsPath(upload.filename or "").suffix.lower()
        if ext not in ALLOWED_EXTENSIONS:
            raise HTTPException(
                status_code=status.HTTP_415_UNSUPPORTED_MEDIA_TYPE,
                detail=f"File '{upload.filename}' has unsupported type '{ext}'. "
                f"Allowed: {', '.join(sorted(ALLOWED_EXTENSIONS))}",
            )
        file_pairs.append((upload.filename or "unnamed", content))
    if oversized:
        raise HTTPException(
            status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
            detail=f"Files exceed {os.getenv('MAX_UPLOAD_MB', '50')} MB limit: {oversized}",
        )
    if not file_pairs:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="No valid files provided.",
        )

    # Index in a thread so we don't block the event loop (FAISS + embeddings are CPU-bound)
    # get_running_loop() is the correct API inside a coroutine; get_event_loop()
    # is deprecated for this use since Python 3.10.
    loop = asyncio.get_running_loop()

    def _index():
        # The per-session lock serializes index mutations.
        with lock:
            return rag.index_documents(file_pairs)

    indexed, failed = await loop.run_in_executor(None, _index)
    logger.info("[%s] Indexed %d file(s), failed: %d", session_id, len(indexed), len(failed))
    return IndexResponse(
        session_id=session_id,
        indexed_files=indexed,
        failed_files=failed,
        chunk_count=rag.chunk_count,
    )
# ── Chat ──────────────────────────────────────────────────────────────────────
@app.post(
    "/sessions/{session_id}/chat",
    tags=["Chat"],
    summary="Ask a question about your documents",
    response_model=ChatResponse,
)
async def chat(
    session_id: str = Path(..., description="Session ID"),
    body: ChatRequest = ...,
):
    """
    Retrieves the most relevant document chunks and uses Llama 3.1 8B to generate
    an answer strictly grounded in those chunks.

    **Requires a HuggingFace token** with Write access and acceptance of the
    [Llama 3.1 license](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct).

    Raises:
        HTTPException 404: unknown/expired session.
        HTTPException 400: retrieval failed (rag.query raised RuntimeError).
        HTTPException 401: call_llm raised ValueError — presumably a rejected
            HF token; confirm against rag_core.
        HTTPException 502: call_llm raised RuntimeError (upstream failure).
    """
    rag, lock = _get_session_or_404(session_id)
    # get_running_loop() is the correct API inside a coroutine; get_event_loop()
    # is deprecated for this use since Python 3.10.
    loop = asyncio.get_running_loop()

    def _run_rag():
        # Runs in a worker thread: retrieval and the LLM call are blocking.
        # The per-session lock serializes access to the index and history.
        with lock:
            # 1. Retrieve relevant chunks
            try:
                top_chunks = rag.query(body.question, top_k=body.top_k)
            except RuntimeError as exc:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail=str(exc),
                ) from exc
            # 2. Call LLM
            try:
                answer = call_llm(
                    context_chunks=top_chunks,
                    question=body.question,
                    history=rag.history,
                    hf_token=body.hf_token,
                    temperature=body.temperature,
                    max_new_tokens=body.max_new_tokens,
                )
            except ValueError as exc:
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED, detail=str(exc)
                ) from exc
            except RuntimeError as exc:
                raise HTTPException(
                    status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc)
                ) from exc
            # 3. Persist to history
            rag.add_turn(body.question, answer)
            # 4. Build source citations (excerpts truncated to 200 chars)
            sources = [
                Source(
                    filename=c.filename,
                    chunk_index=c.chunk_index,
                    excerpt=c.text[:200] + ("…" if len(c.text) > 200 else ""),
                )
                for c in top_chunks
            ]
            return answer, sources

    answer, sources = await loop.run_in_executor(None, _run_rag)
    # Lazy %-formatting; %.60s truncates the question like [:60] did.
    logger.info("[%s] Q: %.60s…", session_id, body.question)
    return ChatResponse(
        session_id=session_id,
        question=body.question,
        answer=answer,
        sources=sources,
    )
@app.delete(
    "/sessions/{session_id}/history",
    tags=["Chat"],
    summary="Clear conversation history",
    response_model=MessageResponse,
)
async def clear_history(session_id: str = Path(...)):
    """Clears the multi-turn conversation history for the session (keeps the FAISS index intact)."""
    rag, lock = _get_session_or_404(session_id)
    # Serialize against concurrent chat/index operations on the same session.
    with lock:
        rag.clear_history()
    return MessageResponse(message="Conversation history cleared.")
# ── Entry point ───────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Standalone/dev launch. NOTE(review): the "api:app" import string implies
    # this module is named api.py — confirm the actual filename matches.
    import uvicorn

    uvicorn.run(
        "api:app",
        host=os.getenv("HOST", "0.0.0.0"),  # bind address, overridable via env
        port=int(os.getenv("PORT", "8000")),  # listen port, overridable via env
        reload=False,
        log_level="info",
    )