Spaces:
Sleeping
Sleeping
clientsphere
#1
by
ChiragPatankar
- opened
This view is limited to 50 files because it contains too many changes.
See the raw diff here.
- .gitattributes +35 -0
- .gitignore +0 -37
- Dockerfile +34 -35
- README.md +47 -47
- README_HF_SPACES.md +230 -231
- app.py +13 -13
- app/__init__.py +7 -7
- app/__pycache__/__init__.cpython-313.pyc +0 -0
- app/__pycache__/config.cpython-313.pyc +0 -0
- app/__pycache__/main.cpython-313.pyc +0 -0
- app/billing/__pycache__/pricing.cpython-313.pyc +0 -0
- app/billing/__pycache__/quota.cpython-313.pyc +0 -0
- app/billing/__pycache__/usage_tracker.cpython-313.pyc +0 -0
- app/billing/pricing.py +57 -57
- app/billing/quota.py +131 -131
- app/billing/usage_tracker.py +173 -173
- app/config.py +77 -77
- app/db/__init__.py +2 -2
- app/db/__pycache__/__init__.cpython-313.pyc +0 -0
- app/db/__pycache__/database.cpython-313.pyc +0 -0
- app/db/__pycache__/models.cpython-313.pyc +0 -0
- app/db/database.py +53 -53
- app/db/models.py +129 -129
- app/main.py +1039 -1039
- app/middleware/__init__.py +13 -13
- app/middleware/__pycache__/__init__.cpython-313.pyc +0 -0
- app/middleware/__pycache__/auth.cpython-313.pyc +0 -0
- app/middleware/__pycache__/rate_limit.cpython-313.pyc +0 -0
- app/middleware/auth.py +212 -212
- app/middleware/rate_limit.py +40 -40
- app/models/__init__.py +33 -33
- app/models/__pycache__/__init__.cpython-313.pyc +0 -0
- app/models/__pycache__/billing_schemas.cpython-313.pyc +0 -0
- app/models/__pycache__/schemas.cpython-313.pyc +0 -0
- app/models/billing_schemas.py +46 -46
- app/models/schemas.py +112 -112
- app/rag/__init__.py +27 -27
- app/rag/__pycache__/__init__.cpython-313.pyc +0 -0
- app/rag/__pycache__/answer.cpython-313.pyc +0 -0
- app/rag/__pycache__/chunking.cpython-313.pyc +0 -0
- app/rag/__pycache__/embeddings.cpython-313.pyc +0 -0
- app/rag/__pycache__/ingest.cpython-313.pyc +0 -0
- app/rag/__pycache__/intent.cpython-313.pyc +0 -0
- app/rag/__pycache__/prompts.cpython-313.pyc +0 -0
- app/rag/__pycache__/retrieval.cpython-313.pyc +0 -0
- app/rag/__pycache__/vectorstore.cpython-313.pyc +0 -0
- app/rag/__pycache__/verifier.cpython-313.pyc +0 -0
- app/rag/answer.py +444 -444
- app/rag/chunking.py +196 -196
- app/rag/embeddings.py +145 -145
.gitattributes
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
.gitignore
DELETED
|
@@ -1,37 +0,0 @@
|
|
| 1 |
-
# Python files
|
| 2 |
-
*.py
|
| 3 |
-
*.pyc
|
| 4 |
-
__pycache__/
|
| 5 |
-
|
| 6 |
-
# Config files
|
| 7 |
-
*.txt
|
| 8 |
-
*.yaml
|
| 9 |
-
*.yml
|
| 10 |
-
*.toml
|
| 11 |
-
*.json
|
| 12 |
-
*.md
|
| 13 |
-
*.sh
|
| 14 |
-
*.bat
|
| 15 |
-
*.ps1
|
| 16 |
-
|
| 17 |
-
# Directories to include
|
| 18 |
-
app/
|
| 19 |
-
requirements.txt
|
| 20 |
-
Dockerfile
|
| 21 |
-
app.py
|
| 22 |
-
README.md
|
| 23 |
-
.gitignore
|
| 24 |
-
.env.example.txt
|
| 25 |
-
|
| 26 |
-
# Exclude binaries
|
| 27 |
-
*.png
|
| 28 |
-
*.jpg
|
| 29 |
-
*.jpeg
|
| 30 |
-
*.db
|
| 31 |
-
*.pdf
|
| 32 |
-
*.bin
|
| 33 |
-
*.sqlite3
|
| 34 |
-
public/
|
| 35 |
-
data/billing/
|
| 36 |
-
data/vectordb/
|
| 37 |
-
venv/
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Dockerfile
CHANGED
|
@@ -1,35 +1,34 @@
|
|
| 1 |
-
FROM python:3.11-slim
|
| 2 |
-
|
| 3 |
-
# Set working directory
|
| 4 |
-
WORKDIR /app
|
| 5 |
-
|
| 6 |
-
# Install system dependencies
|
| 7 |
-
RUN apt-get update && apt-get install -y \
|
| 8 |
-
gcc \
|
| 9 |
-
g++ \
|
| 10 |
-
curl \
|
| 11 |
-
&& rm -rf /var/lib/apt/lists/*
|
| 12 |
-
|
| 13 |
-
# Copy requirements first (for better caching)
|
| 14 |
-
COPY requirements.txt .
|
| 15 |
-
|
| 16 |
-
# Install Python dependencies
|
| 17 |
-
RUN pip install --no-cache-dir -r requirements.txt
|
| 18 |
-
|
| 19 |
-
# Copy application code
|
| 20 |
-
COPY . .
|
| 21 |
-
|
| 22 |
-
# Create necessary directories
|
| 23 |
-
RUN mkdir -p data/uploads data/processed data/vectordb data/billing
|
| 24 |
-
|
| 25 |
-
# Expose port (Hugging Face Spaces uses 7860, but we'll use PORT env var)
|
| 26 |
-
EXPOSE 7860
|
| 27 |
-
|
| 28 |
-
# Health check
|
| 29 |
-
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
|
| 30 |
-
CMD curl -f http://localhost:${PORT:-7860}/health/live || exit 1
|
| 31 |
-
|
| 32 |
-
# Start the application
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
|
|
|
| 1 |
+
FROM python:3.11-slim
|
| 2 |
+
|
| 3 |
+
# Set working directory
|
| 4 |
+
WORKDIR /app
|
| 5 |
+
|
| 6 |
+
# Install system dependencies
|
| 7 |
+
RUN apt-get update && apt-get install -y \
|
| 8 |
+
gcc \
|
| 9 |
+
g++ \
|
| 10 |
+
curl \
|
| 11 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 12 |
+
|
| 13 |
+
# Copy requirements first (for better caching)
|
| 14 |
+
COPY requirements.txt .
|
| 15 |
+
|
| 16 |
+
# Install Python dependencies
|
| 17 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 18 |
+
|
| 19 |
+
# Copy application code
|
| 20 |
+
COPY . .
|
| 21 |
+
|
| 22 |
+
# Create necessary directories
|
| 23 |
+
RUN mkdir -p data/uploads data/processed data/vectordb data/billing
|
| 24 |
+
|
| 25 |
+
# Expose port (Hugging Face Spaces uses 7860, but we'll use PORT env var)
|
| 26 |
+
EXPOSE 7860
|
| 27 |
+
|
| 28 |
+
# Health check
|
| 29 |
+
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
|
| 30 |
+
CMD curl -f http://localhost:${PORT:-7860}/health/live || exit 1
|
| 31 |
+
|
| 32 |
+
# Start the application (Hugging Face Spaces provides PORT env var)
|
| 33 |
+
CMD uvicorn app.main:app --host 0.0.0.0 --port ${PORT:-7860}
|
| 34 |
+
|
|
|
README.md
CHANGED
|
@@ -1,47 +1,47 @@
|
|
| 1 |
-
---
|
| 2 |
-
title: ClientSphere RAG Backend
|
| 3 |
-
emoji: 🤖
|
| 4 |
-
colorFrom: blue
|
| 5 |
-
colorTo: purple
|
| 6 |
-
sdk: docker
|
| 7 |
-
sdk_version: "4.0.0"
|
| 8 |
-
python_version: "3.11"
|
| 9 |
-
app_file: app.py
|
| 10 |
-
pinned: false
|
| 11 |
-
---
|
| 12 |
-
|
| 13 |
-
# ClientSphere RAG Backend
|
| 14 |
-
|
| 15 |
-
FastAPI-based RAG (Retrieval-Augmented Generation) backend for ClientSphere AI customer support platform.
|
| 16 |
-
|
| 17 |
-
## Features
|
| 18 |
-
|
| 19 |
-
- 📚 Knowledge base management
|
| 20 |
-
- 🔍 Semantic search with embeddings
|
| 21 |
-
- 💬 AI-powered chat with citations
|
| 22 |
-
- 📊 Confidence scoring
|
| 23 |
-
- 🔒 Multi-tenant isolation
|
| 24 |
-
- 📈 Usage tracking and billing
|
| 25 |
-
|
| 26 |
-
## API Endpoints
|
| 27 |
-
|
| 28 |
-
- `GET /health/live` - Health check
|
| 29 |
-
- `GET /kb/stats` - Knowledge base statistics
|
| 30 |
-
- `POST /kb/upload` - Upload documents
|
| 31 |
-
- `POST /chat` - Chat with RAG
|
| 32 |
-
- `GET /kb/search` - Search knowledge base
|
| 33 |
-
|
| 34 |
-
## Environment Variables
|
| 35 |
-
|
| 36 |
-
Required:
|
| 37 |
-
- `GEMINI_API_KEY` - Google Gemini API key
|
| 38 |
-
- `ENV` - Set to `prod` for production
|
| 39 |
-
- `LLM_PROVIDER` - `gemini` or `openai`
|
| 40 |
-
|
| 41 |
-
Optional:
|
| 42 |
-
- `ALLOWED_ORIGINS` - CORS allowed origins (comma-separated)
|
| 43 |
-
- `JWT_SECRET` - JWT secret for authentication
|
| 44 |
-
|
| 45 |
-
## Documentation
|
| 46 |
-
|
| 47 |
-
See `README_HF_SPACES.md` for deployment details.
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: ClientSphere RAG Backend
|
| 3 |
+
emoji: 🤖
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: purple
|
| 6 |
+
sdk: docker
|
| 7 |
+
sdk_version: "4.0.0"
|
| 8 |
+
python_version: "3.11"
|
| 9 |
+
app_file: app.py
|
| 10 |
+
pinned: false
|
| 11 |
+
---
|
| 12 |
+
|
| 13 |
+
# ClientSphere RAG Backend
|
| 14 |
+
|
| 15 |
+
FastAPI-based RAG (Retrieval-Augmented Generation) backend for ClientSphere AI customer support platform.
|
| 16 |
+
|
| 17 |
+
## Features
|
| 18 |
+
|
| 19 |
+
- 📚 Knowledge base management
|
| 20 |
+
- 🔍 Semantic search with embeddings
|
| 21 |
+
- 💬 AI-powered chat with citations
|
| 22 |
+
- 📊 Confidence scoring
|
| 23 |
+
- 🔒 Multi-tenant isolation
|
| 24 |
+
- 📈 Usage tracking and billing
|
| 25 |
+
|
| 26 |
+
## API Endpoints
|
| 27 |
+
|
| 28 |
+
- `GET /health/live` - Health check
|
| 29 |
+
- `GET /kb/stats` - Knowledge base statistics
|
| 30 |
+
- `POST /kb/upload` - Upload documents
|
| 31 |
+
- `POST /chat` - Chat with RAG
|
| 32 |
+
- `GET /kb/search` - Search knowledge base
|
| 33 |
+
|
| 34 |
+
## Environment Variables
|
| 35 |
+
|
| 36 |
+
Required:
|
| 37 |
+
- `GEMINI_API_KEY` - Google Gemini API key
|
| 38 |
+
- `ENV` - Set to `prod` for production
|
| 39 |
+
- `LLM_PROVIDER` - `gemini` or `openai`
|
| 40 |
+
|
| 41 |
+
Optional:
|
| 42 |
+
- `ALLOWED_ORIGINS` - CORS allowed origins (comma-separated)
|
| 43 |
+
- `JWT_SECRET` - JWT secret for authentication
|
| 44 |
+
|
| 45 |
+
## Documentation
|
| 46 |
+
|
| 47 |
+
See `README_HF_SPACES.md` for deployment details.
|
README_HF_SPACES.md
CHANGED
|
@@ -1,231 +1,230 @@
|
|
| 1 |
-
# 🚀 Deploy RAG Backend to Hugging Face Spaces
|
| 2 |
-
|
| 3 |
-
Hugging Face Spaces is **perfect** for deploying Python/FastAPI applications with ML dependencies!
|
| 4 |
-
|
| 5 |
-
## ✅ Why Hugging Face Spaces?
|
| 6 |
-
|
| 7 |
-
- ✅ **Free tier** with generous limits
|
| 8 |
-
- ✅ **Full Python 3.11+** support
|
| 9 |
-
- ✅ **ML libraries** fully supported (sentence-transformers, chromadb, etc.)
|
| 10 |
-
- ✅ **Persistent storage** for vector database
|
| 11 |
-
- ✅ **No bundle size limits**
|
| 12 |
-
- ✅ **GPU support** available (paid)
|
| 13 |
-
- ✅ **Automatic HTTPS** and custom domains
|
| 14 |
-
- ✅ **GitHub integration** (auto-deploy on push)
|
| 15 |
-
|
| 16 |
-
## 📋 Prerequisites
|
| 17 |
-
|
| 18 |
-
1. **Hugging Face Account**: Sign up at [huggingface.co](https://huggingface.co)
|
| 19 |
-
2. **GitHub Repository**: Your code should be in a GitHub repository
|
| 20 |
-
3. **Gemini API Key**: Get from [Google AI Studio](https://aistudio.google.com/app/apikey)
|
| 21 |
-
|
| 22 |
-
## 🚀 Step-by-Step Deployment
|
| 23 |
-
|
| 24 |
-
### Step 1: Prepare Your Repository
|
| 25 |
-
|
| 26 |
-
Your `rag-backend/` directory should contain:
|
| 27 |
-
- ✅ `app.py` - Entry point (already created)
|
| 28 |
-
- ✅ `requirements.txt` - Dependencies
|
| 29 |
-
- ✅ `app/main.py` - FastAPI application
|
| 30 |
-
- ✅ All other application files
|
| 31 |
-
|
| 32 |
-
### Step 2: Create Hugging Face Space
|
| 33 |
-
|
| 34 |
-
1. Go to [Hugging Face Spaces](https://huggingface.co/spaces)
|
| 35 |
-
2. Click **"Create new Space"**
|
| 36 |
-
3. Configure:
|
| 37 |
-
- **Owner**: Your username
|
| 38 |
-
- **Space name**: `clientsphere-rag-backend` (or your choice)
|
| 39 |
-
- **SDK**: **Docker** (recommended) or **Gradio** (if you want UI)
|
| 40 |
-
- **Hardware**:
|
| 41 |
-
- **CPU basic** (free) - Good for testing
|
| 42 |
-
- **CPU upgrade** (paid) - Better performance
|
| 43 |
-
- **GPU** (paid) - For heavy ML workloads
|
| 44 |
-
|
| 45 |
-
### Step 3: Connect GitHub Repository
|
| 46 |
-
|
| 47 |
-
1. In Space creation, select **"Repository"** as source
|
| 48 |
-
2. Choose your GitHub repository
|
| 49 |
-
3. Set **Repository path** to: `rag-backend/` (subdirectory)
|
| 50 |
-
4. Click **"Create Space"**
|
| 51 |
-
|
| 52 |
-
### Step 4: Configure Environment Variables
|
| 53 |
-
|
| 54 |
-
1. Go to your Space's **Settings** tab
|
| 55 |
-
2. Scroll to **"Repository secrets"** or **"Variables"**
|
| 56 |
-
3. Add these secrets:
|
| 57 |
-
|
| 58 |
-
**Required:**
|
| 59 |
-
```
|
| 60 |
-
GEMINI_API_KEY=your_gemini_api_key_here
|
| 61 |
-
ENV=prod
|
| 62 |
-
LLM_PROVIDER=gemini
|
| 63 |
-
```
|
| 64 |
-
|
| 65 |
-
**Optional (but recommended):**
|
| 66 |
-
```
|
| 67 |
-
ALLOWED_ORIGINS=https://main.clientsphere.pages.dev,https://abaa49a3.clientsphere.pages.dev
|
| 68 |
-
JWT_SECRET=your_secure_jwt_secret
|
| 69 |
-
DEBUG=false
|
| 70 |
-
```
|
| 71 |
-
|
| 72 |
-
### Step 5: Configure Docker (if using Docker SDK)
|
| 73 |
-
|
| 74 |
-
If you selected **Docker** SDK, Hugging Face will use your `Dockerfile`.
|
| 75 |
-
|
| 76 |
-
**Your existing `Dockerfile` should work!** It's already configured correctly.
|
| 77 |
-
|
| 78 |
-
### Step 6: Alternative - Use app.py (Simpler)
|
| 79 |
-
|
| 80 |
-
If you want to use the simpler `app.py` approach:
|
| 81 |
-
|
| 82 |
-
1. In Space settings, set:
|
| 83 |
-
- **SDK**: **Gradio** or **Streamlit** (but we'll override)
|
| 84 |
-
- **App file**: `app.py`
|
| 85 |
-
|
| 86 |
-
2. Hugging Face will automatically:
|
| 87 |
-
- Install dependencies from `requirements.txt`
|
| 88 |
-
- Run `python app.py`
|
| 89 |
-
- Expose on port 7860
|
| 90 |
-
|
| 91 |
-
### Step 7: Deploy!
|
| 92 |
-
|
| 93 |
-
1. **Push to GitHub** (if not already):
|
| 94 |
-
```bash
|
| 95 |
-
git add rag-backend/app.py
|
| 96 |
-
git commit -m "Add Hugging Face Spaces entry point"
|
| 97 |
-
git push origin main
|
| 98 |
-
```
|
| 99 |
-
|
| 100 |
-
2. **Hugging Face will auto-deploy** from your GitHub repo!
|
| 101 |
-
|
| 102 |
-
3. **Wait for build** (5-10 minutes first time, faster after)
|
| 103 |
-
|
| 104 |
-
4. **Your Space URL**: `https://your-username-clientsphere-rag-backend.hf.space`
|
| 105 |
-
|
| 106 |
-
## 🔧 Configuration Options
|
| 107 |
-
|
| 108 |
-
### Option A: Docker (Recommended)
|
| 109 |
-
|
| 110 |
-
**Advantages:**
|
| 111 |
-
- Full control over environment
|
| 112 |
-
- Can customize Python version
|
| 113 |
-
- Better for production
|
| 114 |
-
|
| 115 |
-
**Setup:**
|
| 116 |
-
- Use existing `Dockerfile`
|
| 117 |
-
- Hugging Face will build and run it
|
| 118 |
-
- Exposes on port 7860 automatically
|
| 119 |
-
|
| 120 |
-
### Option B: app.py (Simpler)
|
| 121 |
-
|
| 122 |
-
**Advantages:**
|
| 123 |
-
- Simpler setup
|
| 124 |
-
- Faster builds
|
| 125 |
-
- Good for development
|
| 126 |
-
|
| 127 |
-
**Setup:**
|
| 128 |
-
- Create `app.py` in `rag-backend/` (already done)
|
| 129 |
-
- Hugging Face runs it automatically
|
| 130 |
-
|
| 131 |
-
## 📝 Environment Variables Reference
|
| 132 |
-
|
| 133 |
-
| Variable | Required | Description |
|
| 134 |
-
|----------|----------|-------------|
|
| 135 |
-
| `GEMINI_API_KEY` | ✅ Yes | Your Google Gemini API key |
|
| 136 |
-
| `ENV` | ✅ Yes | Set to `prod` for production |
|
| 137 |
-
| `LLM_PROVIDER` | ✅ Yes | `gemini` or `openai` |
|
| 138 |
-
| `ALLOWED_ORIGINS` | ⚠️ Recommended | CORS allowed origins (comma-separated) |
|
| 139 |
-
| `JWT_SECRET` | ⚠️ Recommended | JWT secret for authentication |
|
| 140 |
-
| `DEBUG` | ❌ Optional | Set to `false` in production |
|
| 141 |
-
| `OPENAI_API_KEY` | ❌ Optional | If using OpenAI instead of Gemini |
|
| 142 |
-
|
| 143 |
-
## 🌐 CORS Configuration
|
| 144 |
-
|
| 145 |
-
After deployment, update `ALLOWED_ORIGINS` to include:
|
| 146 |
-
- Your Cloudflare Pages frontend URL
|
| 147 |
-
- Your Cloudflare Workers backend URL
|
| 148 |
-
- Any other origins that need access
|
| 149 |
-
|
| 150 |
-
Example:
|
| 151 |
-
```
|
| 152 |
-
ALLOWED_ORIGINS=https://main.clientsphere.pages.dev,https://mcp-backend.officialchiragp1605.workers.dev
|
| 153 |
-
```
|
| 154 |
-
|
| 155 |
-
## 🔄 Updating Deployment
|
| 156 |
-
|
| 157 |
-
**Automatic (Recommended):**
|
| 158 |
-
- Push to GitHub → Hugging Face auto-deploys
|
| 159 |
-
|
| 160 |
-
**Manual:**
|
| 161 |
-
- Go to Space → Settings → "Rebuild Space"
|
| 162 |
-
|
| 163 |
-
## 📊 Resource Limits
|
| 164 |
-
|
| 165 |
-
### Free Tier:
|
| 166 |
-
- ✅ **CPU**: Basic (sufficient for RAG)
|
| 167 |
-
- ✅ **Storage**: 50GB (plenty for vector DB)
|
| 168 |
-
- ✅ **Memory**: 16GB RAM
|
| 169 |
-
- ✅ **Build time**: 20 minutes
|
| 170 |
-
- ✅ **Sleep after inactivity**: 48 hours (wakes on request)
|
| 171 |
-
|
| 172 |
-
### Paid Tiers:
|
| 173 |
-
- **CPU upgrade**: Better performance
|
| 174 |
-
- **GPU**: For heavy ML workloads
|
| 175 |
-
- **No sleep**: Always-on service
|
| 176 |
-
|
| 177 |
-
## 🧪 Testing Deployment
|
| 178 |
-
|
| 179 |
-
After deployment, test your endpoints:
|
| 180 |
-
|
| 181 |
-
```bash
|
| 182 |
-
# Health check
|
| 183 |
-
curl https://your-username-clientsphere-rag-backend.hf.space/health/live
|
| 184 |
-
|
| 185 |
-
# KB Stats (with auth)
|
| 186 |
-
curl https://your-username-clientsphere-rag-backend.hf.space/kb/stats?kb_id=default&tenant_id=test&user_id=test
|
| 187 |
-
```
|
| 188 |
-
|
| 189 |
-
## 🔗 Update Frontend
|
| 190 |
-
|
| 191 |
-
After deployment, update Cloudflare Pages environment variable:
|
| 192 |
-
|
| 193 |
-
```
|
| 194 |
-
VITE_RAG_API_URL=https://your-username-clientsphere-rag-backend.hf.space
|
| 195 |
-
```
|
| 196 |
-
|
| 197 |
-
Then redeploy frontend:
|
| 198 |
-
```bash
|
| 199 |
-
npm run build
|
| 200 |
-
npx wrangler pages deploy dist --project-name=clientsphere
|
| 201 |
-
```
|
| 202 |
-
|
| 203 |
-
## ✅ Advantages Over Render
|
| 204 |
-
|
| 205 |
-
| Feature | Hugging Face Spaces | Render |
|
| 206 |
-
|---------|-------------------|--------|
|
| 207 |
-
| Free Tier | ✅ Generous | ⚠️ Limited |
|
| 208 |
-
| ML Libraries | ✅ Full support | ✅ Full support |
|
| 209 |
-
| Auto-deploy | ✅ GitHub integration | ✅ GitHub integration |
|
| 210 |
-
| Storage | ✅ 50GB free | ⚠️ Limited |
|
| 211 |
-
| Sleep Mode | ✅ Wakes on request | ❌ No sleep mode |
|
| 212 |
-
| GPU Support | ✅ Available | ❌ Not available |
|
| 213 |
-
| Community | ✅ Large ML community | ⚠️ Smaller |
|
| 214 |
-
|
| 215 |
-
## 🎯 Summary
|
| 216 |
-
|
| 217 |
-
1. ✅ Create Hugging Face Space
|
| 218 |
-
2. ✅ Connect GitHub repository (rag-backend/)
|
| 219 |
-
3. ✅ Set environment variables
|
| 220 |
-
4. ✅ Deploy (automatic on push)
|
| 221 |
-
5. ✅ Update frontend `VITE_RAG_API_URL`
|
| 222 |
-
6. ✅ Test and enjoy!
|
| 223 |
-
|
| 224 |
-
**Your RAG backend will be live at:**
|
| 225 |
-
`https://your-username-clientsphere-rag-backend.hf.space`
|
| 226 |
-
|
| 227 |
-
---
|
| 228 |
-
|
| 229 |
-
**Need help?** Check [Hugging Face Spaces Docs](https://huggingface.co/docs/hub/spaces)
|
| 230 |
-
|
| 231 |
-
|
|
|
|
| 1 |
+
# 🚀 Deploy RAG Backend to Hugging Face Spaces
|
| 2 |
+
|
| 3 |
+
Hugging Face Spaces is **perfect** for deploying Python/FastAPI applications with ML dependencies!
|
| 4 |
+
|
| 5 |
+
## ✅ Why Hugging Face Spaces?
|
| 6 |
+
|
| 7 |
+
- ✅ **Free tier** with generous limits
|
| 8 |
+
- ✅ **Full Python 3.11+** support
|
| 9 |
+
- ✅ **ML libraries** fully supported (sentence-transformers, chromadb, etc.)
|
| 10 |
+
- ✅ **Persistent storage** for vector database
|
| 11 |
+
- ✅ **No bundle size limits**
|
| 12 |
+
- ✅ **GPU support** available (paid)
|
| 13 |
+
- ✅ **Automatic HTTPS** and custom domains
|
| 14 |
+
- ✅ **GitHub integration** (auto-deploy on push)
|
| 15 |
+
|
| 16 |
+
## 📋 Prerequisites
|
| 17 |
+
|
| 18 |
+
1. **Hugging Face Account**: Sign up at [huggingface.co](https://huggingface.co)
|
| 19 |
+
2. **GitHub Repository**: Your code should be in a GitHub repository
|
| 20 |
+
3. **Gemini API Key**: Get from [Google AI Studio](https://aistudio.google.com/app/apikey)
|
| 21 |
+
|
| 22 |
+
## 🚀 Step-by-Step Deployment
|
| 23 |
+
|
| 24 |
+
### Step 1: Prepare Your Repository
|
| 25 |
+
|
| 26 |
+
Your `rag-backend/` directory should contain:
|
| 27 |
+
- ✅ `app.py` - Entry point (already created)
|
| 28 |
+
- ✅ `requirements.txt` - Dependencies
|
| 29 |
+
- ✅ `app/main.py` - FastAPI application
|
| 30 |
+
- ✅ All other application files
|
| 31 |
+
|
| 32 |
+
### Step 2: Create Hugging Face Space
|
| 33 |
+
|
| 34 |
+
1. Go to [Hugging Face Spaces](https://huggingface.co/spaces)
|
| 35 |
+
2. Click **"Create new Space"**
|
| 36 |
+
3. Configure:
|
| 37 |
+
- **Owner**: Your username
|
| 38 |
+
- **Space name**: `clientsphere-rag-backend` (or your choice)
|
| 39 |
+
- **SDK**: **Docker** (recommended) or **Gradio** (if you want UI)
|
| 40 |
+
- **Hardware**:
|
| 41 |
+
- **CPU basic** (free) - Good for testing
|
| 42 |
+
- **CPU upgrade** (paid) - Better performance
|
| 43 |
+
- **GPU** (paid) - For heavy ML workloads
|
| 44 |
+
|
| 45 |
+
### Step 3: Connect GitHub Repository
|
| 46 |
+
|
| 47 |
+
1. In Space creation, select **"Repository"** as source
|
| 48 |
+
2. Choose your GitHub repository
|
| 49 |
+
3. Set **Repository path** to: `rag-backend/` (subdirectory)
|
| 50 |
+
4. Click **"Create Space"**
|
| 51 |
+
|
| 52 |
+
### Step 4: Configure Environment Variables
|
| 53 |
+
|
| 54 |
+
1. Go to your Space's **Settings** tab
|
| 55 |
+
2. Scroll to **"Repository secrets"** or **"Variables"**
|
| 56 |
+
3. Add these secrets:
|
| 57 |
+
|
| 58 |
+
**Required:**
|
| 59 |
+
```
|
| 60 |
+
GEMINI_API_KEY=your_gemini_api_key_here
|
| 61 |
+
ENV=prod
|
| 62 |
+
LLM_PROVIDER=gemini
|
| 63 |
+
```
|
| 64 |
+
|
| 65 |
+
**Optional (but recommended):**
|
| 66 |
+
```
|
| 67 |
+
ALLOWED_ORIGINS=https://main.clientsphere.pages.dev,https://abaa49a3.clientsphere.pages.dev
|
| 68 |
+
JWT_SECRET=your_secure_jwt_secret
|
| 69 |
+
DEBUG=false
|
| 70 |
+
```
|
| 71 |
+
|
| 72 |
+
### Step 5: Configure Docker (if using Docker SDK)
|
| 73 |
+
|
| 74 |
+
If you selected **Docker** SDK, Hugging Face will use your `Dockerfile`.
|
| 75 |
+
|
| 76 |
+
**Your existing `Dockerfile` should work!** It's already configured correctly.
|
| 77 |
+
|
| 78 |
+
### Step 6: Alternative - Use app.py (Simpler)
|
| 79 |
+
|
| 80 |
+
If you want to use the simpler `app.py` approach:
|
| 81 |
+
|
| 82 |
+
1. In Space settings, set:
|
| 83 |
+
- **SDK**: **Gradio** or **Streamlit** (but we'll override)
|
| 84 |
+
- **App file**: `app.py`
|
| 85 |
+
|
| 86 |
+
2. Hugging Face will automatically:
|
| 87 |
+
- Install dependencies from `requirements.txt`
|
| 88 |
+
- Run `python app.py`
|
| 89 |
+
- Expose on port 7860
|
| 90 |
+
|
| 91 |
+
### Step 7: Deploy!
|
| 92 |
+
|
| 93 |
+
1. **Push to GitHub** (if not already):
|
| 94 |
+
```bash
|
| 95 |
+
git add rag-backend/app.py
|
| 96 |
+
git commit -m "Add Hugging Face Spaces entry point"
|
| 97 |
+
git push origin main
|
| 98 |
+
```
|
| 99 |
+
|
| 100 |
+
2. **Hugging Face will auto-deploy** from your GitHub repo!
|
| 101 |
+
|
| 102 |
+
3. **Wait for build** (5-10 minutes first time, faster after)
|
| 103 |
+
|
| 104 |
+
4. **Your Space URL**: `https://your-username-clientsphere-rag-backend.hf.space`
|
| 105 |
+
|
| 106 |
+
## 🔧 Configuration Options
|
| 107 |
+
|
| 108 |
+
### Option A: Docker (Recommended)
|
| 109 |
+
|
| 110 |
+
**Advantages:**
|
| 111 |
+
- Full control over environment
|
| 112 |
+
- Can customize Python version
|
| 113 |
+
- Better for production
|
| 114 |
+
|
| 115 |
+
**Setup:**
|
| 116 |
+
- Use existing `Dockerfile`
|
| 117 |
+
- Hugging Face will build and run it
|
| 118 |
+
- Exposes on port 7860 automatically
|
| 119 |
+
|
| 120 |
+
### Option B: app.py (Simpler)
|
| 121 |
+
|
| 122 |
+
**Advantages:**
|
| 123 |
+
- Simpler setup
|
| 124 |
+
- Faster builds
|
| 125 |
+
- Good for development
|
| 126 |
+
|
| 127 |
+
**Setup:**
|
| 128 |
+
- Create `app.py` in `rag-backend/` (already done)
|
| 129 |
+
- Hugging Face runs it automatically
|
| 130 |
+
|
| 131 |
+
## 📝 Environment Variables Reference
|
| 132 |
+
|
| 133 |
+
| Variable | Required | Description |
|
| 134 |
+
|----------|----------|-------------|
|
| 135 |
+
| `GEMINI_API_KEY` | ✅ Yes | Your Google Gemini API key |
|
| 136 |
+
| `ENV` | ✅ Yes | Set to `prod` for production |
|
| 137 |
+
| `LLM_PROVIDER` | ✅ Yes | `gemini` or `openai` |
|
| 138 |
+
| `ALLOWED_ORIGINS` | ⚠️ Recommended | CORS allowed origins (comma-separated) |
|
| 139 |
+
| `JWT_SECRET` | ⚠️ Recommended | JWT secret for authentication |
|
| 140 |
+
| `DEBUG` | ❌ Optional | Set to `false` in production |
|
| 141 |
+
| `OPENAI_API_KEY` | ❌ Optional | If using OpenAI instead of Gemini |
|
| 142 |
+
|
| 143 |
+
## 🌐 CORS Configuration
|
| 144 |
+
|
| 145 |
+
After deployment, update `ALLOWED_ORIGINS` to include:
|
| 146 |
+
- Your Cloudflare Pages frontend URL
|
| 147 |
+
- Your Cloudflare Workers backend URL
|
| 148 |
+
- Any other origins that need access
|
| 149 |
+
|
| 150 |
+
Example:
|
| 151 |
+
```
|
| 152 |
+
ALLOWED_ORIGINS=https://main.clientsphere.pages.dev,https://mcp-backend.officialchiragp1605.workers.dev
|
| 153 |
+
```
|
| 154 |
+
|
| 155 |
+
## 🔄 Updating Deployment
|
| 156 |
+
|
| 157 |
+
**Automatic (Recommended):**
|
| 158 |
+
- Push to GitHub → Hugging Face auto-deploys
|
| 159 |
+
|
| 160 |
+
**Manual:**
|
| 161 |
+
- Go to Space → Settings → "Rebuild Space"
|
| 162 |
+
|
| 163 |
+
## 📊 Resource Limits
|
| 164 |
+
|
| 165 |
+
### Free Tier:
|
| 166 |
+
- ✅ **CPU**: Basic (sufficient for RAG)
|
| 167 |
+
- ✅ **Storage**: 50GB (plenty for vector DB)
|
| 168 |
+
- ✅ **Memory**: 16GB RAM
|
| 169 |
+
- ✅ **Build time**: 20 minutes
|
| 170 |
+
- ✅ **Sleep after inactivity**: 48 hours (wakes on request)
|
| 171 |
+
|
| 172 |
+
### Paid Tiers:
|
| 173 |
+
- **CPU upgrade**: Better performance
|
| 174 |
+
- **GPU**: For heavy ML workloads
|
| 175 |
+
- **No sleep**: Always-on service
|
| 176 |
+
|
| 177 |
+
## 🧪 Testing Deployment
|
| 178 |
+
|
| 179 |
+
After deployment, test your endpoints:
|
| 180 |
+
|
| 181 |
+
```bash
|
| 182 |
+
# Health check
|
| 183 |
+
curl https://your-username-clientsphere-rag-backend.hf.space/health/live
|
| 184 |
+
|
| 185 |
+
# KB Stats (with auth)
|
| 186 |
+
curl https://your-username-clientsphere-rag-backend.hf.space/kb/stats?kb_id=default&tenant_id=test&user_id=test
|
| 187 |
+
```
|
| 188 |
+
|
| 189 |
+
## 🔗 Update Frontend
|
| 190 |
+
|
| 191 |
+
After deployment, update Cloudflare Pages environment variable:
|
| 192 |
+
|
| 193 |
+
```
|
| 194 |
+
VITE_RAG_API_URL=https://your-username-clientsphere-rag-backend.hf.space
|
| 195 |
+
```
|
| 196 |
+
|
| 197 |
+
Then redeploy frontend:
|
| 198 |
+
```bash
|
| 199 |
+
npm run build
|
| 200 |
+
npx wrangler pages deploy dist --project-name=clientsphere
|
| 201 |
+
```
|
| 202 |
+
|
| 203 |
+
## ✅ Advantages Over Render
|
| 204 |
+
|
| 205 |
+
| Feature | Hugging Face Spaces | Render |
|
| 206 |
+
|---------|-------------------|--------|
|
| 207 |
+
| Free Tier | ✅ Generous | ⚠️ Limited |
|
| 208 |
+
| ML Libraries | ✅ Full support | ✅ Full support |
|
| 209 |
+
| Auto-deploy | ✅ GitHub integration | ✅ GitHub integration |
|
| 210 |
+
| Storage | ✅ 50GB free | ⚠️ Limited |
|
| 211 |
+
| Sleep Mode | ✅ Wakes on request | ❌ No sleep mode |
|
| 212 |
+
| GPU Support | ✅ Available | ❌ Not available |
|
| 213 |
+
| Community | ✅ Large ML community | ⚠️ Smaller |
|
| 214 |
+
|
| 215 |
+
## 🎯 Summary
|
| 216 |
+
|
| 217 |
+
1. ✅ Create Hugging Face Space
|
| 218 |
+
2. ✅ Connect GitHub repository (rag-backend/)
|
| 219 |
+
3. ✅ Set environment variables
|
| 220 |
+
4. ✅ Deploy (automatic on push)
|
| 221 |
+
5. ✅ Update frontend `VITE_RAG_API_URL`
|
| 222 |
+
6. ✅ Test and enjoy!
|
| 223 |
+
|
| 224 |
+
**Your RAG backend will be live at:**
|
| 225 |
+
`https://your-username-clientsphere-rag-backend.hf.space`
|
| 226 |
+
|
| 227 |
+
---
|
| 228 |
+
|
| 229 |
+
**Need help?** Check [Hugging Face Spaces Docs](https://huggingface.co/docs/hub/spaces)
|
| 230 |
+
|
|
|
app.py
CHANGED
|
@@ -1,13 +1,13 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Hugging Face Spaces entry point for RAG Backend.
|
| 3 |
-
This file is used when deploying to Hugging Face Spaces.
|
| 4 |
-
"""
|
| 5 |
-
import os
|
| 6 |
-
import uvicorn
|
| 7 |
-
from app.main import app
|
| 8 |
-
|
| 9 |
-
if __name__ == "__main__":
|
| 10 |
-
# Hugging Face Spaces provides PORT environment variable (defaults to 7860)
|
| 11 |
-
port = int(os.getenv("PORT", 7860))
|
| 12 |
-
uvicorn.run(app, host="0.0.0.0", port=port)
|
| 13 |
-
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Hugging Face Spaces entry point for RAG Backend.
|
| 3 |
+
This file is used when deploying to Hugging Face Spaces.
|
| 4 |
+
"""
|
| 5 |
+
import os
|
| 6 |
+
import uvicorn
|
| 7 |
+
from app.main import app
|
| 8 |
+
|
| 9 |
+
if __name__ == "__main__":
|
| 10 |
+
# Hugging Face Spaces provides PORT environment variable (defaults to 7860)
|
| 11 |
+
port = int(os.getenv("PORT", 7860))
|
| 12 |
+
uvicorn.run(app, host="0.0.0.0", port=port)
|
| 13 |
+
|
app/__init__.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
-
"""
|
| 2 |
-
ClientSphere RAG Backend Application.
|
| 3 |
-
"""
|
| 4 |
-
__version__ = "1.0.0"
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ClientSphere RAG Backend Application.
|
| 3 |
+
"""
|
| 4 |
+
__version__ = "1.0.0"
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
|
app/__pycache__/__init__.cpython-313.pyc
DELETED
|
Binary file (222 Bytes)
|
|
|
app/__pycache__/config.cpython-313.pyc
DELETED
|
Binary file (3.22 kB)
|
|
|
app/__pycache__/main.cpython-313.pyc
DELETED
|
Binary file (38.9 kB)
|
|
|
app/billing/__pycache__/pricing.cpython-313.pyc
DELETED
|
Binary file (2.32 kB)
|
|
|
app/billing/__pycache__/quota.cpython-313.pyc
DELETED
|
Binary file (5.12 kB)
|
|
|
app/billing/__pycache__/usage_tracker.cpython-313.pyc
DELETED
|
Binary file (5.56 kB)
|
|
|
app/billing/pricing.py
CHANGED
|
@@ -1,57 +1,57 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Pricing table for LLM providers.
|
| 3 |
-
Used to calculate estimated costs from token usage.
|
| 4 |
-
"""
|
| 5 |
-
from typing import Dict, Optional
|
| 6 |
-
|
| 7 |
-
# Pricing per 1M tokens (as of 2024, update as needed)
|
| 8 |
-
PRICING_TABLE: Dict[str, Dict[str, float]] = {
|
| 9 |
-
"gemini": {
|
| 10 |
-
"gemini-pro": {"input": 0.50, "output": 1.50}, # $0.50/$1.50 per 1M tokens
|
| 11 |
-
"gemini-1.5-pro": {"input": 1.25, "output": 5.00},
|
| 12 |
-
"gemini-1.5-flash": {"input": 0.075, "output": 0.30},
|
| 13 |
-
"gemini-1.0-pro": {"input": 0.50, "output": 1.50},
|
| 14 |
-
"default": {"input": 0.50, "output": 1.50}
|
| 15 |
-
},
|
| 16 |
-
"openai": {
|
| 17 |
-
"gpt-4": {"input": 30.00, "output": 60.00},
|
| 18 |
-
"gpt-4-turbo": {"input": 10.00, "output": 30.00},
|
| 19 |
-
"gpt-3.5-turbo": {"input": 0.50, "output": 1.50},
|
| 20 |
-
"default": {"input": 0.50, "output": 1.50}
|
| 21 |
-
}
|
| 22 |
-
}
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
def calculate_cost(
|
| 26 |
-
provider: str,
|
| 27 |
-
model: str,
|
| 28 |
-
prompt_tokens: int,
|
| 29 |
-
completion_tokens: int
|
| 30 |
-
) -> float:
|
| 31 |
-
"""
|
| 32 |
-
Calculate estimated cost in USD based on token usage.
|
| 33 |
-
|
| 34 |
-
Args:
|
| 35 |
-
provider: "gemini" or "openai"
|
| 36 |
-
model: Model name (e.g., "gemini-pro", "gpt-3.5-turbo")
|
| 37 |
-
prompt_tokens: Number of input tokens
|
| 38 |
-
completion_tokens: Number of output tokens
|
| 39 |
-
|
| 40 |
-
Returns:
|
| 41 |
-
Estimated cost in USD
|
| 42 |
-
"""
|
| 43 |
-
provider_pricing = PRICING_TABLE.get(provider.lower(), {})
|
| 44 |
-
model_pricing = provider_pricing.get(model.lower(), provider_pricing.get("default", {"input": 0.50, "output": 1.50}))
|
| 45 |
-
|
| 46 |
-
# Calculate cost: (tokens / 1M) * price_per_1M
|
| 47 |
-
input_cost = (prompt_tokens / 1_000_000) * model_pricing["input"]
|
| 48 |
-
output_cost = (completion_tokens / 1_000_000) * model_pricing["output"]
|
| 49 |
-
|
| 50 |
-
return input_cost + output_cost
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
def get_model_pricing(provider: str, model: str) -> Dict[str, float]:
|
| 54 |
-
"""Get pricing for a specific model."""
|
| 55 |
-
provider_pricing = PRICING_TABLE.get(provider.lower(), {})
|
| 56 |
-
return provider_pricing.get(model.lower(), provider_pricing.get("default", {"input": 0.50, "output": 1.50}))
|
| 57 |
-
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Pricing table for LLM providers.
|
| 3 |
+
Used to calculate estimated costs from token usage.
|
| 4 |
+
"""
|
| 5 |
+
from typing import Dict, Optional
|
| 6 |
+
|
| 7 |
+
# Pricing per 1M tokens (as of 2024, update as needed)
|
| 8 |
+
PRICING_TABLE: Dict[str, Dict[str, float]] = {
|
| 9 |
+
"gemini": {
|
| 10 |
+
"gemini-pro": {"input": 0.50, "output": 1.50}, # $0.50/$1.50 per 1M tokens
|
| 11 |
+
"gemini-1.5-pro": {"input": 1.25, "output": 5.00},
|
| 12 |
+
"gemini-1.5-flash": {"input": 0.075, "output": 0.30},
|
| 13 |
+
"gemini-1.0-pro": {"input": 0.50, "output": 1.50},
|
| 14 |
+
"default": {"input": 0.50, "output": 1.50}
|
| 15 |
+
},
|
| 16 |
+
"openai": {
|
| 17 |
+
"gpt-4": {"input": 30.00, "output": 60.00},
|
| 18 |
+
"gpt-4-turbo": {"input": 10.00, "output": 30.00},
|
| 19 |
+
"gpt-3.5-turbo": {"input": 0.50, "output": 1.50},
|
| 20 |
+
"default": {"input": 0.50, "output": 1.50}
|
| 21 |
+
}
|
| 22 |
+
}
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def calculate_cost(
|
| 26 |
+
provider: str,
|
| 27 |
+
model: str,
|
| 28 |
+
prompt_tokens: int,
|
| 29 |
+
completion_tokens: int
|
| 30 |
+
) -> float:
|
| 31 |
+
"""
|
| 32 |
+
Calculate estimated cost in USD based on token usage.
|
| 33 |
+
|
| 34 |
+
Args:
|
| 35 |
+
provider: "gemini" or "openai"
|
| 36 |
+
model: Model name (e.g., "gemini-pro", "gpt-3.5-turbo")
|
| 37 |
+
prompt_tokens: Number of input tokens
|
| 38 |
+
completion_tokens: Number of output tokens
|
| 39 |
+
|
| 40 |
+
Returns:
|
| 41 |
+
Estimated cost in USD
|
| 42 |
+
"""
|
| 43 |
+
provider_pricing = PRICING_TABLE.get(provider.lower(), {})
|
| 44 |
+
model_pricing = provider_pricing.get(model.lower(), provider_pricing.get("default", {"input": 0.50, "output": 1.50}))
|
| 45 |
+
|
| 46 |
+
# Calculate cost: (tokens / 1M) * price_per_1M
|
| 47 |
+
input_cost = (prompt_tokens / 1_000_000) * model_pricing["input"]
|
| 48 |
+
output_cost = (completion_tokens / 1_000_000) * model_pricing["output"]
|
| 49 |
+
|
| 50 |
+
return input_cost + output_cost
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def get_model_pricing(provider: str, model: str) -> Dict[str, float]:
|
| 54 |
+
"""Get pricing for a specific model."""
|
| 55 |
+
provider_pricing = PRICING_TABLE.get(provider.lower(), {})
|
| 56 |
+
return provider_pricing.get(model.lower(), provider_pricing.get("default", {"input": 0.50, "output": 1.50}))
|
| 57 |
+
|
app/billing/quota.py
CHANGED
|
@@ -1,131 +1,131 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Quota management and enforcement.
|
| 3 |
-
"""
|
| 4 |
-
from sqlalchemy.orm import Session
|
| 5 |
-
from sqlalchemy import func, and_
|
| 6 |
-
from datetime import datetime, timedelta
|
| 7 |
-
from typing import Optional, Tuple
|
| 8 |
-
import logging
|
| 9 |
-
|
| 10 |
-
from app.db.models import TenantPlan, UsageMonthly, Tenant
|
| 11 |
-
logger = logging.getLogger(__name__)
|
| 12 |
-
|
| 13 |
-
# Plan limits (chats per month)
|
| 14 |
-
PLAN_LIMITS = {
|
| 15 |
-
"starter": 500,
|
| 16 |
-
"growth": 5000,
|
| 17 |
-
"pro": -1 # -1 means unlimited
|
| 18 |
-
}
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
def get_tenant_plan(db: Session, tenant_id: str) -> Optional[TenantPlan]:
|
| 22 |
-
"""Get tenant's current plan."""
|
| 23 |
-
return db.query(TenantPlan).filter(TenantPlan.tenant_id == tenant_id).first()
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
def get_monthly_usage(db: Session, tenant_id: str, year: Optional[int] = None, month: Optional[int] = None) -> Optional[UsageMonthly]:
|
| 27 |
-
"""Get monthly usage for tenant."""
|
| 28 |
-
now = datetime.utcnow()
|
| 29 |
-
target_year = year or now.year
|
| 30 |
-
target_month = month or now.month
|
| 31 |
-
|
| 32 |
-
return db.query(UsageMonthly).filter(
|
| 33 |
-
and_(
|
| 34 |
-
UsageMonthly.tenant_id == tenant_id,
|
| 35 |
-
UsageMonthly.year == target_year,
|
| 36 |
-
UsageMonthly.month == target_month
|
| 37 |
-
)
|
| 38 |
-
).first()
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
def check_quota(db: Session, tenant_id: str) -> Tuple[bool, Optional[str]]:
|
| 42 |
-
"""
|
| 43 |
-
Check if tenant has quota remaining for the current month.
|
| 44 |
-
|
| 45 |
-
Returns:
|
| 46 |
-
(has_quota, error_message)
|
| 47 |
-
has_quota: True if quota available, False if exceeded
|
| 48 |
-
error_message: None if quota available, error message if exceeded
|
| 49 |
-
"""
|
| 50 |
-
# Get tenant plan
|
| 51 |
-
plan = get_tenant_plan(db, tenant_id)
|
| 52 |
-
|
| 53 |
-
if not plan:
|
| 54 |
-
# Default to starter plan if no plan assigned
|
| 55 |
-
logger.warning(f"Tenant {tenant_id} has no plan assigned, defaulting to starter")
|
| 56 |
-
monthly_limit = PLAN_LIMITS.get("starter", 500)
|
| 57 |
-
else:
|
| 58 |
-
monthly_limit = plan.monthly_chat_limit
|
| 59 |
-
|
| 60 |
-
# Unlimited plan (-1) always passes
|
| 61 |
-
if monthly_limit == -1:
|
| 62 |
-
return True, None
|
| 63 |
-
|
| 64 |
-
# Get current month usage
|
| 65 |
-
now = datetime.utcnow()
|
| 66 |
-
monthly_usage = get_monthly_usage(db, tenant_id, now.year, now.month)
|
| 67 |
-
|
| 68 |
-
current_usage = monthly_usage.total_requests if monthly_usage else 0
|
| 69 |
-
|
| 70 |
-
# Check if quota exceeded
|
| 71 |
-
if current_usage >= monthly_limit:
|
| 72 |
-
return False, f"AI quota exceeded ({current_usage}/{monthly_limit} chats this month). Upgrade your plan."
|
| 73 |
-
|
| 74 |
-
return True, None
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
def ensure_tenant_exists(db: Session, tenant_id: str) -> None:
|
| 78 |
-
"""Ensure tenant record exists in database."""
|
| 79 |
-
tenant = db.query(Tenant).filter(Tenant.id == tenant_id).first()
|
| 80 |
-
if not tenant:
|
| 81 |
-
# Create tenant with default starter plan
|
| 82 |
-
tenant = Tenant(id=tenant_id, name=f"Tenant {tenant_id}")
|
| 83 |
-
db.add(tenant)
|
| 84 |
-
|
| 85 |
-
# Create default starter plan
|
| 86 |
-
plan = TenantPlan(
|
| 87 |
-
tenant_id=tenant_id,
|
| 88 |
-
plan_name="starter",
|
| 89 |
-
monthly_chat_limit=PLAN_LIMITS["starter"]
|
| 90 |
-
)
|
| 91 |
-
db.add(plan)
|
| 92 |
-
db.commit()
|
| 93 |
-
logger.info(f"Created tenant {tenant_id} with starter plan")
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
def set_tenant_plan(db: Session, tenant_id: str, plan_name: str) -> TenantPlan:
|
| 97 |
-
"""
|
| 98 |
-
Set tenant's subscription plan.
|
| 99 |
-
|
| 100 |
-
Args:
|
| 101 |
-
db: Database session
|
| 102 |
-
tenant_id: Tenant ID
|
| 103 |
-
plan_name: "starter", "growth", or "pro"
|
| 104 |
-
|
| 105 |
-
Returns:
|
| 106 |
-
Updated TenantPlan
|
| 107 |
-
"""
|
| 108 |
-
if plan_name not in PLAN_LIMITS:
|
| 109 |
-
raise ValueError(f"Invalid plan name: {plan_name}. Must be one of: {list(PLAN_LIMITS.keys())}")
|
| 110 |
-
|
| 111 |
-
# Ensure tenant exists
|
| 112 |
-
ensure_tenant_exists(db, tenant_id)
|
| 113 |
-
|
| 114 |
-
# Get or create plan
|
| 115 |
-
plan = get_tenant_plan(db, tenant_id)
|
| 116 |
-
if plan:
|
| 117 |
-
plan.plan_name = plan_name
|
| 118 |
-
plan.monthly_chat_limit = PLAN_LIMITS[plan_name]
|
| 119 |
-
plan.updated_at = datetime.utcnow()
|
| 120 |
-
else:
|
| 121 |
-
plan = TenantPlan(
|
| 122 |
-
tenant_id=tenant_id,
|
| 123 |
-
plan_name=plan_name,
|
| 124 |
-
monthly_chat_limit=PLAN_LIMITS[plan_name]
|
| 125 |
-
)
|
| 126 |
-
db.add(plan)
|
| 127 |
-
|
| 128 |
-
db.commit()
|
| 129 |
-
db.refresh(plan)
|
| 130 |
-
return plan
|
| 131 |
-
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Quota management and enforcement.
|
| 3 |
+
"""
|
| 4 |
+
from sqlalchemy.orm import Session
|
| 5 |
+
from sqlalchemy import func, and_
|
| 6 |
+
from datetime import datetime, timedelta
|
| 7 |
+
from typing import Optional, Tuple
|
| 8 |
+
import logging
|
| 9 |
+
|
| 10 |
+
from app.db.models import TenantPlan, UsageMonthly, Tenant
|
| 11 |
+
logger = logging.getLogger(__name__)
|
| 12 |
+
|
| 13 |
+
# Plan limits (chats per month)
|
| 14 |
+
PLAN_LIMITS = {
|
| 15 |
+
"starter": 500,
|
| 16 |
+
"growth": 5000,
|
| 17 |
+
"pro": -1 # -1 means unlimited
|
| 18 |
+
}
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def get_tenant_plan(db: Session, tenant_id: str) -> Optional[TenantPlan]:
|
| 22 |
+
"""Get tenant's current plan."""
|
| 23 |
+
return db.query(TenantPlan).filter(TenantPlan.tenant_id == tenant_id).first()
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def get_monthly_usage(db: Session, tenant_id: str, year: Optional[int] = None, month: Optional[int] = None) -> Optional[UsageMonthly]:
|
| 27 |
+
"""Get monthly usage for tenant."""
|
| 28 |
+
now = datetime.utcnow()
|
| 29 |
+
target_year = year or now.year
|
| 30 |
+
target_month = month or now.month
|
| 31 |
+
|
| 32 |
+
return db.query(UsageMonthly).filter(
|
| 33 |
+
and_(
|
| 34 |
+
UsageMonthly.tenant_id == tenant_id,
|
| 35 |
+
UsageMonthly.year == target_year,
|
| 36 |
+
UsageMonthly.month == target_month
|
| 37 |
+
)
|
| 38 |
+
).first()
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def check_quota(db: Session, tenant_id: str) -> Tuple[bool, Optional[str]]:
|
| 42 |
+
"""
|
| 43 |
+
Check if tenant has quota remaining for the current month.
|
| 44 |
+
|
| 45 |
+
Returns:
|
| 46 |
+
(has_quota, error_message)
|
| 47 |
+
has_quota: True if quota available, False if exceeded
|
| 48 |
+
error_message: None if quota available, error message if exceeded
|
| 49 |
+
"""
|
| 50 |
+
# Get tenant plan
|
| 51 |
+
plan = get_tenant_plan(db, tenant_id)
|
| 52 |
+
|
| 53 |
+
if not plan:
|
| 54 |
+
# Default to starter plan if no plan assigned
|
| 55 |
+
logger.warning(f"Tenant {tenant_id} has no plan assigned, defaulting to starter")
|
| 56 |
+
monthly_limit = PLAN_LIMITS.get("starter", 500)
|
| 57 |
+
else:
|
| 58 |
+
monthly_limit = plan.monthly_chat_limit
|
| 59 |
+
|
| 60 |
+
# Unlimited plan (-1) always passes
|
| 61 |
+
if monthly_limit == -1:
|
| 62 |
+
return True, None
|
| 63 |
+
|
| 64 |
+
# Get current month usage
|
| 65 |
+
now = datetime.utcnow()
|
| 66 |
+
monthly_usage = get_monthly_usage(db, tenant_id, now.year, now.month)
|
| 67 |
+
|
| 68 |
+
current_usage = monthly_usage.total_requests if monthly_usage else 0
|
| 69 |
+
|
| 70 |
+
# Check if quota exceeded
|
| 71 |
+
if current_usage >= monthly_limit:
|
| 72 |
+
return False, f"AI quota exceeded ({current_usage}/{monthly_limit} chats this month). Upgrade your plan."
|
| 73 |
+
|
| 74 |
+
return True, None
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def ensure_tenant_exists(db: Session, tenant_id: str) -> None:
|
| 78 |
+
"""Ensure tenant record exists in database."""
|
| 79 |
+
tenant = db.query(Tenant).filter(Tenant.id == tenant_id).first()
|
| 80 |
+
if not tenant:
|
| 81 |
+
# Create tenant with default starter plan
|
| 82 |
+
tenant = Tenant(id=tenant_id, name=f"Tenant {tenant_id}")
|
| 83 |
+
db.add(tenant)
|
| 84 |
+
|
| 85 |
+
# Create default starter plan
|
| 86 |
+
plan = TenantPlan(
|
| 87 |
+
tenant_id=tenant_id,
|
| 88 |
+
plan_name="starter",
|
| 89 |
+
monthly_chat_limit=PLAN_LIMITS["starter"]
|
| 90 |
+
)
|
| 91 |
+
db.add(plan)
|
| 92 |
+
db.commit()
|
| 93 |
+
logger.info(f"Created tenant {tenant_id} with starter plan")
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def set_tenant_plan(db: Session, tenant_id: str, plan_name: str) -> TenantPlan:
|
| 97 |
+
"""
|
| 98 |
+
Set tenant's subscription plan.
|
| 99 |
+
|
| 100 |
+
Args:
|
| 101 |
+
db: Database session
|
| 102 |
+
tenant_id: Tenant ID
|
| 103 |
+
plan_name: "starter", "growth", or "pro"
|
| 104 |
+
|
| 105 |
+
Returns:
|
| 106 |
+
Updated TenantPlan
|
| 107 |
+
"""
|
| 108 |
+
if plan_name not in PLAN_LIMITS:
|
| 109 |
+
raise ValueError(f"Invalid plan name: {plan_name}. Must be one of: {list(PLAN_LIMITS.keys())}")
|
| 110 |
+
|
| 111 |
+
# Ensure tenant exists
|
| 112 |
+
ensure_tenant_exists(db, tenant_id)
|
| 113 |
+
|
| 114 |
+
# Get or create plan
|
| 115 |
+
plan = get_tenant_plan(db, tenant_id)
|
| 116 |
+
if plan:
|
| 117 |
+
plan.plan_name = plan_name
|
| 118 |
+
plan.monthly_chat_limit = PLAN_LIMITS[plan_name]
|
| 119 |
+
plan.updated_at = datetime.utcnow()
|
| 120 |
+
else:
|
| 121 |
+
plan = TenantPlan(
|
| 122 |
+
tenant_id=tenant_id,
|
| 123 |
+
plan_name=plan_name,
|
| 124 |
+
monthly_chat_limit=PLAN_LIMITS[plan_name]
|
| 125 |
+
)
|
| 126 |
+
db.add(plan)
|
| 127 |
+
|
| 128 |
+
db.commit()
|
| 129 |
+
db.refresh(plan)
|
| 130 |
+
return plan
|
| 131 |
+
|
app/billing/usage_tracker.py
CHANGED
|
@@ -1,173 +1,173 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Usage tracking service.
|
| 3 |
-
Tracks token usage and costs for each LLM request.
|
| 4 |
-
"""
|
| 5 |
-
from sqlalchemy.orm import Session
|
| 6 |
-
from sqlalchemy import func, and_
|
| 7 |
-
from datetime import datetime, timedelta
|
| 8 |
-
from typing import Optional
|
| 9 |
-
import uuid
|
| 10 |
-
import logging
|
| 11 |
-
|
| 12 |
-
from app.db.models import UsageEvent, UsageDaily, UsageMonthly, Tenant
|
| 13 |
-
from app.billing.pricing import calculate_cost
|
| 14 |
-
from app.billing.quota import ensure_tenant_exists
|
| 15 |
-
|
| 16 |
-
logger = logging.getLogger(__name__)
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
def track_usage(
|
| 20 |
-
db: Session,
|
| 21 |
-
tenant_id: str,
|
| 22 |
-
user_id: str,
|
| 23 |
-
kb_id: str,
|
| 24 |
-
provider: str,
|
| 25 |
-
model: str,
|
| 26 |
-
prompt_tokens: int,
|
| 27 |
-
completion_tokens: int,
|
| 28 |
-
request_timestamp: Optional[datetime] = None
|
| 29 |
-
) -> UsageEvent:
|
| 30 |
-
"""
|
| 31 |
-
Track a single usage event.
|
| 32 |
-
|
| 33 |
-
Args:
|
| 34 |
-
db: Database session
|
| 35 |
-
tenant_id: Tenant ID
|
| 36 |
-
user_id: User ID
|
| 37 |
-
kb_id: Knowledge base ID
|
| 38 |
-
provider: "gemini" or "openai"
|
| 39 |
-
model: Model name
|
| 40 |
-
prompt_tokens: Input tokens
|
| 41 |
-
completion_tokens: Output tokens
|
| 42 |
-
request_timestamp: Request timestamp (defaults to now)
|
| 43 |
-
|
| 44 |
-
Returns:
|
| 45 |
-
Created UsageEvent
|
| 46 |
-
"""
|
| 47 |
-
# Ensure tenant exists
|
| 48 |
-
ensure_tenant_exists(db, tenant_id)
|
| 49 |
-
|
| 50 |
-
# Calculate cost
|
| 51 |
-
total_tokens = prompt_tokens + completion_tokens
|
| 52 |
-
estimated_cost = calculate_cost(provider, model, prompt_tokens, completion_tokens)
|
| 53 |
-
|
| 54 |
-
# Create usage event
|
| 55 |
-
request_id = f"req_{uuid.uuid4().hex[:16]}"
|
| 56 |
-
timestamp = request_timestamp or datetime.utcnow()
|
| 57 |
-
|
| 58 |
-
usage_event = UsageEvent(
|
| 59 |
-
request_id=request_id,
|
| 60 |
-
tenant_id=tenant_id,
|
| 61 |
-
user_id=user_id,
|
| 62 |
-
kb_id=kb_id,
|
| 63 |
-
provider=provider,
|
| 64 |
-
model=model,
|
| 65 |
-
prompt_tokens=prompt_tokens,
|
| 66 |
-
completion_tokens=completion_tokens,
|
| 67 |
-
total_tokens=total_tokens,
|
| 68 |
-
estimated_cost_usd=estimated_cost,
|
| 69 |
-
request_timestamp=timestamp
|
| 70 |
-
)
|
| 71 |
-
|
| 72 |
-
db.add(usage_event)
|
| 73 |
-
|
| 74 |
-
# Update daily aggregation
|
| 75 |
-
_update_daily_usage(db, tenant_id, timestamp, provider, total_tokens, estimated_cost)
|
| 76 |
-
|
| 77 |
-
# Update monthly aggregation
|
| 78 |
-
_update_monthly_usage(db, tenant_id, timestamp, provider, total_tokens, estimated_cost)
|
| 79 |
-
|
| 80 |
-
db.commit()
|
| 81 |
-
db.refresh(usage_event)
|
| 82 |
-
|
| 83 |
-
logger.info(
|
| 84 |
-
f"Tracked usage: tenant={tenant_id}, provider={provider}, "
|
| 85 |
-
f"tokens={total_tokens}, cost=${estimated_cost:.6f}"
|
| 86 |
-
)
|
| 87 |
-
|
| 88 |
-
return usage_event
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
def _update_daily_usage(
|
| 92 |
-
db: Session,
|
| 93 |
-
tenant_id: str,
|
| 94 |
-
timestamp: datetime,
|
| 95 |
-
provider: str,
|
| 96 |
-
tokens: int,
|
| 97 |
-
cost: float
|
| 98 |
-
):
|
| 99 |
-
"""Update daily usage aggregation."""
|
| 100 |
-
date = timestamp.date()
|
| 101 |
-
date_start = datetime.combine(date, datetime.min.time())
|
| 102 |
-
|
| 103 |
-
daily = db.query(UsageDaily).filter(
|
| 104 |
-
and_(
|
| 105 |
-
UsageDaily.tenant_id == tenant_id,
|
| 106 |
-
UsageDaily.date == date_start
|
| 107 |
-
)
|
| 108 |
-
).first()
|
| 109 |
-
|
| 110 |
-
if daily:
|
| 111 |
-
daily.total_requests += 1
|
| 112 |
-
daily.total_tokens += tokens
|
| 113 |
-
daily.total_cost_usd += cost
|
| 114 |
-
if provider == "gemini":
|
| 115 |
-
daily.gemini_requests += 1
|
| 116 |
-
elif provider == "openai":
|
| 117 |
-
daily.openai_requests += 1
|
| 118 |
-
daily.updated_at = datetime.utcnow()
|
| 119 |
-
else:
|
| 120 |
-
daily = UsageDaily(
|
| 121 |
-
tenant_id=tenant_id,
|
| 122 |
-
date=date_start,
|
| 123 |
-
total_requests=1,
|
| 124 |
-
total_tokens=tokens,
|
| 125 |
-
total_cost_usd=cost,
|
| 126 |
-
gemini_requests=1 if provider == "gemini" else 0,
|
| 127 |
-
openai_requests=1 if provider == "openai" else 0
|
| 128 |
-
)
|
| 129 |
-
db.add(daily)
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
def _update_monthly_usage(
|
| 133 |
-
db: Session,
|
| 134 |
-
tenant_id: str,
|
| 135 |
-
timestamp: datetime,
|
| 136 |
-
provider: str,
|
| 137 |
-
tokens: int,
|
| 138 |
-
cost: float
|
| 139 |
-
):
|
| 140 |
-
"""Update monthly usage aggregation."""
|
| 141 |
-
year = timestamp.year
|
| 142 |
-
month = timestamp.month
|
| 143 |
-
|
| 144 |
-
monthly = db.query(UsageMonthly).filter(
|
| 145 |
-
and_(
|
| 146 |
-
UsageMonthly.tenant_id == tenant_id,
|
| 147 |
-
UsageMonthly.year == year,
|
| 148 |
-
UsageMonthly.month == month
|
| 149 |
-
)
|
| 150 |
-
).first()
|
| 151 |
-
|
| 152 |
-
if monthly:
|
| 153 |
-
monthly.total_requests += 1
|
| 154 |
-
monthly.total_tokens += tokens
|
| 155 |
-
monthly.total_cost_usd += cost
|
| 156 |
-
if provider == "gemini":
|
| 157 |
-
monthly.gemini_requests += 1
|
| 158 |
-
elif provider == "openai":
|
| 159 |
-
monthly.openai_requests += 1
|
| 160 |
-
monthly.updated_at = datetime.utcnow()
|
| 161 |
-
else:
|
| 162 |
-
monthly = UsageMonthly(
|
| 163 |
-
tenant_id=tenant_id,
|
| 164 |
-
year=year,
|
| 165 |
-
month=month,
|
| 166 |
-
total_requests=1,
|
| 167 |
-
total_tokens=tokens,
|
| 168 |
-
total_cost_usd=cost,
|
| 169 |
-
gemini_requests=1 if provider == "gemini" else 0,
|
| 170 |
-
openai_requests=1 if provider == "openai" else 0
|
| 171 |
-
)
|
| 172 |
-
db.add(monthly)
|
| 173 |
-
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Usage tracking service.
|
| 3 |
+
Tracks token usage and costs for each LLM request.
|
| 4 |
+
"""
|
| 5 |
+
from sqlalchemy.orm import Session
|
| 6 |
+
from sqlalchemy import func, and_
|
| 7 |
+
from datetime import datetime, timedelta
|
| 8 |
+
from typing import Optional
|
| 9 |
+
import uuid
|
| 10 |
+
import logging
|
| 11 |
+
|
| 12 |
+
from app.db.models import UsageEvent, UsageDaily, UsageMonthly, Tenant
|
| 13 |
+
from app.billing.pricing import calculate_cost
|
| 14 |
+
from app.billing.quota import ensure_tenant_exists
|
| 15 |
+
|
| 16 |
+
logger = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def track_usage(
|
| 20 |
+
db: Session,
|
| 21 |
+
tenant_id: str,
|
| 22 |
+
user_id: str,
|
| 23 |
+
kb_id: str,
|
| 24 |
+
provider: str,
|
| 25 |
+
model: str,
|
| 26 |
+
prompt_tokens: int,
|
| 27 |
+
completion_tokens: int,
|
| 28 |
+
request_timestamp: Optional[datetime] = None
|
| 29 |
+
) -> UsageEvent:
|
| 30 |
+
"""
|
| 31 |
+
Track a single usage event.
|
| 32 |
+
|
| 33 |
+
Args:
|
| 34 |
+
db: Database session
|
| 35 |
+
tenant_id: Tenant ID
|
| 36 |
+
user_id: User ID
|
| 37 |
+
kb_id: Knowledge base ID
|
| 38 |
+
provider: "gemini" or "openai"
|
| 39 |
+
model: Model name
|
| 40 |
+
prompt_tokens: Input tokens
|
| 41 |
+
completion_tokens: Output tokens
|
| 42 |
+
request_timestamp: Request timestamp (defaults to now)
|
| 43 |
+
|
| 44 |
+
Returns:
|
| 45 |
+
Created UsageEvent
|
| 46 |
+
"""
|
| 47 |
+
# Ensure tenant exists
|
| 48 |
+
ensure_tenant_exists(db, tenant_id)
|
| 49 |
+
|
| 50 |
+
# Calculate cost
|
| 51 |
+
total_tokens = prompt_tokens + completion_tokens
|
| 52 |
+
estimated_cost = calculate_cost(provider, model, prompt_tokens, completion_tokens)
|
| 53 |
+
|
| 54 |
+
# Create usage event
|
| 55 |
+
request_id = f"req_{uuid.uuid4().hex[:16]}"
|
| 56 |
+
timestamp = request_timestamp or datetime.utcnow()
|
| 57 |
+
|
| 58 |
+
usage_event = UsageEvent(
|
| 59 |
+
request_id=request_id,
|
| 60 |
+
tenant_id=tenant_id,
|
| 61 |
+
user_id=user_id,
|
| 62 |
+
kb_id=kb_id,
|
| 63 |
+
provider=provider,
|
| 64 |
+
model=model,
|
| 65 |
+
prompt_tokens=prompt_tokens,
|
| 66 |
+
completion_tokens=completion_tokens,
|
| 67 |
+
total_tokens=total_tokens,
|
| 68 |
+
estimated_cost_usd=estimated_cost,
|
| 69 |
+
request_timestamp=timestamp
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
db.add(usage_event)
|
| 73 |
+
|
| 74 |
+
# Update daily aggregation
|
| 75 |
+
_update_daily_usage(db, tenant_id, timestamp, provider, total_tokens, estimated_cost)
|
| 76 |
+
|
| 77 |
+
# Update monthly aggregation
|
| 78 |
+
_update_monthly_usage(db, tenant_id, timestamp, provider, total_tokens, estimated_cost)
|
| 79 |
+
|
| 80 |
+
db.commit()
|
| 81 |
+
db.refresh(usage_event)
|
| 82 |
+
|
| 83 |
+
logger.info(
|
| 84 |
+
f"Tracked usage: tenant={tenant_id}, provider={provider}, "
|
| 85 |
+
f"tokens={total_tokens}, cost=${estimated_cost:.6f}"
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
return usage_event
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def _update_daily_usage(
|
| 92 |
+
db: Session,
|
| 93 |
+
tenant_id: str,
|
| 94 |
+
timestamp: datetime,
|
| 95 |
+
provider: str,
|
| 96 |
+
tokens: int,
|
| 97 |
+
cost: float
|
| 98 |
+
):
|
| 99 |
+
"""Update daily usage aggregation."""
|
| 100 |
+
date = timestamp.date()
|
| 101 |
+
date_start = datetime.combine(date, datetime.min.time())
|
| 102 |
+
|
| 103 |
+
daily = db.query(UsageDaily).filter(
|
| 104 |
+
and_(
|
| 105 |
+
UsageDaily.tenant_id == tenant_id,
|
| 106 |
+
UsageDaily.date == date_start
|
| 107 |
+
)
|
| 108 |
+
).first()
|
| 109 |
+
|
| 110 |
+
if daily:
|
| 111 |
+
daily.total_requests += 1
|
| 112 |
+
daily.total_tokens += tokens
|
| 113 |
+
daily.total_cost_usd += cost
|
| 114 |
+
if provider == "gemini":
|
| 115 |
+
daily.gemini_requests += 1
|
| 116 |
+
elif provider == "openai":
|
| 117 |
+
daily.openai_requests += 1
|
| 118 |
+
daily.updated_at = datetime.utcnow()
|
| 119 |
+
else:
|
| 120 |
+
daily = UsageDaily(
|
| 121 |
+
tenant_id=tenant_id,
|
| 122 |
+
date=date_start,
|
| 123 |
+
total_requests=1,
|
| 124 |
+
total_tokens=tokens,
|
| 125 |
+
total_cost_usd=cost,
|
| 126 |
+
gemini_requests=1 if provider == "gemini" else 0,
|
| 127 |
+
openai_requests=1 if provider == "openai" else 0
|
| 128 |
+
)
|
| 129 |
+
db.add(daily)
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def _update_monthly_usage(
|
| 133 |
+
db: Session,
|
| 134 |
+
tenant_id: str,
|
| 135 |
+
timestamp: datetime,
|
| 136 |
+
provider: str,
|
| 137 |
+
tokens: int,
|
| 138 |
+
cost: float
|
| 139 |
+
):
|
| 140 |
+
"""Update monthly usage aggregation."""
|
| 141 |
+
year = timestamp.year
|
| 142 |
+
month = timestamp.month
|
| 143 |
+
|
| 144 |
+
monthly = db.query(UsageMonthly).filter(
|
| 145 |
+
and_(
|
| 146 |
+
UsageMonthly.tenant_id == tenant_id,
|
| 147 |
+
UsageMonthly.year == year,
|
| 148 |
+
UsageMonthly.month == month
|
| 149 |
+
)
|
| 150 |
+
).first()
|
| 151 |
+
|
| 152 |
+
if monthly:
|
| 153 |
+
monthly.total_requests += 1
|
| 154 |
+
monthly.total_tokens += tokens
|
| 155 |
+
monthly.total_cost_usd += cost
|
| 156 |
+
if provider == "gemini":
|
| 157 |
+
monthly.gemini_requests += 1
|
| 158 |
+
elif provider == "openai":
|
| 159 |
+
monthly.openai_requests += 1
|
| 160 |
+
monthly.updated_at = datetime.utcnow()
|
| 161 |
+
else:
|
| 162 |
+
monthly = UsageMonthly(
|
| 163 |
+
tenant_id=tenant_id,
|
| 164 |
+
year=year,
|
| 165 |
+
month=month,
|
| 166 |
+
total_requests=1,
|
| 167 |
+
total_tokens=tokens,
|
| 168 |
+
total_cost_usd=cost,
|
| 169 |
+
gemini_requests=1 if provider == "gemini" else 0,
|
| 170 |
+
openai_requests=1 if provider == "openai" else 0
|
| 171 |
+
)
|
| 172 |
+
db.add(monthly)
|
| 173 |
+
|
app/config.py
CHANGED
|
@@ -1,77 +1,77 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Configuration settings for the RAG backend.
|
| 3 |
-
"""
|
| 4 |
-
from pydantic_settings import BaseSettings
|
| 5 |
-
from pathlib import Path
|
| 6 |
-
from typing import Optional
|
| 7 |
-
import os
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
class Settings(BaseSettings):
|
| 11 |
-
"""Application settings with environment variable support."""
|
| 12 |
-
|
| 13 |
-
# App settings
|
| 14 |
-
APP_NAME: str = "ClientSphere RAG Backend"
|
| 15 |
-
DEBUG: bool = True
|
| 16 |
-
ENV: str = "dev" # "dev" or "prod" - controls tenant_id security
|
| 17 |
-
|
| 18 |
-
# Paths
|
| 19 |
-
BASE_DIR: Path = Path(__file__).parent.parent
|
| 20 |
-
DATA_DIR: Path = BASE_DIR / "data"
|
| 21 |
-
UPLOADS_DIR: Path = DATA_DIR / "uploads"
|
| 22 |
-
PROCESSED_DIR: Path = DATA_DIR / "processed"
|
| 23 |
-
VECTORDB_DIR: Path = DATA_DIR / "vectordb"
|
| 24 |
-
|
| 25 |
-
# Chunking settings (optimized for retrieval quality)
|
| 26 |
-
CHUNK_SIZE: int = 600 # tokens (increased for better context)
|
| 27 |
-
CHUNK_OVERLAP: int = 150 # tokens (increased for better continuity)
|
| 28 |
-
MIN_CHUNK_SIZE: int = 100 # minimum tokens per chunk (increased to avoid tiny chunks)
|
| 29 |
-
|
| 30 |
-
# Embedding settings
|
| 31 |
-
EMBEDDING_MODEL: str = "all-MiniLM-L6-v2" # Fast, good quality
|
| 32 |
-
EMBEDDING_DIMENSION: int = 384
|
| 33 |
-
|
| 34 |
-
# Vector store settings
|
| 35 |
-
COLLECTION_NAME: str = "clientsphere_kb"
|
| 36 |
-
|
| 37 |
-
# Retrieval settings (optimized for maximum confidence)
|
| 38 |
-
TOP_K: int = 10 # Number of chunks to retrieve (increased to maximize chance of finding strong matches)
|
| 39 |
-
SIMILARITY_THRESHOLD: float = 0.15 # Minimum similarity score (0-1) - lowered to include more potentially relevant chunks
|
| 40 |
-
SIMILARITY_THRESHOLD_STRICT: float = 0.45 # Strict threshold for answer generation (anti-hallucination)
|
| 41 |
-
|
| 42 |
-
# LLM settings
|
| 43 |
-
LLM_PROVIDER: str = "gemini" # Options: "gemini", "openai"
|
| 44 |
-
GEMINI_API_KEY: Optional[str] = None
|
| 45 |
-
OPENAI_API_KEY: Optional[str] = None
|
| 46 |
-
GEMINI_MODEL: str = "gemini-1.5-flash" # Use latest stable model
|
| 47 |
-
OPENAI_MODEL: str = "gpt-3.5-turbo"
|
| 48 |
-
|
| 49 |
-
# Response settings
|
| 50 |
-
MAX_CONTEXT_TOKENS: int = 2500 # Max tokens for context in prompt (reduced for focus)
|
| 51 |
-
TEMPERATURE: float = 0.0 # Zero temperature for maximum determinism (anti-hallucination)
|
| 52 |
-
REQUIRE_VERIFIER: bool = True # Always use verifier for hallucination prevention
|
| 53 |
-
|
| 54 |
-
# Security settings
|
| 55 |
-
MAX_FILE_SIZE_MB: int = 50 # Maximum file size in MB
|
| 56 |
-
ALLOWED_ORIGINS: str = "*" # CORS allowed origins (comma-separated, use "*" for all)
|
| 57 |
-
RATE_LIMIT_PER_MINUTE: int = 60 # Rate limit per user per minute
|
| 58 |
-
JWT_SECRET: Optional[str] = None # JWT secret for authentication
|
| 59 |
-
|
| 60 |
-
# Rate limiting
|
| 61 |
-
RATE_LIMIT_ENABLED: bool = True # Enable/disable rate limiting
|
| 62 |
-
|
| 63 |
-
class Config:
|
| 64 |
-
env_file = ".env"
|
| 65 |
-
env_file_encoding = "utf-8"
|
| 66 |
-
|
| 67 |
-
def __init__(self, **kwargs):
|
| 68 |
-
super().__init__(**kwargs)
|
| 69 |
-
# Create directories if they don't exist
|
| 70 |
-
self.UPLOADS_DIR.mkdir(parents=True, exist_ok=True)
|
| 71 |
-
self.PROCESSED_DIR.mkdir(parents=True, exist_ok=True)
|
| 72 |
-
self.VECTORDB_DIR.mkdir(parents=True, exist_ok=True)
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
# Global settings instance
|
| 76 |
-
settings = Settings()
|
| 77 |
-
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Configuration settings for the RAG backend.
|
| 3 |
+
"""
|
| 4 |
+
from pydantic_settings import BaseSettings
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import Optional
|
| 7 |
+
import os
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class Settings(BaseSettings):
|
| 11 |
+
"""Application settings with environment variable support."""
|
| 12 |
+
|
| 13 |
+
# App settings
|
| 14 |
+
APP_NAME: str = "ClientSphere RAG Backend"
|
| 15 |
+
DEBUG: bool = True
|
| 16 |
+
ENV: str = "dev" # "dev" or "prod" - controls tenant_id security
|
| 17 |
+
|
| 18 |
+
# Paths
|
| 19 |
+
BASE_DIR: Path = Path(__file__).parent.parent
|
| 20 |
+
DATA_DIR: Path = BASE_DIR / "data"
|
| 21 |
+
UPLOADS_DIR: Path = DATA_DIR / "uploads"
|
| 22 |
+
PROCESSED_DIR: Path = DATA_DIR / "processed"
|
| 23 |
+
VECTORDB_DIR: Path = DATA_DIR / "vectordb"
|
| 24 |
+
|
| 25 |
+
# Chunking settings (optimized for retrieval quality)
|
| 26 |
+
CHUNK_SIZE: int = 600 # tokens (increased for better context)
|
| 27 |
+
CHUNK_OVERLAP: int = 150 # tokens (increased for better continuity)
|
| 28 |
+
MIN_CHUNK_SIZE: int = 100 # minimum tokens per chunk (increased to avoid tiny chunks)
|
| 29 |
+
|
| 30 |
+
# Embedding settings
|
| 31 |
+
EMBEDDING_MODEL: str = "all-MiniLM-L6-v2" # Fast, good quality
|
| 32 |
+
EMBEDDING_DIMENSION: int = 384
|
| 33 |
+
|
| 34 |
+
# Vector store settings
|
| 35 |
+
COLLECTION_NAME: str = "clientsphere_kb"
|
| 36 |
+
|
| 37 |
+
# Retrieval settings (optimized for maximum confidence)
|
| 38 |
+
TOP_K: int = 10 # Number of chunks to retrieve (increased to maximize chance of finding strong matches)
|
| 39 |
+
SIMILARITY_THRESHOLD: float = 0.15 # Minimum similarity score (0-1) - lowered to include more potentially relevant chunks
|
| 40 |
+
SIMILARITY_THRESHOLD_STRICT: float = 0.45 # Strict threshold for answer generation (anti-hallucination)
|
| 41 |
+
|
| 42 |
+
# LLM settings
|
| 43 |
+
LLM_PROVIDER: str = "gemini" # Options: "gemini", "openai"
|
| 44 |
+
GEMINI_API_KEY: Optional[str] = None
|
| 45 |
+
OPENAI_API_KEY: Optional[str] = None
|
| 46 |
+
GEMINI_MODEL: str = "gemini-1.5-flash" # Use latest stable model
|
| 47 |
+
OPENAI_MODEL: str = "gpt-3.5-turbo"
|
| 48 |
+
|
| 49 |
+
# Response settings
|
| 50 |
+
MAX_CONTEXT_TOKENS: int = 2500 # Max tokens for context in prompt (reduced for focus)
|
| 51 |
+
TEMPERATURE: float = 0.0 # Zero temperature for maximum determinism (anti-hallucination)
|
| 52 |
+
REQUIRE_VERIFIER: bool = True # Always use verifier for hallucination prevention
|
| 53 |
+
|
| 54 |
+
# Security settings
|
| 55 |
+
MAX_FILE_SIZE_MB: int = 50 # Maximum file size in MB
|
| 56 |
+
ALLOWED_ORIGINS: str = "*" # CORS allowed origins (comma-separated, use "*" for all)
|
| 57 |
+
RATE_LIMIT_PER_MINUTE: int = 60 # Rate limit per user per minute
|
| 58 |
+
JWT_SECRET: Optional[str] = None # JWT secret for authentication
|
| 59 |
+
|
| 60 |
+
# Rate limiting
|
| 61 |
+
RATE_LIMIT_ENABLED: bool = True # Enable/disable rate limiting
|
| 62 |
+
|
| 63 |
+
class Config:
|
| 64 |
+
env_file = ".env"
|
| 65 |
+
env_file_encoding = "utf-8"
|
| 66 |
+
|
| 67 |
+
def __init__(self, **kwargs):
|
| 68 |
+
super().__init__(**kwargs)
|
| 69 |
+
# Create directories if they don't exist
|
| 70 |
+
self.UPLOADS_DIR.mkdir(parents=True, exist_ok=True)
|
| 71 |
+
self.PROCESSED_DIR.mkdir(parents=True, exist_ok=True)
|
| 72 |
+
self.VECTORDB_DIR.mkdir(parents=True, exist_ok=True)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
# Global settings instance
|
| 76 |
+
settings = Settings()
|
| 77 |
+
|
app/db/__init__.py
CHANGED
|
@@ -1,2 +1,2 @@
|
|
| 1 |
-
# Database module
|
| 2 |
-
|
|
|
|
| 1 |
+
# Database module
|
| 2 |
+
|
app/db/__pycache__/__init__.cpython-313.pyc
DELETED
|
Binary file (138 Bytes)
|
|
|
app/db/__pycache__/database.cpython-313.pyc
DELETED
|
Binary file (2.12 kB)
|
|
|
app/db/__pycache__/models.cpython-313.pyc
DELETED
|
Binary file (5.03 kB)
|
|
|
app/db/database.py
CHANGED
|
@@ -1,53 +1,53 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Database setup and session management.
|
| 3 |
-
Uses SQLAlchemy with SQLite for local dev, Postgres-compatible schema.
|
| 4 |
-
"""
|
| 5 |
-
from sqlalchemy import create_engine
|
| 6 |
-
from sqlalchemy.ext.declarative import declarative_base
|
| 7 |
-
from sqlalchemy.orm import sessionmaker, Session
|
| 8 |
-
from pathlib import Path
|
| 9 |
-
import logging
|
| 10 |
-
|
| 11 |
-
from app.config import settings
|
| 12 |
-
|
| 13 |
-
logger = logging.getLogger(__name__)
|
| 14 |
-
|
| 15 |
-
# Database path
|
| 16 |
-
DB_DIR = settings.DATA_DIR / "billing"
|
| 17 |
-
DB_DIR.mkdir(parents=True, exist_ok=True)
|
| 18 |
-
DATABASE_URL = f"sqlite:///{DB_DIR / 'billing.db'}"
|
| 19 |
-
|
| 20 |
-
# Create engine
|
| 21 |
-
engine = create_engine(
|
| 22 |
-
DATABASE_URL,
|
| 23 |
-
connect_args={"check_same_thread": False}, # SQLite specific
|
| 24 |
-
echo=False # Set to True for SQL query logging
|
| 25 |
-
)
|
| 26 |
-
|
| 27 |
-
# Session factory
|
| 28 |
-
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
| 29 |
-
|
| 30 |
-
# Base class for models
|
| 31 |
-
Base = declarative_base()
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
def get_db() -> Session:
|
| 35 |
-
"""Get database session (dependency for FastAPI)."""
|
| 36 |
-
db = SessionLocal()
|
| 37 |
-
try:
|
| 38 |
-
yield db
|
| 39 |
-
finally:
|
| 40 |
-
db.close()
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
def init_db():
|
| 44 |
-
"""Initialize database tables."""
|
| 45 |
-
Base.metadata.create_all(bind=engine)
|
| 46 |
-
logger.info("Database tables created/verified")
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
def drop_db():
|
| 50 |
-
"""Drop all tables (use with caution!)."""
|
| 51 |
-
Base.metadata.drop_all(bind=engine)
|
| 52 |
-
logger.warning("All database tables dropped")
|
| 53 |
-
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Database setup and session management.
|
| 3 |
+
Uses SQLAlchemy with SQLite for local dev, Postgres-compatible schema.
|
| 4 |
+
"""
|
| 5 |
+
from sqlalchemy import create_engine
|
| 6 |
+
from sqlalchemy.ext.declarative import declarative_base
|
| 7 |
+
from sqlalchemy.orm import sessionmaker, Session
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
import logging
|
| 10 |
+
|
| 11 |
+
from app.config import settings
|
| 12 |
+
|
| 13 |
+
logger = logging.getLogger(__name__)
|
| 14 |
+
|
| 15 |
+
# Database path
|
| 16 |
+
DB_DIR = settings.DATA_DIR / "billing"
|
| 17 |
+
DB_DIR.mkdir(parents=True, exist_ok=True)
|
| 18 |
+
DATABASE_URL = f"sqlite:///{DB_DIR / 'billing.db'}"
|
| 19 |
+
|
| 20 |
+
# Create engine
|
| 21 |
+
engine = create_engine(
|
| 22 |
+
DATABASE_URL,
|
| 23 |
+
connect_args={"check_same_thread": False}, # SQLite specific
|
| 24 |
+
echo=False # Set to True for SQL query logging
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
# Session factory
|
| 28 |
+
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
| 29 |
+
|
| 30 |
+
# Base class for models
|
| 31 |
+
Base = declarative_base()
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def get_db() -> Session:
|
| 35 |
+
"""Get database session (dependency for FastAPI)."""
|
| 36 |
+
db = SessionLocal()
|
| 37 |
+
try:
|
| 38 |
+
yield db
|
| 39 |
+
finally:
|
| 40 |
+
db.close()
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def init_db():
|
| 44 |
+
"""Initialize database tables."""
|
| 45 |
+
Base.metadata.create_all(bind=engine)
|
| 46 |
+
logger.info("Database tables created/verified")
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def drop_db():
|
| 50 |
+
"""Drop all tables (use with caution!)."""
|
| 51 |
+
Base.metadata.drop_all(bind=engine)
|
| 52 |
+
logger.warning("All database tables dropped")
|
| 53 |
+
|
app/db/models.py
CHANGED
|
@@ -1,129 +1,129 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Database models for billing and usage tracking.
|
| 3 |
-
"""
|
| 4 |
-
from sqlalchemy import Column, String, Integer, Float, DateTime, Boolean, ForeignKey, Text
|
| 5 |
-
from sqlalchemy.orm import relationship
|
| 6 |
-
from datetime import datetime
|
| 7 |
-
from typing import Optional
|
| 8 |
-
|
| 9 |
-
from app.db.database import Base
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
class Tenant(Base):
|
| 13 |
-
"""Tenant/organization model."""
|
| 14 |
-
__tablename__ = "tenants"
|
| 15 |
-
|
| 16 |
-
id = Column(String, primary_key=True, index=True)
|
| 17 |
-
name = Column(String, nullable=False)
|
| 18 |
-
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
|
| 19 |
-
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False)
|
| 20 |
-
|
| 21 |
-
# Relationships
|
| 22 |
-
plan = relationship("TenantPlan", back_populates="tenant", uselist=False)
|
| 23 |
-
usage_events = relationship("UsageEvent", back_populates="tenant")
|
| 24 |
-
daily_usage = relationship("UsageDaily", back_populates="tenant")
|
| 25 |
-
monthly_usage = relationship("UsageMonthly", back_populates="tenant")
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
class TenantPlan(Base):
|
| 29 |
-
"""Tenant subscription plan."""
|
| 30 |
-
__tablename__ = "tenant_plans"
|
| 31 |
-
|
| 32 |
-
id = Column(Integer, primary_key=True, autoincrement=True)
|
| 33 |
-
tenant_id = Column(String, ForeignKey("tenants.id"), unique=True, nullable=False, index=True)
|
| 34 |
-
plan_name = Column(String, nullable=False, index=True) # "starter", "growth", "pro"
|
| 35 |
-
monthly_chat_limit = Column(Integer, nullable=False) # -1 for unlimited
|
| 36 |
-
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
|
| 37 |
-
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False)
|
| 38 |
-
|
| 39 |
-
# Relationships
|
| 40 |
-
tenant = relationship("Tenant", back_populates="plan")
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
class UsageEvent(Base):
|
| 44 |
-
"""Individual usage event (each /chat request)."""
|
| 45 |
-
__tablename__ = "usage_events"
|
| 46 |
-
|
| 47 |
-
id = Column(Integer, primary_key=True, autoincrement=True)
|
| 48 |
-
request_id = Column(String, unique=True, nullable=False, index=True)
|
| 49 |
-
tenant_id = Column(String, ForeignKey("tenants.id"), nullable=False, index=True)
|
| 50 |
-
user_id = Column(String, nullable=False, index=True)
|
| 51 |
-
kb_id = Column(String, nullable=False)
|
| 52 |
-
|
| 53 |
-
# LLM details
|
| 54 |
-
provider = Column(String, nullable=False) # "gemini" or "openai"
|
| 55 |
-
model = Column(String, nullable=False)
|
| 56 |
-
|
| 57 |
-
# Token usage
|
| 58 |
-
prompt_tokens = Column(Integer, nullable=False, default=0)
|
| 59 |
-
completion_tokens = Column(Integer, nullable=False, default=0)
|
| 60 |
-
total_tokens = Column(Integer, nullable=False, default=0)
|
| 61 |
-
|
| 62 |
-
# Cost tracking
|
| 63 |
-
estimated_cost_usd = Column(Float, nullable=False, default=0.0)
|
| 64 |
-
|
| 65 |
-
# Timestamp
|
| 66 |
-
request_timestamp = Column(DateTime, default=datetime.utcnow, nullable=False, index=True)
|
| 67 |
-
|
| 68 |
-
# Relationships
|
| 69 |
-
tenant = relationship("Tenant", back_populates="usage_events")
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
class UsageDaily(Base):
|
| 73 |
-
"""Daily aggregated usage per tenant."""
|
| 74 |
-
__tablename__ = "usage_daily"
|
| 75 |
-
|
| 76 |
-
id = Column(Integer, primary_key=True, autoincrement=True)
|
| 77 |
-
tenant_id = Column(String, ForeignKey("tenants.id"), nullable=False, index=True)
|
| 78 |
-
date = Column(DateTime, nullable=False, index=True)
|
| 79 |
-
|
| 80 |
-
# Aggregated metrics
|
| 81 |
-
total_requests = Column(Integer, nullable=False, default=0)
|
| 82 |
-
total_tokens = Column(Integer, nullable=False, default=0)
|
| 83 |
-
total_cost_usd = Column(Float, nullable=False, default=0.0)
|
| 84 |
-
|
| 85 |
-
# Provider breakdown
|
| 86 |
-
gemini_requests = Column(Integer, nullable=False, default=0)
|
| 87 |
-
openai_requests = Column(Integer, nullable=False, default=0)
|
| 88 |
-
|
| 89 |
-
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
|
| 90 |
-
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False)
|
| 91 |
-
|
| 92 |
-
# Unique constraint: one record per tenant per day
|
| 93 |
-
__table_args__ = (
|
| 94 |
-
{"sqlite_autoincrement": True},
|
| 95 |
-
)
|
| 96 |
-
|
| 97 |
-
# Relationships
|
| 98 |
-
tenant = relationship("Tenant", back_populates="daily_usage")
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
class UsageMonthly(Base):
|
| 102 |
-
"""Monthly aggregated usage per tenant."""
|
| 103 |
-
__tablename__ = "usage_monthly"
|
| 104 |
-
|
| 105 |
-
id = Column(Integer, primary_key=True, autoincrement=True)
|
| 106 |
-
tenant_id = Column(String, ForeignKey("tenants.id"), nullable=False, index=True)
|
| 107 |
-
year = Column(Integer, nullable=False, index=True)
|
| 108 |
-
month = Column(Integer, nullable=False, index=True) # 1-12
|
| 109 |
-
|
| 110 |
-
# Aggregated metrics
|
| 111 |
-
total_requests = Column(Integer, nullable=False, default=0)
|
| 112 |
-
total_tokens = Column(Integer, nullable=False, default=0)
|
| 113 |
-
total_cost_usd = Column(Float, nullable=False, default=0.0)
|
| 114 |
-
|
| 115 |
-
# Provider breakdown
|
| 116 |
-
gemini_requests = Column(Integer, nullable=False, default=0)
|
| 117 |
-
openai_requests = Column(Integer, nullable=False, default=0)
|
| 118 |
-
|
| 119 |
-
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
|
| 120 |
-
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False)
|
| 121 |
-
|
| 122 |
-
# Unique constraint: one record per tenant per month
|
| 123 |
-
__table_args__ = (
|
| 124 |
-
{"sqlite_autoincrement": True},
|
| 125 |
-
)
|
| 126 |
-
|
| 127 |
-
# Relationships
|
| 128 |
-
tenant = relationship("Tenant", back_populates="monthly_usage")
|
| 129 |
-
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Database models for billing and usage tracking.
|
| 3 |
+
"""
|
| 4 |
+
from sqlalchemy import Column, String, Integer, Float, DateTime, Boolean, ForeignKey, Text
|
| 5 |
+
from sqlalchemy.orm import relationship
|
| 6 |
+
from datetime import datetime
|
| 7 |
+
from typing import Optional
|
| 8 |
+
|
| 9 |
+
from app.db.database import Base
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class Tenant(Base):
|
| 13 |
+
"""Tenant/organization model."""
|
| 14 |
+
__tablename__ = "tenants"
|
| 15 |
+
|
| 16 |
+
id = Column(String, primary_key=True, index=True)
|
| 17 |
+
name = Column(String, nullable=False)
|
| 18 |
+
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
|
| 19 |
+
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False)
|
| 20 |
+
|
| 21 |
+
# Relationships
|
| 22 |
+
plan = relationship("TenantPlan", back_populates="tenant", uselist=False)
|
| 23 |
+
usage_events = relationship("UsageEvent", back_populates="tenant")
|
| 24 |
+
daily_usage = relationship("UsageDaily", back_populates="tenant")
|
| 25 |
+
monthly_usage = relationship("UsageMonthly", back_populates="tenant")
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class TenantPlan(Base):
|
| 29 |
+
"""Tenant subscription plan."""
|
| 30 |
+
__tablename__ = "tenant_plans"
|
| 31 |
+
|
| 32 |
+
id = Column(Integer, primary_key=True, autoincrement=True)
|
| 33 |
+
tenant_id = Column(String, ForeignKey("tenants.id"), unique=True, nullable=False, index=True)
|
| 34 |
+
plan_name = Column(String, nullable=False, index=True) # "starter", "growth", "pro"
|
| 35 |
+
monthly_chat_limit = Column(Integer, nullable=False) # -1 for unlimited
|
| 36 |
+
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
|
| 37 |
+
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False)
|
| 38 |
+
|
| 39 |
+
# Relationships
|
| 40 |
+
tenant = relationship("Tenant", back_populates="plan")
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class UsageEvent(Base):
|
| 44 |
+
"""Individual usage event (each /chat request)."""
|
| 45 |
+
__tablename__ = "usage_events"
|
| 46 |
+
|
| 47 |
+
id = Column(Integer, primary_key=True, autoincrement=True)
|
| 48 |
+
request_id = Column(String, unique=True, nullable=False, index=True)
|
| 49 |
+
tenant_id = Column(String, ForeignKey("tenants.id"), nullable=False, index=True)
|
| 50 |
+
user_id = Column(String, nullable=False, index=True)
|
| 51 |
+
kb_id = Column(String, nullable=False)
|
| 52 |
+
|
| 53 |
+
# LLM details
|
| 54 |
+
provider = Column(String, nullable=False) # "gemini" or "openai"
|
| 55 |
+
model = Column(String, nullable=False)
|
| 56 |
+
|
| 57 |
+
# Token usage
|
| 58 |
+
prompt_tokens = Column(Integer, nullable=False, default=0)
|
| 59 |
+
completion_tokens = Column(Integer, nullable=False, default=0)
|
| 60 |
+
total_tokens = Column(Integer, nullable=False, default=0)
|
| 61 |
+
|
| 62 |
+
# Cost tracking
|
| 63 |
+
estimated_cost_usd = Column(Float, nullable=False, default=0.0)
|
| 64 |
+
|
| 65 |
+
# Timestamp
|
| 66 |
+
request_timestamp = Column(DateTime, default=datetime.utcnow, nullable=False, index=True)
|
| 67 |
+
|
| 68 |
+
# Relationships
|
| 69 |
+
tenant = relationship("Tenant", back_populates="usage_events")
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class UsageDaily(Base):
|
| 73 |
+
"""Daily aggregated usage per tenant."""
|
| 74 |
+
__tablename__ = "usage_daily"
|
| 75 |
+
|
| 76 |
+
id = Column(Integer, primary_key=True, autoincrement=True)
|
| 77 |
+
tenant_id = Column(String, ForeignKey("tenants.id"), nullable=False, index=True)
|
| 78 |
+
date = Column(DateTime, nullable=False, index=True)
|
| 79 |
+
|
| 80 |
+
# Aggregated metrics
|
| 81 |
+
total_requests = Column(Integer, nullable=False, default=0)
|
| 82 |
+
total_tokens = Column(Integer, nullable=False, default=0)
|
| 83 |
+
total_cost_usd = Column(Float, nullable=False, default=0.0)
|
| 84 |
+
|
| 85 |
+
# Provider breakdown
|
| 86 |
+
gemini_requests = Column(Integer, nullable=False, default=0)
|
| 87 |
+
openai_requests = Column(Integer, nullable=False, default=0)
|
| 88 |
+
|
| 89 |
+
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
|
| 90 |
+
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False)
|
| 91 |
+
|
| 92 |
+
# Unique constraint: one record per tenant per day
|
| 93 |
+
__table_args__ = (
|
| 94 |
+
{"sqlite_autoincrement": True},
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
# Relationships
|
| 98 |
+
tenant = relationship("Tenant", back_populates="daily_usage")
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
class UsageMonthly(Base):
|
| 102 |
+
"""Monthly aggregated usage per tenant."""
|
| 103 |
+
__tablename__ = "usage_monthly"
|
| 104 |
+
|
| 105 |
+
id = Column(Integer, primary_key=True, autoincrement=True)
|
| 106 |
+
tenant_id = Column(String, ForeignKey("tenants.id"), nullable=False, index=True)
|
| 107 |
+
year = Column(Integer, nullable=False, index=True)
|
| 108 |
+
month = Column(Integer, nullable=False, index=True) # 1-12
|
| 109 |
+
|
| 110 |
+
# Aggregated metrics
|
| 111 |
+
total_requests = Column(Integer, nullable=False, default=0)
|
| 112 |
+
total_tokens = Column(Integer, nullable=False, default=0)
|
| 113 |
+
total_cost_usd = Column(Float, nullable=False, default=0.0)
|
| 114 |
+
|
| 115 |
+
# Provider breakdown
|
| 116 |
+
gemini_requests = Column(Integer, nullable=False, default=0)
|
| 117 |
+
openai_requests = Column(Integer, nullable=False, default=0)
|
| 118 |
+
|
| 119 |
+
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
|
| 120 |
+
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False)
|
| 121 |
+
|
| 122 |
+
# Unique constraint: one record per tenant per month
|
| 123 |
+
__table_args__ = (
|
| 124 |
+
{"sqlite_autoincrement": True},
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
# Relationships
|
| 128 |
+
tenant = relationship("Tenant", back_populates="monthly_usage")
|
| 129 |
+
|
app/main.py
CHANGED
|
@@ -1,1039 +1,1039 @@
|
|
| 1 |
-
"""
|
| 2 |
-
FastAPI application for ClientSphere RAG Backend.
|
| 3 |
-
Provides endpoints for knowledge base management and chat.
|
| 4 |
-
"""
|
| 5 |
-
from fastapi import FastAPI, File, UploadFile, HTTPException, Form, BackgroundTasks, Request, Depends
|
| 6 |
-
from fastapi.middleware.cors import CORSMiddleware
|
| 7 |
-
from fastapi.exceptions import RequestValidationError
|
| 8 |
-
from fastapi.responses import JSONResponse
|
| 9 |
-
from pathlib import Path
|
| 10 |
-
import shutil
|
| 11 |
-
import uuid
|
| 12 |
-
from datetime import datetime
|
| 13 |
-
from typing import Optional
|
| 14 |
-
import logging
|
| 15 |
-
|
| 16 |
-
from app.config import settings
|
| 17 |
-
from app.middleware.auth import get_auth_context, require_auth
|
| 18 |
-
from app.middleware.rate_limit import (
|
| 19 |
-
limiter,
|
| 20 |
-
get_tenant_rate_limit_key,
|
| 21 |
-
RateLimitExceeded,
|
| 22 |
-
_rate_limit_exceeded_handler
|
| 23 |
-
)
|
| 24 |
-
from app.models.schemas import (
|
| 25 |
-
UploadResponse,
|
| 26 |
-
ChatRequest,
|
| 27 |
-
ChatResponse,
|
| 28 |
-
KnowledgeBaseStats,
|
| 29 |
-
HealthResponse,
|
| 30 |
-
DocumentStatus,
|
| 31 |
-
Citation,
|
| 32 |
-
)
|
| 33 |
-
from app.models.billing_schemas import (
|
| 34 |
-
UsageResponse,
|
| 35 |
-
PlanLimitsResponse,
|
| 36 |
-
CostReportResponse,
|
| 37 |
-
SetPlanRequest
|
| 38 |
-
)
|
| 39 |
-
from app.rag.ingest import parser
|
| 40 |
-
from app.rag.chunking import chunker
|
| 41 |
-
from app.rag.embeddings import get_embedding_service
|
| 42 |
-
from app.rag.vectorstore import get_vector_store
|
| 43 |
-
from app.rag.retrieval import get_retrieval_service
|
| 44 |
-
from app.rag.answer import get_answer_service
|
| 45 |
-
from app.db.database import get_db, init_db
|
| 46 |
-
from app.billing.quota import check_quota, ensure_tenant_exists
|
| 47 |
-
from app.billing.usage_tracker import track_usage
|
| 48 |
-
|
| 49 |
-
logging.basicConfig(level=logging.INFO)
|
| 50 |
-
logger = logging.getLogger(__name__)
|
| 51 |
-
|
| 52 |
-
# Initialize FastAPI app
|
| 53 |
-
app = FastAPI(
|
| 54 |
-
title=settings.APP_NAME,
|
| 55 |
-
description="RAG-based customer support chatbot API",
|
| 56 |
-
version="1.0.0",
|
| 57 |
-
)
|
| 58 |
-
|
| 59 |
-
# Initialize database on startup
|
| 60 |
-
@app.on_event("startup")
|
| 61 |
-
async def startup_event():
|
| 62 |
-
"""Initialize database on application startup."""
|
| 63 |
-
init_db()
|
| 64 |
-
logger.info("Database initialized")
|
| 65 |
-
|
| 66 |
-
# Configure CORS - SECURITY: Restrict in production
|
| 67 |
-
if settings.ALLOWED_ORIGINS == "*":
|
| 68 |
-
allowed_origins = ["*"]
|
| 69 |
-
else:
|
| 70 |
-
# Split by comma and strip whitespace
|
| 71 |
-
allowed_origins = [origin.strip() for origin in settings.ALLOWED_ORIGINS.split(",") if origin.strip()]
|
| 72 |
-
|
| 73 |
-
# Default to allowing localhost if no origins specified
|
| 74 |
-
if not allowed_origins or allowed_origins == ["*"]:
|
| 75 |
-
allowed_origins = ["*"] # Allow all in dev mode
|
| 76 |
-
|
| 77 |
-
logger.info(f"CORS configured with origins: {allowed_origins}")
|
| 78 |
-
|
| 79 |
-
app.add_middleware(
|
| 80 |
-
CORSMiddleware,
|
| 81 |
-
allow_origins=allowed_origins,
|
| 82 |
-
allow_credentials=True,
|
| 83 |
-
allow_methods=["GET", "POST", "DELETE", "OPTIONS"], # Include OPTIONS for preflight
|
| 84 |
-
allow_headers=["Content-Type", "Authorization", "X-Tenant-Id", "X-User-Id"], # Include auth headers
|
| 85 |
-
)
|
| 86 |
-
|
| 87 |
-
# Configure rate limiting
|
| 88 |
-
app.state.limiter = limiter
|
| 89 |
-
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
|
| 90 |
-
|
| 91 |
-
# Add exception handler for validation errors
|
| 92 |
-
@app.exception_handler(RequestValidationError)
|
| 93 |
-
async def validation_exception_handler(request: Request, exc: RequestValidationError):
|
| 94 |
-
"""Handle request validation errors with detailed logging."""
|
| 95 |
-
body = await request.body()
|
| 96 |
-
logger.error(f"Request validation error: {exc.errors()}")
|
| 97 |
-
logger.error(f"Request body (raw): {body}")
|
| 98 |
-
logger.error(f"Request headers: {dict(request.headers)}")
|
| 99 |
-
return JSONResponse(
|
| 100 |
-
status_code=422,
|
| 101 |
-
content={"detail": exc.errors(), "body": body.decode('utf-8', errors='ignore')}
|
| 102 |
-
)
|
| 103 |
-
|
| 104 |
-
# Add exception handler for validation errors
|
| 105 |
-
from fastapi.exceptions import RequestValidationError
|
| 106 |
-
from fastapi.responses import JSONResponse
|
| 107 |
-
|
| 108 |
-
@app.exception_handler(RequestValidationError)
|
| 109 |
-
async def validation_exception_handler(request: Request, exc: RequestValidationError):
|
| 110 |
-
"""Handle request validation errors with detailed logging."""
|
| 111 |
-
logger.error(f"Request validation error: {exc.errors()}")
|
| 112 |
-
logger.error(f"Request body: {await request.body()}")
|
| 113 |
-
return JSONResponse(
|
| 114 |
-
status_code=422,
|
| 115 |
-
content={"detail": exc.errors(), "body": str(await request.body())}
|
| 116 |
-
)
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
# ============== Health & Status Endpoints ==============
|
| 120 |
-
|
| 121 |
-
@app.get("/", response_model=HealthResponse)
|
| 122 |
-
async def root():
|
| 123 |
-
"""Root endpoint with basic info."""
|
| 124 |
-
return HealthResponse(
|
| 125 |
-
status="ok",
|
| 126 |
-
version="1.0.0",
|
| 127 |
-
vector_db_connected=True,
|
| 128 |
-
llm_configured=bool(settings.GEMINI_API_KEY or settings.OPENAI_API_KEY)
|
| 129 |
-
)
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
@app.get("/health", response_model=HealthResponse)
|
| 133 |
-
async def health_check():
|
| 134 |
-
"""Health check endpoint."""
|
| 135 |
-
try:
|
| 136 |
-
vector_store = get_vector_store()
|
| 137 |
-
stats = vector_store.get_stats()
|
| 138 |
-
|
| 139 |
-
return HealthResponse(
|
| 140 |
-
status="healthy",
|
| 141 |
-
version="1.0.0",
|
| 142 |
-
vector_db_connected=True,
|
| 143 |
-
llm_configured=bool(settings.GEMINI_API_KEY or settings.OPENAI_API_KEY)
|
| 144 |
-
)
|
| 145 |
-
except Exception as e:
|
| 146 |
-
logger.error(f"Health check failed: {e}")
|
| 147 |
-
return HealthResponse(
|
| 148 |
-
status="unhealthy",
|
| 149 |
-
version="1.0.0",
|
| 150 |
-
vector_db_connected=False,
|
| 151 |
-
llm_configured=False
|
| 152 |
-
)
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
@app.get("/health/live")
|
| 156 |
-
async def liveness():
|
| 157 |
-
"""Kubernetes liveness probe - always returns alive."""
|
| 158 |
-
return {"status": "alive"}
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
@app.get("/health/ready")
|
| 162 |
-
async def readiness():
|
| 163 |
-
"""Kubernetes readiness probe - checks dependencies."""
|
| 164 |
-
checks = {
|
| 165 |
-
"vector_db": False,
|
| 166 |
-
"llm_configured": bool(settings.GEMINI_API_KEY or settings.OPENAI_API_KEY)
|
| 167 |
-
}
|
| 168 |
-
|
| 169 |
-
# Check vector DB connection
|
| 170 |
-
try:
|
| 171 |
-
vector_store = get_vector_store()
|
| 172 |
-
vector_store.get_stats()
|
| 173 |
-
checks["vector_db"] = True
|
| 174 |
-
except Exception as e:
|
| 175 |
-
logger.warning(f"Vector DB check failed: {e}")
|
| 176 |
-
checks["vector_db"] = False
|
| 177 |
-
|
| 178 |
-
# All checks must pass
|
| 179 |
-
if all(checks.values()):
|
| 180 |
-
return {"status": "ready", "checks": checks}
|
| 181 |
-
else:
|
| 182 |
-
from fastapi import HTTPException
|
| 183 |
-
raise HTTPException(status_code=503, detail={"status": "not_ready", "checks": checks})
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
# ============== Knowledge Base Endpoints ==============
|
| 187 |
-
|
| 188 |
-
@app.post("/kb/upload", response_model=UploadResponse)
|
| 189 |
-
@limiter.limit("20/hour", key_func=get_tenant_rate_limit_key)
|
| 190 |
-
async def upload_document(
|
| 191 |
-
background_tasks: BackgroundTasks,
|
| 192 |
-
request: Request,
|
| 193 |
-
file: UploadFile = File(...),
|
| 194 |
-
tenant_id: Optional[str] = Form(None), # Optional in dev, ignored in prod
|
| 195 |
-
user_id: Optional[str] = Form(None), # Optional in dev, ignored in prod
|
| 196 |
-
kb_id: str = Form(...)
|
| 197 |
-
):
|
| 198 |
-
"""
|
| 199 |
-
Upload a document to the knowledge base.
|
| 200 |
-
|
| 201 |
-
- Saves file to disk
|
| 202 |
-
- Parses and chunks the document
|
| 203 |
-
- Generates embeddings
|
| 204 |
-
- Stores in vector database
|
| 205 |
-
"""
|
| 206 |
-
# SECURITY: Extract tenant_id from auth token in production
|
| 207 |
-
if settings.ENV == "prod":
|
| 208 |
-
auth_context = await require_auth(request)
|
| 209 |
-
tenant_id = auth_context.get("tenant_id")
|
| 210 |
-
if not tenant_id:
|
| 211 |
-
raise HTTPException(
|
| 212 |
-
status_code=403,
|
| 213 |
-
detail="tenant_id must come from authentication token in production mode"
|
| 214 |
-
)
|
| 215 |
-
elif not tenant_id:
|
| 216 |
-
raise HTTPException(
|
| 217 |
-
status_code=400,
|
| 218 |
-
detail="tenant_id is required"
|
| 219 |
-
)
|
| 220 |
-
|
| 221 |
-
# Validate file type
|
| 222 |
-
file_ext = Path(file.filename).suffix.lower()
|
| 223 |
-
if file_ext not in parser.SUPPORTED_EXTENSIONS:
|
| 224 |
-
raise HTTPException(
|
| 225 |
-
status_code=400,
|
| 226 |
-
detail=f"Unsupported file type: {file_ext}. Supported: {parser.SUPPORTED_EXTENSIONS}"
|
| 227 |
-
)
|
| 228 |
-
|
| 229 |
-
# Validate file size (SECURITY)
|
| 230 |
-
file.file.seek(0, 2) # Seek to end
|
| 231 |
-
file_size = file.file.tell()
|
| 232 |
-
file.file.seek(0) # Reset to start
|
| 233 |
-
max_size_bytes = settings.MAX_FILE_SIZE_MB * 1024 * 1024
|
| 234 |
-
if file_size > max_size_bytes:
|
| 235 |
-
raise HTTPException(
|
| 236 |
-
status_code=400,
|
| 237 |
-
detail=f"File too large. Maximum size: {settings.MAX_FILE_SIZE_MB}MB"
|
| 238 |
-
)
|
| 239 |
-
|
| 240 |
-
# Generate document ID
|
| 241 |
-
doc_id = f"{tenant_id}_{kb_id}_{uuid.uuid4().hex[:8]}"
|
| 242 |
-
|
| 243 |
-
# Save file to uploads directory
|
| 244 |
-
upload_path = settings.UPLOADS_DIR / f"{doc_id}_{file.filename}"
|
| 245 |
-
try:
|
| 246 |
-
with open(upload_path, "wb") as buffer:
|
| 247 |
-
shutil.copyfileobj(file.file, buffer)
|
| 248 |
-
logger.info(f"Saved file: {upload_path}")
|
| 249 |
-
except Exception as e:
|
| 250 |
-
logger.error(f"Error saving file: {e}")
|
| 251 |
-
raise HTTPException(status_code=500, detail="Failed to save file")
|
| 252 |
-
|
| 253 |
-
# Process document in background
|
| 254 |
-
background_tasks.add_task(
|
| 255 |
-
process_document,
|
| 256 |
-
upload_path,
|
| 257 |
-
tenant_id, # CRITICAL: Multi-tenant isolation
|
| 258 |
-
user_id,
|
| 259 |
-
kb_id,
|
| 260 |
-
file.filename,
|
| 261 |
-
doc_id
|
| 262 |
-
)
|
| 263 |
-
|
| 264 |
-
return UploadResponse(
|
| 265 |
-
success=True,
|
| 266 |
-
message="Document upload started. Processing in background.",
|
| 267 |
-
document_id=doc_id,
|
| 268 |
-
file_name=file.filename,
|
| 269 |
-
chunks_created=0,
|
| 270 |
-
status=DocumentStatus.PROCESSING
|
| 271 |
-
)
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
async def process_document(
|
| 275 |
-
file_path: Path,
|
| 276 |
-
tenant_id: str, # CRITICAL: Multi-tenant isolation
|
| 277 |
-
user_id: str,
|
| 278 |
-
kb_id: str,
|
| 279 |
-
original_filename: str,
|
| 280 |
-
document_id: str
|
| 281 |
-
):
|
| 282 |
-
"""
|
| 283 |
-
Background task to process an uploaded document.
|
| 284 |
-
"""
|
| 285 |
-
try:
|
| 286 |
-
logger.info(f"Processing document: {original_filename}")
|
| 287 |
-
|
| 288 |
-
# Parse document
|
| 289 |
-
parsed_doc = parser.parse(file_path)
|
| 290 |
-
logger.info(f"Parsed document: {len(parsed_doc.text)} characters")
|
| 291 |
-
|
| 292 |
-
# Chunk document
|
| 293 |
-
chunks = chunker.chunk_text(
|
| 294 |
-
parsed_doc.text,
|
| 295 |
-
page_numbers=parsed_doc.page_map
|
| 296 |
-
)
|
| 297 |
-
logger.info(f"Created {len(chunks)} chunks")
|
| 298 |
-
|
| 299 |
-
if not chunks:
|
| 300 |
-
logger.warning(f"No chunks created from {original_filename}")
|
| 301 |
-
return
|
| 302 |
-
|
| 303 |
-
# Create metadata for each chunk
|
| 304 |
-
metadatas = []
|
| 305 |
-
chunk_ids = []
|
| 306 |
-
chunk_texts = []
|
| 307 |
-
|
| 308 |
-
for chunk in chunks:
|
| 309 |
-
metadata = chunker.create_chunk_metadata(
|
| 310 |
-
chunk=chunk,
|
| 311 |
-
tenant_id=tenant_id, # CRITICAL: Multi-tenant isolation
|
| 312 |
-
kb_id=kb_id,
|
| 313 |
-
user_id=user_id,
|
| 314 |
-
file_name=original_filename,
|
| 315 |
-
file_type=parsed_doc.file_type,
|
| 316 |
-
total_chunks=len(chunks),
|
| 317 |
-
document_id=document_id
|
| 318 |
-
)
|
| 319 |
-
metadatas.append(metadata)
|
| 320 |
-
chunk_ids.append(metadata["chunk_id"])
|
| 321 |
-
chunk_texts.append(chunk.content)
|
| 322 |
-
|
| 323 |
-
# Generate embeddings
|
| 324 |
-
embedding_service = get_embedding_service()
|
| 325 |
-
embeddings = embedding_service.embed_texts(chunk_texts)
|
| 326 |
-
logger.info(f"Generated {len(embeddings)} embeddings")
|
| 327 |
-
|
| 328 |
-
# Store in vector database
|
| 329 |
-
vector_store = get_vector_store()
|
| 330 |
-
vector_store.add_documents(
|
| 331 |
-
documents=chunk_texts,
|
| 332 |
-
embeddings=embeddings,
|
| 333 |
-
metadatas=metadatas,
|
| 334 |
-
ids=chunk_ids
|
| 335 |
-
)
|
| 336 |
-
|
| 337 |
-
logger.info(f"Successfully processed {original_filename}: {len(chunks)} chunks stored")
|
| 338 |
-
|
| 339 |
-
except Exception as e:
|
| 340 |
-
logger.error(f"Error processing document {original_filename}: {e}")
|
| 341 |
-
raise
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
@app.get("/kb/stats", response_model=KnowledgeBaseStats)
|
| 345 |
-
async def get_kb_stats(
|
| 346 |
-
request: Request,
|
| 347 |
-
tenant_id: Optional[str] = None, # Optional in dev, ignored in prod
|
| 348 |
-
kb_id: Optional[str] = None,
|
| 349 |
-
user_id: Optional[str] = None # Optional in dev, ignored in prod
|
| 350 |
-
):
|
| 351 |
-
"""Get statistics for a knowledge base."""
|
| 352 |
-
# SECURITY: Get tenant_id and user_id from auth context
|
| 353 |
-
auth_context = await get_auth_context(request)
|
| 354 |
-
tenant_id_from_auth = auth_context.get("tenant_id")
|
| 355 |
-
user_id_from_auth = auth_context.get("user_id")
|
| 356 |
-
|
| 357 |
-
if settings.ENV == "prod":
|
| 358 |
-
if not tenant_id_from_auth or not user_id_from_auth:
|
| 359 |
-
raise HTTPException(
|
| 360 |
-
status_code=403,
|
| 361 |
-
detail="tenant_id and user_id must come from authentication token in production mode"
|
| 362 |
-
)
|
| 363 |
-
tenant_id = tenant_id_from_auth
|
| 364 |
-
user_id = user_id_from_auth
|
| 365 |
-
else:
|
| 366 |
-
tenant_id = tenant_id or tenant_id_from_auth
|
| 367 |
-
user_id = user_id or user_id_from_auth
|
| 368 |
-
if not tenant_id or not kb_id or not user_id:
|
| 369 |
-
raise HTTPException(
|
| 370 |
-
status_code=400,
|
| 371 |
-
detail="tenant_id, kb_id, and user_id are required"
|
| 372 |
-
)
|
| 373 |
-
|
| 374 |
-
try:
|
| 375 |
-
vector_store = get_vector_store()
|
| 376 |
-
stats = vector_store.get_stats(tenant_id=tenant_id, kb_id=kb_id, user_id=user_id)
|
| 377 |
-
|
| 378 |
-
return KnowledgeBaseStats(
|
| 379 |
-
tenant_id=tenant_id, # CRITICAL: Multi-tenant isolation
|
| 380 |
-
kb_id=kb_id,
|
| 381 |
-
user_id=user_id,
|
| 382 |
-
total_documents=len(stats.get("file_names", [])),
|
| 383 |
-
total_chunks=stats.get("total_chunks", 0),
|
| 384 |
-
file_names=stats.get("file_names", []),
|
| 385 |
-
last_updated=datetime.utcnow()
|
| 386 |
-
)
|
| 387 |
-
except Exception as e:
|
| 388 |
-
logger.error(f"Error getting KB stats: {e}")
|
| 389 |
-
raise HTTPException(status_code=500, detail=str(e))
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
@app.delete("/kb/document")
async def delete_document(
    request: Request,
    tenant_id: Optional[str] = None,  # Optional in dev, ignored in prod
    kb_id: Optional[str] = None,
    user_id: Optional[str] = None,  # Optional in dev, ignored in prod
    file_name: Optional[str] = None
):
    """Remove a single document's chunks from a tenant's knowledge base."""
    # SECURITY: resolve caller identity from the auth context first.
    ctx = await get_auth_context(request)
    auth_tenant = ctx.get("tenant_id")
    auth_user = ctx.get("user_id")

    if settings.ENV == "prod":
        # Production never trusts client-supplied identifiers.
        if not auth_tenant or not auth_user:
            raise HTTPException(
                status_code=403,
                detail="tenant_id and user_id must come from authentication token in production mode"
            )
        tenant_id, user_id = auth_tenant, auth_user
    else:
        # Development: fall back to the auth context when not supplied.
        tenant_id = tenant_id or auth_tenant
        user_id = user_id or auth_user
    if not (tenant_id and kb_id and user_id and file_name):
        raise HTTPException(
            status_code=400,
            detail="tenant_id, kb_id, user_id, and file_name are required (provide via headers or query params)"
        )

    try:
        store = get_vector_store()
        removed = store.delete_by_filter({
            "tenant_id": tenant_id,  # CRITICAL: Multi-tenant isolation
            "kb_id": kb_id,
            "user_id": user_id,
            "file_name": file_name,
        })

        return {
            "success": True,
            "message": f"Deleted {removed} chunks",
            "file_name": file_name,
        }
    except Exception as e:
        logger.error(f"Error deleting document: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
| 440 |
-
|
| 441 |
-
|
| 442 |
-
@app.delete("/kb/clear")
async def clear_kb(
    request: Request,
    tenant_id: Optional[str] = None,  # Optional in dev, ignored in prod
    kb_id: Optional[str] = None,
    user_id: Optional[str] = None  # Optional in dev, ignored in prod
):
    """Delete every chunk belonging to one tenant's knowledge base."""
    # SECURITY: identity is taken from the auth context, not the query string.
    ctx = await get_auth_context(request)
    auth_tenant = ctx.get("tenant_id")
    auth_user = ctx.get("user_id")

    if settings.ENV == "prod":
        # In production the token is the only source of truth.
        if not auth_tenant or not auth_user:
            raise HTTPException(
                status_code=403,
                detail="tenant_id and user_id must come from authentication token in production mode"
            )
        tenant_id, user_id = auth_tenant, auth_user
    else:
        tenant_id = tenant_id or auth_tenant
        user_id = user_id or auth_user
    if not (tenant_id and kb_id and user_id):
        raise HTTPException(
            status_code=400,
            detail="tenant_id, kb_id, and user_id are required"
        )
    try:
        store = get_vector_store()
        removed = store.delete_by_filter({
            "tenant_id": tenant_id,  # CRITICAL: Multi-tenant isolation
            "kb_id": kb_id,
            "user_id": user_id,
        })

        return {
            "success": True,
            "message": f"Cleared knowledge base. Deleted {removed} chunks.",
            "kb_id": kb_id,
        }
    except Exception as e:
        logger.error(f"Error clearing KB: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
| 487 |
-
|
| 488 |
-
|
| 489 |
-
# ============== Chat Endpoints ==============
|
| 490 |
-
|
| 491 |
-
@app.post("/chat", response_model=ChatResponse)
@limiter.limit("10/minute", key_func=get_tenant_rate_limit_key)
async def chat(chat_request: ChatRequest, request: Request):
    """
    Process a chat message using RAG.

    - Retrieves relevant context from knowledge base
    - Generates answer using LLM
    - Returns answer with citations

    Raises:
        HTTPException 401: auth context could not be resolved.
        HTTPException 403: prod request without token identity.
        HTTPException 400: missing tenant_id/user_id in dev mode.
        HTTPException 402: tenant AI quota exhausted.
        HTTPException 500: database session could not be created.
    """
    conversation_id = "unknown"
    try:
        logger.info(f"=== CHAT REQUEST RECEIVED ===")
        logger.info(f"Request body: tenant_id={chat_request.tenant_id}, user_id={chat_request.user_id}, kb_id={chat_request.kb_id}, question_length={len(chat_request.question)}")
        logger.info(f"Request headers: {dict(request.headers)}")

        # SECURITY: Get tenant_id and user_id from auth context
        # In PROD: MUST come from JWT token (never from request body)
        try:
            auth_context = await get_auth_context(request)
        except Exception as e:
            logger.error(f"Error getting auth context: {e}", exc_info=True)
            raise HTTPException(status_code=401, detail=f"Authentication error: {str(e)}")

        tenant_id_from_auth = auth_context.get("tenant_id")
        user_id_from_auth = auth_context.get("user_id")

        if settings.ENV == "prod":
            if not tenant_id_from_auth or not user_id_from_auth:
                raise HTTPException(
                    status_code=403,
                    detail="tenant_id and user_id must come from authentication token in production mode"
                )
            # Override request values with auth context (security enforcement)
            chat_request.tenant_id = tenant_id_from_auth
            chat_request.user_id = user_id_from_auth
        else:
            # DEV mode: use from request if provided, otherwise from auth context
            if not chat_request.tenant_id:
                chat_request.tenant_id = tenant_id_from_auth
            if not chat_request.user_id:
                chat_request.user_id = user_id_from_auth
        if not chat_request.tenant_id or not chat_request.user_id:
            raise HTTPException(
                status_code=400,
                detail="tenant_id and user_id are required (provide via X-Tenant-Id/X-User-Id headers or request body)"
            )

        # Log without PII in production
        if settings.ENV == "prod":
            logger.info(f"Chat request: tenant={chat_request.tenant_id}, user={chat_request.user_id}, kb={chat_request.kb_id}, q_length={len(chat_request.question)}")
        else:
            logger.info(f"Chat request: tenant={chat_request.tenant_id}, user={chat_request.user_id}, kb={chat_request.kb_id}, q={chat_request.question[:50]}...")

        # Generate conversation ID if not provided
        conversation_id = chat_request.conversation_id or f"conv_{uuid.uuid4().hex[:12]}"

        # Get database session
        try:
            db = next(get_db())
        except Exception as e:
            logger.error(f"Database connection error: {e}", exc_info=True)
            raise HTTPException(status_code=500, detail=f"Database error: {str(e)}")

        try:
            # Ensure tenant exists in billing DB
            ensure_tenant_exists(db, chat_request.tenant_id)

            # Check quota BEFORE making LLM call
            has_quota, quota_error = check_quota(db, chat_request.tenant_id)
            if not has_quota:
                logger.warning(f"Quota exceeded for tenant {chat_request.tenant_id}")
                raise HTTPException(
                    status_code=402,
                    detail=quota_error or "AI quota exceeded. Upgrade your plan."
                )

            # Retrieve relevant context
            retrieval_service = get_retrieval_service()
            results, confidence, has_relevant = retrieval_service.retrieve(
                query=chat_request.question,
                tenant_id=chat_request.tenant_id,  # CRITICAL: Multi-tenant isolation
                kb_id=chat_request.kb_id,
                user_id=chat_request.user_id
            )

            logger.info(f"Retrieval results: {len(results)} results, confidence={confidence:.3f}, has_relevant={has_relevant}")

            # Format context for LLM
            context, citations_info = retrieval_service.get_context_for_llm(results)

            logger.info(f"Formatted context length: {len(context)} chars, citations: {len(citations_info)}")

            # Generate answer
            answer_service = get_answer_service()
            answer_result = answer_service.generate_answer(
                question=chat_request.question,
                context=context,
                citations_info=citations_info,
                confidence=confidence,
                has_relevant_results=has_relevant
            )

            # Track usage if LLM was called (usage info present)
            usage_info = answer_result.get("usage")
            if usage_info:
                try:
                    track_usage(
                        db=db,
                        tenant_id=chat_request.tenant_id,
                        user_id=chat_request.user_id,
                        kb_id=chat_request.kb_id,
                        provider=settings.LLM_PROVIDER,
                        model=usage_info.get("model_used", settings.GEMINI_MODEL if settings.LLM_PROVIDER == "gemini" else settings.OPENAI_MODEL),
                        prompt_tokens=usage_info.get("prompt_tokens", 0),
                        completion_tokens=usage_info.get("completion_tokens", 0)
                    )
                except Exception as e:
                    logger.error(f"Failed to track usage: {e}", exc_info=True)
                    # Don't fail the request if usage tracking fails

            # Build metadata with refusal info
            metadata = {
                "chunks_retrieved": len(results),
                "kb_id": chat_request.kb_id
            }
            if "refused" in answer_result:
                metadata["refused"] = answer_result["refused"]
            if "refusal_reason" in answer_result:
                metadata["refusal_reason"] = answer_result["refusal_reason"]
            if "verifier_passed" in answer_result:
                metadata["verifier_passed"] = answer_result["verifier_passed"]

            return ChatResponse(
                success=True,
                answer=answer_result["answer"],
                citations=answer_result["citations"],
                confidence=answer_result["confidence"],
                from_knowledge_base=answer_result["from_knowledge_base"],
                escalation_suggested=answer_result["escalation_suggested"],
                conversation_id=conversation_id,
                refused=answer_result.get("refused", False),
                metadata=metadata
            )

        except ValueError as e:
            # API key or configuration error
            error_msg = str(e)
            logger.error(f"Configuration error: {error_msg}")
            # FIX: the original compared "API key" (mixed case) against the
            # lowercased message, so this branch could never trigger.
            if "api key" in error_msg.lower():
                return ChatResponse(
                    success=False,
                    answer="⚠️ LLM API key not configured. Please set GEMINI_API_KEY in your .env file. Retrieval is working, but answer generation requires an API key.",
                    citations=[],
                    confidence=0.0,
                    from_knowledge_base=False,
                    escalation_suggested=True,
                    conversation_id=conversation_id,
                    metadata={"error": error_msg, "error_type": "configuration"}
                )
            else:
                return ChatResponse(
                    success=False,
                    answer=f"Configuration error: {error_msg}",
                    citations=[],
                    confidence=0.0,
                    from_knowledge_base=False,
                    escalation_suggested=True,
                    conversation_id=conversation_id,
                    metadata={"error": error_msg}
                )
        except HTTPException:
            # Re-raise HTTP exceptions (they have proper status codes)
            raise
        except Exception as e:
            logger.error(f"Chat error: {e}", exc_info=True)
            logger.error(f"Error type: {type(e).__name__}", exc_info=True)
            return ChatResponse(
                success=False,
                answer=f"I encountered an error processing your request: {str(e)}. Please check the server logs for details.",
                citations=[],
                confidence=0.0,
                from_knowledge_base=False,
                escalation_suggested=True,
                conversation_id=conversation_id,
                metadata={"error": str(e), "error_type": type(e).__name__}
            )
        finally:
            # FIX: next(get_db()) bypasses FastAPI's dependency cleanup, so the
            # session leaked on every request; close it explicitly.
            db.close()
    except HTTPException:
        # Re-raise HTTP exceptions from outer try block
        raise
    except Exception as e:
        logger.error(f"Outer chat error: {e}", exc_info=True)
        return ChatResponse(
            success=False,
            answer=f"I encountered an error processing your request: {str(e)}. Please check the server logs for details.",
            citations=[],
            confidence=0.0,
            from_knowledge_base=False,
            escalation_suggested=True,
            conversation_id=conversation_id,
            metadata={"error": str(e), "error_type": type(e).__name__}
        )
|
| 693 |
-
|
| 694 |
-
|
| 695 |
-
# ============== Utility Endpoints ==============
|
| 696 |
-
|
| 697 |
-
@app.get("/kb/search")
@limiter.limit("30/minute", key_func=get_tenant_rate_limit_key)
async def search_kb(
    request: Request,
    query: str,
    tenant_id: Optional[str] = None,  # Optional in dev, ignored in prod
    kb_id: Optional[str] = None,
    user_id: Optional[str] = None,  # Optional in dev, ignored in prod
    top_k: int = 5
):
    """
    Search the knowledge base without generating an answer.
    Useful for debugging and testing retrieval.

    Raises:
        HTTPException 403: prod token lacks tenant/user identity.
        HTTPException 400: missing tenant_id, kb_id, or user_id.
        HTTPException 500: retrieval backend failure.
    """
    # SECURITY: Extract tenant_id from auth token in production
    if settings.ENV == "prod":
        auth_context = await require_auth(request)
        tenant_id = auth_context.get("tenant_id")
        user_id = auth_context.get("user_id")
        if not tenant_id or not user_id:
            raise HTTPException(
                status_code=403,
                detail="tenant_id and user_id must come from authentication token in production mode"
            )
        # FIX: kb_id was previously unvalidated in prod, letting None flow
        # into retrieval; fail fast with a 400 instead.
        if not kb_id:
            raise HTTPException(
                status_code=400,
                detail="tenant_id, kb_id, and user_id are required"
            )
    elif not tenant_id or not kb_id or not user_id:
        raise HTTPException(
            status_code=400,
            detail="tenant_id, kb_id, and user_id are required"
        )

    try:
        retrieval_service = get_retrieval_service()
        results, confidence, has_relevant = retrieval_service.retrieve(
            query=query,
            tenant_id=tenant_id,  # CRITICAL: Multi-tenant isolation
            kb_id=kb_id,
            user_id=user_id,
            top_k=top_k
        )

        return {
            "success": True,
            "results": [
                {
                    "chunk_id": r.chunk_id,
                    # Truncate long chunks so debug responses stay small.
                    "content": r.content[:500] + "..." if len(r.content) > 500 else r.content,
                    "metadata": r.metadata,
                    "similarity_score": r.similarity_score
                }
                for r in results
            ],
            "confidence": confidence,
            "has_relevant_results": has_relevant
        }
    except Exception as e:
        logger.error(f"Search error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
| 754 |
-
|
| 755 |
-
|
| 756 |
-
# ============== Billing & Usage Endpoints ==============
|
| 757 |
-
|
| 758 |
-
@app.get("/billing/usage", response_model=UsageResponse)
async def get_usage(
    request: Request,
    range: str = "month",  # "day" or "month" (name kept for API compat, shadows builtin)
    year: Optional[int] = None,
    month: Optional[int] = None,
    day: Optional[int] = None
):
    """
    Get usage statistics for the current tenant.

    Args:
        range: "day" or "month"
        year: Year (optional, defaults to current)
        month: Month 1-12 (optional, defaults to current)
        day: Day 1-31 (optional, defaults to current, only for range="day")

    Raises:
        HTTPException 403: no tenant identity in the auth context.
        HTTPException 500: database failure while reading usage rows.
    """
    # Get tenant from auth
    auth_context = await get_auth_context(request)
    tenant_id = auth_context.get("tenant_id")

    if not tenant_id:
        raise HTTPException(status_code=403, detail="tenant_id required")

    db = next(get_db())

    try:
        from app.db.models import UsageDaily, UsageMonthly
        from datetime import datetime
        from calendar import monthrange

        now = datetime.utcnow()
        target_year = year or now.year
        target_month = month or now.month

        if range == "day":
            target_day = day or now.day
            date_start = datetime(target_year, target_month, target_day)

            daily = db.query(UsageDaily).filter(
                UsageDaily.tenant_id == tenant_id,
                UsageDaily.date == date_start
            ).first()

            if not daily:
                # No usage recorded that day; report zeros.
                return UsageResponse(
                    tenant_id=tenant_id,
                    period="day",
                    total_requests=0,
                    total_tokens=0,
                    total_cost_usd=0.0,
                    start_date=date_start,
                    end_date=date_start
                )

            return UsageResponse(
                tenant_id=tenant_id,
                period="day",
                total_requests=daily.total_requests,
                total_tokens=daily.total_tokens,
                total_cost_usd=daily.total_cost_usd,
                gemini_requests=daily.gemini_requests,
                openai_requests=daily.openai_requests,
                start_date=daily.date,
                end_date=daily.date
            )
        else:  # month
            monthly = db.query(UsageMonthly).filter(
                UsageMonthly.tenant_id == tenant_id,
                UsageMonthly.year == target_year,
                UsageMonthly.month == target_month
            ).first()

            if not monthly:
                # Calculate date range for the month
                _, last_day = monthrange(target_year, target_month)
                start_date = datetime(target_year, target_month, 1)
                end_date = datetime(target_year, target_month, last_day)

                return UsageResponse(
                    tenant_id=tenant_id,
                    period="month",
                    total_requests=0,
                    total_tokens=0,
                    total_cost_usd=0.0,
                    start_date=start_date,
                    end_date=end_date
                )

            _, last_day = monthrange(monthly.year, monthly.month)
            start_date = datetime(monthly.year, monthly.month, 1)
            end_date = datetime(monthly.year, monthly.month, last_day)

            return UsageResponse(
                tenant_id=tenant_id,
                period="month",
                total_requests=monthly.total_requests,
                total_tokens=monthly.total_tokens,
                total_cost_usd=monthly.total_cost_usd,
                gemini_requests=monthly.gemini_requests,
                openai_requests=monthly.openai_requests,
                start_date=start_date,
                end_date=end_date
            )
    except Exception as e:
        logger.error(f"Error getting usage: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # FIX: next(get_db()) bypasses FastAPI dependency cleanup; close the
        # session explicitly so the connection is returned to the pool.
        db.close()
|
| 865 |
-
|
| 866 |
-
|
| 867 |
-
@app.get("/billing/limits", response_model=PlanLimitsResponse)
async def get_limits(request: Request):
    """Get current plan limits and usage for the tenant.

    Raises:
        HTTPException 403: no tenant identity in the auth context.
        HTTPException 500: database failure while reading plan/usage.
    """
    # Get tenant from auth
    auth_context = await get_auth_context(request)
    tenant_id = auth_context.get("tenant_id")

    if not tenant_id:
        raise HTTPException(status_code=403, detail="tenant_id required")

    db = next(get_db())

    try:
        from app.billing.quota import get_tenant_plan, get_monthly_usage
        from datetime import datetime

        plan = get_tenant_plan(db, tenant_id)
        if not plan:
            # Default to starter
            plan_name = "starter"
            monthly_limit = 500
        else:
            plan_name = plan.plan_name
            monthly_limit = plan.monthly_chat_limit

        # Get current month usage
        now = datetime.utcnow()
        monthly_usage = get_monthly_usage(db, tenant_id, now.year, now.month)
        current_usage = monthly_usage.total_requests if monthly_usage else 0

        # -1 marks an unlimited plan, so there is no remaining count.
        remaining = None if monthly_limit == -1 else max(0, monthly_limit - current_usage)

        return PlanLimitsResponse(
            tenant_id=tenant_id,
            plan_name=plan_name,
            monthly_chat_limit=monthly_limit,
            current_month_usage=current_usage,
            remaining_chats=remaining
        )
    except Exception as e:
        logger.error(f"Error getting limits: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # FIX: close the manually-created session (no dependency cleanup here).
        db.close()
|
| 909 |
-
|
| 910 |
-
|
| 911 |
-
@app.post("/billing/plan")
async def set_plan(request_body: SetPlanRequest, http_request: Request):
    """
    Set tenant's subscription plan (admin only in production).

    In dev mode, allows any tenant to set their plan.
    In prod mode, should be restricted to admin users.

    Raises:
        HTTPException 403: prod tenant trying to change another tenant's plan.
        HTTPException 400: unknown plan name.
        HTTPException 500: database failure.
    """
    # Get tenant from auth
    auth_context = await get_auth_context(http_request)
    auth_tenant_id = auth_context.get("tenant_id")

    # In prod, verify admin role (placeholder - implement actual admin check)
    if settings.ENV == "prod":
        # TODO: Add admin role check
        if auth_tenant_id != request_body.tenant_id:
            raise HTTPException(status_code=403, detail="Cannot set plan for other tenants")

    db = next(get_db())

    try:
        from app.billing.quota import set_tenant_plan

        plan = set_tenant_plan(db, request_body.tenant_id, request_body.plan_name)

        return {
            "success": True,
            "tenant_id": request_body.tenant_id,
            "plan_name": plan.plan_name,
            "monthly_chat_limit": plan.monthly_chat_limit
        }
    except ValueError as e:
        # Invalid plan name -> client error.
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.error(f"Error setting plan: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # FIX: close the manually-created session (no dependency cleanup here).
        db.close()
|
| 947 |
-
|
| 948 |
-
|
| 949 |
-
@app.get("/billing/cost-report", response_model=CostReportResponse)
async def get_cost_report(
    request: Request,
    range: str = "month",  # "month" or anything else for all-time (name kept for API compat)
    year: Optional[int] = None,
    month: Optional[int] = None
):
    """Get cost report with breakdown by provider and model.

    Raises:
        HTTPException 403: no tenant identity in the auth context.
        HTTPException 500: database failure while aggregating events.
    """
    # Get tenant from auth
    auth_context = await get_auth_context(request)
    tenant_id = auth_context.get("tenant_id")

    if not tenant_id:
        raise HTTPException(status_code=403, detail="tenant_id required")

    db = next(get_db())

    try:
        from app.db.models import UsageEvent
        from datetime import datetime
        from sqlalchemy import func, and_

        now = datetime.utcnow()
        target_year = year or now.year
        target_month = month or now.month

        # Query usage events for the period
        if range == "month":
            query = db.query(UsageEvent).filter(
                and_(
                    UsageEvent.tenant_id == tenant_id,
                    func.extract('year', UsageEvent.request_timestamp) == target_year,
                    func.extract('month', UsageEvent.request_timestamp) == target_month
                )
            )
        else:  # all time
            query = db.query(UsageEvent).filter(UsageEvent.tenant_id == tenant_id)

        events = query.all()

        # Calculate totals
        total_cost = sum(e.estimated_cost_usd for e in events)
        total_requests = len(events)
        total_tokens = sum(e.total_tokens for e in events)

        # Breakdowns by provider and by model in a single pass
        # (the original iterated the event list twice).
        breakdown_by_provider = {}
        breakdown_by_model = {}
        for event in events:
            for bucket, key in (
                (breakdown_by_provider, event.provider),
                (breakdown_by_model, event.model),
            ):
                entry = bucket.setdefault(key, {
                    "requests": 0,
                    "tokens": 0,
                    "cost_usd": 0.0
                })
                entry["requests"] += 1
                entry["tokens"] += event.total_tokens
                entry["cost_usd"] += event.estimated_cost_usd

        return CostReportResponse(
            tenant_id=tenant_id,
            period=range,
            total_cost_usd=total_cost,
            total_requests=total_requests,
            total_tokens=total_tokens,
            breakdown_by_provider=breakdown_by_provider,
            breakdown_by_model=breakdown_by_model
        )
    except Exception as e:
        logger.error(f"Error getting cost report: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # FIX: next(get_db()) bypasses FastAPI dependency cleanup; close the
        # session explicitly so the connection is returned to the pool.
        db.close()
|
| 1034 |
-
|
| 1035 |
-
|
| 1036 |
-
if __name__ == "__main__":
    # Run a local development server when this module is executed directly.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
|
| 1039 |
-
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
FastAPI application for ClientSphere RAG Backend.
|
| 3 |
+
Provides endpoints for knowledge base management and chat.
|
| 4 |
+
"""
|
| 5 |
+
from fastapi import FastAPI, File, UploadFile, HTTPException, Form, BackgroundTasks, Request, Depends
|
| 6 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 7 |
+
from fastapi.exceptions import RequestValidationError
|
| 8 |
+
from fastapi.responses import JSONResponse
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
import shutil
|
| 11 |
+
import uuid
|
| 12 |
+
from datetime import datetime
|
| 13 |
+
from typing import Optional
|
| 14 |
+
import logging
|
| 15 |
+
|
| 16 |
+
from app.config import settings
|
| 17 |
+
from app.middleware.auth import get_auth_context, require_auth
|
| 18 |
+
from app.middleware.rate_limit import (
|
| 19 |
+
limiter,
|
| 20 |
+
get_tenant_rate_limit_key,
|
| 21 |
+
RateLimitExceeded,
|
| 22 |
+
_rate_limit_exceeded_handler
|
| 23 |
+
)
|
| 24 |
+
from app.models.schemas import (
|
| 25 |
+
UploadResponse,
|
| 26 |
+
ChatRequest,
|
| 27 |
+
ChatResponse,
|
| 28 |
+
KnowledgeBaseStats,
|
| 29 |
+
HealthResponse,
|
| 30 |
+
DocumentStatus,
|
| 31 |
+
Citation,
|
| 32 |
+
)
|
| 33 |
+
from app.models.billing_schemas import (
|
| 34 |
+
UsageResponse,
|
| 35 |
+
PlanLimitsResponse,
|
| 36 |
+
CostReportResponse,
|
| 37 |
+
SetPlanRequest
|
| 38 |
+
)
|
| 39 |
+
from app.rag.ingest import parser
|
| 40 |
+
from app.rag.chunking import chunker
|
| 41 |
+
from app.rag.embeddings import get_embedding_service
|
| 42 |
+
from app.rag.vectorstore import get_vector_store
|
| 43 |
+
from app.rag.retrieval import get_retrieval_service
|
| 44 |
+
from app.rag.answer import get_answer_service
|
| 45 |
+
from app.db.database import get_db, init_db
|
| 46 |
+
from app.billing.quota import check_quota, ensure_tenant_exists
|
| 47 |
+
from app.billing.usage_tracker import track_usage
|
| 48 |
+
|
| 49 |
+
logging.basicConfig(level=logging.INFO)
|
| 50 |
+
logger = logging.getLogger(__name__)
|
| 51 |
+
|
| 52 |
+
# Initialize FastAPI app
|
| 53 |
+
app = FastAPI(
|
| 54 |
+
title=settings.APP_NAME,
|
| 55 |
+
description="RAG-based customer support chatbot API",
|
| 56 |
+
version="1.0.0",
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
# Initialize database on startup
|
| 60 |
+
@app.on_event("startup")
async def startup_event():
    """Run one-time initialization when the API process starts."""
    # Create the billing/usage tables before serving any requests.
    init_db()
    logger.info("Database initialized")
|
| 65 |
+
|
| 66 |
+
# Configure CORS - SECURITY: Restrict in production
if settings.ALLOWED_ORIGINS == "*":
    allowed_origins = ["*"]
else:
    # Split by comma and strip whitespace
    allowed_origins = [origin.strip() for origin in settings.ALLOWED_ORIGINS.split(",") if origin.strip()]

# FIX(comment): the old comment claimed a localhost default, but the code
# actually falls back to allowing ALL origins when the parsed list is empty.
if not allowed_origins or allowed_origins == ["*"]:
    allowed_origins = ["*"]  # Allow all in dev mode

logger.info(f"CORS configured with origins: {allowed_origins}")

# NOTE(review): allow_credentials=True combined with a wildcard origin is
# rejected by browsers per the CORS spec; production must list explicit origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=allowed_origins,
    allow_credentials=True,
    allow_methods=["GET", "POST", "DELETE", "OPTIONS"],  # Include OPTIONS for preflight
    allow_headers=["Content-Type", "Authorization", "X-Tenant-Id", "X-User-Id"],  # Include auth headers
)

# Configure rate limiting
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
|
| 90 |
+
|
| 91 |
+
# Add exception handler for validation errors
|
| 92 |
+
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    """Handle request validation errors with detailed logging."""
    raw = await request.body()
    logger.error(f"Request validation error: {exc.errors()}")
    logger.error(f"Request body (raw): {raw}")
    logger.error(f"Request headers: {dict(request.headers)}")
    # Echo the offending payload back so clients can debug 422s.
    return JSONResponse(
        status_code=422,
        content={"detail": exc.errors(), "body": raw.decode('utf-8', errors='ignore')}
    )
|
| 103 |
+
|
| 104 |
+
# Add exception handler for validation errors
|
| 105 |
+
from fastapi.exceptions import RequestValidationError
|
| 106 |
+
from fastapi.responses import JSONResponse
|
| 107 |
+
|
| 108 |
+
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    """Handle request validation errors with detailed logging.

    NOTE(review): this re-registers the handler defined earlier in the module
    and overrides it; the duplicate registration looks accidental.
    """
    # FIX: read the body once instead of awaiting request.body() twice.
    body = await request.body()
    logger.error(f"Request validation error: {exc.errors()}")
    logger.error(f"Request body: {body}")
    return JSONResponse(
        status_code=422,
        content={"detail": exc.errors(), "body": str(body)}
    )
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
# ============== Health & Status Endpoints ==============
|
| 120 |
+
|
| 121 |
+
@app.get("/", response_model=HealthResponse)
async def root():
    """Root endpoint with basic info."""
    # An LLM is considered configured when either provider key is present.
    llm_ready = bool(settings.GEMINI_API_KEY or settings.OPENAI_API_KEY)
    return HealthResponse(
        status="ok",
        version="1.0.0",
        vector_db_connected=True,
        llm_configured=llm_ready,
    )
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Health check endpoint.

    Probes the vector store; reports healthy/unhealthy accordingly.
    """
    version = "1.0.0"
    try:
        # A successful stats call doubles as a connectivity check.
        store = get_vector_store()
        stats = store.get_stats()
        return HealthResponse(
            status="healthy",
            version=version,
            vector_db_connected=True,
            llm_configured=bool(settings.GEMINI_API_KEY or settings.OPENAI_API_KEY),
        )
    except Exception as e:
        logger.error(f"Health check failed: {e}")
        return HealthResponse(
            status="unhealthy",
            version=version,
            vector_db_connected=False,
            llm_configured=False,
        )
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
@app.get("/health/live")
async def liveness():
    """Kubernetes liveness probe - always returns alive.

    Intentionally performs no dependency checks: liveness only asserts the
    process can serve requests; dependency health is /health/ready's job.
    """
    return {"status": "alive"}
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
@app.get("/health/ready")
async def readiness():
    """Kubernetes readiness probe - checks dependencies.

    Returns 200 with per-check results when all dependencies are
    reachable; otherwise raises 503 so the orchestrator stops routing
    traffic to this pod.
    """
    checks = {
        "vector_db": False,
        "llm_configured": bool(settings.GEMINI_API_KEY or settings.OPENAI_API_KEY)
    }

    # Check vector DB connection with a cheap stats call.
    try:
        vector_store = get_vector_store()
        vector_store.get_stats()
        checks["vector_db"] = True
    except Exception as e:
        logger.warning(f"Vector DB check failed: {e}")
        checks["vector_db"] = False

    # All checks must pass
    if all(checks.values()):
        return {"status": "ready", "checks": checks}
    # HTTPException is already imported at module level (every other handler
    # raises it without a local import), so the redundant function-local
    # "from fastapi import HTTPException" was removed.
    raise HTTPException(status_code=503, detail={"status": "not_ready", "checks": checks})
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
# ============== Knowledge Base Endpoints ==============
|
| 187 |
+
|
| 188 |
+
@app.post("/kb/upload", response_model=UploadResponse)
@limiter.limit("20/hour", key_func=get_tenant_rate_limit_key)
async def upload_document(
    background_tasks: BackgroundTasks,
    request: Request,
    file: UploadFile = File(...),
    tenant_id: Optional[str] = Form(None),  # Optional in dev, ignored in prod
    user_id: Optional[str] = Form(None),  # Optional in dev, ignored in prod
    kb_id: str = Form(...)
):
    """
    Upload a document to the knowledge base.

    - Saves file to disk
    - Parses and chunks the document
    - Generates embeddings
    - Stores in vector database

    Raises 400 for missing tenant/bad filename/unsupported type/oversize
    file, 403 when prod auth lacks a tenant, 500 when the save fails.
    """
    # SECURITY: Extract tenant_id from auth token in production
    if settings.ENV == "prod":
        auth_context = await require_auth(request)
        tenant_id = auth_context.get("tenant_id")
        if not tenant_id:
            raise HTTPException(
                status_code=403,
                detail="tenant_id must come from authentication token in production mode"
            )
    elif not tenant_id:
        raise HTTPException(
            status_code=400,
            detail="tenant_id is required"
        )

    # SECURITY FIX: normalize the client-supplied filename to its basename
    # so a crafted name (e.g. "../../etc/passwd") cannot escape UPLOADS_DIR.
    # Also guards against a missing filename (UploadFile.filename may be None).
    safe_filename = Path(file.filename or "").name
    if not safe_filename:
        raise HTTPException(status_code=400, detail="A valid filename is required")

    # Validate file type
    file_ext = Path(safe_filename).suffix.lower()
    if file_ext not in parser.SUPPORTED_EXTENSIONS:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported file type: {file_ext}. Supported: {parser.SUPPORTED_EXTENSIONS}"
        )

    # Validate file size (SECURITY)
    file.file.seek(0, 2)  # Seek to end
    file_size = file.file.tell()
    file.file.seek(0)  # Reset to start
    max_size_bytes = settings.MAX_FILE_SIZE_MB * 1024 * 1024
    if file_size > max_size_bytes:
        raise HTTPException(
            status_code=400,
            detail=f"File too large. Maximum size: {settings.MAX_FILE_SIZE_MB}MB"
        )

    # Generate document ID
    doc_id = f"{tenant_id}_{kb_id}_{uuid.uuid4().hex[:8]}"

    # Save file to uploads directory
    upload_path = settings.UPLOADS_DIR / f"{doc_id}_{safe_filename}"
    try:
        with open(upload_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)
        logger.info(f"Saved file: {upload_path}")
    except Exception as e:
        logger.error(f"Error saving file: {e}")
        raise HTTPException(status_code=500, detail="Failed to save file")

    # Process document in background
    background_tasks.add_task(
        process_document,
        upload_path,
        tenant_id,  # CRITICAL: Multi-tenant isolation
        user_id,
        kb_id,
        safe_filename,
        doc_id
    )

    return UploadResponse(
        success=True,
        message="Document upload started. Processing in background.",
        document_id=doc_id,
        file_name=safe_filename,
        chunks_created=0,
        status=DocumentStatus.PROCESSING
    )
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
async def process_document(
    file_path: Path,
    tenant_id: str,  # CRITICAL: Multi-tenant isolation
    user_id: str,
    kb_id: str,
    original_filename: str,
    document_id: str
):
    """
    Background task to process an uploaded document:
    parse -> chunk -> per-chunk metadata -> embed -> index.
    """
    try:
        logger.info(f"Processing document: {original_filename}")

        # Parse the raw file into text plus a page map.
        parsed_doc = parser.parse(file_path)
        logger.info(f"Parsed document: {len(parsed_doc.text)} characters")

        # Split into retrieval-sized chunks.
        chunks = chunker.chunk_text(
            parsed_doc.text,
            page_numbers=parsed_doc.page_map
        )
        logger.info(f"Created {len(chunks)} chunks")

        if not chunks:
            logger.warning(f"No chunks created from {original_filename}")
            return

        # Build metadata, ids, and texts in lockstep with the chunk list.
        metadatas = [
            chunker.create_chunk_metadata(
                chunk=chunk,
                tenant_id=tenant_id,  # CRITICAL: Multi-tenant isolation
                kb_id=kb_id,
                user_id=user_id,
                file_name=original_filename,
                file_type=parsed_doc.file_type,
                total_chunks=len(chunks),
                document_id=document_id
            )
            for chunk in chunks
        ]
        chunk_ids = [meta["chunk_id"] for meta in metadatas]
        chunk_texts = [chunk.content for chunk in chunks]

        # Embed every chunk in one batch call.
        embeddings = get_embedding_service().embed_texts(chunk_texts)
        logger.info(f"Generated {len(embeddings)} embeddings")

        # Persist texts + vectors + metadata under the chunk ids.
        get_vector_store().add_documents(
            documents=chunk_texts,
            embeddings=embeddings,
            metadatas=metadatas,
            ids=chunk_ids
        )

        logger.info(f"Successfully processed {original_filename}: {len(chunks)} chunks stored")

    except Exception as e:
        logger.error(f"Error processing document {original_filename}: {e}")
        raise
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
@app.get("/kb/stats", response_model=KnowledgeBaseStats)
async def get_kb_stats(
    request: Request,
    tenant_id: Optional[str] = None,  # Optional in dev, ignored in prod
    kb_id: Optional[str] = None,
    user_id: Optional[str] = None  # Optional in dev, ignored in prod
):
    """Get statistics for a knowledge base.

    Identity resolution: in prod, tenant_id/user_id come only from the
    auth token; in dev, query params win with the auth context as fallback.
    Returns document/chunk counts and file names for the KB.
    """
    # SECURITY: Get tenant_id and user_id from auth context
    auth_context = await get_auth_context(request)
    tenant_id_from_auth = auth_context.get("tenant_id")
    user_id_from_auth = auth_context.get("user_id")

    if settings.ENV == "prod":
        # Prod: token is the only trusted identity source.
        if not tenant_id_from_auth or not user_id_from_auth:
            raise HTTPException(
                status_code=403,
                detail="tenant_id and user_id must come from authentication token in production mode"
            )
        tenant_id = tenant_id_from_auth
        user_id = user_id_from_auth
    else:
        # Dev: explicit params win, auth context fills the gaps.
        tenant_id = tenant_id or tenant_id_from_auth
        user_id = user_id or user_id_from_auth
        # NOTE(review): this completeness check guards only the dev path —
        # kb_id is not validated in prod; confirm that is intentional.
        if not tenant_id or not kb_id or not user_id:
            raise HTTPException(
                status_code=400,
                detail="tenant_id, kb_id, and user_id are required"
            )

    try:
        vector_store = get_vector_store()
        stats = vector_store.get_stats(tenant_id=tenant_id, kb_id=kb_id, user_id=user_id)

        return KnowledgeBaseStats(
            tenant_id=tenant_id,  # CRITICAL: Multi-tenant isolation
            kb_id=kb_id,
            user_id=user_id,
            total_documents=len(stats.get("file_names", [])),
            total_chunks=stats.get("total_chunks", 0),
            file_names=stats.get("file_names", []),
            last_updated=datetime.utcnow()
        )
    except Exception as e:
        logger.error(f"Error getting KB stats: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
@app.delete("/kb/document")
async def delete_document(
    request: Request,
    tenant_id: Optional[str] = None,  # Optional in dev, ignored in prod
    kb_id: Optional[str] = None,
    user_id: Optional[str] = None,  # Optional in dev, ignored in prod
    file_name: Optional[str] = None
):
    """Delete a document from the knowledge base.

    Removes every stored chunk matching (tenant_id, kb_id, user_id,
    file_name) from the vector store and reports how many were deleted.
    """
    # SECURITY: Get tenant_id and user_id from auth context
    auth_context = await get_auth_context(request)
    tenant_id_from_auth = auth_context.get("tenant_id")
    user_id_from_auth = auth_context.get("user_id")

    if settings.ENV == "prod":
        # Prod: token is the only trusted identity source.
        if not tenant_id_from_auth or not user_id_from_auth:
            raise HTTPException(
                status_code=403,
                detail="tenant_id and user_id must come from authentication token in production mode"
            )
        tenant_id = tenant_id_from_auth
        user_id = user_id_from_auth
    else:
        # Dev: explicit params win, auth context fills the gaps.
        tenant_id = tenant_id or tenant_id_from_auth
        user_id = user_id or user_id_from_auth
        if not tenant_id or not kb_id or not user_id or not file_name:
            raise HTTPException(
                status_code=400,
                detail="tenant_id, kb_id, user_id, and file_name are required (provide via headers or query params)"
            )

    try:
        vector_store = get_vector_store()
        # delete_by_filter returns the number of chunks removed.
        deleted = vector_store.delete_by_filter({
            "tenant_id": tenant_id,  # CRITICAL: Multi-tenant isolation
            "kb_id": kb_id,
            "user_id": user_id,
            "file_name": file_name
        })

        return {
            "success": True,
            "message": f"Deleted {deleted} chunks",
            "file_name": file_name
        }
    except Exception as e:
        logger.error(f"Error deleting document: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
| 440 |
+
|
| 441 |
+
|
| 442 |
+
@app.delete("/kb/clear")
async def clear_kb(
    request: Request,
    tenant_id: Optional[str] = None,  # Optional in dev, ignored in prod
    kb_id: Optional[str] = None,
    user_id: Optional[str] = None  # Optional in dev, ignored in prod
):
    """Clear all documents from a knowledge base.

    Deletes every chunk stored for (tenant_id, kb_id, user_id) from the
    vector store and reports how many chunks were removed.
    """
    # SECURITY: Get tenant_id and user_id from auth context
    auth_context = await get_auth_context(request)
    tenant_id_from_auth = auth_context.get("tenant_id")
    user_id_from_auth = auth_context.get("user_id")

    if settings.ENV == "prod":
        # Prod: token is the only trusted identity source.
        if not tenant_id_from_auth or not user_id_from_auth:
            raise HTTPException(
                status_code=403,
                detail="tenant_id and user_id must come from authentication token in production mode"
            )
        tenant_id = tenant_id_from_auth
        user_id = user_id_from_auth
    else:
        # Dev: explicit params win, auth context fills the gaps.
        tenant_id = tenant_id or tenant_id_from_auth
        user_id = user_id or user_id_from_auth
        if not tenant_id or not kb_id or not user_id:
            raise HTTPException(
                status_code=400,
                detail="tenant_id, kb_id, and user_id are required"
            )
    try:
        vector_store = get_vector_store()
        # delete_by_filter returns the number of chunks removed.
        deleted = vector_store.delete_by_filter({
            "tenant_id": tenant_id,  # CRITICAL: Multi-tenant isolation
            "kb_id": kb_id,
            "user_id": user_id
        })

        return {
            "success": True,
            "message": f"Cleared knowledge base. Deleted {deleted} chunks.",
            "kb_id": kb_id
        }
    except Exception as e:
        logger.error(f"Error clearing KB: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
| 487 |
+
|
| 488 |
+
|
| 489 |
+
# ============== Chat Endpoints ==============
|
| 490 |
+
|
| 491 |
+
@app.post("/chat", response_model=ChatResponse)
@limiter.limit("10/minute", key_func=get_tenant_rate_limit_key)
async def chat(chat_request: ChatRequest, request: Request):
    """
    Process a chat message using RAG.

    - Retrieves relevant context from knowledge base
    - Generates answer using LLM
    - Returns answer with citations

    Error model: auth failures raise 401/403, quota exhaustion raises 402,
    DB connection failure raises 500; most other errors are converted into
    a ChatResponse with success=False so the client gets a structured reply.
    """
    # Default so error responses always carry a conversation id, even when
    # the failure happens before one is generated below.
    conversation_id = "unknown"
    try:
        logger.info(f"=== CHAT REQUEST RECEIVED ===")
        logger.info(f"Request body: tenant_id={chat_request.tenant_id}, user_id={chat_request.user_id}, kb_id={chat_request.kb_id}, question_length={len(chat_request.question)}")
        # NOTE(review): this logs all headers verbatim (incl. Authorization);
        # consider redacting before production use.
        logger.info(f"Request headers: {dict(request.headers)}")

        # SECURITY: Get tenant_id and user_id from auth context
        # In PROD: MUST come from JWT token (never from request body)
        try:
            auth_context = await get_auth_context(request)
        except Exception as e:
            logger.error(f"Error getting auth context: {e}", exc_info=True)
            raise HTTPException(status_code=401, detail=f"Authentication error: {str(e)}")

        tenant_id_from_auth = auth_context.get("tenant_id")
        user_id_from_auth = auth_context.get("user_id")

        if settings.ENV == "prod":
            if not tenant_id_from_auth or not user_id_from_auth:
                raise HTTPException(
                    status_code=403,
                    detail="tenant_id and user_id must come from authentication token in production mode"
                )
            # Override request values with auth context (security enforcement)
            chat_request.tenant_id = tenant_id_from_auth
            chat_request.user_id = user_id_from_auth
        else:
            # DEV mode: use from request if provided, otherwise from auth context
            if not chat_request.tenant_id:
                chat_request.tenant_id = tenant_id_from_auth
            if not chat_request.user_id:
                chat_request.user_id = user_id_from_auth
            if not chat_request.tenant_id or not chat_request.user_id:
                raise HTTPException(
                    status_code=400,
                    detail="tenant_id and user_id are required (provide via X-Tenant-Id/X-User-Id headers or request body)"
                )

        # Log without PII in production
        if settings.ENV == "prod":
            logger.info(f"Chat request: tenant={chat_request.tenant_id}, user={chat_request.user_id}, kb={chat_request.kb_id}, q_length={len(chat_request.question)}")
        else:
            logger.info(f"Chat request: tenant={chat_request.tenant_id}, user={chat_request.user_id}, kb={chat_request.kb_id}, q={chat_request.question[:50]}...")

        # Generate conversation ID if not provided
        conversation_id = chat_request.conversation_id or f"conv_{uuid.uuid4().hex[:12]}"

        # Get database session
        try:
            db = next(get_db())
        except Exception as e:
            logger.error(f"Database connection error: {e}", exc_info=True)
            raise HTTPException(status_code=500, detail=f"Database error: {str(e)}")

        try:
            # Ensure tenant exists in billing DB
            ensure_tenant_exists(db, chat_request.tenant_id)

            # Check quota BEFORE making LLM call
            has_quota, quota_error = check_quota(db, chat_request.tenant_id)
            if not has_quota:
                logger.warning(f"Quota exceeded for tenant {chat_request.tenant_id}")
                # 402 Payment Required signals the client to prompt an upgrade.
                raise HTTPException(
                    status_code=402,
                    detail=quota_error or "AI quota exceeded. Upgrade your plan."
                )

            # Retrieve relevant context
            retrieval_service = get_retrieval_service()
            results, confidence, has_relevant = retrieval_service.retrieve(
                query=chat_request.question,
                tenant_id=chat_request.tenant_id,  # CRITICAL: Multi-tenant isolation
                kb_id=chat_request.kb_id,
                user_id=chat_request.user_id
            )

            logger.info(f"Retrieval results: {len(results)} results, confidence={confidence:.3f}, has_relevant={has_relevant}")

            # Format context for LLM
            context, citations_info = retrieval_service.get_context_for_llm(results)

            logger.info(f"Formatted context length: {len(context)} chars, citations: {len(citations_info)}")

            # Generate answer
            answer_service = get_answer_service()
            answer_result = answer_service.generate_answer(
                question=chat_request.question,
                context=context,
                citations_info=citations_info,
                confidence=confidence,
                has_relevant_results=has_relevant
            )

            # Track usage if LLM was called (usage info present)
            usage_info = answer_result.get("usage")
            if usage_info:
                try:
                    track_usage(
                        db=db,
                        tenant_id=chat_request.tenant_id,
                        user_id=chat_request.user_id,
                        kb_id=chat_request.kb_id,
                        provider=settings.LLM_PROVIDER,
                        model=usage_info.get("model_used", settings.GEMINI_MODEL if settings.LLM_PROVIDER == "gemini" else settings.OPENAI_MODEL),
                        prompt_tokens=usage_info.get("prompt_tokens", 0),
                        completion_tokens=usage_info.get("completion_tokens", 0)
                    )
                except Exception as e:
                    logger.error(f"Failed to track usage: {e}", exc_info=True)
                    # Don't fail the request if usage tracking fails

            # Build metadata with refusal info
            metadata = {
                "chunks_retrieved": len(results),
                "kb_id": chat_request.kb_id
            }
            if "refused" in answer_result:
                metadata["refused"] = answer_result["refused"]
            if "refusal_reason" in answer_result:
                metadata["refusal_reason"] = answer_result["refusal_reason"]
            if "verifier_passed" in answer_result:
                metadata["verifier_passed"] = answer_result["verifier_passed"]

            return ChatResponse(
                success=True,
                answer=answer_result["answer"],
                citations=answer_result["citations"],
                confidence=answer_result["confidence"],
                from_knowledge_base=answer_result["from_knowledge_base"],
                escalation_suggested=answer_result["escalation_suggested"],
                conversation_id=conversation_id,
                refused=answer_result.get("refused", False),
                metadata=metadata
            )

        except ValueError as e:
            # API key or configuration error
            error_msg = str(e)
            logger.error(f"Configuration error: {error_msg}")
            # BUGFIX: the needle must be lowercase — the original tested
            # '"API key" in error_msg.lower()', which can never match a
            # lowercased string, so the helpful API-key branch was dead.
            if "api key" in error_msg.lower():
                return ChatResponse(
                    success=False,
                    answer="⚠️ LLM API key not configured. Please set GEMINI_API_KEY in your .env file. Retrieval is working, but answer generation requires an API key.",
                    citations=[],
                    confidence=0.0,
                    from_knowledge_base=False,
                    escalation_suggested=True,
                    conversation_id=conversation_id,
                    metadata={"error": error_msg, "error_type": "configuration"}
                )
            else:
                return ChatResponse(
                    success=False,
                    answer=f"Configuration error: {error_msg}",
                    citations=[],
                    confidence=0.0,
                    from_knowledge_base=False,
                    escalation_suggested=True,
                    conversation_id=conversation_id,
                    metadata={"error": error_msg}
                )
        except HTTPException:
            # Re-raise HTTP exceptions (they have proper status codes)
            raise
        except Exception as e:
            logger.error(f"Chat error: {e}", exc_info=True)
            logger.error(f"Error type: {type(e).__name__}", exc_info=True)
            return ChatResponse(
                success=False,
                answer=f"I encountered an error processing your request: {str(e)}. Please check the server logs for details.",
                citations=[],
                confidence=0.0,
                from_knowledge_base=False,
                escalation_suggested=True,
                conversation_id=conversation_id,
                metadata={"error": str(e), "error_type": type(e).__name__}
            )
    except HTTPException:
        # Re-raise HTTP exceptions from outer try block
        raise
    except Exception as e:
        logger.error(f"Outer chat error: {e}", exc_info=True)
        return ChatResponse(
            success=False,
            answer=f"I encountered an error processing your request: {str(e)}. Please check the server logs for details.",
            citations=[],
            confidence=0.0,
            from_knowledge_base=False,
            escalation_suggested=True,
            conversation_id=conversation_id,
            metadata={"error": str(e), "error_type": type(e).__name__}
        )
|
| 693 |
+
|
| 694 |
+
|
| 695 |
+
# ============== Utility Endpoints ==============
|
| 696 |
+
|
| 697 |
+
@app.get("/kb/search")
@limiter.limit("30/minute", key_func=get_tenant_rate_limit_key)
async def search_kb(
    request: Request,
    query: str,
    tenant_id: Optional[str] = None,  # Optional in dev, ignored in prod
    kb_id: Optional[str] = None,
    user_id: Optional[str] = None,  # Optional in dev, ignored in prod
    top_k: int = 5
):
    """
    Search the knowledge base without generating an answer.
    Useful for debugging and testing retrieval.
    """
    # SECURITY: in production, identity comes from the verified token only.
    if settings.ENV == "prod":
        auth_context = await require_auth(request)
        tenant_id = auth_context.get("tenant_id")
        user_id = auth_context.get("user_id")
        if not tenant_id or not user_id:
            raise HTTPException(
                status_code=403,
                detail="tenant_id and user_id must come from authentication token in production mode"
            )
    elif not tenant_id or not kb_id or not user_id:
        raise HTTPException(
            status_code=400,
            detail="tenant_id, kb_id, and user_id are required"
        )

    try:
        results, confidence, has_relevant = get_retrieval_service().retrieve(
            query=query,
            tenant_id=tenant_id,  # CRITICAL: Multi-tenant isolation
            kb_id=kb_id,
            user_id=user_id,
            top_k=top_k
        )

        def _preview(text):
            # Truncate long chunk bodies so the debug payload stays small.
            return text[:500] + "..." if len(text) > 500 else text

        return {
            "success": True,
            "results": [
                {
                    "chunk_id": hit.chunk_id,
                    "content": _preview(hit.content),
                    "metadata": hit.metadata,
                    "similarity_score": hit.similarity_score
                }
                for hit in results
            ],
            "confidence": confidence,
            "has_relevant_results": has_relevant
        }
    except Exception as e:
        logger.error(f"Search error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
| 754 |
+
|
| 755 |
+
|
| 756 |
+
# ============== Billing & Usage Endpoints ==============
|
| 757 |
+
|
| 758 |
+
@app.get("/billing/usage", response_model=UsageResponse)
async def get_usage(
    request: Request,
    range: str = "month",  # "day" or "month"
    year: Optional[int] = None,
    month: Optional[int] = None,
    day: Optional[int] = None
):
    """
    Get usage statistics for the current tenant.

    Args:
        range: "day" or "month"
        year: Year (optional, defaults to current)
        month: Month 1-12 (optional, defaults to current)
        day: Day 1-31 (optional, defaults to current, only for range="day")

    Raises 403 when no tenant can be resolved from auth; 500 on DB errors.
    """
    # Get tenant from auth
    auth_context = await get_auth_context(request)
    tenant_id = auth_context.get("tenant_id")

    if not tenant_id:
        raise HTTPException(status_code=403, detail="tenant_id required")

    db = next(get_db())

    try:
        from app.db.models import UsageDaily, UsageMonthly
        from datetime import datetime
        from calendar import monthrange

        # Default missing period components to the current UTC date.
        now = datetime.utcnow()
        target_year = year or now.year
        target_month = month or now.month

        if range == "day":
            target_day = day or now.day
            date_start = datetime(target_year, target_month, target_day)

            daily = db.query(UsageDaily).filter(
                UsageDaily.tenant_id == tenant_id,
                UsageDaily.date == date_start
            ).first()

            # No row means no usage that day: return an all-zero report.
            if not daily:
                return UsageResponse(
                    tenant_id=tenant_id,
                    period="day",
                    total_requests=0,
                    total_tokens=0,
                    total_cost_usd=0.0,
                    start_date=date_start,
                    end_date=date_start
                )

            return UsageResponse(
                tenant_id=tenant_id,
                period="day",
                total_requests=daily.total_requests,
                total_tokens=daily.total_tokens,
                total_cost_usd=daily.total_cost_usd,
                gemini_requests=daily.gemini_requests,
                openai_requests=daily.openai_requests,
                start_date=daily.date,
                end_date=daily.date
            )
        else:  # month
            monthly = db.query(UsageMonthly).filter(
                UsageMonthly.tenant_id == tenant_id,
                UsageMonthly.year == target_year,
                UsageMonthly.month == target_month
            ).first()

            # No row means no usage that month: zero report over the full
            # calendar month.
            if not monthly:
                # Calculate date range for the month
                _, last_day = monthrange(target_year, target_month)
                start_date = datetime(target_year, target_month, 1)
                end_date = datetime(target_year, target_month, last_day)

                return UsageResponse(
                    tenant_id=tenant_id,
                    period="month",
                    total_requests=0,
                    total_tokens=0,
                    total_cost_usd=0.0,
                    start_date=start_date,
                    end_date=end_date
                )

            _, last_day = monthrange(monthly.year, monthly.month)
            start_date = datetime(monthly.year, monthly.month, 1)
            end_date = datetime(monthly.year, monthly.month, last_day)

            return UsageResponse(
                tenant_id=tenant_id,
                period="month",
                total_requests=monthly.total_requests,
                total_tokens=monthly.total_tokens,
                total_cost_usd=monthly.total_cost_usd,
                gemini_requests=monthly.gemini_requests,
                openai_requests=monthly.openai_requests,
                start_date=start_date,
                end_date=end_date
            )
    except Exception as e:
        logger.error(f"Error getting usage: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
|
| 865 |
+
|
| 866 |
+
|
| 867 |
+
@app.get("/billing/limits", response_model=PlanLimitsResponse)
async def get_limits(request: Request):
    """Get current plan limits and usage for the tenant."""
    # Identity comes from the auth context; a tenant is mandatory here.
    auth_context = await get_auth_context(request)
    tenant_id = auth_context.get("tenant_id")
    if not tenant_id:
        raise HTTPException(status_code=403, detail="tenant_id required")

    db = next(get_db())

    try:
        from app.billing.quota import get_tenant_plan, get_monthly_usage
        from datetime import datetime

        # Fall back to the starter tier when no plan row exists yet.
        plan = get_tenant_plan(db, tenant_id)
        plan_name = plan.plan_name if plan else "starter"
        monthly_limit = plan.monthly_chat_limit if plan else 500

        # Current calendar month's request count (0 when no usage row).
        now = datetime.utcnow()
        monthly_usage = get_monthly_usage(db, tenant_id, now.year, now.month)
        current_usage = monthly_usage.total_requests if monthly_usage else 0

        # A limit of -1 means unlimited, so no "remaining" figure applies.
        remaining = None if monthly_limit == -1 else max(0, monthly_limit - current_usage)

        return PlanLimitsResponse(
            tenant_id=tenant_id,
            plan_name=plan_name,
            monthly_chat_limit=monthly_limit,
            current_month_usage=current_usage,
            remaining_chats=remaining
        )
    except Exception as e:
        logger.error(f"Error getting limits: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
|
| 909 |
+
|
| 910 |
+
|
| 911 |
+
@app.post("/billing/plan")
|
| 912 |
+
async def set_plan(request_body: SetPlanRequest, http_request: Request):
|
| 913 |
+
"""
|
| 914 |
+
Set tenant's subscription plan (admin only in production).
|
| 915 |
+
|
| 916 |
+
In dev mode, allows any tenant to set their plan.
|
| 917 |
+
In prod mode, should be restricted to admin users.
|
| 918 |
+
"""
|
| 919 |
+
# Get tenant from auth
|
| 920 |
+
auth_context = await get_auth_context(http_request)
|
| 921 |
+
auth_tenant_id = auth_context.get("tenant_id")
|
| 922 |
+
|
| 923 |
+
# In prod, verify admin role (placeholder - implement actual admin check)
|
| 924 |
+
if settings.ENV == "prod":
|
| 925 |
+
# TODO: Add admin role check
|
| 926 |
+
if auth_tenant_id != request_body.tenant_id:
|
| 927 |
+
raise HTTPException(status_code=403, detail="Cannot set plan for other tenants")
|
| 928 |
+
|
| 929 |
+
db = next(get_db())
|
| 930 |
+
|
| 931 |
+
try:
|
| 932 |
+
from app.billing.quota import set_tenant_plan
|
| 933 |
+
|
| 934 |
+
plan = set_tenant_plan(db, request_body.tenant_id, request_body.plan_name)
|
| 935 |
+
|
| 936 |
+
return {
|
| 937 |
+
"success": True,
|
| 938 |
+
"tenant_id": request_body.tenant_id,
|
| 939 |
+
"plan_name": plan.plan_name,
|
| 940 |
+
"monthly_chat_limit": plan.monthly_chat_limit
|
| 941 |
+
}
|
| 942 |
+
except ValueError as e:
|
| 943 |
+
raise HTTPException(status_code=400, detail=str(e))
|
| 944 |
+
except Exception as e:
|
| 945 |
+
logger.error(f"Error setting plan: {e}", exc_info=True)
|
| 946 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 947 |
+
|
| 948 |
+
|
| 949 |
+
@app.get("/billing/cost-report", response_model=CostReportResponse)
|
| 950 |
+
async def get_cost_report(
|
| 951 |
+
request: Request,
|
| 952 |
+
range: str = "month",
|
| 953 |
+
year: Optional[int] = None,
|
| 954 |
+
month: Optional[int] = None
|
| 955 |
+
):
|
| 956 |
+
"""Get cost report with breakdown by provider and model."""
|
| 957 |
+
# Get tenant from auth
|
| 958 |
+
auth_context = await get_auth_context(request)
|
| 959 |
+
tenant_id = auth_context.get("tenant_id")
|
| 960 |
+
|
| 961 |
+
if not tenant_id:
|
| 962 |
+
raise HTTPException(status_code=403, detail="tenant_id required")
|
| 963 |
+
|
| 964 |
+
db = next(get_db())
|
| 965 |
+
|
| 966 |
+
try:
|
| 967 |
+
from app.db.models import UsageEvent
|
| 968 |
+
from datetime import datetime
|
| 969 |
+
from sqlalchemy import func, and_
|
| 970 |
+
|
| 971 |
+
now = datetime.utcnow()
|
| 972 |
+
target_year = year or now.year
|
| 973 |
+
target_month = month or now.month
|
| 974 |
+
|
| 975 |
+
# Query usage events for the period
|
| 976 |
+
if range == "month":
|
| 977 |
+
query = db.query(UsageEvent).filter(
|
| 978 |
+
and_(
|
| 979 |
+
UsageEvent.tenant_id == tenant_id,
|
| 980 |
+
func.extract('year', UsageEvent.request_timestamp) == target_year,
|
| 981 |
+
func.extract('month', UsageEvent.request_timestamp) == target_month
|
| 982 |
+
)
|
| 983 |
+
)
|
| 984 |
+
else: # all time
|
| 985 |
+
query = db.query(UsageEvent).filter(UsageEvent.tenant_id == tenant_id)
|
| 986 |
+
|
| 987 |
+
events = query.all()
|
| 988 |
+
|
| 989 |
+
# Calculate totals
|
| 990 |
+
total_cost = sum(e.estimated_cost_usd for e in events)
|
| 991 |
+
total_requests = len(events)
|
| 992 |
+
total_tokens = sum(e.total_tokens for e in events)
|
| 993 |
+
|
| 994 |
+
# Breakdown by provider
|
| 995 |
+
breakdown_by_provider = {}
|
| 996 |
+
for event in events:
|
| 997 |
+
provider = event.provider
|
| 998 |
+
if provider not in breakdown_by_provider:
|
| 999 |
+
breakdown_by_provider[provider] = {
|
| 1000 |
+
"requests": 0,
|
| 1001 |
+
"tokens": 0,
|
| 1002 |
+
"cost_usd": 0.0
|
| 1003 |
+
}
|
| 1004 |
+
breakdown_by_provider[provider]["requests"] += 1
|
| 1005 |
+
breakdown_by_provider[provider]["tokens"] += event.total_tokens
|
| 1006 |
+
breakdown_by_provider[provider]["cost_usd"] += event.estimated_cost_usd
|
| 1007 |
+
|
| 1008 |
+
# Breakdown by model
|
| 1009 |
+
breakdown_by_model = {}
|
| 1010 |
+
for event in events:
|
| 1011 |
+
model = event.model
|
| 1012 |
+
if model not in breakdown_by_model:
|
| 1013 |
+
breakdown_by_model[model] = {
|
| 1014 |
+
"requests": 0,
|
| 1015 |
+
"tokens": 0,
|
| 1016 |
+
"cost_usd": 0.0
|
| 1017 |
+
}
|
| 1018 |
+
breakdown_by_model[model]["requests"] += 1
|
| 1019 |
+
breakdown_by_model[model]["tokens"] += event.total_tokens
|
| 1020 |
+
breakdown_by_model[model]["cost_usd"] += event.estimated_cost_usd
|
| 1021 |
+
|
| 1022 |
+
return CostReportResponse(
|
| 1023 |
+
tenant_id=tenant_id,
|
| 1024 |
+
period=range,
|
| 1025 |
+
total_cost_usd=total_cost,
|
| 1026 |
+
total_requests=total_requests,
|
| 1027 |
+
total_tokens=total_tokens,
|
| 1028 |
+
breakdown_by_provider=breakdown_by_provider,
|
| 1029 |
+
breakdown_by_model=breakdown_by_model
|
| 1030 |
+
)
|
| 1031 |
+
except Exception as e:
|
| 1032 |
+
logger.error(f"Error getting cost report: {e}", exc_info=True)
|
| 1033 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 1034 |
+
|
| 1035 |
+
|
| 1036 |
+
if __name__ == "__main__":
|
| 1037 |
+
import uvicorn
|
| 1038 |
+
uvicorn.run(app, host="0.0.0.0", port=8000)
|
| 1039 |
+
|
app/middleware/__init__.py
CHANGED
|
@@ -1,13 +1,13 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Middleware for authentication, rate limiting, etc.
|
| 3 |
-
"""
|
| 4 |
-
from app.middleware.auth import verify_tenant_access, get_tenant_from_token, require_auth
|
| 5 |
-
|
| 6 |
-
__all__ = [
|
| 7 |
-
"verify_tenant_access",
|
| 8 |
-
"get_tenant_from_token",
|
| 9 |
-
"require_auth",
|
| 10 |
-
]
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Middleware for authentication, rate limiting, etc.
|
| 3 |
+
"""
|
| 4 |
+
from app.middleware.auth import verify_tenant_access, get_tenant_from_token, require_auth
|
| 5 |
+
|
| 6 |
+
__all__ = [
|
| 7 |
+
"verify_tenant_access",
|
| 8 |
+
"get_tenant_from_token",
|
| 9 |
+
"require_auth",
|
| 10 |
+
]
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
|
app/middleware/__pycache__/__init__.cpython-313.pyc
DELETED
|
Binary file (370 Bytes)
|
|
|
app/middleware/__pycache__/auth.cpython-313.pyc
DELETED
|
Binary file (6.5 kB)
|
|
|
app/middleware/__pycache__/rate_limit.cpython-313.pyc
DELETED
|
Binary file (1.45 kB)
|
|
|
app/middleware/auth.py
CHANGED
|
@@ -1,212 +1,212 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Authentication and authorization middleware.
|
| 3 |
-
Extracts tenant_id from JWT token in production mode.
|
| 4 |
-
"""
|
| 5 |
-
from fastapi import Request, HTTPException, Depends
|
| 6 |
-
from typing import Optional, Dict, Any
|
| 7 |
-
import logging
|
| 8 |
-
from jose import JWTError, jwt
|
| 9 |
-
|
| 10 |
-
from app.config import settings
|
| 11 |
-
|
| 12 |
-
logger = logging.getLogger(__name__)
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
async def verify_tenant_access(
|
| 16 |
-
request: Request,
|
| 17 |
-
tenant_id: str,
|
| 18 |
-
user_id: str
|
| 19 |
-
) -> bool:
|
| 20 |
-
"""
|
| 21 |
-
Verify that the user has access to the specified tenant.
|
| 22 |
-
|
| 23 |
-
TODO: Implement actual authentication logic:
|
| 24 |
-
1. Extract JWT token from Authorization header
|
| 25 |
-
2. Verify token signature
|
| 26 |
-
3. Extract user_id and tenant_id from token
|
| 27 |
-
4. Verify user belongs to tenant
|
| 28 |
-
5. Check permissions
|
| 29 |
-
|
| 30 |
-
Args:
|
| 31 |
-
request: FastAPI request object
|
| 32 |
-
tenant_id: Tenant ID from request
|
| 33 |
-
user_id: User ID from request
|
| 34 |
-
|
| 35 |
-
Returns:
|
| 36 |
-
True if access is granted, False otherwise
|
| 37 |
-
"""
|
| 38 |
-
# TODO: Implement actual authentication
|
| 39 |
-
# For now, this is a placeholder that always returns True
|
| 40 |
-
# In production, you MUST implement proper auth
|
| 41 |
-
|
| 42 |
-
# Example implementation:
|
| 43 |
-
# token = request.headers.get("Authorization", "").replace("Bearer ", "")
|
| 44 |
-
# if not token:
|
| 45 |
-
# return False
|
| 46 |
-
#
|
| 47 |
-
# decoded = verify_jwt_token(token)
|
| 48 |
-
# if decoded["user_id"] != user_id or decoded["tenant_id"] != tenant_id:
|
| 49 |
-
# return False
|
| 50 |
-
#
|
| 51 |
-
# return True
|
| 52 |
-
|
| 53 |
-
logger.warning("⚠️ Authentication middleware not implemented - using placeholder")
|
| 54 |
-
return True
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
def get_tenant_from_token(request: Request) -> Optional[str]:
|
| 58 |
-
"""
|
| 59 |
-
Extract tenant_id from authentication token.
|
| 60 |
-
|
| 61 |
-
In production mode, extracts tenant_id from JWT token.
|
| 62 |
-
In dev mode, returns None (allows request tenant_id).
|
| 63 |
-
|
| 64 |
-
Args:
|
| 65 |
-
request: FastAPI request object
|
| 66 |
-
|
| 67 |
-
Returns:
|
| 68 |
-
Tenant ID if found in token, None otherwise
|
| 69 |
-
"""
|
| 70 |
-
if settings.ENV == "dev":
|
| 71 |
-
# Dev mode: allow request tenant_id
|
| 72 |
-
return None
|
| 73 |
-
|
| 74 |
-
# Production mode: extract from JWT
|
| 75 |
-
auth_header = request.headers.get("Authorization", "")
|
| 76 |
-
if not auth_header.startswith("Bearer "):
|
| 77 |
-
logger.warning("Missing or invalid Authorization header")
|
| 78 |
-
return None
|
| 79 |
-
|
| 80 |
-
token = auth_header.replace("Bearer ", "").strip()
|
| 81 |
-
if not token:
|
| 82 |
-
return None
|
| 83 |
-
|
| 84 |
-
try:
|
| 85 |
-
# TODO: Replace with your actual JWT secret key
|
| 86 |
-
# For now, this is a placeholder that expects a specific token format
|
| 87 |
-
# In production, you should:
|
| 88 |
-
# 1. Get JWT_SECRET from environment
|
| 89 |
-
# 2. Verify token signature
|
| 90 |
-
# 3. Extract tenant_id from token payload
|
| 91 |
-
|
| 92 |
-
# Example implementation (replace with your actual JWT verification):
|
| 93 |
-
# JWT_SECRET = os.getenv("JWT_SECRET", "your-secret-key")
|
| 94 |
-
# decoded = jwt.decode(token, JWT_SECRET, algorithms=["HS256"])
|
| 95 |
-
# return decoded.get("tenant_id")
|
| 96 |
-
|
| 97 |
-
# Placeholder: Try to decode without verification (for testing)
|
| 98 |
-
# In production, you MUST verify the signature
|
| 99 |
-
try:
|
| 100 |
-
decoded = jwt.decode(token, options={"verify_signature": False})
|
| 101 |
-
tenant_id = decoded.get("tenant_id")
|
| 102 |
-
if tenant_id:
|
| 103 |
-
logger.info(f"Extracted tenant_id from token: {tenant_id}")
|
| 104 |
-
return tenant_id
|
| 105 |
-
except jwt.DecodeError:
|
| 106 |
-
logger.warning("Failed to decode JWT token")
|
| 107 |
-
return None
|
| 108 |
-
|
| 109 |
-
except Exception as e:
|
| 110 |
-
logger.error(f"Error extracting tenant from token: {e}")
|
| 111 |
-
return None
|
| 112 |
-
|
| 113 |
-
return None
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
async def get_auth_context(request: Request) -> Dict[str, Any]:
|
| 117 |
-
"""
|
| 118 |
-
Get authentication context from request.
|
| 119 |
-
|
| 120 |
-
DEV mode:
|
| 121 |
-
- Allows X-Tenant-Id and X-User-Id headers
|
| 122 |
-
- Falls back to defaults if missing
|
| 123 |
-
|
| 124 |
-
PROD mode:
|
| 125 |
-
- Requires Authorization: Bearer <JWT>
|
| 126 |
-
- Verifies JWT using JWT_SECRET
|
| 127 |
-
- Extracts tenant_id and user_id from token claims
|
| 128 |
-
- NEVER accepts tenant_id from request body/query params
|
| 129 |
-
|
| 130 |
-
Args:
|
| 131 |
-
request: FastAPI request object
|
| 132 |
-
|
| 133 |
-
Returns:
|
| 134 |
-
Dictionary with user_id and tenant_id
|
| 135 |
-
|
| 136 |
-
Raises:
|
| 137 |
-
HTTPException: If authentication fails (production mode only)
|
| 138 |
-
"""
|
| 139 |
-
if settings.ENV == "dev":
|
| 140 |
-
# Dev mode: allow headers, fallback to defaults
|
| 141 |
-
tenant_id = request.headers.get("X-Tenant-Id", "dev_tenant")
|
| 142 |
-
user_id = request.headers.get("X-User-Id", "dev_user")
|
| 143 |
-
return {
|
| 144 |
-
"user_id": user_id,
|
| 145 |
-
"tenant_id": tenant_id
|
| 146 |
-
}
|
| 147 |
-
|
| 148 |
-
# Production mode: require JWT token
|
| 149 |
-
auth_header = request.headers.get("Authorization")
|
| 150 |
-
if not auth_header or not auth_header.startswith("Bearer "):
|
| 151 |
-
raise HTTPException(
|
| 152 |
-
status_code=401,
|
| 153 |
-
detail="Authentication required. Provide valid Bearer token in Authorization header."
|
| 154 |
-
)
|
| 155 |
-
|
| 156 |
-
token = auth_header.replace("Bearer ", "").strip()
|
| 157 |
-
if not token:
|
| 158 |
-
raise HTTPException(
|
| 159 |
-
status_code=401,
|
| 160 |
-
detail="Invalid token format."
|
| 161 |
-
)
|
| 162 |
-
|
| 163 |
-
# Verify JWT token
|
| 164 |
-
if not settings.JWT_SECRET:
|
| 165 |
-
logger.error("JWT_SECRET not configured in production mode")
|
| 166 |
-
raise HTTPException(
|
| 167 |
-
status_code=500,
|
| 168 |
-
detail="Server configuration error: JWT_SECRET not set"
|
| 169 |
-
)
|
| 170 |
-
|
| 171 |
-
try:
|
| 172 |
-
decoded = jwt.decode(token, settings.JWT_SECRET, algorithms=["HS256"])
|
| 173 |
-
|
| 174 |
-
user_id = decoded.get("user_id") or decoded.get("sub")
|
| 175 |
-
tenant_id = decoded.get("tenant_id")
|
| 176 |
-
|
| 177 |
-
if not user_id or not tenant_id:
|
| 178 |
-
raise HTTPException(
|
| 179 |
-
status_code=401,
|
| 180 |
-
detail="Token missing required claims (user_id, tenant_id)."
|
| 181 |
-
)
|
| 182 |
-
|
| 183 |
-
logger.info(f"Authenticated user: {user_id}, tenant: {tenant_id}")
|
| 184 |
-
return {
|
| 185 |
-
"user_id": user_id,
|
| 186 |
-
"tenant_id": tenant_id,
|
| 187 |
-
"email": decoded.get("email"),
|
| 188 |
-
"role": decoded.get("role")
|
| 189 |
-
}
|
| 190 |
-
|
| 191 |
-
except JWTError as e:
|
| 192 |
-
logger.warning(f"JWT verification failed: {e}")
|
| 193 |
-
raise HTTPException(
|
| 194 |
-
status_code=401,
|
| 195 |
-
detail="Invalid or expired token."
|
| 196 |
-
)
|
| 197 |
-
except Exception as e:
|
| 198 |
-
logger.error(f"Auth error: {e}", exc_info=True)
|
| 199 |
-
raise HTTPException(
|
| 200 |
-
status_code=401,
|
| 201 |
-
detail="Authentication failed."
|
| 202 |
-
)
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
# FastAPI dependency for easy use in endpoints
|
| 206 |
-
async def require_auth(request: Request) -> Dict[str, Any]:
|
| 207 |
-
"""
|
| 208 |
-
FastAPI dependency for requiring authentication.
|
| 209 |
-
Alias for get_auth_context for backward compatibility.
|
| 210 |
-
"""
|
| 211 |
-
return await get_auth_context(request)
|
| 212 |
-
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Authentication and authorization middleware.
|
| 3 |
+
Extracts tenant_id from JWT token in production mode.
|
| 4 |
+
"""
|
| 5 |
+
from fastapi import Request, HTTPException, Depends
|
| 6 |
+
from typing import Optional, Dict, Any
|
| 7 |
+
import logging
|
| 8 |
+
from jose import JWTError, jwt
|
| 9 |
+
|
| 10 |
+
from app.config import settings
|
| 11 |
+
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
async def verify_tenant_access(
|
| 16 |
+
request: Request,
|
| 17 |
+
tenant_id: str,
|
| 18 |
+
user_id: str
|
| 19 |
+
) -> bool:
|
| 20 |
+
"""
|
| 21 |
+
Verify that the user has access to the specified tenant.
|
| 22 |
+
|
| 23 |
+
TODO: Implement actual authentication logic:
|
| 24 |
+
1. Extract JWT token from Authorization header
|
| 25 |
+
2. Verify token signature
|
| 26 |
+
3. Extract user_id and tenant_id from token
|
| 27 |
+
4. Verify user belongs to tenant
|
| 28 |
+
5. Check permissions
|
| 29 |
+
|
| 30 |
+
Args:
|
| 31 |
+
request: FastAPI request object
|
| 32 |
+
tenant_id: Tenant ID from request
|
| 33 |
+
user_id: User ID from request
|
| 34 |
+
|
| 35 |
+
Returns:
|
| 36 |
+
True if access is granted, False otherwise
|
| 37 |
+
"""
|
| 38 |
+
# TODO: Implement actual authentication
|
| 39 |
+
# For now, this is a placeholder that always returns True
|
| 40 |
+
# In production, you MUST implement proper auth
|
| 41 |
+
|
| 42 |
+
# Example implementation:
|
| 43 |
+
# token = request.headers.get("Authorization", "").replace("Bearer ", "")
|
| 44 |
+
# if not token:
|
| 45 |
+
# return False
|
| 46 |
+
#
|
| 47 |
+
# decoded = verify_jwt_token(token)
|
| 48 |
+
# if decoded["user_id"] != user_id or decoded["tenant_id"] != tenant_id:
|
| 49 |
+
# return False
|
| 50 |
+
#
|
| 51 |
+
# return True
|
| 52 |
+
|
| 53 |
+
logger.warning("⚠️ Authentication middleware not implemented - using placeholder")
|
| 54 |
+
return True
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def get_tenant_from_token(request: Request) -> Optional[str]:
|
| 58 |
+
"""
|
| 59 |
+
Extract tenant_id from authentication token.
|
| 60 |
+
|
| 61 |
+
In production mode, extracts tenant_id from JWT token.
|
| 62 |
+
In dev mode, returns None (allows request tenant_id).
|
| 63 |
+
|
| 64 |
+
Args:
|
| 65 |
+
request: FastAPI request object
|
| 66 |
+
|
| 67 |
+
Returns:
|
| 68 |
+
Tenant ID if found in token, None otherwise
|
| 69 |
+
"""
|
| 70 |
+
if settings.ENV == "dev":
|
| 71 |
+
# Dev mode: allow request tenant_id
|
| 72 |
+
return None
|
| 73 |
+
|
| 74 |
+
# Production mode: extract from JWT
|
| 75 |
+
auth_header = request.headers.get("Authorization", "")
|
| 76 |
+
if not auth_header.startswith("Bearer "):
|
| 77 |
+
logger.warning("Missing or invalid Authorization header")
|
| 78 |
+
return None
|
| 79 |
+
|
| 80 |
+
token = auth_header.replace("Bearer ", "").strip()
|
| 81 |
+
if not token:
|
| 82 |
+
return None
|
| 83 |
+
|
| 84 |
+
try:
|
| 85 |
+
# TODO: Replace with your actual JWT secret key
|
| 86 |
+
# For now, this is a placeholder that expects a specific token format
|
| 87 |
+
# In production, you should:
|
| 88 |
+
# 1. Get JWT_SECRET from environment
|
| 89 |
+
# 2. Verify token signature
|
| 90 |
+
# 3. Extract tenant_id from token payload
|
| 91 |
+
|
| 92 |
+
# Example implementation (replace with your actual JWT verification):
|
| 93 |
+
# JWT_SECRET = os.getenv("JWT_SECRET", "your-secret-key")
|
| 94 |
+
# decoded = jwt.decode(token, JWT_SECRET, algorithms=["HS256"])
|
| 95 |
+
# return decoded.get("tenant_id")
|
| 96 |
+
|
| 97 |
+
# Placeholder: Try to decode without verification (for testing)
|
| 98 |
+
# In production, you MUST verify the signature
|
| 99 |
+
try:
|
| 100 |
+
decoded = jwt.decode(token, options={"verify_signature": False})
|
| 101 |
+
tenant_id = decoded.get("tenant_id")
|
| 102 |
+
if tenant_id:
|
| 103 |
+
logger.info(f"Extracted tenant_id from token: {tenant_id}")
|
| 104 |
+
return tenant_id
|
| 105 |
+
except jwt.DecodeError:
|
| 106 |
+
logger.warning("Failed to decode JWT token")
|
| 107 |
+
return None
|
| 108 |
+
|
| 109 |
+
except Exception as e:
|
| 110 |
+
logger.error(f"Error extracting tenant from token: {e}")
|
| 111 |
+
return None
|
| 112 |
+
|
| 113 |
+
return None
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
async def get_auth_context(request: Request) -> Dict[str, Any]:
|
| 117 |
+
"""
|
| 118 |
+
Get authentication context from request.
|
| 119 |
+
|
| 120 |
+
DEV mode:
|
| 121 |
+
- Allows X-Tenant-Id and X-User-Id headers
|
| 122 |
+
- Falls back to defaults if missing
|
| 123 |
+
|
| 124 |
+
PROD mode:
|
| 125 |
+
- Requires Authorization: Bearer <JWT>
|
| 126 |
+
- Verifies JWT using JWT_SECRET
|
| 127 |
+
- Extracts tenant_id and user_id from token claims
|
| 128 |
+
- NEVER accepts tenant_id from request body/query params
|
| 129 |
+
|
| 130 |
+
Args:
|
| 131 |
+
request: FastAPI request object
|
| 132 |
+
|
| 133 |
+
Returns:
|
| 134 |
+
Dictionary with user_id and tenant_id
|
| 135 |
+
|
| 136 |
+
Raises:
|
| 137 |
+
HTTPException: If authentication fails (production mode only)
|
| 138 |
+
"""
|
| 139 |
+
if settings.ENV == "dev":
|
| 140 |
+
# Dev mode: allow headers, fallback to defaults
|
| 141 |
+
tenant_id = request.headers.get("X-Tenant-Id", "dev_tenant")
|
| 142 |
+
user_id = request.headers.get("X-User-Id", "dev_user")
|
| 143 |
+
return {
|
| 144 |
+
"user_id": user_id,
|
| 145 |
+
"tenant_id": tenant_id
|
| 146 |
+
}
|
| 147 |
+
|
| 148 |
+
# Production mode: require JWT token
|
| 149 |
+
auth_header = request.headers.get("Authorization")
|
| 150 |
+
if not auth_header or not auth_header.startswith("Bearer "):
|
| 151 |
+
raise HTTPException(
|
| 152 |
+
status_code=401,
|
| 153 |
+
detail="Authentication required. Provide valid Bearer token in Authorization header."
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
token = auth_header.replace("Bearer ", "").strip()
|
| 157 |
+
if not token:
|
| 158 |
+
raise HTTPException(
|
| 159 |
+
status_code=401,
|
| 160 |
+
detail="Invalid token format."
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
# Verify JWT token
|
| 164 |
+
if not settings.JWT_SECRET:
|
| 165 |
+
logger.error("JWT_SECRET not configured in production mode")
|
| 166 |
+
raise HTTPException(
|
| 167 |
+
status_code=500,
|
| 168 |
+
detail="Server configuration error: JWT_SECRET not set"
|
| 169 |
+
)
|
| 170 |
+
|
| 171 |
+
try:
|
| 172 |
+
decoded = jwt.decode(token, settings.JWT_SECRET, algorithms=["HS256"])
|
| 173 |
+
|
| 174 |
+
user_id = decoded.get("user_id") or decoded.get("sub")
|
| 175 |
+
tenant_id = decoded.get("tenant_id")
|
| 176 |
+
|
| 177 |
+
if not user_id or not tenant_id:
|
| 178 |
+
raise HTTPException(
|
| 179 |
+
status_code=401,
|
| 180 |
+
detail="Token missing required claims (user_id, tenant_id)."
|
| 181 |
+
)
|
| 182 |
+
|
| 183 |
+
logger.info(f"Authenticated user: {user_id}, tenant: {tenant_id}")
|
| 184 |
+
return {
|
| 185 |
+
"user_id": user_id,
|
| 186 |
+
"tenant_id": tenant_id,
|
| 187 |
+
"email": decoded.get("email"),
|
| 188 |
+
"role": decoded.get("role")
|
| 189 |
+
}
|
| 190 |
+
|
| 191 |
+
except JWTError as e:
|
| 192 |
+
logger.warning(f"JWT verification failed: {e}")
|
| 193 |
+
raise HTTPException(
|
| 194 |
+
status_code=401,
|
| 195 |
+
detail="Invalid or expired token."
|
| 196 |
+
)
|
| 197 |
+
except Exception as e:
|
| 198 |
+
logger.error(f"Auth error: {e}", exc_info=True)
|
| 199 |
+
raise HTTPException(
|
| 200 |
+
status_code=401,
|
| 201 |
+
detail="Authentication failed."
|
| 202 |
+
)
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
# FastAPI dependency for easy use in endpoints
|
| 206 |
+
async def require_auth(request: Request) -> Dict[str, Any]:
|
| 207 |
+
"""
|
| 208 |
+
FastAPI dependency for requiring authentication.
|
| 209 |
+
Alias for get_auth_context for backward compatibility.
|
| 210 |
+
"""
|
| 211 |
+
return await get_auth_context(request)
|
| 212 |
+
|
app/middleware/rate_limit.py
CHANGED
|
@@ -1,40 +1,40 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Rate limiting middleware using slowapi.
|
| 3 |
-
"""
|
| 4 |
-
from slowapi import Limiter, _rate_limit_exceeded_handler
|
| 5 |
-
from slowapi.util import get_remote_address
|
| 6 |
-
from slowapi.errors import RateLimitExceeded
|
| 7 |
-
from fastapi import Request
|
| 8 |
-
import logging
|
| 9 |
-
|
| 10 |
-
from app.config import settings
|
| 11 |
-
|
| 12 |
-
logger = logging.getLogger(__name__)
|
| 13 |
-
|
| 14 |
-
# Initialize limiter with default limits (can be overridden per endpoint)
|
| 15 |
-
limiter = Limiter(
|
| 16 |
-
key_func=get_remote_address,
|
| 17 |
-
default_limits=["1000/hour"] if settings.RATE_LIMIT_ENABLED else []
|
| 18 |
-
)
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
def get_tenant_rate_limit_key(request: Request) -> str:
|
| 22 |
-
"""
|
| 23 |
-
Get rate limit key based on tenant_id from headers (dev) or IP (prod).
|
| 24 |
-
|
| 25 |
-
Note: This is a sync function called by slowapi, so we can't await async functions.
|
| 26 |
-
In dev mode, we extract tenant_id from X-Tenant-Id header.
|
| 27 |
-
In prod mode, we fall back to IP address (rate limiting happens before auth).
|
| 28 |
-
"""
|
| 29 |
-
# Try to get tenant_id from headers (works in dev mode)
|
| 30 |
-
tenant_id = request.headers.get("X-Tenant-Id")
|
| 31 |
-
if tenant_id:
|
| 32 |
-
return f"tenant:{tenant_id}"
|
| 33 |
-
|
| 34 |
-
# Fallback to IP address (for prod mode or if no header)
|
| 35 |
-
return get_remote_address(request)
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
# Export limiter and key function
|
| 39 |
-
__all__ = ["limiter", "get_tenant_rate_limit_key", "RateLimitExceeded", "_rate_limit_exceeded_handler"]
|
| 40 |
-
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Rate limiting middleware using slowapi.
|
| 3 |
+
"""
|
| 4 |
+
from slowapi import Limiter, _rate_limit_exceeded_handler
|
| 5 |
+
from slowapi.util import get_remote_address
|
| 6 |
+
from slowapi.errors import RateLimitExceeded
|
| 7 |
+
from fastapi import Request
|
| 8 |
+
import logging
|
| 9 |
+
|
| 10 |
+
from app.config import settings
|
| 11 |
+
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
# Initialize limiter with default limits (can be overridden per endpoint)
|
| 15 |
+
limiter = Limiter(
|
| 16 |
+
key_func=get_remote_address,
|
| 17 |
+
default_limits=["1000/hour"] if settings.RATE_LIMIT_ENABLED else []
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def get_tenant_rate_limit_key(request: Request) -> str:
|
| 22 |
+
"""
|
| 23 |
+
Get rate limit key based on tenant_id from headers (dev) or IP (prod).
|
| 24 |
+
|
| 25 |
+
Note: This is a sync function called by slowapi, so we can't await async functions.
|
| 26 |
+
In dev mode, we extract tenant_id from X-Tenant-Id header.
|
| 27 |
+
In prod mode, we fall back to IP address (rate limiting happens before auth).
|
| 28 |
+
"""
|
| 29 |
+
# Try to get tenant_id from headers (works in dev mode)
|
| 30 |
+
tenant_id = request.headers.get("X-Tenant-Id")
|
| 31 |
+
if tenant_id:
|
| 32 |
+
return f"tenant:{tenant_id}"
|
| 33 |
+
|
| 34 |
+
# Fallback to IP address (for prod mode or if no header)
|
| 35 |
+
return get_remote_address(request)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
# Export limiter and key function
|
| 39 |
+
__all__ = ["limiter", "get_tenant_rate_limit_key", "RateLimitExceeded", "_rate_limit_exceeded_handler"]
|
| 40 |
+
|
app/models/__init__.py
CHANGED
|
@@ -1,33 +1,33 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Pydantic models for the RAG backend.
|
| 3 |
-
"""
|
| 4 |
-
from app.models.schemas import (
|
| 5 |
-
DocumentStatus,
|
| 6 |
-
ChunkMetadata,
|
| 7 |
-
DocumentChunk,
|
| 8 |
-
UploadRequest,
|
| 9 |
-
UploadResponse,
|
| 10 |
-
Citation,
|
| 11 |
-
ChatRequest,
|
| 12 |
-
ChatResponse,
|
| 13 |
-
RetrievalResult,
|
| 14 |
-
KnowledgeBaseStats,
|
| 15 |
-
HealthResponse,
|
| 16 |
-
)
|
| 17 |
-
|
| 18 |
-
__all__ = [
|
| 19 |
-
"DocumentStatus",
|
| 20 |
-
"ChunkMetadata",
|
| 21 |
-
"DocumentChunk",
|
| 22 |
-
"UploadRequest",
|
| 23 |
-
"UploadResponse",
|
| 24 |
-
"Citation",
|
| 25 |
-
"ChatRequest",
|
| 26 |
-
"ChatResponse",
|
| 27 |
-
"RetrievalResult",
|
| 28 |
-
"KnowledgeBaseStats",
|
| 29 |
-
"HealthResponse",
|
| 30 |
-
]
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Pydantic models for the RAG backend.
|
| 3 |
+
"""
|
| 4 |
+
from app.models.schemas import (
|
| 5 |
+
DocumentStatus,
|
| 6 |
+
ChunkMetadata,
|
| 7 |
+
DocumentChunk,
|
| 8 |
+
UploadRequest,
|
| 9 |
+
UploadResponse,
|
| 10 |
+
Citation,
|
| 11 |
+
ChatRequest,
|
| 12 |
+
ChatResponse,
|
| 13 |
+
RetrievalResult,
|
| 14 |
+
KnowledgeBaseStats,
|
| 15 |
+
HealthResponse,
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
__all__ = [
|
| 19 |
+
"DocumentStatus",
|
| 20 |
+
"ChunkMetadata",
|
| 21 |
+
"DocumentChunk",
|
| 22 |
+
"UploadRequest",
|
| 23 |
+
"UploadResponse",
|
| 24 |
+
"Citation",
|
| 25 |
+
"ChatRequest",
|
| 26 |
+
"ChatResponse",
|
| 27 |
+
"RetrievalResult",
|
| 28 |
+
"KnowledgeBaseStats",
|
| 29 |
+
"HealthResponse",
|
| 30 |
+
]
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
|
app/models/__pycache__/__init__.cpython-313.pyc
DELETED
|
Binary file (542 Bytes)
|
|
|
app/models/__pycache__/billing_schemas.cpython-313.pyc
DELETED
|
Binary file (2.09 kB)
|
|
|
app/models/__pycache__/schemas.cpython-313.pyc
DELETED
|
Binary file (5.26 kB)
|
|
|
app/models/billing_schemas.py
CHANGED
|
@@ -1,46 +1,46 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Pydantic schemas for billing endpoints.
|
| 3 |
-
"""
|
| 4 |
-
from pydantic import BaseModel
|
| 5 |
-
from typing import Optional, List
|
| 6 |
-
from datetime import datetime
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
class UsageResponse(BaseModel):
|
| 10 |
-
"""Usage statistics response."""
|
| 11 |
-
tenant_id: str
|
| 12 |
-
period: str # "day" or "month"
|
| 13 |
-
total_requests: int
|
| 14 |
-
total_tokens: int
|
| 15 |
-
total_cost_usd: float
|
| 16 |
-
gemini_requests: int = 0
|
| 17 |
-
openai_requests: int = 0
|
| 18 |
-
start_date: datetime
|
| 19 |
-
end_date: datetime
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
class PlanLimitsResponse(BaseModel):
|
| 23 |
-
"""Current plan limits response."""
|
| 24 |
-
tenant_id: str
|
| 25 |
-
plan_name: str
|
| 26 |
-
monthly_chat_limit: int # -1 for unlimited
|
| 27 |
-
current_month_usage: int
|
| 28 |
-
remaining_chats: Optional[int] # None if unlimited
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
class CostReportResponse(BaseModel):
|
| 32 |
-
"""Cost report response."""
|
| 33 |
-
tenant_id: str
|
| 34 |
-
period: str
|
| 35 |
-
total_cost_usd: float
|
| 36 |
-
total_requests: int
|
| 37 |
-
total_tokens: int
|
| 38 |
-
breakdown_by_provider: dict
|
| 39 |
-
breakdown_by_model: dict
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
class SetPlanRequest(BaseModel):
|
| 43 |
-
"""Request to set tenant plan."""
|
| 44 |
-
tenant_id: str
|
| 45 |
-
plan_name: str # "starter", "growth", or "pro"
|
| 46 |
-
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Pydantic schemas for billing endpoints.
|
| 3 |
+
"""
|
| 4 |
+
from pydantic import BaseModel
|
| 5 |
+
from typing import Optional, List
|
| 6 |
+
from datetime import datetime
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class UsageResponse(BaseModel):
|
| 10 |
+
"""Usage statistics response."""
|
| 11 |
+
tenant_id: str
|
| 12 |
+
period: str # "day" or "month"
|
| 13 |
+
total_requests: int
|
| 14 |
+
total_tokens: int
|
| 15 |
+
total_cost_usd: float
|
| 16 |
+
gemini_requests: int = 0
|
| 17 |
+
openai_requests: int = 0
|
| 18 |
+
start_date: datetime
|
| 19 |
+
end_date: datetime
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class PlanLimitsResponse(BaseModel):
|
| 23 |
+
"""Current plan limits response."""
|
| 24 |
+
tenant_id: str
|
| 25 |
+
plan_name: str
|
| 26 |
+
monthly_chat_limit: int # -1 for unlimited
|
| 27 |
+
current_month_usage: int
|
| 28 |
+
remaining_chats: Optional[int] # None if unlimited
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class CostReportResponse(BaseModel):
|
| 32 |
+
"""Cost report response."""
|
| 33 |
+
tenant_id: str
|
| 34 |
+
period: str
|
| 35 |
+
total_cost_usd: float
|
| 36 |
+
total_requests: int
|
| 37 |
+
total_tokens: int
|
| 38 |
+
breakdown_by_provider: dict
|
| 39 |
+
breakdown_by_model: dict
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class SetPlanRequest(BaseModel):
|
| 43 |
+
"""Request to set tenant plan."""
|
| 44 |
+
tenant_id: str
|
| 45 |
+
plan_name: str # "starter", "growth", or "pro"
|
| 46 |
+
|
app/models/schemas.py
CHANGED
|
@@ -1,112 +1,112 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Pydantic models for API request/response schemas.
|
| 3 |
-
"""
|
| 4 |
-
from pydantic import BaseModel, Field
|
| 5 |
-
from typing import List, Optional, Dict, Any
|
| 6 |
-
from datetime import datetime
|
| 7 |
-
from enum import Enum
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
class DocumentStatus(str, Enum):
|
| 11 |
-
PENDING = "pending"
|
| 12 |
-
PROCESSING = "processing"
|
| 13 |
-
COMPLETED = "completed"
|
| 14 |
-
FAILED = "failed"
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
class ChunkMetadata(BaseModel):
|
| 18 |
-
"""Metadata for a document chunk."""
|
| 19 |
-
tenant_id: str # CRITICAL: Multi-tenant isolation
|
| 20 |
-
kb_id: str
|
| 21 |
-
user_id: str
|
| 22 |
-
file_name: str
|
| 23 |
-
file_type: str
|
| 24 |
-
chunk_id: str
|
| 25 |
-
chunk_index: int
|
| 26 |
-
page_number: Optional[int] = None
|
| 27 |
-
total_chunks: int
|
| 28 |
-
document_id: Optional[str] = None # Track original document
|
| 29 |
-
created_at: datetime = Field(default_factory=datetime.utcnow)
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
class DocumentChunk(BaseModel):
|
| 33 |
-
"""A chunk of text with metadata."""
|
| 34 |
-
id: str
|
| 35 |
-
content: str
|
| 36 |
-
metadata: ChunkMetadata
|
| 37 |
-
embedding: Optional[List[float]] = None
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
class UploadRequest(BaseModel):
|
| 41 |
-
"""Request model for file upload."""
|
| 42 |
-
tenant_id: str # CRITICAL: Multi-tenant isolation
|
| 43 |
-
user_id: str
|
| 44 |
-
kb_id: str
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
class UploadResponse(BaseModel):
|
| 48 |
-
"""Response model for file upload."""
|
| 49 |
-
success: bool
|
| 50 |
-
message: str
|
| 51 |
-
document_id: Optional[str] = None
|
| 52 |
-
file_name: str
|
| 53 |
-
chunks_created: int = 0
|
| 54 |
-
status: DocumentStatus = DocumentStatus.PENDING
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
class Citation(BaseModel):
|
| 58 |
-
"""Citation reference for an answer."""
|
| 59 |
-
file_name: str
|
| 60 |
-
chunk_id: str
|
| 61 |
-
page_number: Optional[int] = None
|
| 62 |
-
relevance_score: float
|
| 63 |
-
excerpt: str # Short excerpt from the chunk
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
class ChatRequest(BaseModel):
|
| 67 |
-
"""Request model for chat endpoint."""
|
| 68 |
-
tenant_id: str # CRITICAL: Multi-tenant isolation
|
| 69 |
-
user_id: str
|
| 70 |
-
kb_id: str
|
| 71 |
-
conversation_id: Optional[str] = None
|
| 72 |
-
question: str
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
class ChatResponse(BaseModel):
|
| 76 |
-
"""Response model for chat endpoint."""
|
| 77 |
-
success: bool
|
| 78 |
-
answer: str
|
| 79 |
-
citations: List[Citation] = []
|
| 80 |
-
confidence: float # 0-1 score
|
| 81 |
-
from_knowledge_base: bool = True
|
| 82 |
-
escalation_suggested: bool = False
|
| 83 |
-
conversation_id: str
|
| 84 |
-
metadata: Dict[str, Any] = {}
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
class RetrievalResult(BaseModel):
|
| 88 |
-
"""Result from vector store retrieval."""
|
| 89 |
-
chunk_id: str
|
| 90 |
-
content: str
|
| 91 |
-
metadata: Dict[str, Any]
|
| 92 |
-
similarity_score: float
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
class KnowledgeBaseStats(BaseModel):
|
| 96 |
-
"""Statistics for a knowledge base."""
|
| 97 |
-
tenant_id: str # CRITICAL: Multi-tenant isolation
|
| 98 |
-
kb_id: str
|
| 99 |
-
user_id: str
|
| 100 |
-
total_documents: int
|
| 101 |
-
total_chunks: int
|
| 102 |
-
file_names: List[str]
|
| 103 |
-
last_updated: Optional[datetime] = None
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
class HealthResponse(BaseModel):
|
| 107 |
-
"""Health check response."""
|
| 108 |
-
status: str
|
| 109 |
-
version: str = "1.0.0"
|
| 110 |
-
vector_db_connected: bool
|
| 111 |
-
llm_configured: bool
|
| 112 |
-
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Pydantic models for API request/response schemas.
|
| 3 |
+
"""
|
| 4 |
+
from pydantic import BaseModel, Field
|
| 5 |
+
from typing import List, Optional, Dict, Any
|
| 6 |
+
from datetime import datetime
|
| 7 |
+
from enum import Enum
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class DocumentStatus(str, Enum):
|
| 11 |
+
PENDING = "pending"
|
| 12 |
+
PROCESSING = "processing"
|
| 13 |
+
COMPLETED = "completed"
|
| 14 |
+
FAILED = "failed"
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class ChunkMetadata(BaseModel):
|
| 18 |
+
"""Metadata for a document chunk."""
|
| 19 |
+
tenant_id: str # CRITICAL: Multi-tenant isolation
|
| 20 |
+
kb_id: str
|
| 21 |
+
user_id: str
|
| 22 |
+
file_name: str
|
| 23 |
+
file_type: str
|
| 24 |
+
chunk_id: str
|
| 25 |
+
chunk_index: int
|
| 26 |
+
page_number: Optional[int] = None
|
| 27 |
+
total_chunks: int
|
| 28 |
+
document_id: Optional[str] = None # Track original document
|
| 29 |
+
created_at: datetime = Field(default_factory=datetime.utcnow)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class DocumentChunk(BaseModel):
|
| 33 |
+
"""A chunk of text with metadata."""
|
| 34 |
+
id: str
|
| 35 |
+
content: str
|
| 36 |
+
metadata: ChunkMetadata
|
| 37 |
+
embedding: Optional[List[float]] = None
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class UploadRequest(BaseModel):
|
| 41 |
+
"""Request model for file upload."""
|
| 42 |
+
tenant_id: str # CRITICAL: Multi-tenant isolation
|
| 43 |
+
user_id: str
|
| 44 |
+
kb_id: str
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class UploadResponse(BaseModel):
|
| 48 |
+
"""Response model for file upload."""
|
| 49 |
+
success: bool
|
| 50 |
+
message: str
|
| 51 |
+
document_id: Optional[str] = None
|
| 52 |
+
file_name: str
|
| 53 |
+
chunks_created: int = 0
|
| 54 |
+
status: DocumentStatus = DocumentStatus.PENDING
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class Citation(BaseModel):
|
| 58 |
+
"""Citation reference for an answer."""
|
| 59 |
+
file_name: str
|
| 60 |
+
chunk_id: str
|
| 61 |
+
page_number: Optional[int] = None
|
| 62 |
+
relevance_score: float
|
| 63 |
+
excerpt: str # Short excerpt from the chunk
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class ChatRequest(BaseModel):
|
| 67 |
+
"""Request model for chat endpoint."""
|
| 68 |
+
tenant_id: str # CRITICAL: Multi-tenant isolation
|
| 69 |
+
user_id: str
|
| 70 |
+
kb_id: str
|
| 71 |
+
conversation_id: Optional[str] = None
|
| 72 |
+
question: str
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class ChatResponse(BaseModel):
|
| 76 |
+
"""Response model for chat endpoint."""
|
| 77 |
+
success: bool
|
| 78 |
+
answer: str
|
| 79 |
+
citations: List[Citation] = []
|
| 80 |
+
confidence: float # 0-1 score
|
| 81 |
+
from_knowledge_base: bool = True
|
| 82 |
+
escalation_suggested: bool = False
|
| 83 |
+
conversation_id: str
|
| 84 |
+
metadata: Dict[str, Any] = {}
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
class RetrievalResult(BaseModel):
|
| 88 |
+
"""Result from vector store retrieval."""
|
| 89 |
+
chunk_id: str
|
| 90 |
+
content: str
|
| 91 |
+
metadata: Dict[str, Any]
|
| 92 |
+
similarity_score: float
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
class KnowledgeBaseStats(BaseModel):
|
| 96 |
+
"""Statistics for a knowledge base."""
|
| 97 |
+
tenant_id: str # CRITICAL: Multi-tenant isolation
|
| 98 |
+
kb_id: str
|
| 99 |
+
user_id: str
|
| 100 |
+
total_documents: int
|
| 101 |
+
total_chunks: int
|
| 102 |
+
file_names: List[str]
|
| 103 |
+
last_updated: Optional[datetime] = None
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
class HealthResponse(BaseModel):
|
| 107 |
+
"""Health check response."""
|
| 108 |
+
status: str
|
| 109 |
+
version: str = "1.0.0"
|
| 110 |
+
vector_db_connected: bool
|
| 111 |
+
llm_configured: bool
|
| 112 |
+
|
app/rag/__init__.py
CHANGED
|
@@ -1,27 +1,27 @@
|
|
| 1 |
-
"""
|
| 2 |
-
RAG (Retrieval-Augmented Generation) pipeline modules.
|
| 3 |
-
"""
|
| 4 |
-
from app.rag.ingest import parser, DocumentParser
|
| 5 |
-
from app.rag.chunking import chunker, DocumentChunker
|
| 6 |
-
from app.rag.embeddings import get_embedding_service, EmbeddingService
|
| 7 |
-
from app.rag.vectorstore import get_vector_store, VectorStore
|
| 8 |
-
from app.rag.retrieval import get_retrieval_service, RetrievalService
|
| 9 |
-
from app.rag.answer import get_answer_service, AnswerService
|
| 10 |
-
|
| 11 |
-
__all__ = [
|
| 12 |
-
"parser",
|
| 13 |
-
"DocumentParser",
|
| 14 |
-
"chunker",
|
| 15 |
-
"DocumentChunker",
|
| 16 |
-
"get_embedding_service",
|
| 17 |
-
"EmbeddingService",
|
| 18 |
-
"get_vector_store",
|
| 19 |
-
"VectorStore",
|
| 20 |
-
"get_retrieval_service",
|
| 21 |
-
"RetrievalService",
|
| 22 |
-
"get_answer_service",
|
| 23 |
-
"AnswerService",
|
| 24 |
-
]
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
RAG (Retrieval-Augmented Generation) pipeline modules.
|
| 3 |
+
"""
|
| 4 |
+
from app.rag.ingest import parser, DocumentParser
|
| 5 |
+
from app.rag.chunking import chunker, DocumentChunker
|
| 6 |
+
from app.rag.embeddings import get_embedding_service, EmbeddingService
|
| 7 |
+
from app.rag.vectorstore import get_vector_store, VectorStore
|
| 8 |
+
from app.rag.retrieval import get_retrieval_service, RetrievalService
|
| 9 |
+
from app.rag.answer import get_answer_service, AnswerService
|
| 10 |
+
|
| 11 |
+
__all__ = [
|
| 12 |
+
"parser",
|
| 13 |
+
"DocumentParser",
|
| 14 |
+
"chunker",
|
| 15 |
+
"DocumentChunker",
|
| 16 |
+
"get_embedding_service",
|
| 17 |
+
"EmbeddingService",
|
| 18 |
+
"get_vector_store",
|
| 19 |
+
"VectorStore",
|
| 20 |
+
"get_retrieval_service",
|
| 21 |
+
"RetrievalService",
|
| 22 |
+
"get_answer_service",
|
| 23 |
+
"AnswerService",
|
| 24 |
+
]
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
|
app/rag/__pycache__/__init__.cpython-313.pyc
DELETED
|
Binary file (799 Bytes)
|
|
|
app/rag/__pycache__/answer.cpython-313.pyc
DELETED
|
Binary file (18.6 kB)
|
|
|
app/rag/__pycache__/chunking.cpython-313.pyc
DELETED
|
Binary file (7.26 kB)
|
|
|
app/rag/__pycache__/embeddings.cpython-313.pyc
DELETED
|
Binary file (6.06 kB)
|
|
|
app/rag/__pycache__/ingest.cpython-313.pyc
DELETED
|
Binary file (10.5 kB)
|
|
|
app/rag/__pycache__/intent.cpython-313.pyc
DELETED
|
Binary file (5.51 kB)
|
|
|
app/rag/__pycache__/prompts.cpython-313.pyc
DELETED
|
Binary file (6.94 kB)
|
|
|
app/rag/__pycache__/retrieval.cpython-313.pyc
DELETED
|
Binary file (9.74 kB)
|
|
|
app/rag/__pycache__/vectorstore.cpython-313.pyc
DELETED
|
Binary file (9.35 kB)
|
|
|
app/rag/__pycache__/verifier.cpython-313.pyc
DELETED
|
Binary file (9.77 kB)
|
|
|
app/rag/answer.py
CHANGED
|
@@ -1,444 +1,444 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Answer generation using LLM with RAG context.
|
| 3 |
-
Supports Gemini and OpenAI as providers.
|
| 4 |
-
"""
|
| 5 |
-
import google.generativeai as genai
|
| 6 |
-
from openai import OpenAI
|
| 7 |
-
from typing import Optional, Dict, Any, List
|
| 8 |
-
import logging
|
| 9 |
-
import os
|
| 10 |
-
import re
|
| 11 |
-
|
| 12 |
-
from app.config import settings
|
| 13 |
-
from app.rag.prompts import (
|
| 14 |
-
format_rag_prompt,
|
| 15 |
-
format_draft_prompt,
|
| 16 |
-
get_no_context_response,
|
| 17 |
-
get_low_confidence_response
|
| 18 |
-
)
|
| 19 |
-
from app.rag.verifier import get_verifier_service
|
| 20 |
-
from app.rag.intent import detect_intents
|
| 21 |
-
from app.models.schemas import Citation
|
| 22 |
-
from abc import ABC, abstractmethod
|
| 23 |
-
|
| 24 |
-
logging.basicConfig(level=logging.INFO)
|
| 25 |
-
logger = logging.getLogger(__name__)
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
class LLMProvider(ABC):
|
| 29 |
-
"""Base class for LLM providers."""
|
| 30 |
-
|
| 31 |
-
@abstractmethod
|
| 32 |
-
def generate(self, system_prompt: str, user_prompt: str) -> str:
|
| 33 |
-
"""Generate response from prompts."""
|
| 34 |
-
raise NotImplementedError
|
| 35 |
-
|
| 36 |
-
@abstractmethod
|
| 37 |
-
def generate_with_usage(self, system_prompt: str, user_prompt: str) -> tuple[str, dict]:
|
| 38 |
-
"""
|
| 39 |
-
Generate response and return usage information.
|
| 40 |
-
|
| 41 |
-
Returns:
|
| 42 |
-
(response_text, usage_info)
|
| 43 |
-
usage_info: dict with keys: prompt_tokens, completion_tokens, total_tokens, model_used
|
| 44 |
-
"""
|
| 45 |
-
raise NotImplementedError
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
class GeminiProvider(LLMProvider):
|
| 49 |
-
"""Google Gemini LLM provider."""
|
| 50 |
-
|
| 51 |
-
def __init__(self, api_key: Optional[str] = None, model: str = None):
|
| 52 |
-
self.api_key = api_key or settings.GEMINI_API_KEY or os.getenv("GEMINI_API_KEY")
|
| 53 |
-
# Default to gemini-1.5-flash if not specified
|
| 54 |
-
self.model = model or settings.GEMINI_MODEL or "gemini-1.5-flash"
|
| 55 |
-
|
| 56 |
-
if not self.api_key:
|
| 57 |
-
raise ValueError("Gemini API key not configured. Set GEMINI_API_KEY environment variable.")
|
| 58 |
-
|
| 59 |
-
genai.configure(api_key=self.api_key)
|
| 60 |
-
|
| 61 |
-
# Don't initialize client here - do it lazily in generate() to handle errors better
|
| 62 |
-
self._client = None
|
| 63 |
-
logger.info(f"Gemini provider initialized (will use model: {self.model})")
|
| 64 |
-
|
| 65 |
-
def generate(self, system_prompt: str, user_prompt: str) -> str:
|
| 66 |
-
"""Generate response using Gemini."""
|
| 67 |
-
text, _ = self.generate_with_usage(system_prompt, user_prompt)
|
| 68 |
-
return text
|
| 69 |
-
|
| 70 |
-
def generate_with_usage(self, system_prompt: str, user_prompt: str) -> tuple[str, dict]:
|
| 71 |
-
"""Generate response using Gemini and return usage info."""
|
| 72 |
-
# Combine system and user prompts for Gemini
|
| 73 |
-
full_prompt = f"{system_prompt}\n\n{user_prompt}"
|
| 74 |
-
|
| 75 |
-
# Estimate prompt tokens (rough: 1 token ≈ 4 chars)
|
| 76 |
-
prompt_tokens = len(full_prompt) // 4
|
| 77 |
-
|
| 78 |
-
# Try to list available models first, then use the first available one
|
| 79 |
-
# If that fails, try common model names
|
| 80 |
-
models_to_try = []
|
| 81 |
-
|
| 82 |
-
# First, try to get available models
|
| 83 |
-
try:
|
| 84 |
-
available_models = genai.list_models()
|
| 85 |
-
model_names = [m.name for m in available_models if 'generateContent' in m.supported_generation_methods]
|
| 86 |
-
if model_names:
|
| 87 |
-
# Extract just the model name (remove 'models/' prefix if present)
|
| 88 |
-
clean_names = [name.split('/')[-1] if '/' in name else name for name in model_names]
|
| 89 |
-
models_to_try.extend(clean_names[:3]) # Use first 3 available models
|
| 90 |
-
logger.info(f"Found {len(model_names)} available models, will try: {clean_names[:3]}")
|
| 91 |
-
except Exception as e:
|
| 92 |
-
logger.warning(f"Could not list available models: {e}, using fallback list")
|
| 93 |
-
|
| 94 |
-
# Fallback to common model names if listing failed
|
| 95 |
-
if not models_to_try:
|
| 96 |
-
models_to_try = ["gemini-pro", "gemini-1.0-pro", "models/gemini-pro"]
|
| 97 |
-
|
| 98 |
-
# Add configured model if different
|
| 99 |
-
if self.model and self.model not in models_to_try:
|
| 100 |
-
models_to_try.insert(0, self.model)
|
| 101 |
-
|
| 102 |
-
# Remove duplicates while preserving order
|
| 103 |
-
seen = set()
|
| 104 |
-
models_to_try = [m for m in models_to_try if not (m in seen or seen.add(m))]
|
| 105 |
-
|
| 106 |
-
last_error = None
|
| 107 |
-
for model_name in models_to_try:
|
| 108 |
-
try:
|
| 109 |
-
logger.info(f"Attempting to generate with model: {model_name}")
|
| 110 |
-
# Create a new client for this model
|
| 111 |
-
client = genai.GenerativeModel(model_name)
|
| 112 |
-
response = client.generate_content(
|
| 113 |
-
full_prompt,
|
| 114 |
-
generation_config=genai.types.GenerationConfig(
|
| 115 |
-
temperature=settings.TEMPERATURE,
|
| 116 |
-
max_output_tokens=1024,
|
| 117 |
-
)
|
| 118 |
-
)
|
| 119 |
-
|
| 120 |
-
# Extract response text
|
| 121 |
-
response_text = response.text
|
| 122 |
-
|
| 123 |
-
# Try to get usage info from response
|
| 124 |
-
usage_info = {
|
| 125 |
-
"prompt_tokens": prompt_tokens,
|
| 126 |
-
"completion_tokens": len(response_text) // 4, # Estimate
|
| 127 |
-
"total_tokens": prompt_tokens + (len(response_text) // 4),
|
| 128 |
-
"model_used": model_name.split('/')[-1] if '/' in model_name else model_name
|
| 129 |
-
}
|
| 130 |
-
|
| 131 |
-
# Try to get actual usage from response if available
|
| 132 |
-
if hasattr(response, 'usage_metadata'):
|
| 133 |
-
usage_metadata = response.usage_metadata
|
| 134 |
-
if hasattr(usage_metadata, 'prompt_token_count'):
|
| 135 |
-
usage_info["prompt_tokens"] = usage_metadata.prompt_token_count
|
| 136 |
-
if hasattr(usage_metadata, 'candidates_token_count'):
|
| 137 |
-
usage_info["completion_tokens"] = usage_metadata.candidates_token_count
|
| 138 |
-
if hasattr(usage_metadata, 'total_token_count'):
|
| 139 |
-
usage_info["total_tokens"] = usage_metadata.total_token_count
|
| 140 |
-
|
| 141 |
-
if model_name != self.model:
|
| 142 |
-
logger.info(f"✅ Successfully used model: {model_name}")
|
| 143 |
-
|
| 144 |
-
return response_text, usage_info
|
| 145 |
-
except Exception as e:
|
| 146 |
-
error_str = str(e).lower()
|
| 147 |
-
last_error = e
|
| 148 |
-
if "not found" in error_str or "not supported" in error_str or "404" in error_str:
|
| 149 |
-
logger.warning(f"Model {model_name} failed: {e}")
|
| 150 |
-
continue # Try next model
|
| 151 |
-
else:
|
| 152 |
-
# Different error (not model not found), re-raise
|
| 153 |
-
logger.error(f"Gemini generation error with {model_name}: {e}")
|
| 154 |
-
raise
|
| 155 |
-
|
| 156 |
-
# All models failed - return a helpful error message
|
| 157 |
-
error_msg = f"All Gemini model attempts failed. Last error: {last_error}. Please check your GEMINI_API_KEY and ensure it has access to Gemini models."
|
| 158 |
-
logger.error(error_msg)
|
| 159 |
-
raise Exception(error_msg)
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
class OpenAIProvider(LLMProvider):
|
| 163 |
-
"""OpenAI LLM provider."""
|
| 164 |
-
|
| 165 |
-
def __init__(self, api_key: Optional[str] = None, model: str = settings.OPENAI_MODEL):
|
| 166 |
-
self.api_key = api_key or settings.OPENAI_API_KEY or os.getenv("OPENAI_API_KEY")
|
| 167 |
-
self.model = model
|
| 168 |
-
|
| 169 |
-
if not self.api_key:
|
| 170 |
-
raise ValueError("OpenAI API key not configured. Set OPENAI_API_KEY environment variable.")
|
| 171 |
-
|
| 172 |
-
self.client = OpenAI(api_key=self.api_key)
|
| 173 |
-
logger.info(f"OpenAI provider initialized with model: {model}")
|
| 174 |
-
|
| 175 |
-
def generate(self, system_prompt: str, user_prompt: str) -> str:
|
| 176 |
-
"""Generate response using OpenAI."""
|
| 177 |
-
text, _ = self.generate_with_usage(system_prompt, user_prompt)
|
| 178 |
-
return text
|
| 179 |
-
|
| 180 |
-
def generate_with_usage(self, system_prompt: str, user_prompt: str) -> tuple[str, dict]:
|
| 181 |
-
"""Generate response using OpenAI and return usage info."""
|
| 182 |
-
try:
|
| 183 |
-
response = self.client.chat.completions.create(
|
| 184 |
-
model=self.model,
|
| 185 |
-
messages=[
|
| 186 |
-
{"role": "system", "content": system_prompt},
|
| 187 |
-
{"role": "user", "content": user_prompt}
|
| 188 |
-
],
|
| 189 |
-
temperature=settings.TEMPERATURE,
|
| 190 |
-
max_tokens=1024
|
| 191 |
-
)
|
| 192 |
-
|
| 193 |
-
response_text = response.choices[0].message.content
|
| 194 |
-
|
| 195 |
-
# Extract usage info from OpenAI response
|
| 196 |
-
usage_info = {
|
| 197 |
-
"prompt_tokens": response.usage.prompt_tokens,
|
| 198 |
-
"completion_tokens": response.usage.completion_tokens,
|
| 199 |
-
"total_tokens": response.usage.total_tokens,
|
| 200 |
-
"model_used": self.model
|
| 201 |
-
}
|
| 202 |
-
|
| 203 |
-
return response_text, usage_info
|
| 204 |
-
except Exception as e:
|
| 205 |
-
logger.error(f"OpenAI generation error: {e}")
|
| 206 |
-
raise
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
class AnswerService:
|
| 210 |
-
"""
|
| 211 |
-
Generates answers using RAG context and LLM.
|
| 212 |
-
Handles confidence scoring and citation extraction.
|
| 213 |
-
"""
|
| 214 |
-
|
| 215 |
-
# Confidence thresholds
|
| 216 |
-
HIGH_CONFIDENCE_THRESHOLD = 0.5
|
| 217 |
-
LOW_CONFIDENCE_THRESHOLD = 0.20 # Lowered to match similarity threshold
|
| 218 |
-
STRICT_CONFIDENCE_THRESHOLD = 0.30 # Strict threshold for answer generation (lowered from 0.45 to allow good retrieval results)
|
| 219 |
-
|
| 220 |
-
def __init__(self, provider: str = settings.LLM_PROVIDER):
|
| 221 |
-
"""
|
| 222 |
-
Initialize the answer service.
|
| 223 |
-
|
| 224 |
-
Args:
|
| 225 |
-
provider: LLM provider to use ("gemini" or "openai")
|
| 226 |
-
"""
|
| 227 |
-
self.provider_name = provider
|
| 228 |
-
self._provider: Optional[LLMProvider] = None
|
| 229 |
-
|
| 230 |
-
@property
|
| 231 |
-
def provider(self) -> LLMProvider:
|
| 232 |
-
"""Lazy load the LLM provider."""
|
| 233 |
-
if self._provider is None:
|
| 234 |
-
if self.provider_name == "gemini":
|
| 235 |
-
self._provider = GeminiProvider()
|
| 236 |
-
elif self.provider_name == "openai":
|
| 237 |
-
self._provider = OpenAIProvider()
|
| 238 |
-
else:
|
| 239 |
-
raise ValueError(f"Unknown LLM provider: {self.provider_name}")
|
| 240 |
-
return self._provider
|
| 241 |
-
|
| 242 |
-
def generate_answer(
|
| 243 |
-
self,
|
| 244 |
-
question: str,
|
| 245 |
-
context: str,
|
| 246 |
-
citations_info: List[Dict[str, Any]],
|
| 247 |
-
confidence: float,
|
| 248 |
-
has_relevant_results: bool,
|
| 249 |
-
use_verifier: bool = None # None = use config default
|
| 250 |
-
) -> Dict[str, Any]:
|
| 251 |
-
"""
|
| 252 |
-
Generate an answer based on retrieved context with mandatory verifier.
|
| 253 |
-
|
| 254 |
-
Args:
|
| 255 |
-
question: User's question
|
| 256 |
-
context: Retrieved context from knowledge base
|
| 257 |
-
citations_info: List of citation information
|
| 258 |
-
confidence: Average confidence score from retrieval
|
| 259 |
-
has_relevant_results: Whether any results passed the threshold
|
| 260 |
-
use_verifier: Whether to use verifier mode (None = use config default)
|
| 261 |
-
|
| 262 |
-
Returns:
|
| 263 |
-
Dictionary with answer, citations, confidence, and metadata
|
| 264 |
-
"""
|
| 265 |
-
# Determine if verifier should be used (mandatory by default)
|
| 266 |
-
if use_verifier is None:
|
| 267 |
-
use_verifier = settings.REQUIRE_VERIFIER
|
| 268 |
-
|
| 269 |
-
# GATE 1: No relevant results found - REFUSE
|
| 270 |
-
if not has_relevant_results or not context:
|
| 271 |
-
logger.info("No relevant context found, returning no-context response")
|
| 272 |
-
return {
|
| 273 |
-
"answer": get_no_context_response(),
|
| 274 |
-
"citations": [],
|
| 275 |
-
"confidence": 0.0,
|
| 276 |
-
"from_knowledge_base": False,
|
| 277 |
-
"escalation_suggested": True,
|
| 278 |
-
"refused": True
|
| 279 |
-
}
|
| 280 |
-
|
| 281 |
-
# GATE 2: Strict confidence threshold - REFUSE if below strict threshold
|
| 282 |
-
if confidence < self.STRICT_CONFIDENCE_THRESHOLD:
|
| 283 |
-
logger.warning(
|
| 284 |
-
f"Confidence ({confidence:.3f}) below strict threshold ({self.STRICT_CONFIDENCE_THRESHOLD}), "
|
| 285 |
-
f"REFUSING to answer to prevent hallucination"
|
| 286 |
-
)
|
| 287 |
-
return {
|
| 288 |
-
"answer": get_no_context_response(),
|
| 289 |
-
"citations": [],
|
| 290 |
-
"confidence": confidence,
|
| 291 |
-
"from_knowledge_base": False,
|
| 292 |
-
"escalation_suggested": True,
|
| 293 |
-
"refused": True,
|
| 294 |
-
"refusal_reason": f"Confidence {confidence:.3f} below strict threshold {self.STRICT_CONFIDENCE_THRESHOLD}"
|
| 295 |
-
}
|
| 296 |
-
|
| 297 |
-
# GATE 3: Intent-based gating for specific intents (integration, API, etc.)
|
| 298 |
-
intents = detect_intents(question)
|
| 299 |
-
if "integration" in intents or "api" in question.lower():
|
| 300 |
-
# For integration/API questions, require strong relevance
|
| 301 |
-
if confidence < 0.50: # Even stricter for integration questions
|
| 302 |
-
logger.warning(
|
| 303 |
-
f"Integration/API question with low confidence ({confidence:.3f}), "
|
| 304 |
-
f"REFUSING to prevent hallucination"
|
| 305 |
-
)
|
| 306 |
-
return {
|
| 307 |
-
"answer": get_no_context_response(),
|
| 308 |
-
"citations": [],
|
| 309 |
-
"confidence": confidence,
|
| 310 |
-
"from_knowledge_base": False,
|
| 311 |
-
"escalation_suggested": True,
|
| 312 |
-
"refused": True,
|
| 313 |
-
"refusal_reason": "Integration/API questions require higher confidence"
|
| 314 |
-
}
|
| 315 |
-
|
| 316 |
-
# Case 3: Passed all gates - generate answer with MANDATORY verifier
|
| 317 |
-
logger.info(f"Confidence ({confidence:.3f}) passed all gates, generating answer with verifier={use_verifier}")
|
| 318 |
-
|
| 319 |
-
try:
|
| 320 |
-
# VERIFIER MODE IS MANDATORY: Draft → Verify → Final
|
| 321 |
-
# Step 1: Generate draft answer with usage tracking
|
| 322 |
-
draft_system, draft_user = format_draft_prompt(context, question)
|
| 323 |
-
draft_answer, usage_info = self.provider.generate_with_usage(draft_system, draft_user)
|
| 324 |
-
logger.info("Generated draft answer, running verifier...")
|
| 325 |
-
|
| 326 |
-
# Step 2: Verify draft answer (MANDATORY)
|
| 327 |
-
verifier = get_verifier_service()
|
| 328 |
-
verification = verifier.verify_answer(
|
| 329 |
-
draft_answer=draft_answer,
|
| 330 |
-
context=context,
|
| 331 |
-
citations_info=citations_info
|
| 332 |
-
)
|
| 333 |
-
|
| 334 |
-
# Step 3: Handle verification result
|
| 335 |
-
if verification["pass"]:
|
| 336 |
-
logger.info("✅ Verifier PASSED - Using draft answer")
|
| 337 |
-
citations = self._extract_citations(draft_answer, citations_info)
|
| 338 |
-
return {
|
| 339 |
-
"answer": draft_answer,
|
| 340 |
-
"citations": citations,
|
| 341 |
-
"confidence": confidence,
|
| 342 |
-
"from_knowledge_base": True,
|
| 343 |
-
"escalation_suggested": confidence < self.HIGH_CONFIDENCE_THRESHOLD,
|
| 344 |
-
"verifier_passed": True,
|
| 345 |
-
"refused": False,
|
| 346 |
-
"usage": usage_info # Include usage info for tracking
|
| 347 |
-
}
|
| 348 |
-
else:
|
| 349 |
-
# Verifier failed - REFUSE to answer
|
| 350 |
-
issues = verification.get('issues', [])
|
| 351 |
-
unsupported = verification.get('unsupported_claims', [])
|
| 352 |
-
logger.warning(
|
| 353 |
-
f"❌ Verifier FAILED - Issues: {issues}, "
|
| 354 |
-
f"Unsupported claims: {unsupported}"
|
| 355 |
-
)
|
| 356 |
-
refusal_message = (
|
| 357 |
-
get_no_context_response() +
|
| 358 |
-
"\n\n**Note:** The system could not verify the accuracy of the information needed to answer your question. "
|
| 359 |
-
"This helps prevent providing incorrect information."
|
| 360 |
-
)
|
| 361 |
-
return {
|
| 362 |
-
"answer": refusal_message,
|
| 363 |
-
"citations": [],
|
| 364 |
-
"confidence": 0.0,
|
| 365 |
-
"from_knowledge_base": False,
|
| 366 |
-
"escalation_suggested": True,
|
| 367 |
-
"verifier_passed": False,
|
| 368 |
-
"verifier_issues": issues,
|
| 369 |
-
"unsupported_claims": unsupported,
|
| 370 |
-
"refused": True,
|
| 371 |
-
"refusal_reason": "Verifier failed: claims not supported by context",
|
| 372 |
-
"usage": usage_info # Still track usage even if refused
|
| 373 |
-
}
|
| 374 |
-
|
| 375 |
-
except ValueError as e:
|
| 376 |
-
# Configuration errors (e.g., missing API key)
|
| 377 |
-
error_msg = str(e)
|
| 378 |
-
logger.error(f"Configuration error in answer generation: {error_msg}")
|
| 379 |
-
if "API key" in error_msg.lower():
|
| 380 |
-
raise ValueError(f"LLM API key not configured: {error_msg}")
|
| 381 |
-
raise
|
| 382 |
-
except Exception as e:
|
| 383 |
-
logger.error(f"Error generating answer: {e}", exc_info=True)
|
| 384 |
-
# Re-raise to be handled by the endpoint
|
| 385 |
-
raise
|
| 386 |
-
|
| 387 |
-
def _extract_citations(
|
| 388 |
-
self,
|
| 389 |
-
answer: str,
|
| 390 |
-
citations_info: List[Dict[str, Any]]
|
| 391 |
-
) -> List[Citation]:
|
| 392 |
-
"""
|
| 393 |
-
Extract and format citations from the answer.
|
| 394 |
-
|
| 395 |
-
Args:
|
| 396 |
-
answer: Generated answer with [Source X] references
|
| 397 |
-
citations_info: Available citation information
|
| 398 |
-
|
| 399 |
-
Returns:
|
| 400 |
-
List of Citation objects
|
| 401 |
-
"""
|
| 402 |
-
citations = []
|
| 403 |
-
|
| 404 |
-
# Find all [Source X] references in the answer
|
| 405 |
-
source_pattern = r'\[Source\s*(\d+)\]'
|
| 406 |
-
matches = re.findall(source_pattern, answer)
|
| 407 |
-
referenced_indices = set(int(m) for m in matches)
|
| 408 |
-
|
| 409 |
-
# Build citation objects for referenced sources
|
| 410 |
-
for info in citations_info:
|
| 411 |
-
if info.get("index") in referenced_indices:
|
| 412 |
-
citations.append(Citation(
|
| 413 |
-
file_name=info.get("file_name", "Unknown"),
|
| 414 |
-
chunk_id=info.get("chunk_id", ""),
|
| 415 |
-
page_number=info.get("page_number"),
|
| 416 |
-
relevance_score=info.get("similarity_score", 0.0),
|
| 417 |
-
excerpt=info.get("excerpt", "")
|
| 418 |
-
))
|
| 419 |
-
|
| 420 |
-
# If no specific citations found but we have context, include top sources
|
| 421 |
-
if not citations and citations_info:
|
| 422 |
-
for info in citations_info[:3]: # Top 3 sources
|
| 423 |
-
citations.append(Citation(
|
| 424 |
-
file_name=info.get("file_name", "Unknown"),
|
| 425 |
-
chunk_id=info.get("chunk_id", ""),
|
| 426 |
-
page_number=info.get("page_number"),
|
| 427 |
-
relevance_score=info.get("similarity_score", 0.0),
|
| 428 |
-
excerpt=info.get("excerpt", "")
|
| 429 |
-
))
|
| 430 |
-
|
| 431 |
-
return citations
|
| 432 |
-
|
| 433 |
-
|
| 434 |
-
# Global answer service instance
|
| 435 |
-
_answer_service: Optional[AnswerService] = None
|
| 436 |
-
|
| 437 |
-
|
| 438 |
-
def get_answer_service() -> AnswerService:
|
| 439 |
-
"""Get the global answer service instance."""
|
| 440 |
-
global _answer_service
|
| 441 |
-
if _answer_service is None:
|
| 442 |
-
_answer_service = AnswerService()
|
| 443 |
-
return _answer_service
|
| 444 |
-
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Answer generation using LLM with RAG context.
|
| 3 |
+
Supports Gemini and OpenAI as providers.
|
| 4 |
+
"""
|
| 5 |
+
import google.generativeai as genai
|
| 6 |
+
from openai import OpenAI
|
| 7 |
+
from typing import Optional, Dict, Any, List
|
| 8 |
+
import logging
|
| 9 |
+
import os
|
| 10 |
+
import re
|
| 11 |
+
|
| 12 |
+
from app.config import settings
|
| 13 |
+
from app.rag.prompts import (
|
| 14 |
+
format_rag_prompt,
|
| 15 |
+
format_draft_prompt,
|
| 16 |
+
get_no_context_response,
|
| 17 |
+
get_low_confidence_response
|
| 18 |
+
)
|
| 19 |
+
from app.rag.verifier import get_verifier_service
|
| 20 |
+
from app.rag.intent import detect_intents
|
| 21 |
+
from app.models.schemas import Citation
|
| 22 |
+
from abc import ABC, abstractmethod
|
| 23 |
+
|
| 24 |
+
logging.basicConfig(level=logging.INFO)
|
| 25 |
+
logger = logging.getLogger(__name__)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class LLMProvider(ABC):
|
| 29 |
+
"""Base class for LLM providers."""
|
| 30 |
+
|
| 31 |
+
@abstractmethod
|
| 32 |
+
def generate(self, system_prompt: str, user_prompt: str) -> str:
|
| 33 |
+
"""Generate response from prompts."""
|
| 34 |
+
raise NotImplementedError
|
| 35 |
+
|
| 36 |
+
@abstractmethod
|
| 37 |
+
def generate_with_usage(self, system_prompt: str, user_prompt: str) -> tuple[str, dict]:
|
| 38 |
+
"""
|
| 39 |
+
Generate response and return usage information.
|
| 40 |
+
|
| 41 |
+
Returns:
|
| 42 |
+
(response_text, usage_info)
|
| 43 |
+
usage_info: dict with keys: prompt_tokens, completion_tokens, total_tokens, model_used
|
| 44 |
+
"""
|
| 45 |
+
raise NotImplementedError
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class GeminiProvider(LLMProvider):
    """Google Gemini LLM provider with automatic model-name fallback."""

    def __init__(self, api_key: Optional[str] = None, model: Optional[str] = None):
        """
        Args:
            api_key: Gemini API key; falls back to settings.GEMINI_API_KEY,
                then the GEMINI_API_KEY environment variable.
            model: Preferred model name; falls back to settings.GEMINI_MODEL,
                then "gemini-1.5-flash".

        Raises:
            ValueError: If no API key can be resolved from any source.
        """
        self.api_key = api_key or settings.GEMINI_API_KEY or os.getenv("GEMINI_API_KEY")
        # Default to gemini-1.5-flash if not specified anywhere.
        self.model = model or settings.GEMINI_MODEL or "gemini-1.5-flash"

        if not self.api_key:
            raise ValueError("Gemini API key not configured. Set GEMINI_API_KEY environment variable.")

        genai.configure(api_key=self.api_key)

        # Client is created lazily per-model in generate_with_usage() so that
        # model-availability errors surface at call time, not construction time.
        self._client = None
        logger.info(f"Gemini provider initialized (will use model: {self.model})")

    @staticmethod
    def _estimate_tokens(text: str) -> int:
        """Rough token estimate (~4 chars/token) used when the API response
        carries no usage metadata."""
        return len(text) // 4

    def _candidate_models(self) -> List[str]:
        """Build the ordered, de-duplicated list of model names to try.

        Preference order: the configured model first, then up to three models
        reported available by the API, then a hard-coded fallback list if the
        listing call fails.
        """
        models_to_try: List[str] = []

        # First, try to discover models that actually support generateContent.
        try:
            available_models = genai.list_models()
            model_names = [m.name for m in available_models if 'generateContent' in m.supported_generation_methods]
            if model_names:
                # Extract just the model name (remove 'models/' prefix if present)
                clean_names = [name.split('/')[-1] if '/' in name else name for name in model_names]
                models_to_try.extend(clean_names[:3])  # Use first 3 available models
                logger.info(f"Found {len(model_names)} available models, will try: {clean_names[:3]}")
        except Exception as e:
            logger.warning(f"Could not list available models: {e}, using fallback list")

        # Fallback to common model names if listing failed.
        if not models_to_try:
            models_to_try = ["gemini-pro", "gemini-1.0-pro", "models/gemini-pro"]

        # The explicitly configured model always goes first.
        if self.model and self.model not in models_to_try:
            models_to_try.insert(0, self.model)

        # Remove duplicates while preserving order.
        seen = set()
        return [m for m in models_to_try if not (m in seen or seen.add(m))]

    def generate(self, system_prompt: str, user_prompt: str) -> str:
        """Generate response using Gemini (usage information discarded)."""
        text, _ = self.generate_with_usage(system_prompt, user_prompt)
        return text

    def generate_with_usage(self, system_prompt: str, user_prompt: str) -> tuple[str, dict]:
        """Generate a response, trying each candidate model until one works.

        Returns:
            (response_text, usage_info) where usage_info has the keys
            prompt_tokens, completion_tokens, total_tokens, model_used.

        Raises:
            Exception: If every candidate model fails with a model-availability
                error; non-availability errors are re-raised immediately.
        """
        # Gemini has no separate system role in this API shape; concatenate.
        full_prompt = f"{system_prompt}\n\n{user_prompt}"
        prompt_tokens = self._estimate_tokens(full_prompt)

        last_error = None
        for model_name in self._candidate_models():
            try:
                logger.info(f"Attempting to generate with model: {model_name}")
                # Create a new client for this model
                client = genai.GenerativeModel(model_name)
                response = client.generate_content(
                    full_prompt,
                    generation_config=genai.types.GenerationConfig(
                        temperature=settings.TEMPERATURE,
                        max_output_tokens=1024,
                    )
                )

                response_text = response.text

                # Start with estimates; overwritten below when the API
                # returns real usage metadata.
                completion_estimate = self._estimate_tokens(response_text)
                usage_info = {
                    "prompt_tokens": prompt_tokens,
                    "completion_tokens": completion_estimate,
                    "total_tokens": prompt_tokens + completion_estimate,
                    "model_used": model_name.split('/')[-1] if '/' in model_name else model_name
                }

                # Prefer actual usage counts from the response when available.
                if hasattr(response, 'usage_metadata'):
                    usage_metadata = response.usage_metadata
                    if hasattr(usage_metadata, 'prompt_token_count'):
                        usage_info["prompt_tokens"] = usage_metadata.prompt_token_count
                    if hasattr(usage_metadata, 'candidates_token_count'):
                        usage_info["completion_tokens"] = usage_metadata.candidates_token_count
                    if hasattr(usage_metadata, 'total_token_count'):
                        usage_info["total_tokens"] = usage_metadata.total_token_count

                if model_name != self.model:
                    logger.info(f"✅ Successfully used model: {model_name}")

                return response_text, usage_info
            except Exception as e:
                error_str = str(e).lower()
                last_error = e
                if "not found" in error_str or "not supported" in error_str or "404" in error_str:
                    # Model-availability problem: fall through to next candidate.
                    logger.warning(f"Model {model_name} failed: {e}")
                    continue
                # Different error (not model not found), re-raise
                logger.error(f"Gemini generation error with {model_name}: {e}")
                raise

        # All models failed - return a helpful error message
        error_msg = f"All Gemini model attempts failed. Last error: {last_error}. Please check your GEMINI_API_KEY and ensure it has access to Gemini models."
        logger.error(error_msg)
        raise Exception(error_msg)
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
class OpenAIProvider(LLMProvider):
    """OpenAI chat-completions LLM provider."""

    def __init__(self, api_key: Optional[str] = None, model: Optional[str] = None):
        """
        Args:
            api_key: OpenAI API key; falls back to settings.OPENAI_API_KEY,
                then the OPENAI_API_KEY environment variable.
            model: Chat model name; falls back to settings.OPENAI_MODEL.
                Resolved here rather than as a signature default so that
                changes to settings after import time are honored.

        Raises:
            ValueError: If no API key can be resolved from any source.
        """
        self.api_key = api_key or settings.OPENAI_API_KEY or os.getenv("OPENAI_API_KEY")
        self.model = model or settings.OPENAI_MODEL

        if not self.api_key:
            raise ValueError("OpenAI API key not configured. Set OPENAI_API_KEY environment variable.")

        self.client = OpenAI(api_key=self.api_key)
        logger.info(f"OpenAI provider initialized with model: {self.model}")

    def generate(self, system_prompt: str, user_prompt: str) -> str:
        """Generate response using OpenAI (usage information discarded)."""
        text, _ = self.generate_with_usage(system_prompt, user_prompt)
        return text

    def generate_with_usage(self, system_prompt: str, user_prompt: str) -> tuple[str, dict]:
        """Generate a response and report exact token usage.

        Returns:
            (response_text, usage_info) where usage_info has the keys
            prompt_tokens, completion_tokens, total_tokens, model_used —
            all taken directly from the API's usage accounting.

        Raises:
            Exception: Propagates any API error after logging it.
        """
        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt}
                ],
                temperature=settings.TEMPERATURE,
                max_tokens=1024
            )

            response_text = response.choices[0].message.content

            # OpenAI returns exact token counts; no estimation needed.
            usage_info = {
                "prompt_tokens": response.usage.prompt_tokens,
                "completion_tokens": response.usage.completion_tokens,
                "total_tokens": response.usage.total_tokens,
                "model_used": self.model
            }

            return response_text, usage_info
        except Exception as e:
            logger.error(f"OpenAI generation error: {e}")
            raise
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
class AnswerService:
    """
    Generates answers using RAG context and an LLM.

    Applies a sequence of confidence gates before generation, then (by
    default) verifies the draft answer against the retrieved context to
    suppress hallucinated claims. Handles confidence scoring and citation
    extraction.
    """

    # Confidence thresholds
    HIGH_CONFIDENCE_THRESHOLD = 0.5
    LOW_CONFIDENCE_THRESHOLD = 0.20  # Lowered to match similarity threshold
    STRICT_CONFIDENCE_THRESHOLD = 0.30  # Strict threshold for answer generation (lowered from 0.45 to allow good retrieval results)

    def __init__(self, provider: Optional[str] = None):
        """
        Initialize the answer service.

        Args:
            provider: LLM provider to use ("gemini" or "openai"). Defaults to
                settings.LLM_PROVIDER, resolved at call time rather than at
                import time so configuration changes are honored.
        """
        self.provider_name = provider or settings.LLM_PROVIDER
        self._provider: Optional[LLMProvider] = None

    @property
    def provider(self) -> LLMProvider:
        """Lazily construct and cache the configured LLM provider.

        Raises:
            ValueError: If the configured provider name is unrecognized.
        """
        if self._provider is None:
            if self.provider_name == "gemini":
                self._provider = GeminiProvider()
            elif self.provider_name == "openai":
                self._provider = OpenAIProvider()
            else:
                raise ValueError(f"Unknown LLM provider: {self.provider_name}")
        return self._provider

    def generate_answer(
        self,
        question: str,
        context: str,
        citations_info: List[Dict[str, Any]],
        confidence: float,
        has_relevant_results: bool,
        use_verifier: Optional[bool] = None  # None = use config default
    ) -> Dict[str, Any]:
        """
        Generate an answer based on retrieved context, with optional verification.

        Args:
            question: User's question.
            context: Retrieved context from the knowledge base.
            citations_info: List of citation information dicts.
            confidence: Average confidence score from retrieval.
            has_relevant_results: Whether any results passed the threshold.
            use_verifier: Whether to run the verification pass on the draft
                answer. None means "use settings.REQUIRE_VERIFIER".

        Returns:
            Dictionary with answer, citations, confidence, and metadata
            (including "refused", "escalation_suggested", and — when an LLM
            call was made — "usage" token accounting).
        """
        # Resolve the verifier flag from config when not explicitly given.
        # NOTE: previously this value was computed but never consulted, so the
        # verifier ran unconditionally; it is now honored.
        if use_verifier is None:
            use_verifier = settings.REQUIRE_VERIFIER

        # GATE 1: No relevant results found - REFUSE
        if not has_relevant_results or not context:
            logger.info("No relevant context found, returning no-context response")
            return {
                "answer": get_no_context_response(),
                "citations": [],
                "confidence": 0.0,
                "from_knowledge_base": False,
                "escalation_suggested": True,
                "refused": True
            }

        # GATE 2: Strict confidence threshold - REFUSE if below strict threshold
        if confidence < self.STRICT_CONFIDENCE_THRESHOLD:
            logger.warning(
                f"Confidence ({confidence:.3f}) below strict threshold ({self.STRICT_CONFIDENCE_THRESHOLD}), "
                f"REFUSING to answer to prevent hallucination"
            )
            return {
                "answer": get_no_context_response(),
                "citations": [],
                "confidence": confidence,
                "from_knowledge_base": False,
                "escalation_suggested": True,
                "refused": True,
                "refusal_reason": f"Confidence {confidence:.3f} below strict threshold {self.STRICT_CONFIDENCE_THRESHOLD}"
            }

        # GATE 3: Intent-based gating for specific intents (integration, API, etc.)
        intents = detect_intents(question)
        if "integration" in intents or "api" in question.lower():
            # For integration/API questions, require strong relevance
            if confidence < 0.50:  # Even stricter for integration questions
                logger.warning(
                    f"Integration/API question with low confidence ({confidence:.3f}), "
                    f"REFUSING to prevent hallucination"
                )
                return {
                    "answer": get_no_context_response(),
                    "citations": [],
                    "confidence": confidence,
                    "from_knowledge_base": False,
                    "escalation_suggested": True,
                    "refused": True,
                    "refusal_reason": "Integration/API questions require higher confidence"
                }

        # Passed all gates - generate the answer (Draft -> [Verify] -> Final).
        logger.info(f"Confidence ({confidence:.3f}) passed all gates, generating answer with verifier={use_verifier}")

        try:
            # Step 1: Generate draft answer with usage tracking.
            draft_system, draft_user = format_draft_prompt(context, question)
            draft_answer, usage_info = self.provider.generate_with_usage(draft_system, draft_user)

            if not use_verifier:
                # Verification explicitly disabled: return the draft as-is.
                logger.info("Verifier disabled; returning draft answer without verification")
                citations = self._extract_citations(draft_answer, citations_info)
                return {
                    "answer": draft_answer,
                    "citations": citations,
                    "confidence": confidence,
                    "from_knowledge_base": True,
                    "escalation_suggested": confidence < self.HIGH_CONFIDENCE_THRESHOLD,
                    "verifier_passed": None,  # verification skipped
                    "refused": False,
                    "usage": usage_info
                }

            logger.info("Generated draft answer, running verifier...")

            # Step 2: Verify draft answer against the retrieved context.
            verifier = get_verifier_service()
            verification = verifier.verify_answer(
                draft_answer=draft_answer,
                context=context,
                citations_info=citations_info
            )

            # Step 3: Handle verification result.
            if verification["pass"]:
                logger.info("✅ Verifier PASSED - Using draft answer")
                citations = self._extract_citations(draft_answer, citations_info)
                return {
                    "answer": draft_answer,
                    "citations": citations,
                    "confidence": confidence,
                    "from_knowledge_base": True,
                    "escalation_suggested": confidence < self.HIGH_CONFIDENCE_THRESHOLD,
                    "verifier_passed": True,
                    "refused": False,
                    "usage": usage_info  # Include usage info for tracking
                }
            else:
                # Verifier failed - REFUSE to answer
                issues = verification.get('issues', [])
                unsupported = verification.get('unsupported_claims', [])
                logger.warning(
                    f"❌ Verifier FAILED - Issues: {issues}, "
                    f"Unsupported claims: {unsupported}"
                )
                refusal_message = (
                    get_no_context_response() +
                    "\n\n**Note:** The system could not verify the accuracy of the information needed to answer your question. "
                    "This helps prevent providing incorrect information."
                )
                return {
                    "answer": refusal_message,
                    "citations": [],
                    "confidence": 0.0,
                    "from_knowledge_base": False,
                    "escalation_suggested": True,
                    "verifier_passed": False,
                    "verifier_issues": issues,
                    "unsupported_claims": unsupported,
                    "refused": True,
                    "refusal_reason": "Verifier failed: claims not supported by context",
                    "usage": usage_info  # Still track usage even if refused
                }

        except ValueError as e:
            # Configuration errors (e.g., missing API key)
            error_msg = str(e)
            logger.error(f"Configuration error in answer generation: {error_msg}")
            if "API key" in error_msg.lower():
                raise ValueError(f"LLM API key not configured: {error_msg}")
            raise
        except Exception as e:
            logger.error(f"Error generating answer: {e}", exc_info=True)
            # Re-raise to be handled by the endpoint
            raise

    def _extract_citations(
        self,
        answer: str,
        citations_info: List[Dict[str, Any]]
    ) -> List[Citation]:
        """
        Extract and format citations from the answer.

        Args:
            answer: Generated answer with [Source X] references.
            citations_info: Available citation information.

        Returns:
            List of Citation objects — the sources actually referenced in the
            answer, or the top 3 sources when the answer cites none explicitly.
        """
        citations = []

        # Find all [Source X] references in the answer.
        source_pattern = r'\[Source\s*(\d+)\]'
        matches = re.findall(source_pattern, answer)
        referenced_indices = set(int(m) for m in matches)

        # Build citation objects for referenced sources.
        for info in citations_info:
            if info.get("index") in referenced_indices:
                citations.append(Citation(
                    file_name=info.get("file_name", "Unknown"),
                    chunk_id=info.get("chunk_id", ""),
                    page_number=info.get("page_number"),
                    relevance_score=info.get("similarity_score", 0.0),
                    excerpt=info.get("excerpt", "")
                ))

        # If no specific citations found but we have context, include top sources.
        if not citations and citations_info:
            for info in citations_info[:3]:  # Top 3 sources
                citations.append(Citation(
                    file_name=info.get("file_name", "Unknown"),
                    chunk_id=info.get("chunk_id", ""),
                    page_number=info.get("page_number"),
                    relevance_score=info.get("similarity_score", 0.0),
                    excerpt=info.get("excerpt", "")
                ))

        return citations
|
| 432 |
+
|
| 433 |
+
|
| 434 |
+
# Process-wide AnswerService singleton, created lazily on first use.
_answer_service: Optional[AnswerService] = None


def get_answer_service() -> AnswerService:
    """Return the shared AnswerService, constructing it on first call."""
    global _answer_service
    if _answer_service is None:
        _answer_service = AnswerService()
    return _answer_service
|
| 444 |
+
|
app/rag/chunking.py
CHANGED
|
@@ -1,196 +1,196 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Document chunking with overlap and metadata preservation.
|
| 3 |
-
"""
|
| 4 |
-
import tiktoken
|
| 5 |
-
from typing import List, Dict, Any, Optional
|
| 6 |
-
from dataclasses import dataclass
|
| 7 |
-
import re
|
| 8 |
-
import uuid
|
| 9 |
-
from datetime import datetime
|
| 10 |
-
|
| 11 |
-
from app.config import settings
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
@dataclass
|
| 15 |
-
class TextChunk:
|
| 16 |
-
"""Represents a chunk of text with metadata."""
|
| 17 |
-
content: str
|
| 18 |
-
chunk_index: int
|
| 19 |
-
start_char: int
|
| 20 |
-
end_char: int
|
| 21 |
-
page_number: Optional[int] = None
|
| 22 |
-
token_count: int = 0
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
class DocumentChunker:
|
| 26 |
-
"""
|
| 27 |
-
Chunks documents into smaller pieces with overlap.
|
| 28 |
-
Uses tiktoken for accurate token counting.
|
| 29 |
-
"""
|
| 30 |
-
|
| 31 |
-
def __init__(
|
| 32 |
-
self,
|
| 33 |
-
chunk_size: int = settings.CHUNK_SIZE,
|
| 34 |
-
chunk_overlap: int = settings.CHUNK_OVERLAP,
|
| 35 |
-
min_chunk_size: int = settings.MIN_CHUNK_SIZE
|
| 36 |
-
):
|
| 37 |
-
self.chunk_size = chunk_size
|
| 38 |
-
self.chunk_overlap = chunk_overlap
|
| 39 |
-
self.min_chunk_size = min_chunk_size
|
| 40 |
-
# Use cl100k_base encoding (same as GPT-4, good general purpose)
|
| 41 |
-
self.encoding = tiktoken.get_encoding("cl100k_base")
|
| 42 |
-
|
| 43 |
-
def count_tokens(self, text: str) -> int:
|
| 44 |
-
"""Count tokens in text."""
|
| 45 |
-
return len(self.encoding.encode(text))
|
| 46 |
-
|
| 47 |
-
def _split_into_sentences(self, text: str) -> List[str]:
|
| 48 |
-
"""Split text into sentences while preserving structure."""
|
| 49 |
-
# Split on sentence boundaries but keep delimiters
|
| 50 |
-
sentence_endings = r'(?<=[.!?])\s+'
|
| 51 |
-
sentences = re.split(sentence_endings, text)
|
| 52 |
-
return [s.strip() for s in sentences if s.strip()]
|
| 53 |
-
|
| 54 |
-
def _split_into_paragraphs(self, text: str) -> List[str]:
|
| 55 |
-
"""Split text into paragraphs."""
|
| 56 |
-
paragraphs = re.split(r'\n\s*\n', text)
|
| 57 |
-
return [p.strip() for p in paragraphs if p.strip()]
|
| 58 |
-
|
| 59 |
-
def chunk_text(
|
| 60 |
-
self,
|
| 61 |
-
text: str,
|
| 62 |
-
page_numbers: Optional[Dict[int, int]] = None # char_position -> page_number
|
| 63 |
-
) -> List[TextChunk]:
|
| 64 |
-
"""
|
| 65 |
-
Chunk text into smaller pieces with overlap.
|
| 66 |
-
|
| 67 |
-
Args:
|
| 68 |
-
text: The text to chunk
|
| 69 |
-
page_numbers: Optional mapping of character positions to page numbers
|
| 70 |
-
|
| 71 |
-
Returns:
|
| 72 |
-
List of TextChunk objects
|
| 73 |
-
"""
|
| 74 |
-
if not text.strip():
|
| 75 |
-
return []
|
| 76 |
-
|
| 77 |
-
chunks = []
|
| 78 |
-
current_chunk = ""
|
| 79 |
-
current_start = 0
|
| 80 |
-
chunk_index = 0
|
| 81 |
-
|
| 82 |
-
# First, split into paragraphs for natural boundaries
|
| 83 |
-
paragraphs = self._split_into_paragraphs(text)
|
| 84 |
-
|
| 85 |
-
char_position = 0
|
| 86 |
-
for para in paragraphs:
|
| 87 |
-
para_tokens = self.count_tokens(para)
|
| 88 |
-
current_tokens = self.count_tokens(current_chunk)
|
| 89 |
-
|
| 90 |
-
# If adding this paragraph exceeds chunk size
|
| 91 |
-
if current_tokens + para_tokens > self.chunk_size and current_chunk:
|
| 92 |
-
# Save current chunk if it meets minimum size
|
| 93 |
-
if current_tokens >= self.min_chunk_size:
|
| 94 |
-
page_num = None
|
| 95 |
-
if page_numbers:
|
| 96 |
-
# Find the page number for this chunk's start position
|
| 97 |
-
for pos, page in sorted(page_numbers.items()):
|
| 98 |
-
if pos <= current_start:
|
| 99 |
-
page_num = page
|
| 100 |
-
|
| 101 |
-
chunks.append(TextChunk(
|
| 102 |
-
content=current_chunk.strip(),
|
| 103 |
-
chunk_index=chunk_index,
|
| 104 |
-
start_char=current_start,
|
| 105 |
-
end_char=char_position,
|
| 106 |
-
page_number=page_num,
|
| 107 |
-
token_count=current_tokens
|
| 108 |
-
))
|
| 109 |
-
chunk_index += 1
|
| 110 |
-
|
| 111 |
-
# Start new chunk with overlap
|
| 112 |
-
overlap_text = self._get_overlap_text(current_chunk)
|
| 113 |
-
current_chunk = overlap_text + "\n\n" + para if overlap_text else para
|
| 114 |
-
current_start = char_position - len(overlap_text) if overlap_text else char_position
|
| 115 |
-
else:
|
| 116 |
-
# Add paragraph to current chunk
|
| 117 |
-
if current_chunk:
|
| 118 |
-
current_chunk += "\n\n" + para
|
| 119 |
-
else:
|
| 120 |
-
current_chunk = para
|
| 121 |
-
current_start = char_position
|
| 122 |
-
|
| 123 |
-
char_position += len(para) + 2 # +2 for paragraph separator
|
| 124 |
-
|
| 125 |
-
# Don't forget the last chunk
|
| 126 |
-
if current_chunk and self.count_tokens(current_chunk) >= self.min_chunk_size:
|
| 127 |
-
page_num = None
|
| 128 |
-
if page_numbers:
|
| 129 |
-
for pos, page in sorted(page_numbers.items()):
|
| 130 |
-
if pos <= current_start:
|
| 131 |
-
page_num = page
|
| 132 |
-
|
| 133 |
-
chunks.append(TextChunk(
|
| 134 |
-
content=current_chunk.strip(),
|
| 135 |
-
chunk_index=chunk_index,
|
| 136 |
-
start_char=current_start,
|
| 137 |
-
end_char=len(text),
|
| 138 |
-
page_number=page_num,
|
| 139 |
-
token_count=self.count_tokens(current_chunk)
|
| 140 |
-
))
|
| 141 |
-
|
| 142 |
-
return chunks
|
| 143 |
-
|
| 144 |
-
def _get_overlap_text(self, text: str) -> str:
|
| 145 |
-
"""Get the overlap text from the end of a chunk."""
|
| 146 |
-
sentences = self._split_into_sentences(text)
|
| 147 |
-
if not sentences:
|
| 148 |
-
return ""
|
| 149 |
-
|
| 150 |
-
overlap = ""
|
| 151 |
-
tokens = 0
|
| 152 |
-
|
| 153 |
-
# Work backwards through sentences
|
| 154 |
-
for sentence in reversed(sentences):
|
| 155 |
-
sentence_tokens = self.count_tokens(sentence)
|
| 156 |
-
if tokens + sentence_tokens <= self.chunk_overlap:
|
| 157 |
-
overlap = sentence + " " + overlap if overlap else sentence
|
| 158 |
-
tokens += sentence_tokens
|
| 159 |
-
else:
|
| 160 |
-
break
|
| 161 |
-
|
| 162 |
-
return overlap.strip()
|
| 163 |
-
|
| 164 |
-
def create_chunk_metadata(
|
| 165 |
-
self,
|
| 166 |
-
chunk: TextChunk,
|
| 167 |
-
tenant_id: str, # CRITICAL: Multi-tenant isolation
|
| 168 |
-
kb_id: str,
|
| 169 |
-
user_id: str,
|
| 170 |
-
file_name: str,
|
| 171 |
-
file_type: str,
|
| 172 |
-
total_chunks: int,
|
| 173 |
-
document_id: Optional[str] = None
|
| 174 |
-
) -> Dict[str, Any]:
|
| 175 |
-
"""Create metadata dictionary for a chunk."""
|
| 176 |
-
chunk_id = f"{tenant_id}_{kb_id}_{file_name}_{chunk.chunk_index}_{uuid.uuid4().hex[:8]}"
|
| 177 |
-
|
| 178 |
-
return {
|
| 179 |
-
"tenant_id": tenant_id, # CRITICAL: Multi-tenant isolation
|
| 180 |
-
"kb_id": kb_id,
|
| 181 |
-
"user_id": user_id,
|
| 182 |
-
"file_name": file_name,
|
| 183 |
-
"file_type": file_type,
|
| 184 |
-
"chunk_id": chunk_id,
|
| 185 |
-
"chunk_index": chunk.chunk_index,
|
| 186 |
-
"page_number": chunk.page_number,
|
| 187 |
-
"total_chunks": total_chunks,
|
| 188 |
-
"token_count": chunk.token_count,
|
| 189 |
-
"document_id": document_id, # Track original document
|
| 190 |
-
"created_at": datetime.utcnow().isoformat()
|
| 191 |
-
}
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
# Global chunker instance
|
| 195 |
-
chunker = DocumentChunker()
|
| 196 |
-
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Document chunking with overlap and metadata preservation.
|
| 3 |
+
"""
|
| 4 |
+
import tiktoken
|
| 5 |
+
from typing import List, Dict, Any, Optional
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
import re
|
| 8 |
+
import uuid
|
| 9 |
+
from datetime import datetime
|
| 10 |
+
|
| 11 |
+
from app.config import settings
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@dataclass
class TextChunk:
    """Represents a chunk of text with metadata."""
    # The chunk's text content.
    content: str
    # 0-based position of this chunk within the source document.
    chunk_index: int
    # Character offsets of the chunk within the original full text.
    start_char: int
    end_char: int
    # Source page, when the document format provides one (e.g. PDF).
    page_number: Optional[int] = None
    # Token count as measured by the chunker's tiktoken encoding.
    token_count: int = 0
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class DocumentChunker:
    """
    Chunks documents into smaller pieces with overlap.
    Uses tiktoken for accurate token counting.

    Splits on paragraph boundaries first, carrying a sentence-level overlap
    from the end of each chunk into the start of the next so that context is
    not lost at chunk edges. All size limits are measured in tokens.
    """

    def __init__(
        self,
        chunk_size: int = settings.CHUNK_SIZE,
        chunk_overlap: int = settings.CHUNK_OVERLAP,
        min_chunk_size: int = settings.MIN_CHUNK_SIZE
    ):
        """Configure chunk sizing (all values in tokens, not characters)."""
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.min_chunk_size = min_chunk_size
        # Use cl100k_base encoding (same as GPT-4, good general purpose)
        self.encoding = tiktoken.get_encoding("cl100k_base")

    def count_tokens(self, text: str) -> int:
        """Count tokens in text."""
        return len(self.encoding.encode(text))

    def _split_into_sentences(self, text: str) -> List[str]:
        """Split text into sentences while preserving structure."""
        # Split on sentence boundaries but keep delimiters
        sentence_endings = r'(?<=[.!?])\s+'
        sentences = re.split(sentence_endings, text)
        return [s.strip() for s in sentences if s.strip()]

    def _split_into_paragraphs(self, text: str) -> List[str]:
        """Split text into paragraphs (separated by blank lines)."""
        paragraphs = re.split(r'\n\s*\n', text)
        return [p.strip() for p in paragraphs if p.strip()]

    def chunk_text(
        self,
        text: str,
        page_numbers: Optional[Dict[int, int]] = None  # char_position -> page_number
    ) -> List[TextChunk]:
        """
        Chunk text into smaller pieces with overlap.

        Args:
            text: The text to chunk
            page_numbers: Optional mapping of character positions to page numbers

        Returns:
            List of TextChunk objects. Chunks smaller than min_chunk_size
            (in tokens) are dropped.
        """
        if not text.strip():
            return []

        chunks = []
        current_chunk = ""
        current_start = 0
        chunk_index = 0

        # First, split into paragraphs for natural boundaries
        paragraphs = self._split_into_paragraphs(text)

        # NOTE: char_position is an approximation of the offset in the
        # original text — paragraph separators are assumed to be 2 chars wide.
        char_position = 0
        for para in paragraphs:
            para_tokens = self.count_tokens(para)
            current_tokens = self.count_tokens(current_chunk)

            # If adding this paragraph exceeds chunk size
            if current_tokens + para_tokens > self.chunk_size and current_chunk:
                # Save current chunk if it meets minimum size
                if current_tokens >= self.min_chunk_size:
                    page_num = None
                    if page_numbers:
                        # Find the page number for this chunk's start position:
                        # the loop deliberately has no break, so it keeps the
                        # *last* page whose start position is <= current_start.
                        for pos, page in sorted(page_numbers.items()):
                            if pos <= current_start:
                                page_num = page

                    chunks.append(TextChunk(
                        content=current_chunk.strip(),
                        chunk_index=chunk_index,
                        start_char=current_start,
                        end_char=char_position,
                        page_number=page_num,
                        token_count=current_tokens
                    ))
                    chunk_index += 1

                # Start new chunk with overlap carried from the previous chunk
                overlap_text = self._get_overlap_text(current_chunk)
                current_chunk = overlap_text + "\n\n" + para if overlap_text else para
                # Shift the start offset back so it covers the overlapped text
                current_start = char_position - len(overlap_text) if overlap_text else char_position
            else:
                # Add paragraph to current chunk
                if current_chunk:
                    current_chunk += "\n\n" + para
                else:
                    current_chunk = para
                    current_start = char_position

            char_position += len(para) + 2  # +2 for paragraph separator

        # Don't forget the last chunk
        if current_chunk and self.count_tokens(current_chunk) >= self.min_chunk_size:
            page_num = None
            if page_numbers:
                # Same last-match-wins page lookup as above.
                for pos, page in sorted(page_numbers.items()):
                    if pos <= current_start:
                        page_num = page

            chunks.append(TextChunk(
                content=current_chunk.strip(),
                chunk_index=chunk_index,
                start_char=current_start,
                end_char=len(text),
                page_number=page_num,
                token_count=self.count_tokens(current_chunk)
            ))

        return chunks

    def _get_overlap_text(self, text: str) -> str:
        """Get the overlap text from the end of a chunk.

        Collects whole sentences from the end of *text* until adding the next
        one would exceed chunk_overlap tokens.
        """
        sentences = self._split_into_sentences(text)
        if not sentences:
            return ""

        overlap = ""
        tokens = 0

        # Work backwards through sentences
        for sentence in reversed(sentences):
            sentence_tokens = self.count_tokens(sentence)
            if tokens + sentence_tokens <= self.chunk_overlap:
                overlap = sentence + " " + overlap if overlap else sentence
                tokens += sentence_tokens
            else:
                break

        return overlap.strip()

    def create_chunk_metadata(
        self,
        chunk: TextChunk,
        tenant_id: str,  # CRITICAL: Multi-tenant isolation
        kb_id: str,
        user_id: str,
        file_name: str,
        file_type: str,
        total_chunks: int,
        document_id: Optional[str] = None
    ) -> Dict[str, Any]:
        """Create metadata dictionary for a chunk.

        The chunk_id embeds tenant/kb/file identifiers plus a random suffix,
        so ids stay unique even when the same file is re-ingested.
        """
        chunk_id = f"{tenant_id}_{kb_id}_{file_name}_{chunk.chunk_index}_{uuid.uuid4().hex[:8]}"

        return {
            "tenant_id": tenant_id,  # CRITICAL: Multi-tenant isolation
            "kb_id": kb_id,
            "user_id": user_id,
            "file_name": file_name,
            "file_type": file_type,
            "chunk_id": chunk_id,
            "chunk_index": chunk.chunk_index,
            "page_number": chunk.page_number,
            "total_chunks": total_chunks,
            "token_count": chunk.token_count,
            "document_id": document_id,  # Track original document
            "created_at": datetime.utcnow().isoformat()
        }
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
# Global chunker instance (module-level singleton shared by the ingestion pipeline).
chunker = DocumentChunker()
|
| 196 |
+
|
app/rag/embeddings.py
CHANGED
|
@@ -1,145 +1,145 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Embedding generation using Sentence Transformers.
|
| 3 |
-
Supports local models for privacy and offline use.
|
| 4 |
-
"""
|
| 5 |
-
from sentence_transformers import SentenceTransformer
|
| 6 |
-
from typing import List, Optional
|
| 7 |
-
import numpy as np
|
| 8 |
-
import logging
|
| 9 |
-
from functools import lru_cache
|
| 10 |
-
|
| 11 |
-
from app.config import settings
|
| 12 |
-
|
| 13 |
-
logging.basicConfig(level=logging.INFO)
|
| 14 |
-
logger = logging.getLogger(__name__)
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
class EmbeddingService:
|
| 18 |
-
"""
|
| 19 |
-
Generates embeddings for text using Sentence Transformers.
|
| 20 |
-
Uses a lightweight model optimized for semantic search.
|
| 21 |
-
"""
|
| 22 |
-
|
| 23 |
-
def __init__(self, model_name: str = settings.EMBEDDING_MODEL):
|
| 24 |
-
"""
|
| 25 |
-
Initialize the embedding service.
|
| 26 |
-
|
| 27 |
-
Args:
|
| 28 |
-
model_name: Name of the Sentence Transformer model to use
|
| 29 |
-
"""
|
| 30 |
-
self.model_name = model_name
|
| 31 |
-
self._model: Optional[SentenceTransformer] = None
|
| 32 |
-
logger.info(f"Embedding service initialized with model: {model_name}")
|
| 33 |
-
|
| 34 |
-
@property
|
| 35 |
-
def model(self) -> SentenceTransformer:
|
| 36 |
-
"""Lazy load the model."""
|
| 37 |
-
if self._model is None:
|
| 38 |
-
logger.info(f"Loading embedding model: {self.model_name}")
|
| 39 |
-
self._model = SentenceTransformer(self.model_name)
|
| 40 |
-
logger.info(f"Model loaded. Embedding dimension: {self._model.get_sentence_embedding_dimension()}")
|
| 41 |
-
return self._model
|
| 42 |
-
|
| 43 |
-
def embed_text(self, text: str) -> List[float]:
|
| 44 |
-
"""
|
| 45 |
-
Generate embedding for a single text.
|
| 46 |
-
|
| 47 |
-
Args:
|
| 48 |
-
text: Text to embed
|
| 49 |
-
|
| 50 |
-
Returns:
|
| 51 |
-
List of floats representing the embedding vector
|
| 52 |
-
"""
|
| 53 |
-
if not text.strip():
|
| 54 |
-
raise ValueError("Cannot embed empty text")
|
| 55 |
-
|
| 56 |
-
embedding = self.model.encode(text, convert_to_numpy=True)
|
| 57 |
-
return embedding.tolist()
|
| 58 |
-
|
| 59 |
-
def embed_texts(self, texts: List[str], batch_size: int = 32) -> List[List[float]]:
|
| 60 |
-
"""
|
| 61 |
-
Generate embeddings for multiple texts.
|
| 62 |
-
|
| 63 |
-
Args:
|
| 64 |
-
texts: List of texts to embed
|
| 65 |
-
batch_size: Batch size for processing
|
| 66 |
-
|
| 67 |
-
Returns:
|
| 68 |
-
List of embedding vectors
|
| 69 |
-
"""
|
| 70 |
-
if not texts:
|
| 71 |
-
return []
|
| 72 |
-
|
| 73 |
-
# Filter out empty texts
|
| 74 |
-
valid_texts = [t for t in texts if t.strip()]
|
| 75 |
-
if len(valid_texts) != len(texts):
|
| 76 |
-
logger.warning(f"Filtered out {len(texts) - len(valid_texts)} empty texts")
|
| 77 |
-
|
| 78 |
-
logger.info(f"Generating embeddings for {len(valid_texts)} texts")
|
| 79 |
-
|
| 80 |
-
embeddings = self.model.encode(
|
| 81 |
-
valid_texts,
|
| 82 |
-
batch_size=batch_size,
|
| 83 |
-
show_progress_bar=len(valid_texts) > 100,
|
| 84 |
-
convert_to_numpy=True
|
| 85 |
-
)
|
| 86 |
-
|
| 87 |
-
return embeddings.tolist()
|
| 88 |
-
|
| 89 |
-
def embed_query(self, query: str) -> List[float]:
|
| 90 |
-
"""
|
| 91 |
-
Generate embedding for a search query.
|
| 92 |
-
Some models have different embeddings for queries vs documents.
|
| 93 |
-
|
| 94 |
-
Args:
|
| 95 |
-
query: Search query to embed
|
| 96 |
-
|
| 97 |
-
Returns:
|
| 98 |
-
Embedding vector for the query
|
| 99 |
-
"""
|
| 100 |
-
# For most models, query embedding is the same as document embedding
|
| 101 |
-
# But we keep this separate for models that differentiate
|
| 102 |
-
return self.embed_text(query)
|
| 103 |
-
|
| 104 |
-
def get_dimension(self) -> int:
|
| 105 |
-
"""Get the embedding dimension."""
|
| 106 |
-
return self.model.get_sentence_embedding_dimension()
|
| 107 |
-
|
| 108 |
-
def compute_similarity(self, embedding1: List[float], embedding2: List[float]) -> float:
|
| 109 |
-
"""
|
| 110 |
-
Compute cosine similarity between two embeddings.
|
| 111 |
-
|
| 112 |
-
Args:
|
| 113 |
-
embedding1: First embedding vector
|
| 114 |
-
embedding2: Second embedding vector
|
| 115 |
-
|
| 116 |
-
Returns:
|
| 117 |
-
Cosine similarity score (0-1)
|
| 118 |
-
"""
|
| 119 |
-
vec1 = np.array(embedding1)
|
| 120 |
-
vec2 = np.array(embedding2)
|
| 121 |
-
|
| 122 |
-
# Cosine similarity
|
| 123 |
-
dot_product = np.dot(vec1, vec2)
|
| 124 |
-
norm1 = np.linalg.norm(vec1)
|
| 125 |
-
norm2 = np.linalg.norm(vec2)
|
| 126 |
-
|
| 127 |
-
if norm1 == 0 or norm2 == 0:
|
| 128 |
-
return 0.0
|
| 129 |
-
|
| 130 |
-
return float(dot_product / (norm1 * norm2))
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
# Global embedding service instance (lazy loaded)
|
| 134 |
-
_embedding_service: Optional[EmbeddingService] = None
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
def get_embedding_service() -> EmbeddingService:
|
| 138 |
-
"""Get the global embedding service instance."""
|
| 139 |
-
global _embedding_service
|
| 140 |
-
if _embedding_service is None:
|
| 141 |
-
_embedding_service = EmbeddingService()
|
| 142 |
-
return _embedding_service
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Embedding generation using Sentence Transformers.
|
| 3 |
+
Supports local models for privacy and offline use.
|
| 4 |
+
"""
|
| 5 |
+
from sentence_transformers import SentenceTransformer
|
| 6 |
+
from typing import List, Optional
|
| 7 |
+
import numpy as np
|
| 8 |
+
import logging
|
| 9 |
+
from functools import lru_cache
|
| 10 |
+
|
| 11 |
+
from app.config import settings
|
| 12 |
+
|
| 13 |
+
logging.basicConfig(level=logging.INFO)
|
| 14 |
+
logger = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class EmbeddingService:
    """
    Generates embeddings for text using Sentence Transformers.

    The underlying model is loaded lazily on first access, so constructing
    the service is cheap. All embeddings are returned as plain Python lists
    of floats.
    """

    def __init__(self, model_name: str = settings.EMBEDDING_MODEL):
        """
        Initialize the embedding service.

        Args:
            model_name: Name of the Sentence Transformer model to use
        """
        self.model_name = model_name
        # Loading is deferred until the `model` property is first read.
        self._model: Optional[SentenceTransformer] = None
        logger.info(f"Embedding service initialized with model: {model_name}")

    @property
    def model(self) -> SentenceTransformer:
        """Return the underlying model, loading it on first access."""
        if self._model is not None:
            return self._model
        logger.info(f"Loading embedding model: {self.model_name}")
        self._model = SentenceTransformer(self.model_name)
        logger.info(f"Model loaded. Embedding dimension: {self._model.get_sentence_embedding_dimension()}")
        return self._model

    def embed_text(self, text: str) -> List[float]:
        """
        Generate an embedding for a single text.

        Args:
            text: Text to embed

        Returns:
            The embedding vector as a list of floats

        Raises:
            ValueError: If the text is empty or whitespace-only.
        """
        if not text.strip():
            raise ValueError("Cannot embed empty text")

        vector = self.model.encode(text, convert_to_numpy=True)
        return vector.tolist()

    def embed_texts(self, texts: List[str], batch_size: int = 32) -> List[List[float]]:
        """
        Generate embeddings for multiple texts.

        NOTE: blank (empty/whitespace-only) entries are silently dropped, so
        the result may be shorter than the input — callers zipping the output
        against `texts` should pre-filter their input.

        Args:
            texts: List of texts to embed
            batch_size: Batch size used for model inference

        Returns:
            One embedding vector per non-blank input text
        """
        if not texts:
            return []

        # Drop blank entries before encoding; warn so the caller can tell.
        kept = [t for t in texts if t.strip()]
        dropped = len(texts) - len(kept)
        if dropped:
            logger.warning(f"Filtered out {dropped} empty texts")

        logger.info(f"Generating embeddings for {len(kept)} texts")

        vectors = self.model.encode(
            kept,
            batch_size=batch_size,
            show_progress_bar=len(kept) > 100,  # progress bar only for large batches
            convert_to_numpy=True
        )

        return vectors.tolist()

    def embed_query(self, query: str) -> List[float]:
        """
        Generate an embedding for a search query.

        Kept separate from embed_text for models that distinguish query and
        document embeddings; for symmetric models the two are identical.

        Args:
            query: Search query to embed

        Returns:
            Embedding vector for the query
        """
        return self.embed_text(query)

    def get_dimension(self) -> int:
        """Return the dimensionality of the embedding vectors."""
        return self.model.get_sentence_embedding_dimension()

    def compute_similarity(self, embedding1: List[float], embedding2: List[float]) -> float:
        """
        Compute the cosine similarity between two embeddings.

        Args:
            embedding1: First embedding vector
            embedding2: Second embedding vector

        Returns:
            Cosine similarity score; 0.0 if either vector has zero norm.
        """
        a = np.array(embedding1)
        b = np.array(embedding2)

        norm_a = np.linalg.norm(a)
        norm_b = np.linalg.norm(b)

        # Guard against division by zero for degenerate (all-zero) vectors.
        if norm_a == 0 or norm_b == 0:
            return 0.0

        return float(np.dot(a, b) / (norm_a * norm_b))
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
# Global embedding service instance (lazy loaded)
_embedding_service: Optional[EmbeddingService] = None


def get_embedding_service() -> EmbeddingService:
    """Return the process-wide EmbeddingService, creating it on first call."""
    global _embedding_service
    if _embedding_service is not None:
        return _embedding_service
    _embedding_service = EmbeddingService()
    return _embedding_service
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
|