""" Kerdos AI — Custom LLM Chat REST API FastAPI application exposing the full RAG pipeline as HTTP endpoints. """ from __future__ import annotations import asyncio import logging import os import time from contextlib import asynccontextmanager from dotenv import load_dotenv from fastapi import FastAPI, File, HTTPException, Path, UploadFile, status from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse from models import ( ChatRequest, ChatResponse, HealthResponse, IndexResponse, MessageResponse, SessionCreateResponse, SessionStatusResponse, Source, ) from rag_core import call_llm from sessions import store load_dotenv() logging.basicConfig( level=logging.INFO, format="%(asctime)s | %(levelname)-8s | %(name)s — %(message)s", ) logger = logging.getLogger("kerdos.api") _START_TIME = time.time() API_VERSION = "1.0.0" # ── Lifespan: background cleanup task ──────────────────────────────────────── @asynccontextmanager async def lifespan(app: FastAPI): """Start a background task that purges expired sessions every 10 minutes.""" async def _cleanup_loop(): while True: await asyncio.sleep(600) removed = store.cleanup_expired() if removed: logger.info(f"Cleaned up {removed} expired session(s).") task = asyncio.create_task(_cleanup_loop()) logger.info("Kerdos AI RAG API started.") yield task.cancel() logger.info("Kerdos AI RAG API shutting down.") # ── App ─────────────────────────────────────────────────────────────────────── app = FastAPI( title="Kerdos AI — Custom LLM RAG API", description=( "REST API for the Kerdos AI document Q&A system.\n\n" "Upload your documents, index them, and ask questions — " "answers are strictly grounded in your uploaded content.\n\n" "**LLM**: `meta-llama/Llama-3.1-8B-Instruct` via HuggingFace Inference API \n" "**Embeddings**: `sentence-transformers/all-MiniLM-L6-v2` \n" "**Vector Store**: FAISS (in-memory, per-session) \n\n" "© 2024–2025 [Kerdos Infrasoft Private Limited](https://kerdos.in)" ), version=API_VERSION, contact={ "name": "Kerdos Infrasoft", "url": "https://kerdos.in/contact", "email": "partnership@kerdos.in", }, license_info={"name": "MIT"}, lifespan=lifespan, ) app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) MAX_UPLOAD_BYTES = int(os.getenv("MAX_UPLOAD_MB", "50")) * 1024 * 1024 ALLOWED_EXTENSIONS = {".pdf", ".docx", ".txt", ".md", ".csv"} # ── Helpers ─────────────────────────────────────────────────────────────────── def _get_session_or_404(session_id: str): try: return store.get(session_id) except KeyError: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=f"Session '{session_id}' not found or has expired.", ) # ── Routes ──────────────────────────────────────────────────────────────────── @app.get( "/", tags=["Info"], summary="API root", response_model=dict, ) async def root(): return { "name": "Kerdos AI RAG API", "version": API_VERSION, "docs": "/docs", "health": "/health", "website": "https://kerdos.in", } @app.get( "/health", tags=["Info"], summary="Health check", response_model=HealthResponse, ) async def health(): return HealthResponse( status="ok", version=API_VERSION, uptime_seconds=round(time.time() - _START_TIME, 2), active_sessions=store.active_count, ) # ── Sessions ────────────────────────────────────────────────────────────────── @app.post( "/sessions", tags=["Sessions"], summary="Create a new RAG session", response_model=SessionCreateResponse, status_code=status.HTTP_201_CREATED, ) async def create_session(): """ Creates a new isolated session with its own FAISS index and conversation history. Returns a `session_id` that must be passed to all subsequent requests. """ sid = store.create() logger.info(f"Session created: {sid}") return SessionCreateResponse(session_id=sid) @app.get( "/sessions/{session_id}", tags=["Sessions"], summary="Get session status", response_model=SessionStatusResponse, ) async def get_session(session_id: str = Path(..., description="Session ID")): """Returns metadata about the session: document count, chunk count, history length, TTL.""" rag, _ = _get_session_or_404(session_id) meta = store.get_meta(session_id) return SessionStatusResponse( session_id=session_id, document_count=rag.document_count, chunk_count=rag.chunk_count, history_length=len(rag.history), created_at=meta["created_at"], expires_at=meta["expires_at"], ) @app.delete( "/sessions/{session_id}", tags=["Sessions"], summary="Delete a session", response_model=MessageResponse, ) async def delete_session(session_id: str = Path(...)): """Immediately removes the session and frees all in-memory resources.""" deleted = store.delete(session_id) if not deleted: raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found.") logger.info(f"Session deleted: {session_id}") return MessageResponse(message=f"Session '{session_id}' deleted.") # ── Documents ───────────────────────────────────────────────────────────────── @app.post( "/sessions/{session_id}/documents", tags=["Documents"], summary="Upload and index documents", response_model=IndexResponse, ) async def upload_documents( session_id: str = Path(..., description="Session ID"), files: list[UploadFile] = File(..., description="Files to index (PDF, DOCX, TXT, MD, CSV)"), ): """ Upload one or more files to the session's FAISS index. Supported formats: PDF, DOCX, TXT, MD, CSV. Can be called multiple times to add more documents to an existing index. """ rag, lock = _get_session_or_404(session_id) file_pairs: list[tuple[str, bytes]] = [] oversized: list[str] = [] for upload in files: content = await upload.read() if len(content) > MAX_UPLOAD_BYTES: oversized.append(upload.filename or "unknown") continue from pathlib import Path as P ext = P(upload.filename or "").suffix.lower() if ext not in ALLOWED_EXTENSIONS: raise HTTPException( status_code=status.HTTP_415_UNSUPPORTED_MEDIA_TYPE, detail=f"File '{upload.filename}' has unsupported type '{ext}'. " f"Allowed: {', '.join(sorted(ALLOWED_EXTENSIONS))}", ) file_pairs.append((upload.filename or "unnamed", content)) if oversized: raise HTTPException( status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE, detail=f"Files exceed {os.getenv('MAX_UPLOAD_MB', '50')} MB limit: {oversized}", ) if not file_pairs: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="No valid files provided.", ) # Index in a thread so we don't block the event loop (FAISS + embeddings are CPU-bound) loop = asyncio.get_event_loop() def _index(): with lock: return rag.index_documents(file_pairs) indexed, failed = await loop.run_in_executor(None, _index) logger.info(f"[{session_id}] Indexed {len(indexed)} file(s), failed: {len(failed)}") return IndexResponse( session_id=session_id, indexed_files=indexed, failed_files=failed, chunk_count=rag.chunk_count, ) # ── Chat ────────────────────────────────────────────────────────────────────── @app.post( "/sessions/{session_id}/chat", tags=["Chat"], summary="Ask a question about your documents", response_model=ChatResponse, ) async def chat( session_id: str = Path(..., description="Session ID"), body: ChatRequest = ..., ): """ Retrieves the most relevant document chunks and uses Llama 3.1 8B to generate an answer strictly grounded in those chunks. **Requires a HuggingFace token** with Write access and acceptance of the [Llama 3.1 license](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct). """ rag, lock = _get_session_or_404(session_id) loop = asyncio.get_event_loop() def _run_rag(): with lock: # 1. Retrieve relevant chunks try: top_chunks = rag.query(body.question, top_k=body.top_k) except RuntimeError as exc: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc), ) # 2. Call LLM try: answer = call_llm( context_chunks=top_chunks, question=body.question, history=rag.history, hf_token=body.hf_token, temperature=body.temperature, max_new_tokens=body.max_new_tokens, ) except ValueError as exc: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=str(exc)) except RuntimeError as exc: raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc)) # 3. Persist to history rag.add_turn(body.question, answer) # 4. Build source citations sources = [ Source( filename=c.filename, chunk_index=c.chunk_index, excerpt=c.text[:200] + ("…" if len(c.text) > 200 else ""), ) for c in top_chunks ] return answer, sources answer, sources = await loop.run_in_executor(None, _run_rag) logger.info(f"[{session_id}] Q: {body.question[:60]}…") return ChatResponse( session_id=session_id, question=body.question, answer=answer, sources=sources, ) @app.delete( "/sessions/{session_id}/history", tags=["Chat"], summary="Clear conversation history", response_model=MessageResponse, ) async def clear_history(session_id: str = Path(...)): """Clears the multi-turn conversation history for the session (keeps the FAISS index intact).""" rag, lock = _get_session_or_404(session_id) with lock: rag.clear_history() return MessageResponse(message="Conversation history cleared.") # ── Entry point ─────────────────────────────────────────────────────────────── if __name__ == "__main__": import uvicorn uvicorn.run( "api:app", host=os.getenv("HOST", "0.0.0.0"), port=int(os.getenv("PORT", "8000")), reload=False, log_level="info", )