# NOTE: Hugging Face Spaces page chrome ("Spaces: Running") was captured with
# this file during extraction; it is not part of the Python source.
| """ | |
| Kerdos AI — Custom LLM Chat REST API | |
| FastAPI application exposing the full RAG pipeline as HTTP endpoints. | |
| """ | |
from __future__ import annotations

import asyncio
import logging
import os
import time
from contextlib import asynccontextmanager

from dotenv import load_dotenv
from fastapi import FastAPI, File, HTTPException, Path, UploadFile, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse

from models import (
    ChatRequest,
    ChatResponse,
    HealthResponse,
    IndexResponse,
    MessageResponse,
    SessionCreateResponse,
    SessionStatusResponse,
    Source,
)
from rag_core import call_llm
from sessions import store
# Load variables from a local .env file into the process environment (no-op if absent).
load_dotenv()
# Process-wide logging configuration; module loggers inherit this format.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)-8s | %(name)s — %(message)s",
)
logger = logging.getLogger("kerdos.api")
# Recorded once at import time; the health endpoint reports uptime relative to this.
_START_TIME = time.time()
# Reported by the health and root endpoints, and shown in the OpenAPI docs.
API_VERSION = "1.0.0"
# ── Lifespan: background cleanup task ────────────────────────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Start a background task that purges expired sessions every 10 minutes.

    Fix: the ``@asynccontextmanager`` decorator was missing — it is imported at
    the top of the file and FastAPI's ``lifespan=`` parameter requires an async
    context manager, so the app could not start without it.  The shutdown
    section is also wrapped in ``finally`` so the cleanup task is cancelled
    even if the app body raises.
    """
    async def _cleanup_loop() -> None:
        # Runs for the lifetime of the app; cancelled on shutdown below.
        while True:
            await asyncio.sleep(600)  # sweep every 10 minutes
            removed = store.cleanup_expired()
            if removed:
                logger.info("Cleaned up %d expired session(s).", removed)

    task = asyncio.create_task(_cleanup_loop())
    logger.info("Kerdos AI RAG API started.")
    try:
        yield
    finally:
        task.cancel()
        logger.info("Kerdos AI RAG API shutting down.")
# ── App ───────────────────────────────────────────────────────────────────────
# OpenAPI metadata is kept in named constants so the FastAPI() call stays short.
_API_DESCRIPTION = (
    "REST API for the Kerdos AI document Q&A system.\n\n"
    "Upload your documents, index them, and ask questions — "
    "answers are strictly grounded in your uploaded content.\n\n"
    "**LLM**: `meta-llama/Llama-3.1-8B-Instruct` via HuggingFace Inference API \n"
    "**Embeddings**: `sentence-transformers/all-MiniLM-L6-v2` \n"
    "**Vector Store**: FAISS (in-memory, per-session) \n\n"
    "© 2024–2025 [Kerdos Infrasoft Private Limited](https://kerdos.in)"
)
_API_CONTACT = {
    "name": "Kerdos Infrasoft",
    "url": "https://kerdos.in/contact",
    "email": "partnership@kerdos.in",
}

app = FastAPI(
    title="Kerdos AI — Custom LLM RAG API",
    description=_API_DESCRIPTION,
    version=API_VERSION,
    contact=_API_CONTACT,
    license_info={"name": "MIT"},
    lifespan=lifespan,
)
# Allow browser clients from any origin to call the API.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive (and browsers will not send credentials to a literal "*") —
# confirm this is the intended deployment posture.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Per-file upload size cap in bytes, configurable via MAX_UPLOAD_MB (default 50 MB).
MAX_UPLOAD_BYTES = int(os.getenv("MAX_UPLOAD_MB", "50")) * 1024 * 1024
# File extensions accepted by the document upload endpoint.
ALLOWED_EXTENSIONS = {".pdf", ".docx", ".txt", ".md", ".csv"}
# ── Helpers ───────────────────────────────────────────────────────────────────
def _get_session_or_404(session_id: str):
    """Return the store entry (rag, lock) for *session_id*, or raise HTTP 404.

    ``from None`` suppresses the internal ``KeyError`` so the 404 response is
    not accompanied by a misleading chained traceback in logs.
    """
    try:
        return store.get(session_id)
    except KeyError:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Session '{session_id}' not found or has expired.",
        ) from None
# ── Routes ────────────────────────────────────────────────────────────────────
@app.get("/")  # decorator restored (stripped during extraction); path inferred — TODO confirm
async def root():
    """Service banner with pointers to the interactive docs and health check."""
    return {
        "name": "Kerdos AI RAG API",
        "version": API_VERSION,
        "docs": "/docs",
        "health": "/health",
        "website": "https://kerdos.in",
    }
@app.get("/health", response_model=HealthResponse)  # decorator restored; path matches root()'s "health" link
async def health():
    """Liveness/status probe: version, uptime in seconds, and active session count."""
    return HealthResponse(
        status="ok",
        version=API_VERSION,
        uptime_seconds=round(time.time() - _START_TIME, 2),
        active_sessions=store.active_count,
    )
# ── Sessions ──────────────────────────────────────────────────────────────────
@app.post("/sessions", response_model=SessionCreateResponse)  # decorator restored; path inferred — TODO confirm
async def create_session():
    """
    Creates a new isolated session with its own FAISS index and conversation history.

    Returns a `session_id` that must be passed to all subsequent requests.
    """
    sid = store.create()
    logger.info(f"Session created: {sid}")
    return SessionCreateResponse(session_id=sid)
@app.get("/sessions/{session_id}", response_model=SessionStatusResponse)  # decorator restored — TODO confirm path
async def get_session(session_id: str = Path(..., description="Session ID")):
    """Returns metadata about the session: document count, chunk count, history length, TTL.

    Raises HTTP 404 (via the helper) if the session is unknown or expired.
    """
    rag, _ = _get_session_or_404(session_id)
    meta = store.get_meta(session_id)
    return SessionStatusResponse(
        session_id=session_id,
        document_count=rag.document_count,
        chunk_count=rag.chunk_count,
        history_length=len(rag.history),
        created_at=meta["created_at"],
        expires_at=meta["expires_at"],
    )
@app.delete("/sessions/{session_id}", response_model=MessageResponse)  # decorator restored — TODO confirm path
async def delete_session(session_id: str = Path(...)):
    """Immediately removes the session and frees all in-memory resources.

    Raises HTTP 404 if the session does not exist.
    """
    deleted = store.delete(session_id)
    if not deleted:
        # Use the status constant for consistency with the other handlers.
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Session '{session_id}' not found.",
        )
    logger.info(f"Session deleted: {session_id}")
    return MessageResponse(message=f"Session '{session_id}' deleted.")
# ── Documents ─────────────────────────────────────────────────────────────────
@app.post("/sessions/{session_id}/documents", response_model=IndexResponse)  # decorator restored — TODO confirm path
async def upload_documents(
    session_id: str = Path(..., description="Session ID"),
    files: list[UploadFile] = File(..., description="Files to index (PDF, DOCX, TXT, MD, CSV)"),
):
    """
    Upload one or more files to the session's FAISS index.

    Supported formats: PDF, DOCX, TXT, MD, CSV.
    Can be called multiple times to add more documents to an existing index.

    Raises:
        404: unknown or expired session.
        413: any file exceeds the configured size limit.
        415: a file has an unsupported extension.
        400: no valid files remained after filtering.
    """
    rag, lock = _get_session_or_404(session_id)

    # Hoisted out of the loop (was re-imported every iteration); the alias
    # avoids clashing with fastapi.Path imported at module level.
    from pathlib import Path as FilePath

    file_pairs: list[tuple[str, bytes]] = []
    oversized: list[str] = []
    for upload in files:
        content = await upload.read()
        if len(content) > MAX_UPLOAD_BYTES:
            # Collect all oversized names so the error lists every offender.
            oversized.append(upload.filename or "unknown")
            continue
        ext = FilePath(upload.filename or "").suffix.lower()
        if ext not in ALLOWED_EXTENSIONS:
            raise HTTPException(
                status_code=status.HTTP_415_UNSUPPORTED_MEDIA_TYPE,
                detail=f"File '{upload.filename}' has unsupported type '{ext}'. "
                f"Allowed: {', '.join(sorted(ALLOWED_EXTENSIONS))}",
            )
        file_pairs.append((upload.filename or "unnamed", content))
    if oversized:
        raise HTTPException(
            status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
            detail=f"Files exceed {os.getenv('MAX_UPLOAD_MB', '50')} MB limit: {oversized}",
        )
    if not file_pairs:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="No valid files provided.",
        )

    # Index in a thread so we don't block the event loop (FAISS + embeddings
    # are CPU-bound).  get_running_loop() replaces the deprecated-in-coroutine
    # get_event_loop().
    loop = asyncio.get_running_loop()

    def _index():
        with lock:
            return rag.index_documents(file_pairs)

    indexed, failed = await loop.run_in_executor(None, _index)
    logger.info(f"[{session_id}] Indexed {len(indexed)} file(s), failed: {len(failed)}")
    return IndexResponse(
        session_id=session_id,
        indexed_files=indexed,
        failed_files=failed,
        chunk_count=rag.chunk_count,
    )
# ── Chat ──────────────────────────────────────────────────────────────────────
@app.post("/sessions/{session_id}/chat", response_model=ChatResponse)  # decorator restored — TODO confirm path
async def chat(
    session_id: str = Path(..., description="Session ID"),
    body: ChatRequest = ...,
):
    """
    Retrieves the most relevant document chunks and uses Llama 3.1 8B to generate
    an answer strictly grounded in those chunks.

    **Requires a HuggingFace token** with Write access and acceptance of the
    [Llama 3.1 license](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct).

    Raises:
        404: unknown or expired session.
        400: retrieval failed (RuntimeError from rag.query).
        401: token rejected (ValueError from call_llm).
        502: upstream LLM failure (RuntimeError from call_llm).
    """
    rag, lock = _get_session_or_404(session_id)
    # Retrieval + LLM call are blocking, so run them in a worker thread.
    # HTTPExceptions raised inside the thread propagate through the await
    # below and are handled by FastAPI as usual.  get_running_loop() replaces
    # the deprecated-in-coroutine get_event_loop().
    loop = asyncio.get_running_loop()

    def _run_rag():
        with lock:
            # 1. Retrieve relevant chunks.
            try:
                top_chunks = rag.query(body.question, top_k=body.top_k)
            except RuntimeError as exc:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail=str(exc),
                )
            # 2. Call the LLM with the retrieved context and prior turns.
            try:
                answer = call_llm(
                    context_chunks=top_chunks,
                    question=body.question,
                    history=rag.history,
                    hf_token=body.hf_token,
                    temperature=body.temperature,
                    max_new_tokens=body.max_new_tokens,
                )
            except ValueError as exc:
                raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=str(exc))
            except RuntimeError as exc:
                raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc))
            # 3. Persist the Q/A turn to the multi-turn history.
            rag.add_turn(body.question, answer)
            # 4. Build source citations (first 200 chars of each chunk).
            sources = [
                Source(
                    filename=c.filename,
                    chunk_index=c.chunk_index,
                    excerpt=c.text[:200] + ("…" if len(c.text) > 200 else ""),
                )
                for c in top_chunks
            ]
            return answer, sources

    answer, sources = await loop.run_in_executor(None, _run_rag)
    logger.info(f"[{session_id}] Q: {body.question[:60]}…")
    return ChatResponse(
        session_id=session_id,
        question=body.question,
        answer=answer,
        sources=sources,
    )
@app.delete("/sessions/{session_id}/history", response_model=MessageResponse)  # decorator restored — TODO confirm path
async def clear_history(session_id: str = Path(...)):
    """Clears the multi-turn conversation history for the session (keeps the FAISS index intact)."""
    rag, lock = _get_session_or_404(session_id)
    with lock:
        rag.clear_history()
    return MessageResponse(message="Conversation history cleared.")
# ── Entry point ───────────────────────────────────────────────────────────────
| if __name__ == "__main__": | |
| import uvicorn | |
| uvicorn.run( | |
| "api:app", | |
| host=os.getenv("HOST", "0.0.0.0"), | |
| port=int(os.getenv("PORT", "8000")), | |
| reload=False, | |
| log_level="info", | |
| ) | |