Spaces:
Running
Running
Bhaskar Ram committed on
Commit ·
b1a3dce
0
Parent(s):
feat: Kerdos AI RAG API v1.0
Browse files- .env.example +19 -0
- .gitignore +41 -0
- Dockerfile +35 -0
- README.md +119 -0
- api.py +366 -0
- models.py +73 -0
- rag_core.py +313 -0
- requirements.txt +27 -0
- sessions.py +102 -0
.env.example
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ─── Kerdos AI RAG API — Environment Variables ───────────────────────────────
|
| 2 |
+
|
| 3 |
+
# Your Hugging Face API token (Write access required for Llama 3.1)
|
| 4 |
+
# Get yours at: https://huggingface.co/settings/tokens
|
| 5 |
+
# You must also accept the Llama 3.1 license:
|
| 6 |
+
# https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct
|
| 7 |
+
HF_TOKEN=hf_your_token_here
|
| 8 |
+
|
| 9 |
+
# Session time-to-live in minutes (default: 60)
|
| 10 |
+
SESSION_TTL_MINUTES=60
|
| 11 |
+
|
| 12 |
+
# Maximum file size for uploads in megabytes (default: 50)
|
| 13 |
+
MAX_UPLOAD_MB=50
|
| 14 |
+
|
| 15 |
+
# Server bind address
|
| 16 |
+
HOST=0.0.0.0
|
| 17 |
+
|
| 18 |
+
# Server port
|
| 19 |
+
PORT=8000
|
.gitignore
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*.pyo
|
| 5 |
+
*.pyd
|
| 6 |
+
.Python
|
| 7 |
+
*.egg-info/
|
| 8 |
+
dist/
|
| 9 |
+
build/
|
| 10 |
+
*.egg
|
| 11 |
+
|
| 12 |
+
# Env
|
| 13 |
+
.env
|
| 14 |
+
*.env.local
|
| 15 |
+
|
| 16 |
+
# Test artifacts
|
| 17 |
+
.pytest_cache/
|
| 18 |
+
.coverage
|
| 19 |
+
htmlcov/
|
| 20 |
+
|
| 21 |
+
# IDEs
|
| 22 |
+
.vscode/
|
| 23 |
+
.idea/
|
| 24 |
+
*.suo
|
| 25 |
+
*.user
|
| 26 |
+
|
| 27 |
+
# OS
|
| 28 |
+
.DS_Store
|
| 29 |
+
Thumbs.db
|
| 30 |
+
|
| 31 |
+
# API test file (contains token)
|
| 32 |
+
api.txt
|
| 33 |
+
|
| 34 |
+
# Stray files from curl test commands
|
| 35 |
+
files-@*
|
| 36 |
+
|
| 37 |
+
# Sample doc (don't need in repo)
|
| 38 |
+
sample_doc.txt
|
| 39 |
+
|
| 40 |
+
# Uploaded files (never persisted, but just in case)
|
| 41 |
+
uploads/
|
Dockerfile
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Read the doc: https://huggingface.co/docs/hub/spaces-sdks-docker
|
| 2 |
+
# Kerdos AI — Custom LLM RAG API
|
| 3 |
+
|
| 4 |
+
FROM python:3.11
|
| 5 |
+
|
| 6 |
+
# Create non-root user as HF Spaces recommends
|
| 7 |
+
RUN useradd -m -u 1000 user
|
| 8 |
+
USER user
|
| 9 |
+
ENV PATH="/home/user/.local/bin:$PATH"
|
| 10 |
+
|
| 11 |
+
WORKDIR /app
|
| 12 |
+
|
| 13 |
+
# Install OS-level dependency for faiss at runtime
|
| 14 |
+
# (must be done before switching to non-root, but faiss-cpu binary wheel
|
| 15 |
+
# includes its own libgomp so extra system libs aren't needed on py3.11-slim)
|
| 16 |
+
|
| 17 |
+
# Install Python dependencies first (Docker cache layer)
|
| 18 |
+
COPY --chown=user requirements.txt .
|
| 19 |
+
RUN pip install --no-cache-dir --upgrade pip \
|
| 20 |
+
&& pip install --no-cache-dir -r requirements.txt
|
| 21 |
+
|
| 22 |
+
# Pre-download embedding model at build time (avoids cold-start delay)
|
| 23 |
+
RUN python -c "from sentence_transformers import SentenceTransformer; \
|
| 24 |
+
SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')"
|
| 25 |
+
|
| 26 |
+
# Copy application source
|
| 27 |
+
COPY --chown=user api.py models.py rag_core.py sessions.py ./
|
| 28 |
+
|
| 29 |
+
# HF Spaces required port
|
| 30 |
+
EXPOSE 7860
|
| 31 |
+
|
| 32 |
+
ENV HOST=0.0.0.0 \
|
| 33 |
+
PORT=7860
|
| 34 |
+
|
| 35 |
+
CMD ["uvicorn", "api:app", "--host", "0.0.0.0", "--port", "7860"]
|
README.md
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Kerdos AI — Custom LLM RAG API
|
| 3 |
+
emoji: 🤖
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: indigo
|
| 6 |
+
sdk: docker
|
| 7 |
+
app_port: 7860
|
| 8 |
+
pinned: false
|
| 9 |
+
license: mit
|
| 10 |
+
tags:
|
| 11 |
+
- rag
|
| 12 |
+
- document-qa
|
| 13 |
+
- fastapi
|
| 14 |
+
- llama
|
| 15 |
+
- faiss
|
| 16 |
+
- nlp
|
| 17 |
+
- question-answering
|
| 18 |
+
- kerdos
|
| 19 |
+
- private-llm
|
| 20 |
+
- api
|
| 21 |
+
---
|
| 22 |
+
|
| 23 |
+
# 🤖 Kerdos AI — Custom LLM RAG API
|
| 24 |
+
|
| 25 |
+
> **A REST API by [Kerdos Infrasoft Private Limited](https://kerdos.in)**
|
| 26 |
+
> Upload documents. Ask questions. Get answers — strictly grounded in your data.
|
| 27 |
+
|
| 28 |
+
---
|
| 29 |
+
|
| 30 |
+
## ✨ Features
|
| 31 |
+
|
| 32 |
+
| | |
|
| 33 |
+
| -------------------- | ---------------------------------------------------------- |
|
| 34 |
+
| 📄 **Multi-format** | PDF, DOCX, TXT, MD, CSV |
|
| 35 |
+
| 🧠 **LLM** | `meta-llama/Llama-3.1-8B-Instruct` via HF Inference Router |
|
| 36 |
+
| 🔒 **Grounded** | Answers only from your uploaded documents |
|
| 37 |
+
| 💬 **Multi-turn** | Conversation history per session |
|
| 38 |
+
| ⚡ **Fast** | `all-MiniLM-L6-v2` + FAISS in-memory |
|
| 39 |
+
| 🔑 **Session-based** | Each client gets an isolated FAISS index |
|
| 40 |
+
|
| 41 |
+
---
|
| 42 |
+
|
| 43 |
+
## 📡 API Reference
|
| 44 |
+
|
| 45 |
+
Interactive docs → `/docs` (Swagger UI)
|
| 46 |
+
|
| 47 |
+
| Method | Path | Description |
|
| 48 |
+
| -------- | -------------------------- | ----------------------------------- |
|
| 49 |
+
| `POST` | `/sessions` | Create a session → get `session_id` |
|
| 50 |
+
| `GET` | `/sessions/{id}` | Session status |
|
| 51 |
+
| `DELETE` | `/sessions/{id}` | Delete session |
|
| 52 |
+
| `POST` | `/sessions/{id}/documents` | Upload & index files |
|
| 53 |
+
| `POST` | `/sessions/{id}/chat` | Ask a question |
|
| 54 |
+
| `DELETE` | `/sessions/{id}/history` | Clear chat history |
|
| 55 |
+
| `GET` | `/health` | Health check |
|
| 56 |
+
|
| 57 |
+
---
|
| 58 |
+
|
| 59 |
+
## 🔁 Typical Workflow
|
| 60 |
+
|
| 61 |
+
```bash
|
| 62 |
+
BASE=https://kerdosdotio-kerdos-llm-rag-api.hf.space
|
| 63 |
+
|
| 64 |
+
# 1. Create session
|
| 65 |
+
curl -X POST $BASE/sessions
|
| 66 |
+
|
| 67 |
+
# 2. Upload a document
|
| 68 |
+
curl -X POST "$BASE/sessions/{session_id}/documents" \
|
| 69 |
+
-F "files=@your_doc.pdf"
|
| 70 |
+
|
| 71 |
+
# 3. Ask a question
|
| 72 |
+
curl -X POST "$BASE/sessions/{session_id}/chat" \
|
| 73 |
+
-H "Content-Type: application/json" \
|
| 74 |
+
-d '{"question": "Summarise this document", "hf_token": "hf_..."}'
|
| 75 |
+
```
|
| 76 |
+
|
| 77 |
+
---
|
| 78 |
+
|
| 79 |
+
## ⚙️ Environment / Secrets
|
| 80 |
+
|
| 81 |
+
Set these in **Settings → Variables and secrets** of this Space:
|
| 82 |
+
|
| 83 |
+
| Secret | Description |
|
| 84 |
+
| --------------------- | ------------------------------------------------------------------ |
|
| 85 |
+
| `HF_TOKEN` | Your HuggingFace token (Write access + Llama 3.1 licence accepted) |
|
| 86 |
+
| `SESSION_TTL_MINUTES` | Session expiry (default: 60) |
|
| 87 |
+
| `MAX_UPLOAD_MB` | Max upload size in MB (default: 50) |
|
| 88 |
+
|
| 89 |
+
---
|
| 90 |
+
|
| 91 |
+
## 🏗️ Architecture
|
| 92 |
+
|
| 93 |
+
```
|
| 94 |
+
FastAPI (api.py)
|
| 95 |
+
├── SessionStore — UUID sessions, TTL, per-session lock
|
| 96 |
+
└── RAGSession
|
| 97 |
+
├── parse_file() — PDF/DOCX/TXT/CSV
|
| 98 |
+
├── chunk_text() — 512-char chunks, 64 overlap
|
| 99 |
+
├── all-MiniLM-L6-v2 — embeddings
|
| 100 |
+
├── FAISS — in-memory vector search
|
| 101 |
+
└── call_llm() — HF Router → Llama 3.1 8B
|
| 102 |
+
```
|
| 103 |
+
|
| 104 |
+
---
|
| 105 |
+
|
| 106 |
+
## 💼 Enterprise Edition
|
| 107 |
+
|
| 108 |
+
Interested in **private, on-premise** deployment?
|
| 109 |
+
|
| 110 |
+
- 🔒 Private LLM Hosting
|
| 111 |
+
- 🎛️ Custom Model Fine-tuning
|
| 112 |
+
- 🛡️ Data Privacy Guarantees
|
| 113 |
+
- 🏷️ White-label Deployments
|
| 114 |
+
|
| 115 |
+
📧 [partnership@kerdos.in](mailto:partnership@kerdos.in) | 🌐 [kerdos.in/contact](https://kerdos.in/contact)
|
| 116 |
+
|
| 117 |
+
---
|
| 118 |
+
|
| 119 |
+
_© 2024–2025 Kerdos Infrasoft Private Limited | Bengaluru, Karnataka, India_
|
api.py
ADDED
|
@@ -0,0 +1,366 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Kerdos AI — Custom LLM Chat REST API
|
| 3 |
+
FastAPI application exposing the full RAG pipeline as HTTP endpoints.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
import asyncio
|
| 9 |
+
import logging
|
| 10 |
+
import os
|
| 11 |
+
import time
|
| 12 |
+
from contextlib import asynccontextmanager
|
| 13 |
+
|
| 14 |
+
from dotenv import load_dotenv
|
| 15 |
+
from fastapi import FastAPI, File, HTTPException, Path, UploadFile, status
|
| 16 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 17 |
+
from fastapi.responses import JSONResponse
|
| 18 |
+
|
| 19 |
+
from models import (
|
| 20 |
+
ChatRequest,
|
| 21 |
+
ChatResponse,
|
| 22 |
+
HealthResponse,
|
| 23 |
+
IndexResponse,
|
| 24 |
+
MessageResponse,
|
| 25 |
+
SessionCreateResponse,
|
| 26 |
+
SessionStatusResponse,
|
| 27 |
+
Source,
|
| 28 |
+
)
|
| 29 |
+
from rag_core import call_llm
|
| 30 |
+
from sessions import store
|
| 31 |
+
|
| 32 |
+
load_dotenv()
|
| 33 |
+
|
| 34 |
+
logging.basicConfig(
|
| 35 |
+
level=logging.INFO,
|
| 36 |
+
format="%(asctime)s | %(levelname)-8s | %(name)s — %(message)s",
|
| 37 |
+
)
|
| 38 |
+
logger = logging.getLogger("kerdos.api")
|
| 39 |
+
|
| 40 |
+
_START_TIME = time.time()
|
| 41 |
+
API_VERSION = "1.0.0"
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# ── Lifespan: background cleanup task ────────────────────────────────────────
|
| 45 |
+
|
| 46 |
+
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Start a background task that purges expired sessions every 10 minutes.

    On shutdown the task is cancelled AND awaited, so the event loop closes
    cleanly instead of exiting with a still-pending task (the original
    cancelled without awaiting, which can emit "Task was destroyed but it is
    pending!" warnings on shutdown).
    """
    async def _cleanup_loop():
        while True:
            await asyncio.sleep(600)
            removed = store.cleanup_expired()
            if removed:
                logger.info(f"Cleaned up {removed} expired session(s).")

    task = asyncio.create_task(_cleanup_loop())
    logger.info("Kerdos AI RAG API started.")
    try:
        yield
    finally:
        task.cancel()
        # Await the cancellation so teardown is complete before we return.
        try:
            await task
        except asyncio.CancelledError:
            pass
        logger.info("Kerdos AI RAG API shutting down.")
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
# ── App ───────────────────────────────────────────────────────────────────────
|
| 64 |
+
|
| 65 |
+
# Application object: OpenAPI metadata is rendered on /docs (Swagger UI).
app = FastAPI(
    title="Kerdos AI — Custom LLM RAG API",
    description=(
        "REST API for the Kerdos AI document Q&A system.\n\n"
        "Upload your documents, index them, and ask questions — "
        "answers are strictly grounded in your uploaded content.\n\n"
        "**LLM**: `meta-llama/Llama-3.1-8B-Instruct` via HuggingFace Inference API \n"
        "**Embeddings**: `sentence-transformers/all-MiniLM-L6-v2` \n"
        "**Vector Store**: FAISS (in-memory, per-session) \n\n"
        "© 2024–2025 [Kerdos Infrasoft Private Limited](https://kerdos.in)"
    ),
    version=API_VERSION,
    contact={
        "name": "Kerdos Infrasoft",
        "url": "https://kerdos.in/contact",
        "email": "partnership@kerdos.in",
    },
    license_info={"name": "MIT"},
    # Background session-cleanup task starts/stops with the app.
    lifespan=lifespan,
)

# NOTE(review): browsers reject credentialed CORS requests when the allowed
# origin is the wildcard "*" — confirm whether allow_credentials=True is
# actually needed, or pin explicit origins instead.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Per-file upload ceiling (MAX_UPLOAD_MB env var, default 50 MB) and the
# extensions the parsers in rag_core support.
MAX_UPLOAD_BYTES = int(os.getenv("MAX_UPLOAD_MB", "50")) * 1024 * 1024
ALLOWED_EXTENSIONS = {".pdf", ".docx", ".txt", ".md", ".csv"}
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
# ── Helpers ───────────────────────────────────────────────────────────────────
|
| 99 |
+
|
| 100 |
+
def _get_session_or_404(session_id: str):
    """Return the session object pair stored under *session_id*.

    Raises:
        HTTPException: 404 when the session is unknown or has expired.
    """
    try:
        return store.get(session_id)
    except KeyError:
        # `from None` suppresses the internal KeyError from the error chain —
        # the 404 detail is the whole story for the client and the logs.
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Session '{session_id}' not found or has expired.",
        ) from None
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
# ── Routes ────────────────────────────────────────────────────────────────────
|
| 111 |
+
|
| 112 |
+
@app.get(
    "/",
    tags=["Info"],
    summary="API root",
    response_model=dict,
)
async def root():
    """Service card: name, version, and links to docs and the health probe."""
    info = {
        "name": "Kerdos AI RAG API",
        "version": API_VERSION,
        "docs": "/docs",
        "health": "/health",
        "website": "https://kerdos.in",
    }
    return info
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
@app.get(
    "/health",
    tags=["Info"],
    summary="Health check",
    response_model=HealthResponse,
)
async def health():
    """Liveness probe: reports API version, uptime, and live session count."""
    uptime = round(time.time() - _START_TIME, 2)
    return HealthResponse(
        status="ok",
        version=API_VERSION,
        uptime_seconds=uptime,
        active_sessions=store.active_count,
    )
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
# ── Sessions ──────────────────────────────────────────────────────────────────
|
| 144 |
+
|
| 145 |
+
@app.post(
    "/sessions",
    tags=["Sessions"],
    summary="Create a new RAG session",
    response_model=SessionCreateResponse,
    status_code=status.HTTP_201_CREATED,
)
async def create_session():
    """
    Create a fresh, isolated session (own FAISS index + conversation history).

    The returned `session_id` must accompany every subsequent request.
    """
    new_id = store.create()
    logger.info(f"Session created: {new_id}")
    return SessionCreateResponse(session_id=new_id)
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
@app.get(
    "/sessions/{session_id}",
    tags=["Sessions"],
    summary="Get session status",
    response_model=SessionStatusResponse,
)
async def get_session(session_id: str = Path(..., description="Session ID")):
    """Report session metadata: document/chunk counts, history length, TTL."""
    rag, _ = _get_session_or_404(session_id)
    meta = store.get_meta(session_id)
    status_fields = {
        "session_id": session_id,
        "document_count": rag.document_count,
        "chunk_count": rag.chunk_count,
        "history_length": len(rag.history),
        "created_at": meta["created_at"],
        "expires_at": meta["expires_at"],
    }
    return SessionStatusResponse(**status_fields)
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
@app.delete(
    "/sessions/{session_id}",
    tags=["Sessions"],
    summary="Delete a session",
    response_model=MessageResponse,
)
async def delete_session(session_id: str = Path(...)):
    """Immediately remove the session and free all of its in-memory state."""
    # Guard clause: store.delete returns a falsy value for unknown sessions.
    if not store.delete(session_id):
        raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found.")
    logger.info(f"Session deleted: {session_id}")
    return MessageResponse(message=f"Session '{session_id}' deleted.")
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
# ── Documents ─────────────────────────────────────────────────────────────────
|
| 198 |
+
|
| 199 |
+
@app.post(
    "/sessions/{session_id}/documents",
    tags=["Documents"],
    summary="Upload and index documents",
    response_model=IndexResponse,
)
async def upload_documents(
    session_id: str = Path(..., description="Session ID"),
    files: list[UploadFile] = File(..., description="Files to index (PDF, DOCX, TXT, MD, CSV)"),
):
    """
    Upload one or more files to the session's FAISS index.

    Supported formats: PDF, DOCX, TXT, MD, CSV.
    Can be called multiple times to add more documents to an existing index.

    Raises:
        HTTPException 404: unknown or expired session.
        HTTPException 413: one or more files exceed the MAX_UPLOAD_MB limit.
        HTTPException 415: a file has an unsupported extension.
        HTTPException 400: request contained no valid files.
    """
    # `Path` in this module is fastapi.Path, so alias pathlib's Path.
    # (Hoisted out of the per-file loop — the original re-imported it on
    # every iteration.)
    from pathlib import Path as _FSPath

    rag, lock = _get_session_or_404(session_id)

    file_pairs: list[tuple[str, bytes]] = []
    oversized: list[str] = []

    for upload in files:
        content = await upload.read()
        # Size check first (matching the original order): oversized files are
        # collected and reported together as a single 413.
        if len(content) > MAX_UPLOAD_BYTES:
            oversized.append(upload.filename or "unknown")
            continue
        ext = _FSPath(upload.filename or "").suffix.lower()
        if ext not in ALLOWED_EXTENSIONS:
            raise HTTPException(
                status_code=status.HTTP_415_UNSUPPORTED_MEDIA_TYPE,
                detail=f"File '{upload.filename}' has unsupported type '{ext}'. "
                f"Allowed: {', '.join(sorted(ALLOWED_EXTENSIONS))}",
            )
        file_pairs.append((upload.filename or "unnamed", content))

    if oversized:
        raise HTTPException(
            status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
            detail=f"Files exceed {os.getenv('MAX_UPLOAD_MB', '50')} MB limit: {oversized}",
        )

    if not file_pairs:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="No valid files provided.",
        )

    # Index in a worker thread so we don't block the event loop
    # (FAISS + embeddings are CPU-bound).
    # get_running_loop(): get_event_loop() is deprecated inside coroutines.
    loop = asyncio.get_running_loop()

    def _index():
        # The per-session lock serialises index mutation with /chat.
        with lock:
            return rag.index_documents(file_pairs)

    indexed, failed = await loop.run_in_executor(None, _index)

    logger.info(f"[{session_id}] Indexed {len(indexed)} file(s), failed: {len(failed)}")

    return IndexResponse(
        session_id=session_id,
        indexed_files=indexed,
        failed_files=failed,
        chunk_count=rag.chunk_count,
    )
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
# ── Chat ──────────────────────────────────────────────────────────────────────
|
| 266 |
+
|
| 267 |
+
@app.post(
    "/sessions/{session_id}/chat",
    tags=["Chat"],
    summary="Ask a question about your documents",
    response_model=ChatResponse,
)
async def chat(
    session_id: str = Path(..., description="Session ID"),
    body: ChatRequest = ...,
):
    """
    Retrieve the most relevant document chunks and use Llama 3.1 8B to generate
    an answer strictly grounded in those chunks.

    **Requires a HuggingFace token** with Write access and acceptance of the
    [Llama 3.1 license](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct).

    Raises:
        HTTPException 404: unknown or expired session.
        HTTPException 400: retrieval failed (e.g. nothing indexed yet).
        HTTPException 401: LLM call rejected the token (ValueError upstream).
        HTTPException 502: LLM backend failure (RuntimeError upstream).
    """
    rag, lock = _get_session_or_404(session_id)

    # get_running_loop(): get_event_loop() is deprecated inside coroutines.
    loop = asyncio.get_running_loop()

    def _run_rag():
        # Runs in a worker thread; HTTPExceptions raised here propagate back
        # to the caller through the executor future.
        with lock:
            # 1. Retrieve relevant chunks
            try:
                top_chunks = rag.query(body.question, top_k=body.top_k)
            except RuntimeError as exc:
                # Chain the cause so logs show the underlying retrieval error.
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail=str(exc),
                ) from exc

            # 2. Call LLM (ValueError → auth problem, RuntimeError → upstream)
            try:
                answer = call_llm(
                    context_chunks=top_chunks,
                    question=body.question,
                    history=rag.history,
                    hf_token=body.hf_token,
                    temperature=body.temperature,
                    max_new_tokens=body.max_new_tokens,
                )
            except ValueError as exc:
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED, detail=str(exc)
                ) from exc
            except RuntimeError as exc:
                raise HTTPException(
                    status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc)
                ) from exc

            # 3. Persist to history (only after a successful LLM call)
            rag.add_turn(body.question, answer)

            # 4. Build source citations; excerpts truncated to 200 chars
            sources = [
                Source(
                    filename=c.filename,
                    chunk_index=c.chunk_index,
                    excerpt=c.text[:200] + ("…" if len(c.text) > 200 else ""),
                )
                for c in top_chunks
            ]

            return answer, sources

    answer, sources = await loop.run_in_executor(None, _run_rag)

    logger.info(f"[{session_id}] Q: {body.question[:60]}…")

    return ChatResponse(
        session_id=session_id,
        question=body.question,
        answer=answer,
        sources=sources,
    )
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
@app.delete(
    "/sessions/{session_id}/history",
    tags=["Chat"],
    summary="Clear conversation history",
    response_model=MessageResponse,
)
async def clear_history(session_id: str = Path(...)):
    """Drop the session's multi-turn history; the FAISS index is untouched."""
    rag, lock = _get_session_or_404(session_id)
    # Hold the per-session lock so we don't race a concurrent /chat call.
    with lock:
        rag.clear_history()
    return MessageResponse(message="Conversation history cleared.")
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
# ── Entry point ───────────────────────────────────────────────────────────────
|
| 356 |
+
|
| 357 |
+
if __name__ == "__main__":
    import uvicorn

    # Local/dev entry point. In the Docker image, uvicorn is launched by the
    # CMD line instead (host 0.0.0.0, port 7860 for HF Spaces); HOST/PORT env
    # vars override the defaults here.
    uvicorn.run(
        "api:app",
        host=os.getenv("HOST", "0.0.0.0"),
        port=int(os.getenv("PORT", "8000")),
        reload=False,
        log_level="info",
    )
|
models.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Pydantic request/response models for the Kerdos AI RAG API.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
|
| 7 |
+
from typing import List, Optional
|
| 8 |
+
from pydantic import BaseModel, Field
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
# ─── Session ────────────────────────────────────────────────────────────────
|
| 12 |
+
|
| 13 |
+
class SessionCreateResponse(BaseModel):
    # Body of POST /sessions (201): the id the client must pass to all
    # subsequent requests, plus a fixed confirmation message.
    session_id: str = Field(..., description="Unique session identifier")
    message: str = Field(default="Session created successfully")
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class SessionStatusResponse(BaseModel):
    # Body of GET /sessions/{id}: counters plus lifetime bounds.
    session_id: str
    document_count: int = Field(..., description="Number of uploaded documents")
    chunk_count: int = Field(..., description="Number of indexed text chunks")
    history_length: int = Field(..., description="Number of turns in conversation history")
    # Timestamps are strings as produced by the session store —
    # presumably ISO 8601; confirm against sessions.py.
    created_at: str
    expires_at: str
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
# ─── Documents ──────────────────────────────────────────────────────────────
|
| 28 |
+
|
| 29 |
+
class IndexResponse(BaseModel):
    # Body of POST /sessions/{id}/documents: which files made it into the
    # index, which failed to parse, and the resulting index size.
    session_id: str
    indexed_files: List[str] = Field(..., description="Names of successfully indexed files")
    failed_files: List[str] = Field(default_factory=list, description="Files that failed to parse")
    chunk_count: int = Field(..., description="Total chunks in FAISS index")
    message: str = Field(default="Documents indexed successfully")
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# ─── Chat ────────────────────────────────────────────────────────────────────
|
| 38 |
+
|
| 39 |
+
class Source(BaseModel):
    # One citation attached to a chat answer: which file/chunk the answer
    # drew on, with a short excerpt (truncated to 200 chars by the API layer).
    filename: str
    chunk_index: int
    excerpt: str = Field(..., description="Short preview of the retrieved chunk")
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class ChatRequest(BaseModel):
    # Body of POST /sessions/{id}/chat.
    # NOTE(review): the HF token travels in the request body per call —
    # confirm clients only send it over HTTPS and that it is never logged.
    question: str = Field(..., min_length=1, description="The question to ask about your documents")
    hf_token: str = Field(..., description="Hugging Face API token (Write access required for Llama 3)")
    top_k: int = Field(default=5, ge=1, le=20, description="Number of chunks to retrieve")
    # Sampling controls forwarded to the LLM call.
    temperature: float = Field(default=0.3, ge=0.0, le=1.0)
    max_new_tokens: int = Field(default=512, ge=64, le=2048)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class ChatResponse(BaseModel):
    # Body of POST /sessions/{id}/chat: echoes the question, returns the
    # grounded answer plus the chunks it was grounded in.
    session_id: str
    question: str
    answer: str
    sources: List[Source] = Field(default_factory=list)
    # Fixed model identifier, reported for client-side bookkeeping.
    model: str = Field(default="meta-llama/Llama-3.1-8B-Instruct")
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
# ─── Health ──────────────────────────────────────────────────────────────────
|
| 62 |
+
|
| 63 |
+
class HealthResponse(BaseModel):
    # Body of GET /health: liveness plus a few cheap runtime counters.
    status: str = "ok"
    version: str
    uptime_seconds: float
    active_sessions: int
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
# ─── Generic ─────────────────────────────────────────────────────────────────
|
| 71 |
+
|
| 72 |
+
class MessageResponse(BaseModel):
    # Generic acknowledgement body for delete/clear endpoints.
    message: str
|
rag_core.py
ADDED
|
@@ -0,0 +1,313 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Core RAG engine: document parsing, chunking, embedding, FAISS indexing, and LLM querying.
|
| 3 |
+
No Gradio dependency — pure Python, importable by the FastAPI layer.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
import io
|
| 9 |
+
import logging
|
| 10 |
+
import textwrap
|
| 11 |
+
from dataclasses import dataclass, field
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
from typing import List, Optional, Tuple
|
| 14 |
+
|
| 15 |
+
import numpy as np
|
| 16 |
+
import requests
|
| 17 |
+
|
| 18 |
+
logger = logging.getLogger(__name__)
|
| 19 |
+
|
| 20 |
+
# ──────────────────────────────────────────────────────────────────────────────
|
| 21 |
+
# Lazy imports (heavy libraries loaded only once at first use)
|
| 22 |
+
# ──────────────────────────────────────────────────────────────────────────────
|
| 23 |
+
|
| 24 |
+
_embedding_model = None
|
| 25 |
+
_faiss = None
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def _get_embedding_model():
    """Return the process-wide SentenceTransformer, loading it on first use."""
    global _embedding_model
    if _embedding_model is not None:
        return _embedding_model
    # Import lazily so module import stays cheap; the model is loaded once
    # and cached in the module-level singleton.
    from sentence_transformers import SentenceTransformer

    logger.info("Loading SentenceTransformer model…")
    _embedding_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
    logger.info("Embedding model loaded.")
    return _embedding_model
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def _get_faiss():
    """Return the faiss module, importing it lazily on first call."""
    global _faiss
    if _faiss is not None:
        return _faiss
    import faiss as _faiss_module

    _faiss = _faiss_module
    return _faiss
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
# ──────────────────────────────────────────────────────────────────────────────
|
| 47 |
+
# Document parsing
|
| 48 |
+
# ──────────────────────────────────────────────────────────────────────────────
|
| 49 |
+
|
def _parse_pdf(data: bytes) -> str:
    """Extract plain text from PDF bytes; pages joined by blank lines.

    Fix: the original never closed the fitz document, leaking the native
    handle. `fitz.Document` is a context manager, so use `with` to close
    it deterministically.
    """
    import fitz  # PyMuPDF

    with fitz.open(stream=data, filetype="pdf") as doc:
        return "\n\n".join(page.get_text() for page in doc)
| 54 |
+
|
| 55 |
+
|
def _parse_docx(data: bytes) -> str:
    """Extract text from DOCX bytes, keeping only non-blank paragraphs."""
    from docx import Document

    document = Document(io.BytesIO(data))
    paragraphs = [para.text for para in document.paragraphs if para.text.strip()]
    return "\n\n".join(paragraphs)
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def _parse_txt(data: bytes) -> str:
|
| 63 |
+
for enc in ("utf-8", "latin-1", "cp1252"):
|
| 64 |
+
try:
|
| 65 |
+
return data.decode(enc)
|
| 66 |
+
except UnicodeDecodeError:
|
| 67 |
+
continue
|
| 68 |
+
return data.decode("utf-8", errors="replace")
|
| 69 |
+
|
| 70 |
+
|
def _parse_csv(data: bytes) -> str:
    """Flatten CSV bytes into text: one line per row, cells comma-joined."""
    import csv

    reader = csv.reader(io.StringIO(_parse_txt(data)))
    return "\n".join(", ".join(row) for row in reader)
| 78 |
+
|
| 79 |
+
|
# Extension → parser dispatch table consumed by parse_file().
# Keys are lowercase suffixes; .md reuses the plain-text parser.
PARSERS = {
    ".pdf": _parse_pdf,
    ".docx": _parse_docx,
    ".txt": _parse_txt,
    ".md": _parse_txt,
    ".csv": _parse_csv,
}
| 87 |
+
|
| 88 |
+
|
def parse_file(filename: str, data: bytes) -> str:
    """
    Parse a file by extension and return its plain-text content.
    Raises ValueError for unsupported extensions.
    """
    suffix = Path(filename).suffix.lower()
    try:
        handler = PARSERS[suffix]
    except KeyError:
        supported = ", ".join(PARSERS)
        raise ValueError(
            f"Unsupported file type '{suffix}'. "
            f"Supported: {supported}"
        ) from None
    return handler(data)
| 102 |
+
|
| 103 |
+
|
| 104 |
+
# ──────────────────────────────────────────────────────────────────────────────
|
| 105 |
+
# Text chunking
|
| 106 |
+
# ──────────────────────────────────────────────────────────────────────────────
|
| 107 |
+
|
def chunk_text(
    text: str,
    chunk_size: int = 512,
    overlap: int = 64,
) -> List[str]:
    """Split *text* into overlapping fixed-size character chunks.

    Args:
        text: Input text; leading/trailing whitespace is stripped first.
        chunk_size: Maximum characters per chunk; must be positive.
        overlap: Characters shared between consecutive chunks; must satisfy
            0 <= overlap < chunk_size.

    Returns:
        List of non-empty, stripped chunks; [] for blank input.

    Raises:
        ValueError: On invalid chunk_size/overlap. Fix: the original looped
            forever when overlap >= chunk_size, because
            `start = end - overlap` never advanced past the previous start.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if not 0 <= overlap < chunk_size:
        raise ValueError(
            f"overlap must satisfy 0 <= overlap < chunk_size, got {overlap}"
        )

    text = text.strip()
    if not text:
        return []

    chunks: List[str] = []
    start = 0
    while start < len(text):
        end = min(start + chunk_size, len(text))
        chunk = text[start:end].strip()
        if chunk:
            chunks.append(chunk)
        if end == len(text):
            break
        # Step forward, keeping `overlap` characters of trailing context.
        start = end - overlap
    return chunks
| 128 |
+
|
| 129 |
+
|
| 130 |
+
# ──────────────────────────────────────────────────────────────────────────────
|
| 131 |
+
# RAG Session
|
| 132 |
+
# ──────────────────────────────────────────────────────────────────────────────
|
| 133 |
+
|
@dataclass
class IndexedChunk:
    """One embedded text chunk plus enough metadata to cite its source."""
    text: str  # raw chunk content fed to the embedding model
    filename: str  # source document this chunk was parsed from
    chunk_index: int  # global index inside this session's chunk list
| 139 |
+
|
| 140 |
+
|
@dataclass
class RAGSession:
    """
    Holds the FAISS vector index and conversation history for a single API session.
    Thread-safety is the responsibility of the caller (sessions.py uses a per-session lock).
    """
    chunks: List[IndexedChunk] = field(default_factory=list)
    history: List[Tuple[str, str]] = field(default_factory=list)  # [(user, assistant), …]
    document_names: List[str] = field(default_factory=list)
    # faiss.IndexFlatL2 — class-level sentinel, rebound per instance on first add.
    _index = None

    # ── Public helpers ────────────────────────────────────────────────────────

    @property
    def document_count(self) -> int:
        """Number of distinct documents indexed in this session."""
        return len(self.document_names)

    @property
    def chunk_count(self) -> int:
        """Total number of indexed chunks across all documents."""
        return len(self.chunks)

    def index_documents(self, files: List[Tuple[str, bytes]]) -> Tuple[List[str], List[str]]:
        """
        Parse, chunk, and embed a list of (filename, bytes) pairs into the FAISS index.
        Returns (indexed_names, failed_names).
        """
        model = _get_embedding_model()
        faiss = _get_faiss()

        new_chunks: List[IndexedChunk] = []
        indexed: List[str] = []
        failed: List[str] = []

        for filename, data in files:
            try:
                text = parse_file(filename, data)
                raw_chunks = chunk_text(text)
                # Offset by existing + pending chunks so chunk_index stays
                # globally unique across multiple uploads.
                start_idx = len(self.chunks) + len(new_chunks)
                for i, c in enumerate(raw_chunks):
                    new_chunks.append(IndexedChunk(
                        text=c,
                        filename=filename,
                        chunk_index=start_idx + i,
                    ))
                indexed.append(filename)
                if filename not in self.document_names:
                    self.document_names.append(filename)
                # Fix: original logged the literal "(unknown)" instead of the
                # filename; also use lazy %-style logging args.
                logger.info("Indexed '%s': %d chunks", filename, len(raw_chunks))
            except Exception as exc:
                logger.warning("Failed to parse '%s': %s", filename, exc)
                failed.append(filename)

        if not new_chunks:
            return indexed, failed

        # Embed all new chunks in one batch.
        texts = [c.text for c in new_chunks]
        vectors = model.encode(texts, show_progress_bar=False).astype(np.float32)

        dim = vectors.shape[1]
        if self._index is None:
            self._index = faiss.IndexFlatL2(dim)

        self._index.add(vectors)
        self.chunks.extend(new_chunks)
        return indexed, failed

    def query(self, question: str, top_k: int = 5) -> List[IndexedChunk]:
        """
        Run a similarity search and return the most relevant chunks.
        Raises RuntimeError if no documents have been indexed yet.
        """
        if self._index is None or not self.chunks:
            raise RuntimeError("No documents indexed. Upload documents first.")

        model = _get_embedding_model()
        q_vec = model.encode([question], show_progress_bar=False).astype(np.float32)
        k = min(top_k, len(self.chunks))
        _, indices = self._index.search(q_vec, k)

        # Fix: FAISS pads missing results with -1; the original `i < len(...)`
        # let -1 through, silently returning the *last* chunk. Require i >= 0.
        return [self.chunks[i] for i in indices[0] if 0 <= i < len(self.chunks)]

    def add_turn(self, question: str, answer: str) -> None:
        """Append one completed (user, assistant) exchange to the history."""
        self.history.append((question, answer))

    def clear_history(self) -> None:
        """Forget the conversation history (indexed documents are kept)."""
        self.history.clear()
| 228 |
+
|
| 229 |
+
|
| 230 |
+
# ──────────────────────────────────────────────────────────────────────────────
|
| 231 |
+
# LLM call (HuggingFace Inference API)
|
| 232 |
+
# ──────────────────────────────────────────────────────────────────────────────
|
| 233 |
+
|
# Chat-completions endpoint of the Hugging Face Inference router.
_HF_API_URL = "https://router.huggingface.co/v1/chat/completions"

# System prompt prepended to every conversation; instructs the model to
# ground its answers strictly in the retrieved excerpts and to cite sources.
_SYSTEM_PROMPT = textwrap.dedent("""\
You are Kerdos AI, an expert document assistant.
Answer ONLY from the provided document excerpts.
If the answer is not in the excerpts, say:
"I could not find this information in the uploaded documents."
Be concise, factual, and cite which document your answer comes from.
""")
| 243 |
+
|
| 244 |
+
|
def call_llm(
    context_chunks: List[IndexedChunk],
    question: str,
    history: List[Tuple[str, str]],
    hf_token: str,
    temperature: float = 0.3,
    max_new_tokens: int = 512,
) -> str:
    """
    Build a chat prompt and call the HF Inference API.

    Args:
        context_chunks: Retrieved chunks injected as grounding context.
        question: The user's current question.
        history: Prior (user, assistant) turns; only the last 6 are sent.
        hf_token: HuggingFace token used as the Bearer credential.
        temperature: Sampling temperature forwarded to the API.
        max_new_tokens: Upper bound on generated tokens ("max_tokens").

    Returns:
        The assistant's reply as a string.

    Raises:
        ValueError: On auth failures (HTTP 401/403).
        RuntimeError: On other HTTP errors, network errors, or a malformed
            API payload (the original crashed with an opaque KeyError there).
    """
    # Context block: each excerpt tagged with its source file for citations.
    context_text = "\n\n---\n\n".join(
        f"[Source: {chunk.filename}]\n{chunk.text}" for chunk in context_chunks
    )

    # Build messages for the chat template.
    messages = [{"role": "system", "content": _SYSTEM_PROMPT}]

    # Add recent history (last 6 turns to stay within context window).
    for user_msg, asst_msg in history[-6:]:
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "assistant", "content": asst_msg})

    # Current turn with injected context.
    user_content = (
        f"Document excerpts:\n\n{context_text}\n\n"
        f"Question: {question}"
    )
    messages.append({"role": "user", "content": user_content})

    payload = {
        "model": "meta-llama/Llama-3.1-8B-Instruct",
        "messages": messages,
        "temperature": temperature,
        "max_tokens": max_new_tokens,
    }
    headers = {
        "Authorization": f"Bearer {hf_token}",
        "Content-Type": "application/json",
    }

    try:
        response = requests.post(
            _HF_API_URL,
            json=payload,
            headers=headers,
            timeout=120,
        )
        response.raise_for_status()
        data = response.json()
    except requests.HTTPError as exc:
        status = exc.response.status_code
        if status == 401:
            raise ValueError("Invalid HuggingFace token. Please check your HF_TOKEN.") from exc
        if status == 403:
            raise ValueError(
                "Access denied. Your HF token needs 'Write' permission and you must accept "
                "the Llama 3.1 license at https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct"
            ) from exc
        raise RuntimeError(f"HF API error {status}: {exc.response.text}") from exc
    except requests.RequestException as exc:
        raise RuntimeError(f"Network error calling HF API: {exc}") from exc

    # Fix: guard against unexpected payload shapes instead of letting a raw
    # KeyError/IndexError escape to the API layer.
    try:
        return data["choices"][0]["message"]["content"].strip()
    except (KeyError, IndexError, TypeError, AttributeError) as exc:
        raise RuntimeError(f"Unexpected HF API response format: {data!r}") from exc
requirements.txt
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Kerdos AI RAG API — Python dependencies
|
| 2 |
+
|
| 3 |
+
# Web framework & server
|
| 4 |
+
fastapi>=0.110.0
|
| 5 |
+
uvicorn[standard]>=0.27.0
|
| 6 |
+
python-multipart>=0.0.9
|
| 7 |
+
|
| 8 |
+
# Data validation
|
| 9 |
+
pydantic>=2.0.0
|
| 10 |
+
pydantic-settings>=2.0.0
|
| 11 |
+
|
| 12 |
+
# AI / ML
|
| 13 |
+
sentence-transformers>=2.6.0
|
| 14 |
+
faiss-cpu>=1.7.4
|
| 15 |
+
|
| 16 |
+
# Document parsing
|
| 17 |
+
pymupdf>=1.23.0 # PDF via fitz
|
| 18 |
+
python-docx>=1.1.0 # DOCX
|
| 19 |
+
|
| 20 |
+
# HTTP client (for HF Inference API)
|
| 21 |
+
requests>=2.31.0
|
| 22 |
+
|
| 23 |
+
# Config
|
| 24 |
+
python-dotenv>=1.0.0
|
| 25 |
+
|
| 26 |
+
# Numerical
|
| 27 |
+
numpy>=1.24.0
|
sessions.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Thread-safe in-memory session store with TTL-based expiry.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
|
| 7 |
+
import threading
|
| 8 |
+
import uuid
|
| 9 |
+
from dataclasses import dataclass, field
|
| 10 |
+
from datetime import datetime, timedelta
|
| 11 |
+
from typing import Dict, Optional
|
| 12 |
+
|
| 13 |
+
from rag_core import RAGSession
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@dataclass
|
| 17 |
+
class _SessionEntry:
|
| 18 |
+
session: RAGSession
|
| 19 |
+
lock: threading.Lock = field(default_factory=threading.Lock)
|
| 20 |
+
created_at: datetime = field(default_factory=datetime.utcnow)
|
| 21 |
+
expires_at: datetime = field(default_factory=datetime.utcnow) # set in __post_init__
|
| 22 |
+
|
| 23 |
+
def __post_init__(self):
|
| 24 |
+
# will be overwritten by SessionStore with the real TTL
|
| 25 |
+
pass
|
| 26 |
+
|
| 27 |
+
|
class SessionStore:
    """
    Global in-memory store for RAG sessions.
    Each session has its own lock to allow concurrent requests on different sessions.
    """

    def __init__(self, ttl_minutes: int = 60):
        self._sessions: Dict[str, _SessionEntry] = {}
        self._store_lock = threading.Lock()
        self._ttl = timedelta(minutes=ttl_minutes)

    # ── Public API ────────────────────────────────────────────────────────────

    def create(self) -> str:
        """Create a new session and return its ID."""
        session_id = str(uuid.uuid4())
        stamp = datetime.utcnow()
        with self._store_lock:
            self._sessions[session_id] = _SessionEntry(
                session=RAGSession(),
                created_at=stamp,
                expires_at=stamp + self._ttl,
            )
        return session_id

    def get(self, session_id: str) -> tuple[RAGSession, threading.Lock]:
        """
        Return (RAGSession, per-session Lock) or raise KeyError if not found/expired.
        Also refreshes the TTL on access.
        """
        with self._store_lock:
            entry = self._sessions.get(session_id)
            alive = entry is not None and datetime.utcnow() <= entry.expires_at
            if not alive:
                # Drop any stale record before signalling the miss.
                self._sessions.pop(session_id, None)
                raise KeyError(session_id)
            entry.expires_at = datetime.utcnow() + self._ttl  # sliding-window TTL
            return entry.session, entry.lock

    def get_meta(self, session_id: str) -> dict:
        """Return metadata (created_at, expires_at) without refreshing TTL."""
        with self._store_lock:
            entry = self._sessions.get(session_id)
            if entry is None:
                raise KeyError(session_id)
            if datetime.utcnow() > entry.expires_at:
                raise KeyError(session_id)
            return {
                "created_at": entry.created_at.isoformat() + "Z",
                "expires_at": entry.expires_at.isoformat() + "Z",
            }

    def delete(self, session_id: str) -> bool:
        """Delete a session. Returns True if it existed."""
        with self._store_lock:
            removed = self._sessions.pop(session_id, None)
        return removed is not None

    def cleanup_expired(self) -> int:
        """Remove all expired sessions. Returns the number removed."""
        cutoff = datetime.utcnow()
        with self._store_lock:
            stale = [sid for sid, entry in self._sessions.items() if entry.expires_at < cutoff]
            for sid in stale:
                self._sessions.pop(sid)
            return len(stale)

    @property
    def active_count(self) -> int:
        """Number of sessions whose TTL has not yet elapsed."""
        with self._store_lock:
            moment = datetime.utcnow()
            return sum(1 for entry in self._sessions.values() if entry.expires_at >= moment)
+
|
| 100 |
+
|
| 101 |
+
# Singleton — imported by api.py
|
| 102 |
+
store = SessionStore()
|