diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..a6344aac8c09253b3b630fb776ae94478aa0275b --- /dev/null +++ b/.gitattributes @@ -0,0 +1,35 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ckpt filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.mlmodel filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +*.safetensors filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tar filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zst filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore deleted file mode 100644 index 52252fbf3d731f6c0ff561ae05589bb95d222a09..0000000000000000000000000000000000000000 --- a/.gitignore +++ /dev/null @@ -1,37 +0,0 @@ -๏ปฟ# Python files -*.py -*.pyc -__pycache__/ - -# Config files -*.txt -*.yaml -*.yml -*.toml -*.json -*.md -*.sh -*.bat -*.ps1 - -# Directories to include -app/ -requirements.txt -Dockerfile -app.py -README.md -.gitignore -.env.example.txt - -# Exclude binaries -*.png -*.jpg -*.jpeg -*.db -*.pdf -*.bin -*.sqlite3 -public/ -data/billing/ -data/vectordb/ -venv/ diff --git a/Dockerfile b/Dockerfile index 81583ce08f916ed699d485e323d5915650e20a55..5c82b05724a43d52a7262719defed8d0ff7928c4 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,35 +1,34 @@ -FROM python:3.11-slim - -# Set working directory -WORKDIR /app - -# Install system dependencies -RUN apt-get update && apt-get install -y \ - gcc \ - g++ \ - curl \ - && rm -rf /var/lib/apt/lists/* - -# Copy requirements first (for better caching) -COPY requirements.txt . - -# Install Python dependencies -RUN pip install --no-cache-dir -r requirements.txt - -# Copy application code -COPY . . - -# Create necessary directories -RUN mkdir -p data/uploads data/processed data/vectordb data/billing - -# Expose port (Hugging Face Spaces uses 7860, but we'll use PORT env var) -EXPOSE 7860 - -# Health check -HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \ - CMD curl -f http://localhost:${PORT:-7860}/health/live || exit 1 - -# Start the application using app.py (as specified in README.md) -# Hugging Face Spaces will run app.py when app_file is set -CMD ["python", "app.py"] - +FROM python:3.11-slim + +# Set working directory +WORKDIR /app + +# Install system dependencies +RUN apt-get update && apt-get install -y \ + gcc \ + g++ \ + curl \ + && rm -rf /var/lib/apt/lists/* + +# Copy requirements first (for better caching) +COPY requirements.txt . + +# Install Python dependencies +RUN pip install --no-cache-dir -r requirements.txt + +# Copy application code +COPY . . + +# Create necessary directories +RUN mkdir -p data/uploads data/processed data/vectordb data/billing + +# Expose port (Hugging Face Spaces uses 7860, but we'll use PORT env var) +EXPOSE 7860 + +# Health check +HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \ + CMD curl -f http://localhost:${PORT:-7860}/health/live || exit 1 + +# Start the application (Hugging Face Spaces provides PORT env var) +CMD uvicorn app.main:app --host 0.0.0.0 --port ${PORT:-7860} + diff --git a/README.md b/README.md index 098f39f9861376a5fc28a208d4e8d9c43f8ecd3b..7ae36e506f47ac35a51ef9239ffb525836cfd2dc 100644 --- a/README.md +++ b/README.md @@ -1,47 +1,47 @@ ---- -title: ClientSphere RAG Backend -emoji: ๐Ÿค– -colorFrom: blue -colorTo: purple -sdk: docker -sdk_version: "4.0.0" -python_version: "3.11" -app_file: app.py -pinned: false ---- - -# ClientSphere RAG Backend - -FastAPI-based RAG (Retrieval-Augmented Generation) backend for ClientSphere AI customer support platform. - -## Features - -- ๐Ÿ“š Knowledge base management -- ๐Ÿ” Semantic search with embeddings -- ๐Ÿ’ฌ AI-powered chat with citations -- ๐Ÿ“Š Confidence scoring -- ๐Ÿ”’ Multi-tenant isolation -- ๐Ÿ“ˆ Usage tracking and billing - -## API Endpoints - -- `GET /health/live` - Health check -- `GET /kb/stats` - Knowledge base statistics -- `POST /kb/upload` - Upload documents -- `POST /chat` - Chat with RAG -- `GET /kb/search` - Search knowledge base - -## Environment Variables - -Required: -- `GEMINI_API_KEY` - Google Gemini API key -- `ENV` - Set to `prod` for production -- `LLM_PROVIDER` - `gemini` or `openai` - -Optional: -- `ALLOWED_ORIGINS` - CORS allowed origins (comma-separated) -- `JWT_SECRET` - JWT secret for authentication - -## Documentation - -See `README_HF_SPACES.md` for deployment details. +--- +title: ClientSphere RAG Backend +emoji: ๐Ÿค– +colorFrom: blue +colorTo: purple +sdk: docker +sdk_version: "4.0.0" +python_version: "3.11" +app_file: app.py +pinned: false +--- + +# ClientSphere RAG Backend + +FastAPI-based RAG (Retrieval-Augmented Generation) backend for ClientSphere AI customer support platform. + +## Features + +- ๐Ÿ“š Knowledge base management +- ๐Ÿ” Semantic search with embeddings +- ๐Ÿ’ฌ AI-powered chat with citations +- ๐Ÿ“Š Confidence scoring +- ๐Ÿ”’ Multi-tenant isolation +- ๐Ÿ“ˆ Usage tracking and billing + +## API Endpoints + +- `GET /health/live` - Health check +- `GET /kb/stats` - Knowledge base statistics +- `POST /kb/upload` - Upload documents +- `POST /chat` - Chat with RAG +- `GET /kb/search` - Search knowledge base + +## Environment Variables + +Required: +- `GEMINI_API_KEY` - Google Gemini API key +- `ENV` - Set to `prod` for production +- `LLM_PROVIDER` - `gemini` or `openai` + +Optional: +- `ALLOWED_ORIGINS` - CORS allowed origins (comma-separated) +- `JWT_SECRET` - JWT secret for authentication + +## Documentation + +See `README_HF_SPACES.md` for deployment details. diff --git a/README_HF_SPACES.md b/README_HF_SPACES.md index 428acd39865110495f94d2ee6bf1ada09f3a5ff4..60bcc8a6a05420b72da7251c3aee611ba2807ec9 100644 --- a/README_HF_SPACES.md +++ b/README_HF_SPACES.md @@ -1,231 +1,230 @@ -# ๐Ÿš€ Deploy RAG Backend to Hugging Face Spaces - -Hugging Face Spaces is **perfect** for deploying Python/FastAPI applications with ML dependencies! - -## โœ… Why Hugging Face Spaces? - -- โœ… **Free tier** with generous limits -- โœ… **Full Python 3.11+** support -- โœ… **ML libraries** fully supported (sentence-transformers, chromadb, etc.) -- โœ… **Persistent storage** for vector database -- โœ… **No bundle size limits** -- โœ… **GPU support** available (paid) -- โœ… **Automatic HTTPS** and custom domains -- โœ… **GitHub integration** (auto-deploy on push) - -## ๐Ÿ“‹ Prerequisites - -1. **Hugging Face Account**: Sign up at [huggingface.co](https://huggingface.co) -2. **GitHub Repository**: Your code should be in a GitHub repository -3. **Gemini API Key**: Get from [Google AI Studio](https://aistudio.google.com/app/apikey) - -## ๐Ÿš€ Step-by-Step Deployment - -### Step 1: Prepare Your Repository - -Your `rag-backend/` directory should contain: -- โœ… `app.py` - Entry point (already created) -- โœ… `requirements.txt` - Dependencies -- โœ… `app/main.py` - FastAPI application -- โœ… All other application files - -### Step 2: Create Hugging Face Space - -1. Go to [Hugging Face Spaces](https://huggingface.co/spaces) -2. Click **"Create new Space"** -3. Configure: - - **Owner**: Your username - - **Space name**: `clientsphere-rag-backend` (or your choice) - - **SDK**: **Docker** (recommended) or **Gradio** (if you want UI) - - **Hardware**: - - **CPU basic** (free) - Good for testing - - **CPU upgrade** (paid) - Better performance - - **GPU** (paid) - For heavy ML workloads - -### Step 3: Connect GitHub Repository - -1. In Space creation, select **"Repository"** as source -2. Choose your GitHub repository -3. Set **Repository path** to: `rag-backend/` (subdirectory) -4. Click **"Create Space"** - -### Step 4: Configure Environment Variables - -1. Go to your Space's **Settings** tab -2. Scroll to **"Repository secrets"** or **"Variables"** -3. Add these secrets: - -**Required:** -``` -GEMINI_API_KEY=your_gemini_api_key_here -ENV=prod -LLM_PROVIDER=gemini -``` - -**Optional (but recommended):** -``` -ALLOWED_ORIGINS=https://main.clientsphere.pages.dev,https://abaa49a3.clientsphere.pages.dev -JWT_SECRET=your_secure_jwt_secret -DEBUG=false -``` - -### Step 5: Configure Docker (if using Docker SDK) - -If you selected **Docker** SDK, Hugging Face will use your `Dockerfile`. - -**Your existing `Dockerfile` should work!** It's already configured correctly. - -### Step 6: Alternative - Use app.py (Simpler) - -If you want to use the simpler `app.py` approach: - -1. In Space settings, set: - - **SDK**: **Gradio** or **Streamlit** (but we'll override) - - **App file**: `app.py` - -2. Hugging Face will automatically: - - Install dependencies from `requirements.txt` - - Run `python app.py` - - Expose on port 7860 - -### Step 7: Deploy! - -1. **Push to GitHub** (if not already): - ```bash - git add rag-backend/app.py - git commit -m "Add Hugging Face Spaces entry point" - git push origin main - ``` - -2. **Hugging Face will auto-deploy** from your GitHub repo! - -3. **Wait for build** (5-10 minutes first time, faster after) - -4. **Your Space URL**: `https://your-username-clientsphere-rag-backend.hf.space` - -## ๐Ÿ”ง Configuration Options - -### Option A: Docker (Recommended) - -**Advantages:** -- Full control over environment -- Can customize Python version -- Better for production - -**Setup:** -- Use existing `Dockerfile` -- Hugging Face will build and run it -- Exposes on port 7860 automatically - -### Option B: app.py (Simpler) - -**Advantages:** -- Simpler setup -- Faster builds -- Good for development - -**Setup:** -- Create `app.py` in `rag-backend/` (already done) -- Hugging Face runs it automatically - -## ๐Ÿ“ Environment Variables Reference - -| Variable | Required | Description | -|----------|----------|-------------| -| `GEMINI_API_KEY` | โœ… Yes | Your Google Gemini API key | -| `ENV` | โœ… Yes | Set to `prod` for production | -| `LLM_PROVIDER` | โœ… Yes | `gemini` or `openai` | -| `ALLOWED_ORIGINS` | โš ๏ธ Recommended | CORS allowed origins (comma-separated) | -| `JWT_SECRET` | โš ๏ธ Recommended | JWT secret for authentication | -| `DEBUG` | โŒ Optional | Set to `false` in production | -| `OPENAI_API_KEY` | โŒ Optional | If using OpenAI instead of Gemini | - -## ๐ŸŒ CORS Configuration - -After deployment, update `ALLOWED_ORIGINS` to include: -- Your Cloudflare Pages frontend URL -- Your Cloudflare Workers backend URL -- Any other origins that need access - -Example: -``` -ALLOWED_ORIGINS=https://main.clientsphere.pages.dev,https://mcp-backend.officialchiragp1605.workers.dev -``` - -## ๐Ÿ”„ Updating Deployment - -**Automatic (Recommended):** -- Push to GitHub โ†’ Hugging Face auto-deploys - -**Manual:** -- Go to Space โ†’ Settings โ†’ "Rebuild Space" - -## ๐Ÿ“Š Resource Limits - -### Free Tier: -- โœ… **CPU**: Basic (sufficient for RAG) -- โœ… **Storage**: 50GB (plenty for vector DB) -- โœ… **Memory**: 16GB RAM -- โœ… **Build time**: 20 minutes -- โœ… **Sleep after inactivity**: 48 hours (wakes on request) - -### Paid Tiers: -- **CPU upgrade**: Better performance -- **GPU**: For heavy ML workloads -- **No sleep**: Always-on service - -## ๐Ÿงช Testing Deployment - -After deployment, test your endpoints: - -```bash -# Health check -curl https://your-username-clientsphere-rag-backend.hf.space/health/live - -# KB Stats (with auth) -curl https://your-username-clientsphere-rag-backend.hf.space/kb/stats?kb_id=default&tenant_id=test&user_id=test -``` - -## ๐Ÿ”— Update Frontend - -After deployment, update Cloudflare Pages environment variable: - -``` -VITE_RAG_API_URL=https://your-username-clientsphere-rag-backend.hf.space -``` - -Then redeploy frontend: -```bash -npm run build -npx wrangler pages deploy dist --project-name=clientsphere -``` - -## โœ… Advantages Over Render - -| Feature | Hugging Face Spaces | Render | -|---------|-------------------|--------| -| Free Tier | โœ… Generous | โš ๏ธ Limited | -| ML Libraries | โœ… Full support | โœ… Full support | -| Auto-deploy | โœ… GitHub integration | โœ… GitHub integration | -| Storage | โœ… 50GB free | โš ๏ธ Limited | -| Sleep Mode | โœ… Wakes on request | โŒ No sleep mode | -| GPU Support | โœ… Available | โŒ Not available | -| Community | โœ… Large ML community | โš ๏ธ Smaller | - -## ๐ŸŽฏ Summary - -1. โœ… Create Hugging Face Space -2. โœ… Connect GitHub repository (rag-backend/) -3. โœ… Set environment variables -4. โœ… Deploy (automatic on push) -5. โœ… Update frontend `VITE_RAG_API_URL` -6. โœ… Test and enjoy! - -**Your RAG backend will be live at:** -`https://your-username-clientsphere-rag-backend.hf.space` - ---- - -**Need help?** Check [Hugging Face Spaces Docs](https://huggingface.co/docs/hub/spaces) - - +# ๐Ÿš€ Deploy RAG Backend to Hugging Face Spaces + +Hugging Face Spaces is **perfect** for deploying Python/FastAPI applications with ML dependencies! + +## โœ… Why Hugging Face Spaces? + +- โœ… **Free tier** with generous limits +- โœ… **Full Python 3.11+** support +- โœ… **ML libraries** fully supported (sentence-transformers, chromadb, etc.) +- โœ… **Persistent storage** for vector database +- โœ… **No bundle size limits** +- โœ… **GPU support** available (paid) +- โœ… **Automatic HTTPS** and custom domains +- โœ… **GitHub integration** (auto-deploy on push) + +## ๐Ÿ“‹ Prerequisites + +1. **Hugging Face Account**: Sign up at [huggingface.co](https://huggingface.co) +2. **GitHub Repository**: Your code should be in a GitHub repository +3. **Gemini API Key**: Get from [Google AI Studio](https://aistudio.google.com/app/apikey) + +## ๐Ÿš€ Step-by-Step Deployment + +### Step 1: Prepare Your Repository + +Your `rag-backend/` directory should contain: +- โœ… `app.py` - Entry point (already created) +- โœ… `requirements.txt` - Dependencies +- โœ… `app/main.py` - FastAPI application +- โœ… All other application files + +### Step 2: Create Hugging Face Space + +1. Go to [Hugging Face Spaces](https://huggingface.co/spaces) +2. Click **"Create new Space"** +3. Configure: + - **Owner**: Your username + - **Space name**: `clientsphere-rag-backend` (or your choice) + - **SDK**: **Docker** (recommended) or **Gradio** (if you want UI) + - **Hardware**: + - **CPU basic** (free) - Good for testing + - **CPU upgrade** (paid) - Better performance + - **GPU** (paid) - For heavy ML workloads + +### Step 3: Connect GitHub Repository + +1. In Space creation, select **"Repository"** as source +2. Choose your GitHub repository +3. Set **Repository path** to: `rag-backend/` (subdirectory) +4. Click **"Create Space"** + +### Step 4: Configure Environment Variables + +1. Go to your Space's **Settings** tab +2. Scroll to **"Repository secrets"** or **"Variables"** +3. Add these secrets: + +**Required:** +``` +GEMINI_API_KEY=your_gemini_api_key_here +ENV=prod +LLM_PROVIDER=gemini +``` + +**Optional (but recommended):** +``` +ALLOWED_ORIGINS=https://main.clientsphere.pages.dev,https://abaa49a3.clientsphere.pages.dev +JWT_SECRET=your_secure_jwt_secret +DEBUG=false +``` + +### Step 5: Configure Docker (if using Docker SDK) + +If you selected **Docker** SDK, Hugging Face will use your `Dockerfile`. + +**Your existing `Dockerfile` should work!** It's already configured correctly. + +### Step 6: Alternative - Use app.py (Simpler) + +If you want to use the simpler `app.py` approach: + +1. In Space settings, set: + - **SDK**: **Gradio** or **Streamlit** (but we'll override) + - **App file**: `app.py` + +2. Hugging Face will automatically: + - Install dependencies from `requirements.txt` + - Run `python app.py` + - Expose on port 7860 + +### Step 7: Deploy! + +1. **Push to GitHub** (if not already): + ```bash + git add rag-backend/app.py + git commit -m "Add Hugging Face Spaces entry point" + git push origin main + ``` + +2. **Hugging Face will auto-deploy** from your GitHub repo! + +3. **Wait for build** (5-10 minutes first time, faster after) + +4. **Your Space URL**: `https://your-username-clientsphere-rag-backend.hf.space` + +## ๐Ÿ”ง Configuration Options + +### Option A: Docker (Recommended) + +**Advantages:** +- Full control over environment +- Can customize Python version +- Better for production + +**Setup:** +- Use existing `Dockerfile` +- Hugging Face will build and run it +- Exposes on port 7860 automatically + +### Option B: app.py (Simpler) + +**Advantages:** +- Simpler setup +- Faster builds +- Good for development + +**Setup:** +- Create `app.py` in `rag-backend/` (already done) +- Hugging Face runs it automatically + +## ๐Ÿ“ Environment Variables Reference + +| Variable | Required | Description | +|----------|----------|-------------| +| `GEMINI_API_KEY` | โœ… Yes | Your Google Gemini API key | +| `ENV` | โœ… Yes | Set to `prod` for production | +| `LLM_PROVIDER` | โœ… Yes | `gemini` or `openai` | +| `ALLOWED_ORIGINS` | โš ๏ธ Recommended | CORS allowed origins (comma-separated) | +| `JWT_SECRET` | โš ๏ธ Recommended | JWT secret for authentication | +| `DEBUG` | โŒ Optional | Set to `false` in production | +| `OPENAI_API_KEY` | โŒ Optional | If using OpenAI instead of Gemini | + +## ๐ŸŒ CORS Configuration + +After deployment, update `ALLOWED_ORIGINS` to include: +- Your Cloudflare Pages frontend URL +- Your Cloudflare Workers backend URL +- Any other origins that need access + +Example: +``` +ALLOWED_ORIGINS=https://main.clientsphere.pages.dev,https://mcp-backend.officialchiragp1605.workers.dev +``` + +## ๐Ÿ”„ Updating Deployment + +**Automatic (Recommended):** +- Push to GitHub โ†’ Hugging Face auto-deploys + +**Manual:** +- Go to Space โ†’ Settings โ†’ "Rebuild Space" + +## ๐Ÿ“Š Resource Limits + +### Free Tier: +- โœ… **CPU**: Basic (sufficient for RAG) +- โœ… **Storage**: 50GB (plenty for vector DB) +- โœ… **Memory**: 16GB RAM +- โœ… **Build time**: 20 minutes +- โœ… **Sleep after inactivity**: 48 hours (wakes on request) + +### Paid Tiers: +- **CPU upgrade**: Better performance +- **GPU**: For heavy ML workloads +- **No sleep**: Always-on service + +## ๐Ÿงช Testing Deployment + +After deployment, test your endpoints: + +```bash +# Health check +curl https://your-username-clientsphere-rag-backend.hf.space/health/live + +# KB Stats (with auth) +curl https://your-username-clientsphere-rag-backend.hf.space/kb/stats?kb_id=default&tenant_id=test&user_id=test +``` + +## ๐Ÿ”— Update Frontend + +After deployment, update Cloudflare Pages environment variable: + +``` +VITE_RAG_API_URL=https://your-username-clientsphere-rag-backend.hf.space +``` + +Then redeploy frontend: +```bash +npm run build +npx wrangler pages deploy dist --project-name=clientsphere +``` + +## โœ… Advantages Over Render + +| Feature | Hugging Face Spaces | Render | +|---------|-------------------|--------| +| Free Tier | โœ… Generous | โš ๏ธ Limited | +| ML Libraries | โœ… Full support | โœ… Full support | +| Auto-deploy | โœ… GitHub integration | โœ… GitHub integration | +| Storage | โœ… 50GB free | โš ๏ธ Limited | +| Sleep Mode | โœ… Wakes on request | โŒ No sleep mode | +| GPU Support | โœ… Available | โŒ Not available | +| Community | โœ… Large ML community | โš ๏ธ Smaller | + +## ๐ŸŽฏ Summary + +1. โœ… Create Hugging Face Space +2. โœ… Connect GitHub repository (rag-backend/) +3. โœ… Set environment variables +4. โœ… Deploy (automatic on push) +5. โœ… Update frontend `VITE_RAG_API_URL` +6. โœ… Test and enjoy! + +**Your RAG backend will be live at:** +`https://your-username-clientsphere-rag-backend.hf.space` + +--- + +**Need help?** Check [Hugging Face Spaces Docs](https://huggingface.co/docs/hub/spaces) + diff --git a/app.py b/app.py index ab7cbfba8081e99baf78061bdfc4b9b0d469bc07..799f1ee953cc80c0ed4300a98c179060bf235952 100644 --- a/app.py +++ b/app.py @@ -1,13 +1,13 @@ -""" -Hugging Face Spaces entry point for RAG Backend. -This file is used when deploying to Hugging Face Spaces. -""" -import os -import uvicorn -from app.main import app - -if __name__ == "__main__": - # Hugging Face Spaces provides PORT environment variable (defaults to 7860) - port = int(os.getenv("PORT", 7860)) - uvicorn.run(app, host="0.0.0.0", port=port) - +""" +Hugging Face Spaces entry point for RAG Backend. +This file is used when deploying to Hugging Face Spaces. +""" +import os +import uvicorn +from app.main import app + +if __name__ == "__main__": + # Hugging Face Spaces provides PORT environment variable (defaults to 7860) + port = int(os.getenv("PORT", 7860)) + uvicorn.run(app, host="0.0.0.0", port=port) + diff --git a/app/__init__.py b/app/__init__.py index 2927b3f94e820da8fb6011f7728162ff6588b1e0..1656c77071f1569f32d92292adda04c91ecda1ed 100644 --- a/app/__init__.py +++ b/app/__init__.py @@ -1,7 +1,7 @@ -""" -ClientSphere RAG Backend Application. -""" -__version__ = "1.0.0" - - - +""" +ClientSphere RAG Backend Application. +""" +__version__ = "1.0.0" + + + diff --git a/app/__pycache__/__init__.cpython-313.pyc b/app/__pycache__/__init__.cpython-313.pyc deleted file mode 100644 index cdbcbc7f7b1606aa3483782b38a838732e262abd..0000000000000000000000000000000000000000 Binary files a/app/__pycache__/__init__.cpython-313.pyc and /dev/null differ diff --git a/app/__pycache__/config.cpython-313.pyc b/app/__pycache__/config.cpython-313.pyc deleted file mode 100644 index 1ff6d5db559d78685cbb1b28e0111047ae20418e..0000000000000000000000000000000000000000 Binary files a/app/__pycache__/config.cpython-313.pyc and /dev/null differ diff --git a/app/__pycache__/main.cpython-313.pyc b/app/__pycache__/main.cpython-313.pyc deleted file mode 100644 index 73ca17bb267d5c5aaa06bafd899f1f4fe20eda53..0000000000000000000000000000000000000000 Binary files a/app/__pycache__/main.cpython-313.pyc and /dev/null differ diff --git a/app/billing/__pycache__/pricing.cpython-313.pyc b/app/billing/__pycache__/pricing.cpython-313.pyc deleted file mode 100644 index 43ab282823946c4f77c61b7fa9bfaece57eb5052..0000000000000000000000000000000000000000 Binary files a/app/billing/__pycache__/pricing.cpython-313.pyc and /dev/null differ diff --git a/app/billing/__pycache__/quota.cpython-313.pyc b/app/billing/__pycache__/quota.cpython-313.pyc deleted file mode 100644 index 4cf06d53ff56842e6bd33bac131f1abc7d744757..0000000000000000000000000000000000000000 Binary files a/app/billing/__pycache__/quota.cpython-313.pyc and /dev/null differ diff --git a/app/billing/__pycache__/usage_tracker.cpython-313.pyc b/app/billing/__pycache__/usage_tracker.cpython-313.pyc deleted file mode 100644 index 9613384d8067dd1a3c356c90ab5f648563fbe573..0000000000000000000000000000000000000000 Binary files a/app/billing/__pycache__/usage_tracker.cpython-313.pyc and /dev/null differ diff --git a/app/billing/pricing.py b/app/billing/pricing.py index 186832f56e89bc5d8d3c501d065f816414e6b4fb..14ed2028077a5dd4abea91451b8728fb2e9e1b82 100644 --- a/app/billing/pricing.py +++ b/app/billing/pricing.py @@ -1,57 +1,57 @@ -""" -Pricing table for LLM providers. -Used to calculate estimated costs from token usage. -""" -from typing import Dict, Optional - -# Pricing per 1M tokens (as of 2024, update as needed) -PRICING_TABLE: Dict[str, Dict[str, float]] = { - "gemini": { - "gemini-pro": {"input": 0.50, "output": 1.50}, # $0.50/$1.50 per 1M tokens - "gemini-1.5-pro": {"input": 1.25, "output": 5.00}, - "gemini-1.5-flash": {"input": 0.075, "output": 0.30}, - "gemini-1.0-pro": {"input": 0.50, "output": 1.50}, - "default": {"input": 0.50, "output": 1.50} - }, - "openai": { - "gpt-4": {"input": 30.00, "output": 60.00}, - "gpt-4-turbo": {"input": 10.00, "output": 30.00}, - "gpt-3.5-turbo": {"input": 0.50, "output": 1.50}, - "default": {"input": 0.50, "output": 1.50} - } -} - - -def calculate_cost( - provider: str, - model: str, - prompt_tokens: int, - completion_tokens: int -) -> float: - """ - Calculate estimated cost in USD based on token usage. - - Args: - provider: "gemini" or "openai" - model: Model name (e.g., "gemini-pro", "gpt-3.5-turbo") - prompt_tokens: Number of input tokens - completion_tokens: Number of output tokens - - Returns: - Estimated cost in USD - """ - provider_pricing = PRICING_TABLE.get(provider.lower(), {}) - model_pricing = provider_pricing.get(model.lower(), provider_pricing.get("default", {"input": 0.50, "output": 1.50})) - - # Calculate cost: (tokens / 1M) * price_per_1M - input_cost = (prompt_tokens / 1_000_000) * model_pricing["input"] - output_cost = (completion_tokens / 1_000_000) * model_pricing["output"] - - return input_cost + output_cost - - -def get_model_pricing(provider: str, model: str) -> Dict[str, float]: - """Get pricing for a specific model.""" - provider_pricing = PRICING_TABLE.get(provider.lower(), {}) - return provider_pricing.get(model.lower(), provider_pricing.get("default", {"input": 0.50, "output": 1.50})) - +""" +Pricing table for LLM providers. +Used to calculate estimated costs from token usage. +""" +from typing import Dict, Optional + +# Pricing per 1M tokens (as of 2024, update as needed) +PRICING_TABLE: Dict[str, Dict[str, float]] = { + "gemini": { + "gemini-pro": {"input": 0.50, "output": 1.50}, # $0.50/$1.50 per 1M tokens + "gemini-1.5-pro": {"input": 1.25, "output": 5.00}, + "gemini-1.5-flash": {"input": 0.075, "output": 0.30}, + "gemini-1.0-pro": {"input": 0.50, "output": 1.50}, + "default": {"input": 0.50, "output": 1.50} + }, + "openai": { + "gpt-4": {"input": 30.00, "output": 60.00}, + "gpt-4-turbo": {"input": 10.00, "output": 30.00}, + "gpt-3.5-turbo": {"input": 0.50, "output": 1.50}, + "default": {"input": 0.50, "output": 1.50} + } +} + + +def calculate_cost( + provider: str, + model: str, + prompt_tokens: int, + completion_tokens: int +) -> float: + """ + Calculate estimated cost in USD based on token usage. + + Args: + provider: "gemini" or "openai" + model: Model name (e.g., "gemini-pro", "gpt-3.5-turbo") + prompt_tokens: Number of input tokens + completion_tokens: Number of output tokens + + Returns: + Estimated cost in USD + """ + provider_pricing = PRICING_TABLE.get(provider.lower(), {}) + model_pricing = provider_pricing.get(model.lower(), provider_pricing.get("default", {"input": 0.50, "output": 1.50})) + + # Calculate cost: (tokens / 1M) * price_per_1M + input_cost = (prompt_tokens / 1_000_000) * model_pricing["input"] + output_cost = (completion_tokens / 1_000_000) * model_pricing["output"] + + return input_cost + output_cost + + +def get_model_pricing(provider: str, model: str) -> Dict[str, float]: + """Get pricing for a specific model.""" + provider_pricing = PRICING_TABLE.get(provider.lower(), {}) + return provider_pricing.get(model.lower(), provider_pricing.get("default", {"input": 0.50, "output": 1.50})) + diff --git a/app/billing/quota.py b/app/billing/quota.py index c527a7b5fe3f79c0362eb72f50a6b8d152447541..638b63d4e647395cd389f9967288235ac1c92a6c 100644 --- a/app/billing/quota.py +++ b/app/billing/quota.py @@ -1,131 +1,131 @@ -""" -Quota management and enforcement. -""" -from sqlalchemy.orm import Session -from sqlalchemy import func, and_ -from datetime import datetime, timedelta -from typing import Optional, Tuple -import logging - -from app.db.models import TenantPlan, UsageMonthly, Tenant -logger = logging.getLogger(__name__) - -# Plan limits (chats per month) -PLAN_LIMITS = { - "starter": 500, - "growth": 5000, - "pro": -1 # -1 means unlimited -} - - -def get_tenant_plan(db: Session, tenant_id: str) -> Optional[TenantPlan]: - """Get tenant's current plan.""" - return db.query(TenantPlan).filter(TenantPlan.tenant_id == tenant_id).first() - - -def get_monthly_usage(db: Session, tenant_id: str, year: Optional[int] = None, month: Optional[int] = None) -> Optional[UsageMonthly]: - """Get monthly usage for tenant.""" - now = datetime.utcnow() - target_year = year or now.year - target_month = month or now.month - - return db.query(UsageMonthly).filter( - and_( - UsageMonthly.tenant_id == tenant_id, - UsageMonthly.year == target_year, - UsageMonthly.month == target_month - ) - ).first() - - -def check_quota(db: Session, tenant_id: str) -> Tuple[bool, Optional[str]]: - """ - Check if tenant has quota remaining for the current month. - - Returns: - (has_quota, error_message) - has_quota: True if quota available, False if exceeded - error_message: None if quota available, error message if exceeded - """ - # Get tenant plan - plan = get_tenant_plan(db, tenant_id) - - if not plan: - # Default to starter plan if no plan assigned - logger.warning(f"Tenant {tenant_id} has no plan assigned, defaulting to starter") - monthly_limit = PLAN_LIMITS.get("starter", 500) - else: - monthly_limit = plan.monthly_chat_limit - - # Unlimited plan (-1) always passes - if monthly_limit == -1: - return True, None - - # Get current month usage - now = datetime.utcnow() - monthly_usage = get_monthly_usage(db, tenant_id, now.year, now.month) - - current_usage = monthly_usage.total_requests if monthly_usage else 0 - - # Check if quota exceeded - if current_usage >= monthly_limit: - return False, f"AI quota exceeded ({current_usage}/{monthly_limit} chats this month). Upgrade your plan." - - return True, None - - -def ensure_tenant_exists(db: Session, tenant_id: str) -> None: - """Ensure tenant record exists in database.""" - tenant = db.query(Tenant).filter(Tenant.id == tenant_id).first() - if not tenant: - # Create tenant with default starter plan - tenant = Tenant(id=tenant_id, name=f"Tenant {tenant_id}") - db.add(tenant) - - # Create default starter plan - plan = TenantPlan( - tenant_id=tenant_id, - plan_name="starter", - monthly_chat_limit=PLAN_LIMITS["starter"] - ) - db.add(plan) - db.commit() - logger.info(f"Created tenant {tenant_id} with starter plan") - - -def set_tenant_plan(db: Session, tenant_id: str, plan_name: str) -> TenantPlan: - """ - Set tenant's subscription plan. - - Args: - db: Database session - tenant_id: Tenant ID - plan_name: "starter", "growth", or "pro" - - Returns: - Updated TenantPlan - """ - if plan_name not in PLAN_LIMITS: - raise ValueError(f"Invalid plan name: {plan_name}. Must be one of: {list(PLAN_LIMITS.keys())}") - - # Ensure tenant exists - ensure_tenant_exists(db, tenant_id) - - # Get or create plan - plan = get_tenant_plan(db, tenant_id) - if plan: - plan.plan_name = plan_name - plan.monthly_chat_limit = PLAN_LIMITS[plan_name] - plan.updated_at = datetime.utcnow() - else: - plan = TenantPlan( - tenant_id=tenant_id, - plan_name=plan_name, - monthly_chat_limit=PLAN_LIMITS[plan_name] - ) - db.add(plan) - - db.commit() - db.refresh(plan) - return plan - +""" +Quota management and enforcement. +""" +from sqlalchemy.orm import Session +from sqlalchemy import func, and_ +from datetime import datetime, timedelta +from typing import Optional, Tuple +import logging + +from app.db.models import TenantPlan, UsageMonthly, Tenant +logger = logging.getLogger(__name__) + +# Plan limits (chats per month) +PLAN_LIMITS = { + "starter": 500, + "growth": 5000, + "pro": -1 # -1 means unlimited +} + + +def get_tenant_plan(db: Session, tenant_id: str) -> Optional[TenantPlan]: + """Get tenant's current plan.""" + return db.query(TenantPlan).filter(TenantPlan.tenant_id == tenant_id).first() + + +def get_monthly_usage(db: Session, tenant_id: str, year: Optional[int] = None, month: Optional[int] = None) -> Optional[UsageMonthly]: + """Get monthly usage for tenant.""" + now = datetime.utcnow() + target_year = year or now.year + target_month = month or now.month + + return db.query(UsageMonthly).filter( + and_( + UsageMonthly.tenant_id == tenant_id, + UsageMonthly.year == target_year, + UsageMonthly.month == target_month + ) + ).first() + + +def check_quota(db: Session, tenant_id: str) -> Tuple[bool, Optional[str]]: + """ + Check if tenant has quota remaining for the current month. + + Returns: + (has_quota, error_message) + has_quota: True if quota available, False if exceeded + error_message: None if quota available, error message if exceeded + """ + # Get tenant plan + plan = get_tenant_plan(db, tenant_id) + + if not plan: + # Default to starter plan if no plan assigned + logger.warning(f"Tenant {tenant_id} has no plan assigned, defaulting to starter") + monthly_limit = PLAN_LIMITS.get("starter", 500) + else: + monthly_limit = plan.monthly_chat_limit + + # Unlimited plan (-1) always passes + if monthly_limit == -1: + return True, None + + # Get current month usage + now = datetime.utcnow() + monthly_usage = get_monthly_usage(db, tenant_id, now.year, now.month) + + current_usage = monthly_usage.total_requests if monthly_usage else 0 + + # Check if quota exceeded + if current_usage >= monthly_limit: + return False, f"AI quota exceeded ({current_usage}/{monthly_limit} chats this month). Upgrade your plan." + + return True, None + + +def ensure_tenant_exists(db: Session, tenant_id: str) -> None: + """Ensure tenant record exists in database.""" + tenant = db.query(Tenant).filter(Tenant.id == tenant_id).first() + if not tenant: + # Create tenant with default starter plan + tenant = Tenant(id=tenant_id, name=f"Tenant {tenant_id}") + db.add(tenant) + + # Create default starter plan + plan = TenantPlan( + tenant_id=tenant_id, + plan_name="starter", + monthly_chat_limit=PLAN_LIMITS["starter"] + ) + db.add(plan) + db.commit() + logger.info(f"Created tenant {tenant_id} with starter plan") + + +def set_tenant_plan(db: Session, tenant_id: str, plan_name: str) -> TenantPlan: + """ + Set tenant's subscription plan. + + Args: + db: Database session + tenant_id: Tenant ID + plan_name: "starter", "growth", or "pro" + + Returns: + Updated TenantPlan + """ + if plan_name not in PLAN_LIMITS: + raise ValueError(f"Invalid plan name: {plan_name}. Must be one of: {list(PLAN_LIMITS.keys())}") + + # Ensure tenant exists + ensure_tenant_exists(db, tenant_id) + + # Get or create plan + plan = get_tenant_plan(db, tenant_id) + if plan: + plan.plan_name = plan_name + plan.monthly_chat_limit = PLAN_LIMITS[plan_name] + plan.updated_at = datetime.utcnow() + else: + plan = TenantPlan( + tenant_id=tenant_id, + plan_name=plan_name, + monthly_chat_limit=PLAN_LIMITS[plan_name] + ) + db.add(plan) + + db.commit() + db.refresh(plan) + return plan + diff --git a/app/billing/usage_tracker.py b/app/billing/usage_tracker.py index 390dda761ef82aa5ae678798a906c3232799bcb8..d88b956d7314b0b29e530f20d665088ca4f69a13 100644 --- a/app/billing/usage_tracker.py +++ b/app/billing/usage_tracker.py @@ -1,173 +1,173 @@ -""" -Usage tracking service. -Tracks token usage and costs for each LLM request. -""" -from sqlalchemy.orm import Session -from sqlalchemy import func, and_ -from datetime import datetime, timedelta -from typing import Optional -import uuid -import logging - -from app.db.models import UsageEvent, UsageDaily, UsageMonthly, Tenant -from app.billing.pricing import calculate_cost -from app.billing.quota import ensure_tenant_exists - -logger = logging.getLogger(__name__) - - -def track_usage( - db: Session, - tenant_id: str, - user_id: str, - kb_id: str, - provider: str, - model: str, - prompt_tokens: int, - completion_tokens: int, - request_timestamp: Optional[datetime] = None -) -> UsageEvent: - """ - Track a single usage event. - - Args: - db: Database session - tenant_id: Tenant ID - user_id: User ID - kb_id: Knowledge base ID - provider: "gemini" or "openai" - model: Model name - prompt_tokens: Input tokens - completion_tokens: Output tokens - request_timestamp: Request timestamp (defaults to now) - - Returns: - Created UsageEvent - """ - # Ensure tenant exists - ensure_tenant_exists(db, tenant_id) - - # Calculate cost - total_tokens = prompt_tokens + completion_tokens - estimated_cost = calculate_cost(provider, model, prompt_tokens, completion_tokens) - - # Create usage event - request_id = f"req_{uuid.uuid4().hex[:16]}" - timestamp = request_timestamp or datetime.utcnow() - - usage_event = UsageEvent( - request_id=request_id, - tenant_id=tenant_id, - user_id=user_id, - kb_id=kb_id, - provider=provider, - model=model, - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=total_tokens, - estimated_cost_usd=estimated_cost, - request_timestamp=timestamp - ) - - db.add(usage_event) - - # Update daily aggregation - _update_daily_usage(db, tenant_id, timestamp, provider, total_tokens, estimated_cost) - - # Update monthly aggregation - _update_monthly_usage(db, tenant_id, timestamp, provider, total_tokens, estimated_cost) - - db.commit() - db.refresh(usage_event) - - logger.info( - f"Tracked usage: tenant={tenant_id}, provider={provider}, " - f"tokens={total_tokens}, cost=${estimated_cost:.6f}" - ) - - return usage_event - - -def _update_daily_usage( - db: Session, - tenant_id: str, - timestamp: datetime, - provider: str, - tokens: int, - cost: float -): - """Update daily usage aggregation.""" - date = timestamp.date() - date_start = datetime.combine(date, datetime.min.time()) - - daily = db.query(UsageDaily).filter( - and_( - UsageDaily.tenant_id == tenant_id, - UsageDaily.date == date_start - ) - ).first() - - if daily: - daily.total_requests += 1 - daily.total_tokens += tokens - daily.total_cost_usd += cost - if provider == "gemini": - daily.gemini_requests += 1 - elif provider == "openai": - daily.openai_requests += 1 - daily.updated_at = datetime.utcnow() - else: - daily = UsageDaily( - tenant_id=tenant_id, - date=date_start, - total_requests=1, - total_tokens=tokens, - total_cost_usd=cost, - gemini_requests=1 if provider == "gemini" else 0, - openai_requests=1 if provider == "openai" else 0 - ) - db.add(daily) - - -def _update_monthly_usage( - db: Session, - tenant_id: str, - timestamp: datetime, - provider: str, - tokens: int, - cost: float -): - """Update monthly usage aggregation.""" - year = timestamp.year - month = timestamp.month - - monthly = db.query(UsageMonthly).filter( - and_( - UsageMonthly.tenant_id == tenant_id, - UsageMonthly.year == year, - UsageMonthly.month == month - ) - ).first() - - if monthly: - monthly.total_requests += 1 - monthly.total_tokens += tokens - monthly.total_cost_usd += cost - if provider == "gemini": - monthly.gemini_requests += 1 - elif provider == "openai": - monthly.openai_requests += 1 - monthly.updated_at = datetime.utcnow() - else: - monthly = UsageMonthly( - tenant_id=tenant_id, - year=year, - month=month, - total_requests=1, - total_tokens=tokens, - total_cost_usd=cost, - gemini_requests=1 if provider == "gemini" else 0, - openai_requests=1 if provider == "openai" else 0 - ) - db.add(monthly) - +""" +Usage tracking service. +Tracks token usage and costs for each LLM request. +""" +from sqlalchemy.orm import Session +from sqlalchemy import func, and_ +from datetime import datetime, timedelta +from typing import Optional +import uuid +import logging + +from app.db.models import UsageEvent, UsageDaily, UsageMonthly, Tenant +from app.billing.pricing import calculate_cost +from app.billing.quota import ensure_tenant_exists + +logger = logging.getLogger(__name__) + + +def track_usage( + db: Session, + tenant_id: str, + user_id: str, + kb_id: str, + provider: str, + model: str, + prompt_tokens: int, + completion_tokens: int, + request_timestamp: Optional[datetime] = None +) -> UsageEvent: + """ + Track a single usage event. + + Args: + db: Database session + tenant_id: Tenant ID + user_id: User ID + kb_id: Knowledge base ID + provider: "gemini" or "openai" + model: Model name + prompt_tokens: Input tokens + completion_tokens: Output tokens + request_timestamp: Request timestamp (defaults to now) + + Returns: + Created UsageEvent + """ + # Ensure tenant exists + ensure_tenant_exists(db, tenant_id) + + # Calculate cost + total_tokens = prompt_tokens + completion_tokens + estimated_cost = calculate_cost(provider, model, prompt_tokens, completion_tokens) + + # Create usage event + request_id = f"req_{uuid.uuid4().hex[:16]}" + timestamp = request_timestamp or datetime.utcnow() + + usage_event = UsageEvent( + request_id=request_id, + tenant_id=tenant_id, + user_id=user_id, + kb_id=kb_id, + provider=provider, + model=model, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=total_tokens, + estimated_cost_usd=estimated_cost, + request_timestamp=timestamp + ) + + db.add(usage_event) + + # Update daily aggregation + _update_daily_usage(db, tenant_id, timestamp, provider, total_tokens, estimated_cost) + + # Update monthly aggregation + _update_monthly_usage(db, tenant_id, timestamp, provider, total_tokens, estimated_cost) + + db.commit() + db.refresh(usage_event) + + logger.info( + f"Tracked usage: tenant={tenant_id}, provider={provider}, " + f"tokens={total_tokens}, cost=${estimated_cost:.6f}" + ) + + return usage_event + + +def _update_daily_usage( + db: Session, + tenant_id: str, + timestamp: datetime, + provider: str, + tokens: int, + cost: float +): + """Update daily usage aggregation.""" + date = timestamp.date() + date_start = datetime.combine(date, datetime.min.time()) + + daily = db.query(UsageDaily).filter( + and_( + UsageDaily.tenant_id == tenant_id, + UsageDaily.date == date_start + ) + ).first() + + if daily: + daily.total_requests += 1 + daily.total_tokens += tokens + daily.total_cost_usd += cost + if provider == "gemini": + daily.gemini_requests += 1 + elif provider == "openai": + daily.openai_requests += 1 + daily.updated_at = datetime.utcnow() + else: + daily = UsageDaily( + tenant_id=tenant_id, + date=date_start, + total_requests=1, + total_tokens=tokens, + total_cost_usd=cost, + gemini_requests=1 if provider == "gemini" else 0, + openai_requests=1 if provider == "openai" else 0 + ) + db.add(daily) + + +def _update_monthly_usage( + db: Session, + tenant_id: str, + timestamp: datetime, + provider: str, + tokens: int, + cost: float +): + """Update monthly usage aggregation.""" + year = timestamp.year + month = timestamp.month + + monthly = db.query(UsageMonthly).filter( + and_( + UsageMonthly.tenant_id == tenant_id, + UsageMonthly.year == year, + UsageMonthly.month == month + ) + ).first() + + if monthly: + monthly.total_requests += 1 + monthly.total_tokens += tokens + monthly.total_cost_usd += cost + if provider == "gemini": + monthly.gemini_requests += 1 + elif provider == "openai": + monthly.openai_requests += 1 + monthly.updated_at = datetime.utcnow() + else: + monthly = UsageMonthly( + tenant_id=tenant_id, + year=year, + month=month, + total_requests=1, + total_tokens=tokens, + total_cost_usd=cost, + gemini_requests=1 if provider == "gemini" else 0, + openai_requests=1 if provider == "openai" else 0 + ) + db.add(monthly) + diff --git a/app/config.py b/app/config.py index 093e2aa07dbf518609daaac9b79d06f5d9df2c03..665337d0b7203d288bf95ad2a3f1bb8cf864c3e9 100644 --- a/app/config.py +++ b/app/config.py @@ -1,77 +1,77 @@ -""" -Configuration settings for the RAG backend. -""" -from pydantic_settings import BaseSettings -from pathlib import Path -from typing import Optional -import os - - -class Settings(BaseSettings): - """Application settings with environment variable support.""" - - # App settings - APP_NAME: str = "ClientSphere RAG Backend" - DEBUG: bool = True - ENV: str = "dev" # "dev" or "prod" - controls tenant_id security - - # Paths - BASE_DIR: Path = Path(__file__).parent.parent - DATA_DIR: Path = BASE_DIR / "data" - UPLOADS_DIR: Path = DATA_DIR / "uploads" - PROCESSED_DIR: Path = DATA_DIR / "processed" - VECTORDB_DIR: Path = DATA_DIR / "vectordb" - - # Chunking settings (optimized for retrieval quality) - CHUNK_SIZE: int = 600 # tokens (increased for better context) - CHUNK_OVERLAP: int = 150 # tokens (increased for better continuity) - MIN_CHUNK_SIZE: int = 100 # minimum tokens per chunk (increased to avoid tiny chunks) - - # Embedding settings - EMBEDDING_MODEL: str = "all-MiniLM-L6-v2" # Fast, good quality - EMBEDDING_DIMENSION: int = 384 - - # Vector store settings - COLLECTION_NAME: str = "clientsphere_kb" - - # Retrieval settings (optimized for maximum confidence) - TOP_K: int = 10 # Number of chunks to retrieve (increased to maximize chance of finding strong matches) - SIMILARITY_THRESHOLD: float = 0.15 # Minimum similarity score (0-1) - lowered to include more potentially relevant chunks - SIMILARITY_THRESHOLD_STRICT: float = 0.45 # Strict threshold for answer generation (anti-hallucination) - - # LLM settings - LLM_PROVIDER: str = "gemini" # Options: "gemini", "openai" - GEMINI_API_KEY: Optional[str] = None - OPENAI_API_KEY: Optional[str] = None - GEMINI_MODEL: str = "gemini-1.5-flash" # Use latest stable model - OPENAI_MODEL: str = "gpt-3.5-turbo" - - # Response settings - MAX_CONTEXT_TOKENS: int = 2500 # Max tokens for context in prompt (reduced for focus) - TEMPERATURE: float = 0.0 # Zero temperature for maximum determinism (anti-hallucination) - REQUIRE_VERIFIER: bool = True # Always use verifier for hallucination prevention - - # Security settings - MAX_FILE_SIZE_MB: int = 50 # Maximum file size in MB - ALLOWED_ORIGINS: str = "*" # CORS allowed origins (comma-separated, use "*" for all) - RATE_LIMIT_PER_MINUTE: int = 60 # Rate limit per user per minute - JWT_SECRET: Optional[str] = None # JWT secret for authentication - - # Rate limiting - RATE_LIMIT_ENABLED: bool = True # Enable/disable rate limiting - - class Config: - env_file = ".env" - env_file_encoding = "utf-8" - - def __init__(self, **kwargs): - super().__init__(**kwargs) - # Create directories if they don't exist - self.UPLOADS_DIR.mkdir(parents=True, exist_ok=True) - self.PROCESSED_DIR.mkdir(parents=True, exist_ok=True) - self.VECTORDB_DIR.mkdir(parents=True, exist_ok=True) - - -# Global settings instance -settings = Settings() - +""" +Configuration settings for the RAG backend. +""" +from pydantic_settings import BaseSettings +from pathlib import Path +from typing import Optional +import os + + +class Settings(BaseSettings): + """Application settings with environment variable support.""" + + # App settings + APP_NAME: str = "ClientSphere RAG Backend" + DEBUG: bool = True + ENV: str = "dev" # "dev" or "prod" - controls tenant_id security + + # Paths + BASE_DIR: Path = Path(__file__).parent.parent + DATA_DIR: Path = BASE_DIR / "data" + UPLOADS_DIR: Path = DATA_DIR / "uploads" + PROCESSED_DIR: Path = DATA_DIR / "processed" + VECTORDB_DIR: Path = DATA_DIR / "vectordb" + + # Chunking settings (optimized for retrieval quality) + CHUNK_SIZE: int = 600 # tokens (increased for better context) + CHUNK_OVERLAP: int = 150 # tokens (increased for better continuity) + MIN_CHUNK_SIZE: int = 100 # minimum tokens per chunk (increased to avoid tiny chunks) + + # Embedding settings + EMBEDDING_MODEL: str = "all-MiniLM-L6-v2" # Fast, good quality + EMBEDDING_DIMENSION: int = 384 + + # Vector store settings + COLLECTION_NAME: str = "clientsphere_kb" + + # Retrieval settings (optimized for maximum confidence) + TOP_K: int = 10 # Number of chunks to retrieve (increased to maximize chance of finding strong matches) + SIMILARITY_THRESHOLD: float = 0.15 # Minimum similarity score (0-1) - lowered to include more potentially relevant chunks + SIMILARITY_THRESHOLD_STRICT: float = 0.45 # Strict threshold for answer generation (anti-hallucination) + + # LLM settings + LLM_PROVIDER: str = "gemini" # Options: "gemini", "openai" + GEMINI_API_KEY: Optional[str] = None + OPENAI_API_KEY: Optional[str] = None + GEMINI_MODEL: str = "gemini-1.5-flash" # Use latest stable model + OPENAI_MODEL: str = "gpt-3.5-turbo" + + # Response settings + MAX_CONTEXT_TOKENS: int = 2500 # Max tokens for context in prompt (reduced for focus) + TEMPERATURE: float = 0.0 # Zero temperature for maximum determinism (anti-hallucination) + REQUIRE_VERIFIER: bool = True # Always use verifier for hallucination prevention + + # Security settings + MAX_FILE_SIZE_MB: int = 50 # Maximum file size in MB + ALLOWED_ORIGINS: str = "*" # CORS allowed origins (comma-separated, use "*" for all) + RATE_LIMIT_PER_MINUTE: int = 60 # Rate limit per user per minute + JWT_SECRET: Optional[str] = None # JWT secret for authentication + + # Rate limiting + RATE_LIMIT_ENABLED: bool = True # Enable/disable rate limiting + + class Config: + env_file = ".env" + env_file_encoding = "utf-8" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + # Create directories if they don't exist + self.UPLOADS_DIR.mkdir(parents=True, exist_ok=True) + self.PROCESSED_DIR.mkdir(parents=True, exist_ok=True) + self.VECTORDB_DIR.mkdir(parents=True, exist_ok=True) + + +# Global settings instance +settings = Settings() + diff --git a/app/db/__init__.py b/app/db/__init__.py index cae2a10692977751489e2380a25106789a1b3d25..8c2b90af076993b371b31a5cee9d715dfb794300 100644 --- a/app/db/__init__.py +++ b/app/db/__init__.py @@ -1,2 +1,2 @@ -# Database module - +# Database module + diff --git a/app/db/__pycache__/__init__.cpython-313.pyc b/app/db/__pycache__/__init__.cpython-313.pyc deleted file mode 100644 index 5e4546ddb536d3c73c9d6ccea735b061980d2cb3..0000000000000000000000000000000000000000 Binary files a/app/db/__pycache__/__init__.cpython-313.pyc and /dev/null differ diff --git a/app/db/__pycache__/database.cpython-313.pyc b/app/db/__pycache__/database.cpython-313.pyc deleted file mode 100644 index 7bc47a58f3e4cf8175f382a1b85d5d5ae853d18d..0000000000000000000000000000000000000000 Binary files a/app/db/__pycache__/database.cpython-313.pyc and /dev/null differ diff --git a/app/db/__pycache__/models.cpython-313.pyc b/app/db/__pycache__/models.cpython-313.pyc deleted file mode 100644 index 33cf42d384da0659676e20df5a2cd086b03c3361..0000000000000000000000000000000000000000 Binary files a/app/db/__pycache__/models.cpython-313.pyc and /dev/null differ diff --git a/app/db/database.py b/app/db/database.py index 3ecb1d6017dc1c1f51fca7e96a3a7130b1d3947b..1d24c4a3231eea95ddd61801983ed6b2cc5b2787 100644 --- a/app/db/database.py +++ b/app/db/database.py @@ -1,53 +1,53 @@ -""" -Database setup and session management. -Uses SQLAlchemy with SQLite for local dev, Postgres-compatible schema. -""" -from sqlalchemy import create_engine -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import sessionmaker, Session -from pathlib import Path -import logging - -from app.config import settings - -logger = logging.getLogger(__name__) - -# Database path -DB_DIR = settings.DATA_DIR / "billing" -DB_DIR.mkdir(parents=True, exist_ok=True) -DATABASE_URL = f"sqlite:///{DB_DIR / 'billing.db'}" - -# Create engine -engine = create_engine( - DATABASE_URL, - connect_args={"check_same_thread": False}, # SQLite specific - echo=False # Set to True for SQL query logging -) - -# Session factory -SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) - -# Base class for models -Base = declarative_base() - - -def get_db() -> Session: - """Get database session (dependency for FastAPI).""" - db = SessionLocal() - try: - yield db - finally: - db.close() - - -def init_db(): - """Initialize database tables.""" - Base.metadata.create_all(bind=engine) - logger.info("Database tables created/verified") - - -def drop_db(): - """Drop all tables (use with caution!).""" - Base.metadata.drop_all(bind=engine) - logger.warning("All database tables dropped") - +""" +Database setup and session management. +Uses SQLAlchemy with SQLite for local dev, Postgres-compatible schema. +""" +from sqlalchemy import create_engine +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import sessionmaker, Session +from pathlib import Path +import logging + +from app.config import settings + +logger = logging.getLogger(__name__) + +# Database path +DB_DIR = settings.DATA_DIR / "billing" +DB_DIR.mkdir(parents=True, exist_ok=True) +DATABASE_URL = f"sqlite:///{DB_DIR / 'billing.db'}" + +# Create engine +engine = create_engine( + DATABASE_URL, + connect_args={"check_same_thread": False}, # SQLite specific + echo=False # Set to True for SQL query logging +) + +# Session factory +SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + +# Base class for models +Base = declarative_base() + + +def get_db() -> Session: + """Get database session (dependency for FastAPI).""" + db = SessionLocal() + try: + yield db + finally: + db.close() + + +def init_db(): + """Initialize database tables.""" + Base.metadata.create_all(bind=engine) + logger.info("Database tables created/verified") + + +def drop_db(): + """Drop all tables (use with caution!).""" + Base.metadata.drop_all(bind=engine) + logger.warning("All database tables dropped") + diff --git a/app/db/models.py b/app/db/models.py index f07c57061c84a15bff09f00a89467dcd41d9cb09..ef16b515fe6d16a42dc3039be5ffbddb737363da 100644 --- a/app/db/models.py +++ b/app/db/models.py @@ -1,129 +1,129 @@ -""" -Database models for billing and usage tracking. -""" -from sqlalchemy import Column, String, Integer, Float, DateTime, Boolean, ForeignKey, Text -from sqlalchemy.orm import relationship -from datetime import datetime -from typing import Optional - -from app.db.database import Base - - -class Tenant(Base): - """Tenant/organization model.""" - __tablename__ = "tenants" - - id = Column(String, primary_key=True, index=True) - name = Column(String, nullable=False) - created_at = Column(DateTime, default=datetime.utcnow, nullable=False) - updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False) - - # Relationships - plan = relationship("TenantPlan", back_populates="tenant", uselist=False) - usage_events = relationship("UsageEvent", back_populates="tenant") - daily_usage = relationship("UsageDaily", back_populates="tenant") - monthly_usage = relationship("UsageMonthly", back_populates="tenant") - - -class TenantPlan(Base): - """Tenant subscription plan.""" - __tablename__ = "tenant_plans" - - id = Column(Integer, primary_key=True, autoincrement=True) - tenant_id = Column(String, ForeignKey("tenants.id"), unique=True, nullable=False, index=True) - plan_name = Column(String, nullable=False, index=True) # "starter", "growth", "pro" - monthly_chat_limit = Column(Integer, nullable=False) # -1 for unlimited - created_at = Column(DateTime, default=datetime.utcnow, nullable=False) - updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False) - - # Relationships - tenant = relationship("Tenant", back_populates="plan") - - -class UsageEvent(Base): - """Individual usage event (each /chat request).""" - __tablename__ = "usage_events" - - id = Column(Integer, primary_key=True, autoincrement=True) - request_id = Column(String, unique=True, nullable=False, index=True) - tenant_id = Column(String, ForeignKey("tenants.id"), nullable=False, index=True) - user_id = Column(String, nullable=False, index=True) - kb_id = Column(String, nullable=False) - - # LLM details - provider = Column(String, nullable=False) # "gemini" or "openai" - model = Column(String, nullable=False) - - # Token usage - prompt_tokens = Column(Integer, nullable=False, default=0) - completion_tokens = Column(Integer, nullable=False, default=0) - total_tokens = Column(Integer, nullable=False, default=0) - - # Cost tracking - estimated_cost_usd = Column(Float, nullable=False, default=0.0) - - # Timestamp - request_timestamp = Column(DateTime, default=datetime.utcnow, nullable=False, index=True) - - # Relationships - tenant = relationship("Tenant", back_populates="usage_events") - - -class UsageDaily(Base): - """Daily aggregated usage per tenant.""" - __tablename__ = "usage_daily" - - id = Column(Integer, primary_key=True, autoincrement=True) - tenant_id = Column(String, ForeignKey("tenants.id"), nullable=False, index=True) - date = Column(DateTime, nullable=False, index=True) - - # Aggregated metrics - total_requests = Column(Integer, nullable=False, default=0) - total_tokens = Column(Integer, nullable=False, default=0) - total_cost_usd = Column(Float, nullable=False, default=0.0) - - # Provider breakdown - gemini_requests = Column(Integer, nullable=False, default=0) - openai_requests = Column(Integer, nullable=False, default=0) - - created_at = Column(DateTime, default=datetime.utcnow, nullable=False) - updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False) - - # Unique constraint: one record per tenant per day - __table_args__ = ( - {"sqlite_autoincrement": True}, - ) - - # Relationships - tenant = relationship("Tenant", back_populates="daily_usage") - - -class UsageMonthly(Base): - """Monthly aggregated usage per tenant.""" - __tablename__ = "usage_monthly" - - id = Column(Integer, primary_key=True, autoincrement=True) - tenant_id = Column(String, ForeignKey("tenants.id"), nullable=False, index=True) - year = Column(Integer, nullable=False, index=True) - month = Column(Integer, nullable=False, index=True) # 1-12 - - # Aggregated metrics - total_requests = Column(Integer, nullable=False, default=0) - total_tokens = Column(Integer, nullable=False, default=0) - total_cost_usd = Column(Float, nullable=False, default=0.0) - - # Provider breakdown - gemini_requests = Column(Integer, nullable=False, default=0) - openai_requests = Column(Integer, nullable=False, default=0) - - created_at = Column(DateTime, default=datetime.utcnow, nullable=False) - updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False) - - # Unique constraint: one record per tenant per month - __table_args__ = ( - {"sqlite_autoincrement": True}, - ) - - # Relationships - tenant = relationship("Tenant", back_populates="monthly_usage") - +""" +Database models for billing and usage tracking. +""" +from sqlalchemy import Column, String, Integer, Float, DateTime, Boolean, ForeignKey, Text +from sqlalchemy.orm import relationship +from datetime import datetime +from typing import Optional + +from app.db.database import Base + + +class Tenant(Base): + """Tenant/organization model.""" + __tablename__ = "tenants" + + id = Column(String, primary_key=True, index=True) + name = Column(String, nullable=False) + created_at = Column(DateTime, default=datetime.utcnow, nullable=False) + updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False) + + # Relationships + plan = relationship("TenantPlan", back_populates="tenant", uselist=False) + usage_events = relationship("UsageEvent", back_populates="tenant") + daily_usage = relationship("UsageDaily", back_populates="tenant") + monthly_usage = relationship("UsageMonthly", back_populates="tenant") + + +class TenantPlan(Base): + """Tenant subscription plan.""" + __tablename__ = "tenant_plans" + + id = Column(Integer, primary_key=True, autoincrement=True) + tenant_id = Column(String, ForeignKey("tenants.id"), unique=True, nullable=False, index=True) + plan_name = Column(String, nullable=False, index=True) # "starter", "growth", "pro" + monthly_chat_limit = Column(Integer, nullable=False) # -1 for unlimited + created_at = Column(DateTime, default=datetime.utcnow, nullable=False) + updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False) + + # Relationships + tenant = relationship("Tenant", back_populates="plan") + + +class UsageEvent(Base): + """Individual usage event (each /chat request).""" + __tablename__ = "usage_events" + + id = Column(Integer, primary_key=True, autoincrement=True) + request_id = Column(String, unique=True, nullable=False, index=True) + tenant_id = Column(String, ForeignKey("tenants.id"), nullable=False, index=True) + user_id = Column(String, nullable=False, index=True) + kb_id = Column(String, nullable=False) + + # LLM details + provider = Column(String, nullable=False) # "gemini" or "openai" + model = Column(String, nullable=False) + + # Token usage + prompt_tokens = Column(Integer, nullable=False, default=0) + completion_tokens = Column(Integer, nullable=False, default=0) + total_tokens = Column(Integer, nullable=False, default=0) + + # Cost tracking + estimated_cost_usd = Column(Float, nullable=False, default=0.0) + + # Timestamp + request_timestamp = Column(DateTime, default=datetime.utcnow, nullable=False, index=True) + + # Relationships + tenant = relationship("Tenant", back_populates="usage_events") + + +class UsageDaily(Base): + """Daily aggregated usage per tenant.""" + __tablename__ = "usage_daily" + + id = Column(Integer, primary_key=True, autoincrement=True) + tenant_id = Column(String, ForeignKey("tenants.id"), nullable=False, index=True) + date = Column(DateTime, nullable=False, index=True) + + # Aggregated metrics + total_requests = Column(Integer, nullable=False, default=0) + total_tokens = Column(Integer, nullable=False, default=0) + total_cost_usd = Column(Float, nullable=False, default=0.0) + + # Provider breakdown + gemini_requests = Column(Integer, nullable=False, default=0) + openai_requests = Column(Integer, nullable=False, default=0) + + created_at = Column(DateTime, default=datetime.utcnow, nullable=False) + updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False) + + # Unique constraint: one record per tenant per day + __table_args__ = ( + {"sqlite_autoincrement": True}, + ) + + # Relationships + tenant = relationship("Tenant", back_populates="daily_usage") + + +class UsageMonthly(Base): + """Monthly aggregated usage per tenant.""" + __tablename__ = "usage_monthly" + + id = Column(Integer, primary_key=True, autoincrement=True) + tenant_id = Column(String, ForeignKey("tenants.id"), nullable=False, index=True) + year = Column(Integer, nullable=False, index=True) + month = Column(Integer, nullable=False, index=True) # 1-12 + + # Aggregated metrics + total_requests = Column(Integer, nullable=False, default=0) + total_tokens = Column(Integer, nullable=False, default=0) + total_cost_usd = Column(Float, nullable=False, default=0.0) + + # Provider breakdown + gemini_requests = Column(Integer, nullable=False, default=0) + openai_requests = Column(Integer, nullable=False, default=0) + + created_at = Column(DateTime, default=datetime.utcnow, nullable=False) + updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False) + + # Unique constraint: one record per tenant per month + __table_args__ = ( + {"sqlite_autoincrement": True}, + ) + + # Relationships + tenant = relationship("Tenant", back_populates="monthly_usage") + diff --git a/app/main.py b/app/main.py index bbe9293084087a8cf8545b2702f91549db453aa6..e9b65f05bcfcdd0e6b6b8565bb1ea9772792f27b 100644 --- a/app/main.py +++ b/app/main.py @@ -1,1039 +1,1039 @@ -""" -FastAPI application for ClientSphere RAG Backend. -Provides endpoints for knowledge base management and chat. -""" -from fastapi import FastAPI, File, UploadFile, HTTPException, Form, BackgroundTasks, Request, Depends -from fastapi.middleware.cors import CORSMiddleware -from fastapi.exceptions import RequestValidationError -from fastapi.responses import JSONResponse -from pathlib import Path -import shutil -import uuid -from datetime import datetime -from typing import Optional -import logging - -from app.config import settings -from app.middleware.auth import get_auth_context, require_auth -from app.middleware.rate_limit import ( - limiter, - get_tenant_rate_limit_key, - RateLimitExceeded, - _rate_limit_exceeded_handler -) -from app.models.schemas import ( - UploadResponse, - ChatRequest, - ChatResponse, - KnowledgeBaseStats, - HealthResponse, - DocumentStatus, - Citation, -) -from app.models.billing_schemas import ( - UsageResponse, - PlanLimitsResponse, - CostReportResponse, - SetPlanRequest -) -from app.rag.ingest import parser -from app.rag.chunking import chunker -from app.rag.embeddings import get_embedding_service -from app.rag.vectorstore import get_vector_store -from app.rag.retrieval import get_retrieval_service -from app.rag.answer import get_answer_service -from app.db.database import get_db, init_db -from app.billing.quota import check_quota, ensure_tenant_exists -from app.billing.usage_tracker import track_usage - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - -# Initialize FastAPI app -app = FastAPI( - title=settings.APP_NAME, - description="RAG-based customer support chatbot API", - version="1.0.0", -) - -# Initialize database on startup -@app.on_event("startup") -async def startup_event(): - """Initialize database on application startup.""" - init_db() - logger.info("Database initialized") - -# Configure CORS - SECURITY: Restrict in production -if settings.ALLOWED_ORIGINS == "*": - allowed_origins = ["*"] -else: - # Split by comma and strip whitespace - allowed_origins = [origin.strip() for origin in settings.ALLOWED_ORIGINS.split(",") if origin.strip()] - -# Default to allowing localhost if no origins specified -if not allowed_origins or allowed_origins == ["*"]: - allowed_origins = ["*"] # Allow all in dev mode - -logger.info(f"CORS configured with origins: {allowed_origins}") - -app.add_middleware( - CORSMiddleware, - allow_origins=allowed_origins, - allow_credentials=True, - allow_methods=["GET", "POST", "DELETE", "OPTIONS"], # Include OPTIONS for preflight - allow_headers=["Content-Type", "Authorization", "X-Tenant-Id", "X-User-Id"], # Include auth headers -) - -# Configure rate limiting -app.state.limiter = limiter -app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler) - -# Add exception handler for validation errors -@app.exception_handler(RequestValidationError) -async def validation_exception_handler(request: Request, exc: RequestValidationError): - """Handle request validation errors with detailed logging.""" - body = await request.body() - logger.error(f"Request validation error: {exc.errors()}") - logger.error(f"Request body (raw): {body}") - logger.error(f"Request headers: {dict(request.headers)}") - return JSONResponse( - status_code=422, - content={"detail": exc.errors(), "body": body.decode('utf-8', errors='ignore')} - ) - -# Add exception handler for validation errors -from fastapi.exceptions import RequestValidationError -from fastapi.responses import JSONResponse - -@app.exception_handler(RequestValidationError) -async def validation_exception_handler(request: Request, exc: RequestValidationError): - """Handle request validation errors with detailed logging.""" - logger.error(f"Request validation error: {exc.errors()}") - logger.error(f"Request body: {await request.body()}") - return JSONResponse( - status_code=422, - content={"detail": exc.errors(), "body": str(await request.body())} - ) - - -# ============== Health & Status Endpoints ============== - -@app.get("/", response_model=HealthResponse) -async def root(): - """Root endpoint with basic info.""" - return HealthResponse( - status="ok", - version="1.0.0", - vector_db_connected=True, - llm_configured=bool(settings.GEMINI_API_KEY or settings.OPENAI_API_KEY) - ) - - -@app.get("/health", response_model=HealthResponse) -async def health_check(): - """Health check endpoint.""" - try: - vector_store = get_vector_store() - stats = vector_store.get_stats() - - return HealthResponse( - status="healthy", - version="1.0.0", - vector_db_connected=True, - llm_configured=bool(settings.GEMINI_API_KEY or settings.OPENAI_API_KEY) - ) - except Exception as e: - logger.error(f"Health check failed: {e}") - return HealthResponse( - status="unhealthy", - version="1.0.0", - vector_db_connected=False, - llm_configured=False - ) - - -@app.get("/health/live") -async def liveness(): - """Kubernetes liveness probe - always returns alive.""" - return {"status": "alive"} - - -@app.get("/health/ready") -async def readiness(): - """Kubernetes readiness probe - checks dependencies.""" - checks = { - "vector_db": False, - "llm_configured": bool(settings.GEMINI_API_KEY or settings.OPENAI_API_KEY) - } - - # Check vector DB connection - try: - vector_store = get_vector_store() - vector_store.get_stats() - checks["vector_db"] = True - except Exception as e: - logger.warning(f"Vector DB check failed: {e}") - checks["vector_db"] = False - - # All checks must pass - if all(checks.values()): - return {"status": "ready", "checks": checks} - else: - from fastapi import HTTPException - raise HTTPException(status_code=503, detail={"status": "not_ready", "checks": checks}) - - -# ============== Knowledge Base Endpoints ============== - -@app.post("/kb/upload", response_model=UploadResponse) -@limiter.limit("20/hour", key_func=get_tenant_rate_limit_key) -async def upload_document( - background_tasks: BackgroundTasks, - request: Request, - file: UploadFile = File(...), - tenant_id: Optional[str] = Form(None), # Optional in dev, ignored in prod - user_id: Optional[str] = Form(None), # Optional in dev, ignored in prod - kb_id: str = Form(...) -): - """ - Upload a document to the knowledge base. - - - Saves file to disk - - Parses and chunks the document - - Generates embeddings - - Stores in vector database - """ - # SECURITY: Extract tenant_id from auth token in production - if settings.ENV == "prod": - auth_context = await require_auth(request) - tenant_id = auth_context.get("tenant_id") - if not tenant_id: - raise HTTPException( - status_code=403, - detail="tenant_id must come from authentication token in production mode" - ) - elif not tenant_id: - raise HTTPException( - status_code=400, - detail="tenant_id is required" - ) - - # Validate file type - file_ext = Path(file.filename).suffix.lower() - if file_ext not in parser.SUPPORTED_EXTENSIONS: - raise HTTPException( - status_code=400, - detail=f"Unsupported file type: {file_ext}. Supported: {parser.SUPPORTED_EXTENSIONS}" - ) - - # Validate file size (SECURITY) - file.file.seek(0, 2) # Seek to end - file_size = file.file.tell() - file.file.seek(0) # Reset to start - max_size_bytes = settings.MAX_FILE_SIZE_MB * 1024 * 1024 - if file_size > max_size_bytes: - raise HTTPException( - status_code=400, - detail=f"File too large. Maximum size: {settings.MAX_FILE_SIZE_MB}MB" - ) - - # Generate document ID - doc_id = f"{tenant_id}_{kb_id}_{uuid.uuid4().hex[:8]}" - - # Save file to uploads directory - upload_path = settings.UPLOADS_DIR / f"{doc_id}_{file.filename}" - try: - with open(upload_path, "wb") as buffer: - shutil.copyfileobj(file.file, buffer) - logger.info(f"Saved file: {upload_path}") - except Exception as e: - logger.error(f"Error saving file: {e}") - raise HTTPException(status_code=500, detail="Failed to save file") - - # Process document in background - background_tasks.add_task( - process_document, - upload_path, - tenant_id, # CRITICAL: Multi-tenant isolation - user_id, - kb_id, - file.filename, - doc_id - ) - - return UploadResponse( - success=True, - message="Document upload started. Processing in background.", - document_id=doc_id, - file_name=file.filename, - chunks_created=0, - status=DocumentStatus.PROCESSING - ) - - -async def process_document( - file_path: Path, - tenant_id: str, # CRITICAL: Multi-tenant isolation - user_id: str, - kb_id: str, - original_filename: str, - document_id: str -): - """ - Background task to process an uploaded document. - """ - try: - logger.info(f"Processing document: {original_filename}") - - # Parse document - parsed_doc = parser.parse(file_path) - logger.info(f"Parsed document: {len(parsed_doc.text)} characters") - - # Chunk document - chunks = chunker.chunk_text( - parsed_doc.text, - page_numbers=parsed_doc.page_map - ) - logger.info(f"Created {len(chunks)} chunks") - - if not chunks: - logger.warning(f"No chunks created from {original_filename}") - return - - # Create metadata for each chunk - metadatas = [] - chunk_ids = [] - chunk_texts = [] - - for chunk in chunks: - metadata = chunker.create_chunk_metadata( - chunk=chunk, - tenant_id=tenant_id, # CRITICAL: Multi-tenant isolation - kb_id=kb_id, - user_id=user_id, - file_name=original_filename, - file_type=parsed_doc.file_type, - total_chunks=len(chunks), - document_id=document_id - ) - metadatas.append(metadata) - chunk_ids.append(metadata["chunk_id"]) - chunk_texts.append(chunk.content) - - # Generate embeddings - embedding_service = get_embedding_service() - embeddings = embedding_service.embed_texts(chunk_texts) - logger.info(f"Generated {len(embeddings)} embeddings") - - # Store in vector database - vector_store = get_vector_store() - vector_store.add_documents( - documents=chunk_texts, - embeddings=embeddings, - metadatas=metadatas, - ids=chunk_ids - ) - - logger.info(f"Successfully processed {original_filename}: {len(chunks)} chunks stored") - - except Exception as e: - logger.error(f"Error processing document {original_filename}: {e}") - raise - - -@app.get("/kb/stats", response_model=KnowledgeBaseStats) -async def get_kb_stats( - request: Request, - tenant_id: Optional[str] = None, # Optional in dev, ignored in prod - kb_id: Optional[str] = None, - user_id: Optional[str] = None # Optional in dev, ignored in prod -): - """Get statistics for a knowledge base.""" - # SECURITY: Get tenant_id and user_id from auth context - auth_context = await get_auth_context(request) - tenant_id_from_auth = auth_context.get("tenant_id") - user_id_from_auth = auth_context.get("user_id") - - if settings.ENV == "prod": - if not tenant_id_from_auth or not user_id_from_auth: - raise HTTPException( - status_code=403, - detail="tenant_id and user_id must come from authentication token in production mode" - ) - tenant_id = tenant_id_from_auth - user_id = user_id_from_auth - else: - tenant_id = tenant_id or tenant_id_from_auth - user_id = user_id or user_id_from_auth - if not tenant_id or not kb_id or not user_id: - raise HTTPException( - status_code=400, - detail="tenant_id, kb_id, and user_id are required" - ) - - try: - vector_store = get_vector_store() - stats = vector_store.get_stats(tenant_id=tenant_id, kb_id=kb_id, user_id=user_id) - - return KnowledgeBaseStats( - tenant_id=tenant_id, # CRITICAL: Multi-tenant isolation - kb_id=kb_id, - user_id=user_id, - total_documents=len(stats.get("file_names", [])), - total_chunks=stats.get("total_chunks", 0), - file_names=stats.get("file_names", []), - last_updated=datetime.utcnow() - ) - except Exception as e: - logger.error(f"Error getting KB stats: {e}") - raise HTTPException(status_code=500, detail=str(e)) - - -@app.delete("/kb/document") -async def delete_document( - request: Request, - tenant_id: Optional[str] = None, # Optional in dev, ignored in prod - kb_id: Optional[str] = None, - user_id: Optional[str] = None, # Optional in dev, ignored in prod - file_name: Optional[str] = None -): - """Delete a document from the knowledge base.""" - # SECURITY: Get tenant_id and user_id from auth context - auth_context = await get_auth_context(request) - tenant_id_from_auth = auth_context.get("tenant_id") - user_id_from_auth = auth_context.get("user_id") - - if settings.ENV == "prod": - if not tenant_id_from_auth or not user_id_from_auth: - raise HTTPException( - status_code=403, - detail="tenant_id and user_id must come from authentication token in production mode" - ) - tenant_id = tenant_id_from_auth - user_id = user_id_from_auth - else: - tenant_id = tenant_id or tenant_id_from_auth - user_id = user_id or user_id_from_auth - if not tenant_id or not kb_id or not user_id or not file_name: - raise HTTPException( - status_code=400, - detail="tenant_id, kb_id, user_id, and file_name are required (provide via headers or query params)" - ) - - try: - vector_store = get_vector_store() - deleted = vector_store.delete_by_filter({ - "tenant_id": tenant_id, # CRITICAL: Multi-tenant isolation - "kb_id": kb_id, - "user_id": user_id, - "file_name": file_name - }) - - return { - "success": True, - "message": f"Deleted {deleted} chunks", - "file_name": file_name - } - except Exception as e: - logger.error(f"Error deleting document: {e}") - raise HTTPException(status_code=500, detail=str(e)) - - -@app.delete("/kb/clear") -async def clear_kb( - request: Request, - tenant_id: Optional[str] = None, # Optional in dev, ignored in prod - kb_id: Optional[str] = None, - user_id: Optional[str] = None # Optional in dev, ignored in prod -): - """Clear all documents from a knowledge base.""" - # SECURITY: Get tenant_id and user_id from auth context - auth_context = await get_auth_context(request) - tenant_id_from_auth = auth_context.get("tenant_id") - user_id_from_auth = auth_context.get("user_id") - - if settings.ENV == "prod": - if not tenant_id_from_auth or not user_id_from_auth: - raise HTTPException( - status_code=403, - detail="tenant_id and user_id must come from authentication token in production mode" - ) - tenant_id = tenant_id_from_auth - user_id = user_id_from_auth - else: - tenant_id = tenant_id or tenant_id_from_auth - user_id = user_id or user_id_from_auth - if not tenant_id or not kb_id or not user_id: - raise HTTPException( - status_code=400, - detail="tenant_id, kb_id, and user_id are required" - ) - try: - vector_store = get_vector_store() - deleted = vector_store.delete_by_filter({ - "tenant_id": tenant_id, # CRITICAL: Multi-tenant isolation - "kb_id": kb_id, - "user_id": user_id - }) - - return { - "success": True, - "message": f"Cleared knowledge base. Deleted {deleted} chunks.", - "kb_id": kb_id - } - except Exception as e: - logger.error(f"Error clearing KB: {e}") - raise HTTPException(status_code=500, detail=str(e)) - - -# ============== Chat Endpoints ============== - -@app.post("/chat", response_model=ChatResponse) -@limiter.limit("10/minute", key_func=get_tenant_rate_limit_key) -async def chat(chat_request: ChatRequest, request: Request): - """ - Process a chat message using RAG. - - - Retrieves relevant context from knowledge base - - Generates answer using LLM - - Returns answer with citations - """ - conversation_id = "unknown" - try: - logger.info(f"=== CHAT REQUEST RECEIVED ===") - logger.info(f"Request body: tenant_id={chat_request.tenant_id}, user_id={chat_request.user_id}, kb_id={chat_request.kb_id}, question_length={len(chat_request.question)}") - logger.info(f"Request headers: {dict(request.headers)}") - - # SECURITY: Get tenant_id and user_id from auth context - # In PROD: MUST come from JWT token (never from request body) - try: - auth_context = await get_auth_context(request) - except Exception as e: - logger.error(f"Error getting auth context: {e}", exc_info=True) - raise HTTPException(status_code=401, detail=f"Authentication error: {str(e)}") - - tenant_id_from_auth = auth_context.get("tenant_id") - user_id_from_auth = auth_context.get("user_id") - - if settings.ENV == "prod": - if not tenant_id_from_auth or not user_id_from_auth: - raise HTTPException( - status_code=403, - detail="tenant_id and user_id must come from authentication token in production mode" - ) - # Override request values with auth context (security enforcement) - chat_request.tenant_id = tenant_id_from_auth - chat_request.user_id = user_id_from_auth - else: - # DEV mode: use from request if provided, otherwise from auth context - if not chat_request.tenant_id: - chat_request.tenant_id = tenant_id_from_auth - if not chat_request.user_id: - chat_request.user_id = user_id_from_auth - if not chat_request.tenant_id or not chat_request.user_id: - raise HTTPException( - status_code=400, - detail="tenant_id and user_id are required (provide via X-Tenant-Id/X-User-Id headers or request body)" - ) - - # Log without PII in production - if settings.ENV == "prod": - logger.info(f"Chat request: tenant={chat_request.tenant_id}, user={chat_request.user_id}, kb={chat_request.kb_id}, q_length={len(chat_request.question)}") - else: - logger.info(f"Chat request: tenant={chat_request.tenant_id}, user={chat_request.user_id}, kb={chat_request.kb_id}, q={chat_request.question[:50]}...") - - # Generate conversation ID if not provided - conversation_id = chat_request.conversation_id or f"conv_{uuid.uuid4().hex[:12]}" - - # Get database session - try: - db = next(get_db()) - except Exception as e: - logger.error(f"Database connection error: {e}", exc_info=True) - raise HTTPException(status_code=500, detail=f"Database error: {str(e)}") - - try: - # Ensure tenant exists in billing DB - ensure_tenant_exists(db, chat_request.tenant_id) - - # Check quota BEFORE making LLM call - has_quota, quota_error = check_quota(db, chat_request.tenant_id) - if not has_quota: - logger.warning(f"Quota exceeded for tenant {chat_request.tenant_id}") - raise HTTPException( - status_code=402, - detail=quota_error or "AI quota exceeded. Upgrade your plan." - ) - - # Retrieve relevant context - retrieval_service = get_retrieval_service() - results, confidence, has_relevant = retrieval_service.retrieve( - query=chat_request.question, - tenant_id=chat_request.tenant_id, # CRITICAL: Multi-tenant isolation - kb_id=chat_request.kb_id, - user_id=chat_request.user_id - ) - - logger.info(f"Retrieval results: {len(results)} results, confidence={confidence:.3f}, has_relevant={has_relevant}") - - # Format context for LLM - context, citations_info = retrieval_service.get_context_for_llm(results) - - logger.info(f"Formatted context length: {len(context)} chars, citations: {len(citations_info)}") - - # Generate answer - answer_service = get_answer_service() - answer_result = answer_service.generate_answer( - question=chat_request.question, - context=context, - citations_info=citations_info, - confidence=confidence, - has_relevant_results=has_relevant - ) - - # Track usage if LLM was called (usage info present) - usage_info = answer_result.get("usage") - if usage_info: - try: - track_usage( - db=db, - tenant_id=chat_request.tenant_id, - user_id=chat_request.user_id, - kb_id=chat_request.kb_id, - provider=settings.LLM_PROVIDER, - model=usage_info.get("model_used", settings.GEMINI_MODEL if settings.LLM_PROVIDER == "gemini" else settings.OPENAI_MODEL), - prompt_tokens=usage_info.get("prompt_tokens", 0), - completion_tokens=usage_info.get("completion_tokens", 0) - ) - except Exception as e: - logger.error(f"Failed to track usage: {e}", exc_info=True) - # Don't fail the request if usage tracking fails - - # Build metadata with refusal info - metadata = { - "chunks_retrieved": len(results), - "kb_id": chat_request.kb_id - } - if "refused" in answer_result: - metadata["refused"] = answer_result["refused"] - if "refusal_reason" in answer_result: - metadata["refusal_reason"] = answer_result["refusal_reason"] - if "verifier_passed" in answer_result: - metadata["verifier_passed"] = answer_result["verifier_passed"] - - return ChatResponse( - success=True, - answer=answer_result["answer"], - citations=answer_result["citations"], - confidence=answer_result["confidence"], - from_knowledge_base=answer_result["from_knowledge_base"], - escalation_suggested=answer_result["escalation_suggested"], - conversation_id=conversation_id, - refused=answer_result.get("refused", False), - metadata=metadata - ) - - except ValueError as e: - # API key or configuration error - error_msg = str(e) - logger.error(f"Configuration error: {error_msg}") - if "API key" in error_msg.lower(): - return ChatResponse( - success=False, - answer="โš ๏ธ LLM API key not configured. Please set GEMINI_API_KEY in your .env file. Retrieval is working, but answer generation requires an API key.", - citations=[], - confidence=0.0, - from_knowledge_base=False, - escalation_suggested=True, - conversation_id=conversation_id, - metadata={"error": error_msg, "error_type": "configuration"} - ) - else: - return ChatResponse( - success=False, - answer=f"Configuration error: {error_msg}", - citations=[], - confidence=0.0, - from_knowledge_base=False, - escalation_suggested=True, - conversation_id=conversation_id, - metadata={"error": error_msg} - ) - except HTTPException: - # Re-raise HTTP exceptions (they have proper status codes) - raise - except Exception as e: - logger.error(f"Chat error: {e}", exc_info=True) - logger.error(f"Error type: {type(e).__name__}", exc_info=True) - return ChatResponse( - success=False, - answer=f"I encountered an error processing your request: {str(e)}. Please check the server logs for details.", - citations=[], - confidence=0.0, - from_knowledge_base=False, - escalation_suggested=True, - conversation_id=conversation_id, - metadata={"error": str(e), "error_type": type(e).__name__} - ) - except HTTPException: - # Re-raise HTTP exceptions from outer try block - raise - except Exception as e: - logger.error(f"Outer chat error: {e}", exc_info=True) - return ChatResponse( - success=False, - answer=f"I encountered an error processing your request: {str(e)}. Please check the server logs for details.", - citations=[], - confidence=0.0, - from_knowledge_base=False, - escalation_suggested=True, - conversation_id=conversation_id, - metadata={"error": str(e), "error_type": type(e).__name__} - ) - - -# ============== Utility Endpoints ============== - -@app.get("/kb/search") -@limiter.limit("30/minute", key_func=get_tenant_rate_limit_key) -async def search_kb( - request: Request, - query: str, - tenant_id: Optional[str] = None, # Optional in dev, ignored in prod - kb_id: Optional[str] = None, - user_id: Optional[str] = None, # Optional in dev, ignored in prod - top_k: int = 5 -): - """ - Search the knowledge base without generating an answer. - Useful for debugging and testing retrieval. - """ - # SECURITY: Extract tenant_id from auth token in production - if settings.ENV == "prod": - auth_context = await require_auth(request) - tenant_id = auth_context.get("tenant_id") - user_id = auth_context.get("user_id") - if not tenant_id or not user_id: - raise HTTPException( - status_code=403, - detail="tenant_id and user_id must come from authentication token in production mode" - ) - elif not tenant_id or not kb_id or not user_id: - raise HTTPException( - status_code=400, - detail="tenant_id, kb_id, and user_id are required" - ) - - try: - retrieval_service = get_retrieval_service() - results, confidence, has_relevant = retrieval_service.retrieve( - query=query, - tenant_id=tenant_id, # CRITICAL: Multi-tenant isolation - kb_id=kb_id, - user_id=user_id, - top_k=top_k - ) - - return { - "success": True, - "results": [ - { - "chunk_id": r.chunk_id, - "content": r.content[:500] + "..." if len(r.content) > 500 else r.content, - "metadata": r.metadata, - "similarity_score": r.similarity_score - } - for r in results - ], - "confidence": confidence, - "has_relevant_results": has_relevant - } - except Exception as e: - logger.error(f"Search error: {e}") - raise HTTPException(status_code=500, detail=str(e)) - - -# ============== Billing & Usage Endpoints ============== - -@app.get("/billing/usage", response_model=UsageResponse) -async def get_usage( - request: Request, - range: str = "month", # "day" or "month" - year: Optional[int] = None, - month: Optional[int] = None, - day: Optional[int] = None -): - """ - Get usage statistics for the current tenant. - - Args: - range: "day" or "month" - year: Year (optional, defaults to current) - month: Month 1-12 (optional, defaults to current) - day: Day 1-31 (optional, defaults to current, only for range="day") - """ - # Get tenant from auth - auth_context = await get_auth_context(request) - tenant_id = auth_context.get("tenant_id") - - if not tenant_id: - raise HTTPException(status_code=403, detail="tenant_id required") - - db = next(get_db()) - - try: - from app.db.models import UsageDaily, UsageMonthly - from datetime import datetime - from calendar import monthrange - - now = datetime.utcnow() - target_year = year or now.year - target_month = month or now.month - - if range == "day": - target_day = day or now.day - date_start = datetime(target_year, target_month, target_day) - - daily = db.query(UsageDaily).filter( - UsageDaily.tenant_id == tenant_id, - UsageDaily.date == date_start - ).first() - - if not daily: - return UsageResponse( - tenant_id=tenant_id, - period="day", - total_requests=0, - total_tokens=0, - total_cost_usd=0.0, - start_date=date_start, - end_date=date_start - ) - - return UsageResponse( - tenant_id=tenant_id, - period="day", - total_requests=daily.total_requests, - total_tokens=daily.total_tokens, - total_cost_usd=daily.total_cost_usd, - gemini_requests=daily.gemini_requests, - openai_requests=daily.openai_requests, - start_date=daily.date, - end_date=daily.date - ) - else: # month - monthly = db.query(UsageMonthly).filter( - UsageMonthly.tenant_id == tenant_id, - UsageMonthly.year == target_year, - UsageMonthly.month == target_month - ).first() - - if not monthly: - # Calculate date range for the month - _, last_day = monthrange(target_year, target_month) - start_date = datetime(target_year, target_month, 1) - end_date = datetime(target_year, target_month, last_day) - - return UsageResponse( - tenant_id=tenant_id, - period="month", - total_requests=0, - total_tokens=0, - total_cost_usd=0.0, - start_date=start_date, - end_date=end_date - ) - - _, last_day = monthrange(monthly.year, monthly.month) - start_date = datetime(monthly.year, monthly.month, 1) - end_date = datetime(monthly.year, monthly.month, last_day) - - return UsageResponse( - tenant_id=tenant_id, - period="month", - total_requests=monthly.total_requests, - total_tokens=monthly.total_tokens, - total_cost_usd=monthly.total_cost_usd, - gemini_requests=monthly.gemini_requests, - openai_requests=monthly.openai_requests, - start_date=start_date, - end_date=end_date - ) - except Exception as e: - logger.error(f"Error getting usage: {e}", exc_info=True) - raise HTTPException(status_code=500, detail=str(e)) - - -@app.get("/billing/limits", response_model=PlanLimitsResponse) -async def get_limits(request: Request): - """Get current plan limits and usage for the tenant.""" - # Get tenant from auth - auth_context = await get_auth_context(request) - tenant_id = auth_context.get("tenant_id") - - if not tenant_id: - raise HTTPException(status_code=403, detail="tenant_id required") - - db = next(get_db()) - - try: - from app.billing.quota import get_tenant_plan, get_monthly_usage - from datetime import datetime - - plan = get_tenant_plan(db, tenant_id) - if not plan: - # Default to starter - plan_name = "starter" - monthly_limit = 500 - else: - plan_name = plan.plan_name - monthly_limit = plan.monthly_chat_limit - - # Get current month usage - now = datetime.utcnow() - monthly_usage = get_monthly_usage(db, tenant_id, now.year, now.month) - current_usage = monthly_usage.total_requests if monthly_usage else 0 - - remaining = None if monthly_limit == -1 else max(0, monthly_limit - current_usage) - - return PlanLimitsResponse( - tenant_id=tenant_id, - plan_name=plan_name, - monthly_chat_limit=monthly_limit, - current_month_usage=current_usage, - remaining_chats=remaining - ) - except Exception as e: - logger.error(f"Error getting limits: {e}", exc_info=True) - raise HTTPException(status_code=500, detail=str(e)) - - -@app.post("/billing/plan") -async def set_plan(request_body: SetPlanRequest, http_request: Request): - """ - Set tenant's subscription plan (admin only in production). - - In dev mode, allows any tenant to set their plan. - In prod mode, should be restricted to admin users. - """ - # Get tenant from auth - auth_context = await get_auth_context(http_request) - auth_tenant_id = auth_context.get("tenant_id") - - # In prod, verify admin role (placeholder - implement actual admin check) - if settings.ENV == "prod": - # TODO: Add admin role check - if auth_tenant_id != request_body.tenant_id: - raise HTTPException(status_code=403, detail="Cannot set plan for other tenants") - - db = next(get_db()) - - try: - from app.billing.quota import set_tenant_plan - - plan = set_tenant_plan(db, request_body.tenant_id, request_body.plan_name) - - return { - "success": True, - "tenant_id": request_body.tenant_id, - "plan_name": plan.plan_name, - "monthly_chat_limit": plan.monthly_chat_limit - } - except ValueError as e: - raise HTTPException(status_code=400, detail=str(e)) - except Exception as e: - logger.error(f"Error setting plan: {e}", exc_info=True) - raise HTTPException(status_code=500, detail=str(e)) - - -@app.get("/billing/cost-report", response_model=CostReportResponse) -async def get_cost_report( - request: Request, - range: str = "month", - year: Optional[int] = None, - month: Optional[int] = None -): - """Get cost report with breakdown by provider and model.""" - # Get tenant from auth - auth_context = await get_auth_context(request) - tenant_id = auth_context.get("tenant_id") - - if not tenant_id: - raise HTTPException(status_code=403, detail="tenant_id required") - - db = next(get_db()) - - try: - from app.db.models import UsageEvent - from datetime import datetime - from sqlalchemy import func, and_ - - now = datetime.utcnow() - target_year = year or now.year - target_month = month or now.month - - # Query usage events for the period - if range == "month": - query = db.query(UsageEvent).filter( - and_( - UsageEvent.tenant_id == tenant_id, - func.extract('year', UsageEvent.request_timestamp) == target_year, - func.extract('month', UsageEvent.request_timestamp) == target_month - ) - ) - else: # all time - query = db.query(UsageEvent).filter(UsageEvent.tenant_id == tenant_id) - - events = query.all() - - # Calculate totals - total_cost = sum(e.estimated_cost_usd for e in events) - total_requests = len(events) - total_tokens = sum(e.total_tokens for e in events) - - # Breakdown by provider - breakdown_by_provider = {} - for event in events: - provider = event.provider - if provider not in breakdown_by_provider: - breakdown_by_provider[provider] = { - "requests": 0, - "tokens": 0, - "cost_usd": 0.0 - } - breakdown_by_provider[provider]["requests"] += 1 - breakdown_by_provider[provider]["tokens"] += event.total_tokens - breakdown_by_provider[provider]["cost_usd"] += event.estimated_cost_usd - - # Breakdown by model - breakdown_by_model = {} - for event in events: - model = event.model - if model not in breakdown_by_model: - breakdown_by_model[model] = { - "requests": 0, - "tokens": 0, - "cost_usd": 0.0 - } - breakdown_by_model[model]["requests"] += 1 - breakdown_by_model[model]["tokens"] += event.total_tokens - breakdown_by_model[model]["cost_usd"] += event.estimated_cost_usd - - return CostReportResponse( - tenant_id=tenant_id, - period=range, - total_cost_usd=total_cost, - total_requests=total_requests, - total_tokens=total_tokens, - breakdown_by_provider=breakdown_by_provider, - breakdown_by_model=breakdown_by_model - ) - except Exception as e: - logger.error(f"Error getting cost report: {e}", exc_info=True) - raise HTTPException(status_code=500, detail=str(e)) - - -if __name__ == "__main__": - import uvicorn - uvicorn.run(app, host="0.0.0.0", port=8000) - +""" +FastAPI application for ClientSphere RAG Backend. +Provides endpoints for knowledge base management and chat. +""" +from fastapi import FastAPI, File, UploadFile, HTTPException, Form, BackgroundTasks, Request, Depends +from fastapi.middleware.cors import CORSMiddleware +from fastapi.exceptions import RequestValidationError +from fastapi.responses import JSONResponse +from pathlib import Path +import shutil +import uuid +from datetime import datetime +from typing import Optional +import logging + +from app.config import settings +from app.middleware.auth import get_auth_context, require_auth +from app.middleware.rate_limit import ( + limiter, + get_tenant_rate_limit_key, + RateLimitExceeded, + _rate_limit_exceeded_handler +) +from app.models.schemas import ( + UploadResponse, + ChatRequest, + ChatResponse, + KnowledgeBaseStats, + HealthResponse, + DocumentStatus, + Citation, +) +from app.models.billing_schemas import ( + UsageResponse, + PlanLimitsResponse, + CostReportResponse, + SetPlanRequest +) +from app.rag.ingest import parser +from app.rag.chunking import chunker +from app.rag.embeddings import get_embedding_service +from app.rag.vectorstore import get_vector_store +from app.rag.retrieval import get_retrieval_service +from app.rag.answer import get_answer_service +from app.db.database import get_db, init_db +from app.billing.quota import check_quota, ensure_tenant_exists +from app.billing.usage_tracker import track_usage + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# Initialize FastAPI app +app = FastAPI( + title=settings.APP_NAME, + description="RAG-based customer support chatbot API", + version="1.0.0", +) + +# Initialize database on startup +@app.on_event("startup") +async def startup_event(): + """Initialize database on application startup.""" + init_db() + logger.info("Database initialized") + +# Configure CORS - SECURITY: Restrict in production +if settings.ALLOWED_ORIGINS == "*": + allowed_origins = ["*"] +else: + # Split by comma and strip whitespace + allowed_origins = [origin.strip() for origin in settings.ALLOWED_ORIGINS.split(",") if origin.strip()] + +# Default to allowing localhost if no origins specified +if not allowed_origins or allowed_origins == ["*"]: + allowed_origins = ["*"] # Allow all in dev mode + +logger.info(f"CORS configured with origins: {allowed_origins}") + +app.add_middleware( + CORSMiddleware, + allow_origins=allowed_origins, + allow_credentials=True, + allow_methods=["GET", "POST", "DELETE", "OPTIONS"], # Include OPTIONS for preflight + allow_headers=["Content-Type", "Authorization", "X-Tenant-Id", "X-User-Id"], # Include auth headers +) + +# Configure rate limiting +app.state.limiter = limiter +app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler) + +# Add exception handler for validation errors +@app.exception_handler(RequestValidationError) +async def validation_exception_handler(request: Request, exc: RequestValidationError): + """Handle request validation errors with detailed logging.""" + body = await request.body() + logger.error(f"Request validation error: {exc.errors()}") + logger.error(f"Request body (raw): {body}") + logger.error(f"Request headers: {dict(request.headers)}") + return JSONResponse( + status_code=422, + content={"detail": exc.errors(), "body": body.decode('utf-8', errors='ignore')} + ) + +# Add exception handler for validation errors +from fastapi.exceptions import RequestValidationError +from fastapi.responses import JSONResponse + +@app.exception_handler(RequestValidationError) +async def validation_exception_handler(request: Request, exc: RequestValidationError): + """Handle request validation errors with detailed logging.""" + logger.error(f"Request validation error: {exc.errors()}") + logger.error(f"Request body: {await request.body()}") + return JSONResponse( + status_code=422, + content={"detail": exc.errors(), "body": str(await request.body())} + ) + + +# ============== Health & Status Endpoints ============== + +@app.get("/", response_model=HealthResponse) +async def root(): + """Root endpoint with basic info.""" + return HealthResponse( + status="ok", + version="1.0.0", + vector_db_connected=True, + llm_configured=bool(settings.GEMINI_API_KEY or settings.OPENAI_API_KEY) + ) + + +@app.get("/health", response_model=HealthResponse) +async def health_check(): + """Health check endpoint.""" + try: + vector_store = get_vector_store() + stats = vector_store.get_stats() + + return HealthResponse( + status="healthy", + version="1.0.0", + vector_db_connected=True, + llm_configured=bool(settings.GEMINI_API_KEY or settings.OPENAI_API_KEY) + ) + except Exception as e: + logger.error(f"Health check failed: {e}") + return HealthResponse( + status="unhealthy", + version="1.0.0", + vector_db_connected=False, + llm_configured=False + ) + + +@app.get("/health/live") +async def liveness(): + """Kubernetes liveness probe - always returns alive.""" + return {"status": "alive"} + + +@app.get("/health/ready") +async def readiness(): + """Kubernetes readiness probe - checks dependencies.""" + checks = { + "vector_db": False, + "llm_configured": bool(settings.GEMINI_API_KEY or settings.OPENAI_API_KEY) + } + + # Check vector DB connection + try: + vector_store = get_vector_store() + vector_store.get_stats() + checks["vector_db"] = True + except Exception as e: + logger.warning(f"Vector DB check failed: {e}") + checks["vector_db"] = False + + # All checks must pass + if all(checks.values()): + return {"status": "ready", "checks": checks} + else: + from fastapi import HTTPException + raise HTTPException(status_code=503, detail={"status": "not_ready", "checks": checks}) + + +# ============== Knowledge Base Endpoints ============== + +@app.post("/kb/upload", response_model=UploadResponse) +@limiter.limit("20/hour", key_func=get_tenant_rate_limit_key) +async def upload_document( + background_tasks: BackgroundTasks, + request: Request, + file: UploadFile = File(...), + tenant_id: Optional[str] = Form(None), # Optional in dev, ignored in prod + user_id: Optional[str] = Form(None), # Optional in dev, ignored in prod + kb_id: str = Form(...) +): + """ + Upload a document to the knowledge base. + + - Saves file to disk + - Parses and chunks the document + - Generates embeddings + - Stores in vector database + """ + # SECURITY: Extract tenant_id from auth token in production + if settings.ENV == "prod": + auth_context = await require_auth(request) + tenant_id = auth_context.get("tenant_id") + if not tenant_id: + raise HTTPException( + status_code=403, + detail="tenant_id must come from authentication token in production mode" + ) + elif not tenant_id: + raise HTTPException( + status_code=400, + detail="tenant_id is required" + ) + + # Validate file type + file_ext = Path(file.filename).suffix.lower() + if file_ext not in parser.SUPPORTED_EXTENSIONS: + raise HTTPException( + status_code=400, + detail=f"Unsupported file type: {file_ext}. Supported: {parser.SUPPORTED_EXTENSIONS}" + ) + + # Validate file size (SECURITY) + file.file.seek(0, 2) # Seek to end + file_size = file.file.tell() + file.file.seek(0) # Reset to start + max_size_bytes = settings.MAX_FILE_SIZE_MB * 1024 * 1024 + if file_size > max_size_bytes: + raise HTTPException( + status_code=400, + detail=f"File too large. Maximum size: {settings.MAX_FILE_SIZE_MB}MB" + ) + + # Generate document ID + doc_id = f"{tenant_id}_{kb_id}_{uuid.uuid4().hex[:8]}" + + # Save file to uploads directory + upload_path = settings.UPLOADS_DIR / f"{doc_id}_{file.filename}" + try: + with open(upload_path, "wb") as buffer: + shutil.copyfileobj(file.file, buffer) + logger.info(f"Saved file: {upload_path}") + except Exception as e: + logger.error(f"Error saving file: {e}") + raise HTTPException(status_code=500, detail="Failed to save file") + + # Process document in background + background_tasks.add_task( + process_document, + upload_path, + tenant_id, # CRITICAL: Multi-tenant isolation + user_id, + kb_id, + file.filename, + doc_id + ) + + return UploadResponse( + success=True, + message="Document upload started. Processing in background.", + document_id=doc_id, + file_name=file.filename, + chunks_created=0, + status=DocumentStatus.PROCESSING + ) + + +async def process_document( + file_path: Path, + tenant_id: str, # CRITICAL: Multi-tenant isolation + user_id: str, + kb_id: str, + original_filename: str, + document_id: str +): + """ + Background task to process an uploaded document. + """ + try: + logger.info(f"Processing document: {original_filename}") + + # Parse document + parsed_doc = parser.parse(file_path) + logger.info(f"Parsed document: {len(parsed_doc.text)} characters") + + # Chunk document + chunks = chunker.chunk_text( + parsed_doc.text, + page_numbers=parsed_doc.page_map + ) + logger.info(f"Created {len(chunks)} chunks") + + if not chunks: + logger.warning(f"No chunks created from {original_filename}") + return + + # Create metadata for each chunk + metadatas = [] + chunk_ids = [] + chunk_texts = [] + + for chunk in chunks: + metadata = chunker.create_chunk_metadata( + chunk=chunk, + tenant_id=tenant_id, # CRITICAL: Multi-tenant isolation + kb_id=kb_id, + user_id=user_id, + file_name=original_filename, + file_type=parsed_doc.file_type, + total_chunks=len(chunks), + document_id=document_id + ) + metadatas.append(metadata) + chunk_ids.append(metadata["chunk_id"]) + chunk_texts.append(chunk.content) + + # Generate embeddings + embedding_service = get_embedding_service() + embeddings = embedding_service.embed_texts(chunk_texts) + logger.info(f"Generated {len(embeddings)} embeddings") + + # Store in vector database + vector_store = get_vector_store() + vector_store.add_documents( + documents=chunk_texts, + embeddings=embeddings, + metadatas=metadatas, + ids=chunk_ids + ) + + logger.info(f"Successfully processed {original_filename}: {len(chunks)} chunks stored") + + except Exception as e: + logger.error(f"Error processing document {original_filename}: {e}") + raise + + +@app.get("/kb/stats", response_model=KnowledgeBaseStats) +async def get_kb_stats( + request: Request, + tenant_id: Optional[str] = None, # Optional in dev, ignored in prod + kb_id: Optional[str] = None, + user_id: Optional[str] = None # Optional in dev, ignored in prod +): + """Get statistics for a knowledge base.""" + # SECURITY: Get tenant_id and user_id from auth context + auth_context = await get_auth_context(request) + tenant_id_from_auth = auth_context.get("tenant_id") + user_id_from_auth = auth_context.get("user_id") + + if settings.ENV == "prod": + if not tenant_id_from_auth or not user_id_from_auth: + raise HTTPException( + status_code=403, + detail="tenant_id and user_id must come from authentication token in production mode" + ) + tenant_id = tenant_id_from_auth + user_id = user_id_from_auth + else: + tenant_id = tenant_id or tenant_id_from_auth + user_id = user_id or user_id_from_auth + if not tenant_id or not kb_id or not user_id: + raise HTTPException( + status_code=400, + detail="tenant_id, kb_id, and user_id are required" + ) + + try: + vector_store = get_vector_store() + stats = vector_store.get_stats(tenant_id=tenant_id, kb_id=kb_id, user_id=user_id) + + return KnowledgeBaseStats( + tenant_id=tenant_id, # CRITICAL: Multi-tenant isolation + kb_id=kb_id, + user_id=user_id, + total_documents=len(stats.get("file_names", [])), + total_chunks=stats.get("total_chunks", 0), + file_names=stats.get("file_names", []), + last_updated=datetime.utcnow() + ) + except Exception as e: + logger.error(f"Error getting KB stats: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@app.delete("/kb/document") +async def delete_document( + request: Request, + tenant_id: Optional[str] = None, # Optional in dev, ignored in prod + kb_id: Optional[str] = None, + user_id: Optional[str] = None, # Optional in dev, ignored in prod + file_name: Optional[str] = None +): + """Delete a document from the knowledge base.""" + # SECURITY: Get tenant_id and user_id from auth context + auth_context = await get_auth_context(request) + tenant_id_from_auth = auth_context.get("tenant_id") + user_id_from_auth = auth_context.get("user_id") + + if settings.ENV == "prod": + if not tenant_id_from_auth or not user_id_from_auth: + raise HTTPException( + status_code=403, + detail="tenant_id and user_id must come from authentication token in production mode" + ) + tenant_id = tenant_id_from_auth + user_id = user_id_from_auth + else: + tenant_id = tenant_id or tenant_id_from_auth + user_id = user_id or user_id_from_auth + if not tenant_id or not kb_id or not user_id or not file_name: + raise HTTPException( + status_code=400, + detail="tenant_id, kb_id, user_id, and file_name are required (provide via headers or query params)" + ) + + try: + vector_store = get_vector_store() + deleted = vector_store.delete_by_filter({ + "tenant_id": tenant_id, # CRITICAL: Multi-tenant isolation + "kb_id": kb_id, + "user_id": user_id, + "file_name": file_name + }) + + return { + "success": True, + "message": f"Deleted {deleted} chunks", + "file_name": file_name + } + except Exception as e: + logger.error(f"Error deleting document: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@app.delete("/kb/clear") +async def clear_kb( + request: Request, + tenant_id: Optional[str] = None, # Optional in dev, ignored in prod + kb_id: Optional[str] = None, + user_id: Optional[str] = None # Optional in dev, ignored in prod +): + """Clear all documents from a knowledge base.""" + # SECURITY: Get tenant_id and user_id from auth context + auth_context = await get_auth_context(request) + tenant_id_from_auth = auth_context.get("tenant_id") + user_id_from_auth = auth_context.get("user_id") + + if settings.ENV == "prod": + if not tenant_id_from_auth or not user_id_from_auth: + raise HTTPException( + status_code=403, + detail="tenant_id and user_id must come from authentication token in production mode" + ) + tenant_id = tenant_id_from_auth + user_id = user_id_from_auth + else: + tenant_id = tenant_id or tenant_id_from_auth + user_id = user_id or user_id_from_auth + if not tenant_id or not kb_id or not user_id: + raise HTTPException( + status_code=400, + detail="tenant_id, kb_id, and user_id are required" + ) + try: + vector_store = get_vector_store() + deleted = vector_store.delete_by_filter({ + "tenant_id": tenant_id, # CRITICAL: Multi-tenant isolation + "kb_id": kb_id, + "user_id": user_id + }) + + return { + "success": True, + "message": f"Cleared knowledge base. Deleted {deleted} chunks.", + "kb_id": kb_id + } + except Exception as e: + logger.error(f"Error clearing KB: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +# ============== Chat Endpoints ============== + +@app.post("/chat", response_model=ChatResponse) +@limiter.limit("10/minute", key_func=get_tenant_rate_limit_key) +async def chat(chat_request: ChatRequest, request: Request): + """ + Process a chat message using RAG. + + - Retrieves relevant context from knowledge base + - Generates answer using LLM + - Returns answer with citations + """ + conversation_id = "unknown" + try: + logger.info(f"=== CHAT REQUEST RECEIVED ===") + logger.info(f"Request body: tenant_id={chat_request.tenant_id}, user_id={chat_request.user_id}, kb_id={chat_request.kb_id}, question_length={len(chat_request.question)}") + logger.info(f"Request headers: {dict(request.headers)}") + + # SECURITY: Get tenant_id and user_id from auth context + # In PROD: MUST come from JWT token (never from request body) + try: + auth_context = await get_auth_context(request) + except Exception as e: + logger.error(f"Error getting auth context: {e}", exc_info=True) + raise HTTPException(status_code=401, detail=f"Authentication error: {str(e)}") + + tenant_id_from_auth = auth_context.get("tenant_id") + user_id_from_auth = auth_context.get("user_id") + + if settings.ENV == "prod": + if not tenant_id_from_auth or not user_id_from_auth: + raise HTTPException( + status_code=403, + detail="tenant_id and user_id must come from authentication token in production mode" + ) + # Override request values with auth context (security enforcement) + chat_request.tenant_id = tenant_id_from_auth + chat_request.user_id = user_id_from_auth + else: + # DEV mode: use from request if provided, otherwise from auth context + if not chat_request.tenant_id: + chat_request.tenant_id = tenant_id_from_auth + if not chat_request.user_id: + chat_request.user_id = user_id_from_auth + if not chat_request.tenant_id or not chat_request.user_id: + raise HTTPException( + status_code=400, + detail="tenant_id and user_id are required (provide via X-Tenant-Id/X-User-Id headers or request body)" + ) + + # Log without PII in production + if settings.ENV == "prod": + logger.info(f"Chat request: tenant={chat_request.tenant_id}, user={chat_request.user_id}, kb={chat_request.kb_id}, q_length={len(chat_request.question)}") + else: + logger.info(f"Chat request: tenant={chat_request.tenant_id}, user={chat_request.user_id}, kb={chat_request.kb_id}, q={chat_request.question[:50]}...") + + # Generate conversation ID if not provided + conversation_id = chat_request.conversation_id or f"conv_{uuid.uuid4().hex[:12]}" + + # Get database session + try: + db = next(get_db()) + except Exception as e: + logger.error(f"Database connection error: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=f"Database error: {str(e)}") + + try: + # Ensure tenant exists in billing DB + ensure_tenant_exists(db, chat_request.tenant_id) + + # Check quota BEFORE making LLM call + has_quota, quota_error = check_quota(db, chat_request.tenant_id) + if not has_quota: + logger.warning(f"Quota exceeded for tenant {chat_request.tenant_id}") + raise HTTPException( + status_code=402, + detail=quota_error or "AI quota exceeded. Upgrade your plan." + ) + + # Retrieve relevant context + retrieval_service = get_retrieval_service() + results, confidence, has_relevant = retrieval_service.retrieve( + query=chat_request.question, + tenant_id=chat_request.tenant_id, # CRITICAL: Multi-tenant isolation + kb_id=chat_request.kb_id, + user_id=chat_request.user_id + ) + + logger.info(f"Retrieval results: {len(results)} results, confidence={confidence:.3f}, has_relevant={has_relevant}") + + # Format context for LLM + context, citations_info = retrieval_service.get_context_for_llm(results) + + logger.info(f"Formatted context length: {len(context)} chars, citations: {len(citations_info)}") + + # Generate answer + answer_service = get_answer_service() + answer_result = answer_service.generate_answer( + question=chat_request.question, + context=context, + citations_info=citations_info, + confidence=confidence, + has_relevant_results=has_relevant + ) + + # Track usage if LLM was called (usage info present) + usage_info = answer_result.get("usage") + if usage_info: + try: + track_usage( + db=db, + tenant_id=chat_request.tenant_id, + user_id=chat_request.user_id, + kb_id=chat_request.kb_id, + provider=settings.LLM_PROVIDER, + model=usage_info.get("model_used", settings.GEMINI_MODEL if settings.LLM_PROVIDER == "gemini" else settings.OPENAI_MODEL), + prompt_tokens=usage_info.get("prompt_tokens", 0), + completion_tokens=usage_info.get("completion_tokens", 0) + ) + except Exception as e: + logger.error(f"Failed to track usage: {e}", exc_info=True) + # Don't fail the request if usage tracking fails + + # Build metadata with refusal info + metadata = { + "chunks_retrieved": len(results), + "kb_id": chat_request.kb_id + } + if "refused" in answer_result: + metadata["refused"] = answer_result["refused"] + if "refusal_reason" in answer_result: + metadata["refusal_reason"] = answer_result["refusal_reason"] + if "verifier_passed" in answer_result: + metadata["verifier_passed"] = answer_result["verifier_passed"] + + return ChatResponse( + success=True, + answer=answer_result["answer"], + citations=answer_result["citations"], + confidence=answer_result["confidence"], + from_knowledge_base=answer_result["from_knowledge_base"], + escalation_suggested=answer_result["escalation_suggested"], + conversation_id=conversation_id, + refused=answer_result.get("refused", False), + metadata=metadata + ) + + except ValueError as e: + # API key or configuration error + error_msg = str(e) + logger.error(f"Configuration error: {error_msg}") + if "API key" in error_msg.lower(): + return ChatResponse( + success=False, + answer="โš ๏ธ LLM API key not configured. Please set GEMINI_API_KEY in your .env file. Retrieval is working, but answer generation requires an API key.", + citations=[], + confidence=0.0, + from_knowledge_base=False, + escalation_suggested=True, + conversation_id=conversation_id, + metadata={"error": error_msg, "error_type": "configuration"} + ) + else: + return ChatResponse( + success=False, + answer=f"Configuration error: {error_msg}", + citations=[], + confidence=0.0, + from_knowledge_base=False, + escalation_suggested=True, + conversation_id=conversation_id, + metadata={"error": error_msg} + ) + except HTTPException: + # Re-raise HTTP exceptions (they have proper status codes) + raise + except Exception as e: + logger.error(f"Chat error: {e}", exc_info=True) + logger.error(f"Error type: {type(e).__name__}", exc_info=True) + return ChatResponse( + success=False, + answer=f"I encountered an error processing your request: {str(e)}. Please check the server logs for details.", + citations=[], + confidence=0.0, + from_knowledge_base=False, + escalation_suggested=True, + conversation_id=conversation_id, + metadata={"error": str(e), "error_type": type(e).__name__} + ) + except HTTPException: + # Re-raise HTTP exceptions from outer try block + raise + except Exception as e: + logger.error(f"Outer chat error: {e}", exc_info=True) + return ChatResponse( + success=False, + answer=f"I encountered an error processing your request: {str(e)}. Please check the server logs for details.", + citations=[], + confidence=0.0, + from_knowledge_base=False, + escalation_suggested=True, + conversation_id=conversation_id, + metadata={"error": str(e), "error_type": type(e).__name__} + ) + + +# ============== Utility Endpoints ============== + +@app.get("/kb/search") +@limiter.limit("30/minute", key_func=get_tenant_rate_limit_key) +async def search_kb( + request: Request, + query: str, + tenant_id: Optional[str] = None, # Optional in dev, ignored in prod + kb_id: Optional[str] = None, + user_id: Optional[str] = None, # Optional in dev, ignored in prod + top_k: int = 5 +): + """ + Search the knowledge base without generating an answer. + Useful for debugging and testing retrieval. + """ + # SECURITY: Extract tenant_id from auth token in production + if settings.ENV == "prod": + auth_context = await require_auth(request) + tenant_id = auth_context.get("tenant_id") + user_id = auth_context.get("user_id") + if not tenant_id or not user_id: + raise HTTPException( + status_code=403, + detail="tenant_id and user_id must come from authentication token in production mode" + ) + elif not tenant_id or not kb_id or not user_id: + raise HTTPException( + status_code=400, + detail="tenant_id, kb_id, and user_id are required" + ) + + try: + retrieval_service = get_retrieval_service() + results, confidence, has_relevant = retrieval_service.retrieve( + query=query, + tenant_id=tenant_id, # CRITICAL: Multi-tenant isolation + kb_id=kb_id, + user_id=user_id, + top_k=top_k + ) + + return { + "success": True, + "results": [ + { + "chunk_id": r.chunk_id, + "content": r.content[:500] + "..." if len(r.content) > 500 else r.content, + "metadata": r.metadata, + "similarity_score": r.similarity_score + } + for r in results + ], + "confidence": confidence, + "has_relevant_results": has_relevant + } + except Exception as e: + logger.error(f"Search error: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +# ============== Billing & Usage Endpoints ============== + +@app.get("/billing/usage", response_model=UsageResponse) +async def get_usage( + request: Request, + range: str = "month", # "day" or "month" + year: Optional[int] = None, + month: Optional[int] = None, + day: Optional[int] = None +): + """ + Get usage statistics for the current tenant. + + Args: + range: "day" or "month" + year: Year (optional, defaults to current) + month: Month 1-12 (optional, defaults to current) + day: Day 1-31 (optional, defaults to current, only for range="day") + """ + # Get tenant from auth + auth_context = await get_auth_context(request) + tenant_id = auth_context.get("tenant_id") + + if not tenant_id: + raise HTTPException(status_code=403, detail="tenant_id required") + + db = next(get_db()) + + try: + from app.db.models import UsageDaily, UsageMonthly + from datetime import datetime + from calendar import monthrange + + now = datetime.utcnow() + target_year = year or now.year + target_month = month or now.month + + if range == "day": + target_day = day or now.day + date_start = datetime(target_year, target_month, target_day) + + daily = db.query(UsageDaily).filter( + UsageDaily.tenant_id == tenant_id, + UsageDaily.date == date_start + ).first() + + if not daily: + return UsageResponse( + tenant_id=tenant_id, + period="day", + total_requests=0, + total_tokens=0, + total_cost_usd=0.0, + start_date=date_start, + end_date=date_start + ) + + return UsageResponse( + tenant_id=tenant_id, + period="day", + total_requests=daily.total_requests, + total_tokens=daily.total_tokens, + total_cost_usd=daily.total_cost_usd, + gemini_requests=daily.gemini_requests, + openai_requests=daily.openai_requests, + start_date=daily.date, + end_date=daily.date + ) + else: # month + monthly = db.query(UsageMonthly).filter( + UsageMonthly.tenant_id == tenant_id, + UsageMonthly.year == target_year, + UsageMonthly.month == target_month + ).first() + + if not monthly: + # Calculate date range for the month + _, last_day = monthrange(target_year, target_month) + start_date = datetime(target_year, target_month, 1) + end_date = datetime(target_year, target_month, last_day) + + return UsageResponse( + tenant_id=tenant_id, + period="month", + total_requests=0, + total_tokens=0, + total_cost_usd=0.0, + start_date=start_date, + end_date=end_date + ) + + _, last_day = monthrange(monthly.year, monthly.month) + start_date = datetime(monthly.year, monthly.month, 1) + end_date = datetime(monthly.year, monthly.month, last_day) + + return UsageResponse( + tenant_id=tenant_id, + period="month", + total_requests=monthly.total_requests, + total_tokens=monthly.total_tokens, + total_cost_usd=monthly.total_cost_usd, + gemini_requests=monthly.gemini_requests, + openai_requests=monthly.openai_requests, + start_date=start_date, + end_date=end_date + ) + except Exception as e: + logger.error(f"Error getting usage: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=str(e)) + + +@app.get("/billing/limits", response_model=PlanLimitsResponse) +async def get_limits(request: Request): + """Get current plan limits and usage for the tenant.""" + # Get tenant from auth + auth_context = await get_auth_context(request) + tenant_id = auth_context.get("tenant_id") + + if not tenant_id: + raise HTTPException(status_code=403, detail="tenant_id required") + + db = next(get_db()) + + try: + from app.billing.quota import get_tenant_plan, get_monthly_usage + from datetime import datetime + + plan = get_tenant_plan(db, tenant_id) + if not plan: + # Default to starter + plan_name = "starter" + monthly_limit = 500 + else: + plan_name = plan.plan_name + monthly_limit = plan.monthly_chat_limit + + # Get current month usage + now = datetime.utcnow() + monthly_usage = get_monthly_usage(db, tenant_id, now.year, now.month) + current_usage = monthly_usage.total_requests if monthly_usage else 0 + + remaining = None if monthly_limit == -1 else max(0, monthly_limit - current_usage) + + return PlanLimitsResponse( + tenant_id=tenant_id, + plan_name=plan_name, + monthly_chat_limit=monthly_limit, + current_month_usage=current_usage, + remaining_chats=remaining + ) + except Exception as e: + logger.error(f"Error getting limits: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=str(e)) + + +@app.post("/billing/plan") +async def set_plan(request_body: SetPlanRequest, http_request: Request): + """ + Set tenant's subscription plan (admin only in production). + + In dev mode, allows any tenant to set their plan. + In prod mode, should be restricted to admin users. + """ + # Get tenant from auth + auth_context = await get_auth_context(http_request) + auth_tenant_id = auth_context.get("tenant_id") + + # In prod, verify admin role (placeholder - implement actual admin check) + if settings.ENV == "prod": + # TODO: Add admin role check + if auth_tenant_id != request_body.tenant_id: + raise HTTPException(status_code=403, detail="Cannot set plan for other tenants") + + db = next(get_db()) + + try: + from app.billing.quota import set_tenant_plan + + plan = set_tenant_plan(db, request_body.tenant_id, request_body.plan_name) + + return { + "success": True, + "tenant_id": request_body.tenant_id, + "plan_name": plan.plan_name, + "monthly_chat_limit": plan.monthly_chat_limit + } + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + except Exception as e: + logger.error(f"Error setting plan: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=str(e)) + + +@app.get("/billing/cost-report", response_model=CostReportResponse) +async def get_cost_report( + request: Request, + range: str = "month", + year: Optional[int] = None, + month: Optional[int] = None +): + """Get cost report with breakdown by provider and model.""" + # Get tenant from auth + auth_context = await get_auth_context(request) + tenant_id = auth_context.get("tenant_id") + + if not tenant_id: + raise HTTPException(status_code=403, detail="tenant_id required") + + db = next(get_db()) + + try: + from app.db.models import UsageEvent + from datetime import datetime + from sqlalchemy import func, and_ + + now = datetime.utcnow() + target_year = year or now.year + target_month = month or now.month + + # Query usage events for the period + if range == "month": + query = db.query(UsageEvent).filter( + and_( + UsageEvent.tenant_id == tenant_id, + func.extract('year', UsageEvent.request_timestamp) == target_year, + func.extract('month', UsageEvent.request_timestamp) == target_month + ) + ) + else: # all time + query = db.query(UsageEvent).filter(UsageEvent.tenant_id == tenant_id) + + events = query.all() + + # Calculate totals + total_cost = sum(e.estimated_cost_usd for e in events) + total_requests = len(events) + total_tokens = sum(e.total_tokens for e in events) + + # Breakdown by provider + breakdown_by_provider = {} + for event in events: + provider = event.provider + if provider not in breakdown_by_provider: + breakdown_by_provider[provider] = { + "requests": 0, + "tokens": 0, + "cost_usd": 0.0 + } + breakdown_by_provider[provider]["requests"] += 1 + breakdown_by_provider[provider]["tokens"] += event.total_tokens + breakdown_by_provider[provider]["cost_usd"] += event.estimated_cost_usd + + # Breakdown by model + breakdown_by_model = {} + for event in events: + model = event.model + if model not in breakdown_by_model: + breakdown_by_model[model] = { + "requests": 0, + "tokens": 0, + "cost_usd": 0.0 + } + breakdown_by_model[model]["requests"] += 1 + breakdown_by_model[model]["tokens"] += event.total_tokens + breakdown_by_model[model]["cost_usd"] += event.estimated_cost_usd + + return CostReportResponse( + tenant_id=tenant_id, + period=range, + total_cost_usd=total_cost, + total_requests=total_requests, + total_tokens=total_tokens, + breakdown_by_provider=breakdown_by_provider, + breakdown_by_model=breakdown_by_model + ) + except Exception as e: + logger.error(f"Error getting cost report: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=str(e)) + + +if __name__ == "__main__": + import uvicorn + uvicorn.run(app, host="0.0.0.0", port=8000) + diff --git a/app/middleware/__init__.py b/app/middleware/__init__.py index c669ae02552cd455662b0442a17d5819ccf09f0d..dbbc1360534f7c3bcd1d46c63e645e50fc62d15e 100644 --- a/app/middleware/__init__.py +++ b/app/middleware/__init__.py @@ -1,13 +1,13 @@ -""" -Middleware for authentication, rate limiting, etc. -""" -from app.middleware.auth import verify_tenant_access, get_tenant_from_token, require_auth - -__all__ = [ - "verify_tenant_access", - "get_tenant_from_token", - "require_auth", -] - - - +""" +Middleware for authentication, rate limiting, etc. +""" +from app.middleware.auth import verify_tenant_access, get_tenant_from_token, require_auth + +__all__ = [ + "verify_tenant_access", + "get_tenant_from_token", + "require_auth", +] + + + diff --git a/app/middleware/__pycache__/__init__.cpython-313.pyc b/app/middleware/__pycache__/__init__.cpython-313.pyc deleted file mode 100644 index a3d66b94a8536c5164d1b6f2815522d1f7adbe3e..0000000000000000000000000000000000000000 Binary files a/app/middleware/__pycache__/__init__.cpython-313.pyc and /dev/null differ diff --git a/app/middleware/__pycache__/auth.cpython-313.pyc b/app/middleware/__pycache__/auth.cpython-313.pyc deleted file mode 100644 index 91aa6e10d78b1ff9efc4a7abbf5a1c4c094d72ba..0000000000000000000000000000000000000000 Binary files a/app/middleware/__pycache__/auth.cpython-313.pyc and /dev/null differ diff --git a/app/middleware/__pycache__/rate_limit.cpython-313.pyc b/app/middleware/__pycache__/rate_limit.cpython-313.pyc deleted file mode 100644 index 719e95084d7c04e1142d97f811d06d6fe5d68a79..0000000000000000000000000000000000000000 Binary files a/app/middleware/__pycache__/rate_limit.cpython-313.pyc and /dev/null differ diff --git a/app/middleware/auth.py b/app/middleware/auth.py index 337c397b266e8dd463d0335f878f4998287fb3c2..64252966fcc32df6290aaf0beebab04cca649b1c 100644 --- a/app/middleware/auth.py +++ b/app/middleware/auth.py @@ -1,212 +1,212 @@ -""" -Authentication and authorization middleware. -Extracts tenant_id from JWT token in production mode. -""" -from fastapi import Request, HTTPException, Depends -from typing import Optional, Dict, Any -import logging -from jose import JWTError, jwt - -from app.config import settings - -logger = logging.getLogger(__name__) - - -async def verify_tenant_access( - request: Request, - tenant_id: str, - user_id: str -) -> bool: - """ - Verify that the user has access to the specified tenant. - - TODO: Implement actual authentication logic: - 1. Extract JWT token from Authorization header - 2. Verify token signature - 3. Extract user_id and tenant_id from token - 4. Verify user belongs to tenant - 5. Check permissions - - Args: - request: FastAPI request object - tenant_id: Tenant ID from request - user_id: User ID from request - - Returns: - True if access is granted, False otherwise - """ - # TODO: Implement actual authentication - # For now, this is a placeholder that always returns True - # In production, you MUST implement proper auth - - # Example implementation: - # token = request.headers.get("Authorization", "").replace("Bearer ", "") - # if not token: - # return False - # - # decoded = verify_jwt_token(token) - # if decoded["user_id"] != user_id or decoded["tenant_id"] != tenant_id: - # return False - # - # return True - - logger.warning("โš ๏ธ Authentication middleware not implemented - using placeholder") - return True - - -def get_tenant_from_token(request: Request) -> Optional[str]: - """ - Extract tenant_id from authentication token. - - In production mode, extracts tenant_id from JWT token. - In dev mode, returns None (allows request tenant_id). - - Args: - request: FastAPI request object - - Returns: - Tenant ID if found in token, None otherwise - """ - if settings.ENV == "dev": - # Dev mode: allow request tenant_id - return None - - # Production mode: extract from JWT - auth_header = request.headers.get("Authorization", "") - if not auth_header.startswith("Bearer "): - logger.warning("Missing or invalid Authorization header") - return None - - token = auth_header.replace("Bearer ", "").strip() - if not token: - return None - - try: - # TODO: Replace with your actual JWT secret key - # For now, this is a placeholder that expects a specific token format - # In production, you should: - # 1. Get JWT_SECRET from environment - # 2. Verify token signature - # 3. Extract tenant_id from token payload - - # Example implementation (replace with your actual JWT verification): - # JWT_SECRET = os.getenv("JWT_SECRET", "your-secret-key") - # decoded = jwt.decode(token, JWT_SECRET, algorithms=["HS256"]) - # return decoded.get("tenant_id") - - # Placeholder: Try to decode without verification (for testing) - # In production, you MUST verify the signature - try: - decoded = jwt.decode(token, options={"verify_signature": False}) - tenant_id = decoded.get("tenant_id") - if tenant_id: - logger.info(f"Extracted tenant_id from token: {tenant_id}") - return tenant_id - except jwt.DecodeError: - logger.warning("Failed to decode JWT token") - return None - - except Exception as e: - logger.error(f"Error extracting tenant from token: {e}") - return None - - return None - - -async def get_auth_context(request: Request) -> Dict[str, Any]: - """ - Get authentication context from request. - - DEV mode: - - Allows X-Tenant-Id and X-User-Id headers - - Falls back to defaults if missing - - PROD mode: - - Requires Authorization: Bearer - - Verifies JWT using JWT_SECRET - - Extracts tenant_id and user_id from token claims - - NEVER accepts tenant_id from request body/query params - - Args: - request: FastAPI request object - - Returns: - Dictionary with user_id and tenant_id - - Raises: - HTTPException: If authentication fails (production mode only) - """ - if settings.ENV == "dev": - # Dev mode: allow headers, fallback to defaults - tenant_id = request.headers.get("X-Tenant-Id", "dev_tenant") - user_id = request.headers.get("X-User-Id", "dev_user") - return { - "user_id": user_id, - "tenant_id": tenant_id - } - - # Production mode: require JWT token - auth_header = request.headers.get("Authorization") - if not auth_header or not auth_header.startswith("Bearer "): - raise HTTPException( - status_code=401, - detail="Authentication required. Provide valid Bearer token in Authorization header." - ) - - token = auth_header.replace("Bearer ", "").strip() - if not token: - raise HTTPException( - status_code=401, - detail="Invalid token format." - ) - - # Verify JWT token - if not settings.JWT_SECRET: - logger.error("JWT_SECRET not configured in production mode") - raise HTTPException( - status_code=500, - detail="Server configuration error: JWT_SECRET not set" - ) - - try: - decoded = jwt.decode(token, settings.JWT_SECRET, algorithms=["HS256"]) - - user_id = decoded.get("user_id") or decoded.get("sub") - tenant_id = decoded.get("tenant_id") - - if not user_id or not tenant_id: - raise HTTPException( - status_code=401, - detail="Token missing required claims (user_id, tenant_id)." - ) - - logger.info(f"Authenticated user: {user_id}, tenant: {tenant_id}") - return { - "user_id": user_id, - "tenant_id": tenant_id, - "email": decoded.get("email"), - "role": decoded.get("role") - } - - except JWTError as e: - logger.warning(f"JWT verification failed: {e}") - raise HTTPException( - status_code=401, - detail="Invalid or expired token." - ) - except Exception as e: - logger.error(f"Auth error: {e}", exc_info=True) - raise HTTPException( - status_code=401, - detail="Authentication failed." - ) - - -# FastAPI dependency for easy use in endpoints -async def require_auth(request: Request) -> Dict[str, Any]: - """ - FastAPI dependency for requiring authentication. - Alias for get_auth_context for backward compatibility. - """ - return await get_auth_context(request) - +""" +Authentication and authorization middleware. +Extracts tenant_id from JWT token in production mode. +""" +from fastapi import Request, HTTPException, Depends +from typing import Optional, Dict, Any +import logging +from jose import JWTError, jwt + +from app.config import settings + +logger = logging.getLogger(__name__) + + +async def verify_tenant_access( + request: Request, + tenant_id: str, + user_id: str +) -> bool: + """ + Verify that the user has access to the specified tenant. + + TODO: Implement actual authentication logic: + 1. Extract JWT token from Authorization header + 2. Verify token signature + 3. Extract user_id and tenant_id from token + 4. Verify user belongs to tenant + 5. Check permissions + + Args: + request: FastAPI request object + tenant_id: Tenant ID from request + user_id: User ID from request + + Returns: + True if access is granted, False otherwise + """ + # TODO: Implement actual authentication + # For now, this is a placeholder that always returns True + # In production, you MUST implement proper auth + + # Example implementation: + # token = request.headers.get("Authorization", "").replace("Bearer ", "") + # if not token: + # return False + # + # decoded = verify_jwt_token(token) + # if decoded["user_id"] != user_id or decoded["tenant_id"] != tenant_id: + # return False + # + # return True + + logger.warning("โš ๏ธ Authentication middleware not implemented - using placeholder") + return True + + +def get_tenant_from_token(request: Request) -> Optional[str]: + """ + Extract tenant_id from authentication token. + + In production mode, extracts tenant_id from JWT token. + In dev mode, returns None (allows request tenant_id). + + Args: + request: FastAPI request object + + Returns: + Tenant ID if found in token, None otherwise + """ + if settings.ENV == "dev": + # Dev mode: allow request tenant_id + return None + + # Production mode: extract from JWT + auth_header = request.headers.get("Authorization", "") + if not auth_header.startswith("Bearer "): + logger.warning("Missing or invalid Authorization header") + return None + + token = auth_header.replace("Bearer ", "").strip() + if not token: + return None + + try: + # TODO: Replace with your actual JWT secret key + # For now, this is a placeholder that expects a specific token format + # In production, you should: + # 1. Get JWT_SECRET from environment + # 2. Verify token signature + # 3. Extract tenant_id from token payload + + # Example implementation (replace with your actual JWT verification): + # JWT_SECRET = os.getenv("JWT_SECRET", "your-secret-key") + # decoded = jwt.decode(token, JWT_SECRET, algorithms=["HS256"]) + # return decoded.get("tenant_id") + + # Placeholder: Try to decode without verification (for testing) + # In production, you MUST verify the signature + try: + decoded = jwt.decode(token, options={"verify_signature": False}) + tenant_id = decoded.get("tenant_id") + if tenant_id: + logger.info(f"Extracted tenant_id from token: {tenant_id}") + return tenant_id + except jwt.DecodeError: + logger.warning("Failed to decode JWT token") + return None + + except Exception as e: + logger.error(f"Error extracting tenant from token: {e}") + return None + + return None + + +async def get_auth_context(request: Request) -> Dict[str, Any]: + """ + Get authentication context from request. + + DEV mode: + - Allows X-Tenant-Id and X-User-Id headers + - Falls back to defaults if missing + + PROD mode: + - Requires Authorization: Bearer + - Verifies JWT using JWT_SECRET + - Extracts tenant_id and user_id from token claims + - NEVER accepts tenant_id from request body/query params + + Args: + request: FastAPI request object + + Returns: + Dictionary with user_id and tenant_id + + Raises: + HTTPException: If authentication fails (production mode only) + """ + if settings.ENV == "dev": + # Dev mode: allow headers, fallback to defaults + tenant_id = request.headers.get("X-Tenant-Id", "dev_tenant") + user_id = request.headers.get("X-User-Id", "dev_user") + return { + "user_id": user_id, + "tenant_id": tenant_id + } + + # Production mode: require JWT token + auth_header = request.headers.get("Authorization") + if not auth_header or not auth_header.startswith("Bearer "): + raise HTTPException( + status_code=401, + detail="Authentication required. Provide valid Bearer token in Authorization header." + ) + + token = auth_header.replace("Bearer ", "").strip() + if not token: + raise HTTPException( + status_code=401, + detail="Invalid token format." + ) + + # Verify JWT token + if not settings.JWT_SECRET: + logger.error("JWT_SECRET not configured in production mode") + raise HTTPException( + status_code=500, + detail="Server configuration error: JWT_SECRET not set" + ) + + try: + decoded = jwt.decode(token, settings.JWT_SECRET, algorithms=["HS256"]) + + user_id = decoded.get("user_id") or decoded.get("sub") + tenant_id = decoded.get("tenant_id") + + if not user_id or not tenant_id: + raise HTTPException( + status_code=401, + detail="Token missing required claims (user_id, tenant_id)." + ) + + logger.info(f"Authenticated user: {user_id}, tenant: {tenant_id}") + return { + "user_id": user_id, + "tenant_id": tenant_id, + "email": decoded.get("email"), + "role": decoded.get("role") + } + + except JWTError as e: + logger.warning(f"JWT verification failed: {e}") + raise HTTPException( + status_code=401, + detail="Invalid or expired token." + ) + except Exception as e: + logger.error(f"Auth error: {e}", exc_info=True) + raise HTTPException( + status_code=401, + detail="Authentication failed." + ) + + +# FastAPI dependency for easy use in endpoints +async def require_auth(request: Request) -> Dict[str, Any]: + """ + FastAPI dependency for requiring authentication. + Alias for get_auth_context for backward compatibility. + """ + return await get_auth_context(request) + diff --git a/app/middleware/rate_limit.py b/app/middleware/rate_limit.py index 3a5c6768e4a53f40556415f00cf35c9587811fc9..83346b0ead951cc2f391cb5f16e9606a059939f1 100644 --- a/app/middleware/rate_limit.py +++ b/app/middleware/rate_limit.py @@ -1,40 +1,40 @@ -""" -Rate limiting middleware using slowapi. -""" -from slowapi import Limiter, _rate_limit_exceeded_handler -from slowapi.util import get_remote_address -from slowapi.errors import RateLimitExceeded -from fastapi import Request -import logging - -from app.config import settings - -logger = logging.getLogger(__name__) - -# Initialize limiter with default limits (can be overridden per endpoint) -limiter = Limiter( - key_func=get_remote_address, - default_limits=["1000/hour"] if settings.RATE_LIMIT_ENABLED else [] -) - - -def get_tenant_rate_limit_key(request: Request) -> str: - """ - Get rate limit key based on tenant_id from headers (dev) or IP (prod). - - Note: This is a sync function called by slowapi, so we can't await async functions. - In dev mode, we extract tenant_id from X-Tenant-Id header. - In prod mode, we fall back to IP address (rate limiting happens before auth). - """ - # Try to get tenant_id from headers (works in dev mode) - tenant_id = request.headers.get("X-Tenant-Id") - if tenant_id: - return f"tenant:{tenant_id}" - - # Fallback to IP address (for prod mode or if no header) - return get_remote_address(request) - - -# Export limiter and key function -__all__ = ["limiter", "get_tenant_rate_limit_key", "RateLimitExceeded", "_rate_limit_exceeded_handler"] - +""" +Rate limiting middleware using slowapi. +""" +from slowapi import Limiter, _rate_limit_exceeded_handler +from slowapi.util import get_remote_address +from slowapi.errors import RateLimitExceeded +from fastapi import Request +import logging + +from app.config import settings + +logger = logging.getLogger(__name__) + +# Initialize limiter with default limits (can be overridden per endpoint) +limiter = Limiter( + key_func=get_remote_address, + default_limits=["1000/hour"] if settings.RATE_LIMIT_ENABLED else [] +) + + +def get_tenant_rate_limit_key(request: Request) -> str: + """ + Get rate limit key based on tenant_id from headers (dev) or IP (prod). + + Note: This is a sync function called by slowapi, so we can't await async functions. + In dev mode, we extract tenant_id from X-Tenant-Id header. + In prod mode, we fall back to IP address (rate limiting happens before auth). + """ + # Try to get tenant_id from headers (works in dev mode) + tenant_id = request.headers.get("X-Tenant-Id") + if tenant_id: + return f"tenant:{tenant_id}" + + # Fallback to IP address (for prod mode or if no header) + return get_remote_address(request) + + +# Export limiter and key function +__all__ = ["limiter", "get_tenant_rate_limit_key", "RateLimitExceeded", "_rate_limit_exceeded_handler"] + diff --git a/app/models/__init__.py b/app/models/__init__.py index 0410d8f5408171132e9f3086da9407c2607fec70..e16d79925ac5752e660fb949d4078219fd974744 100644 --- a/app/models/__init__.py +++ b/app/models/__init__.py @@ -1,33 +1,33 @@ -""" -Pydantic models for the RAG backend. -""" -from app.models.schemas import ( - DocumentStatus, - ChunkMetadata, - DocumentChunk, - UploadRequest, - UploadResponse, - Citation, - ChatRequest, - ChatResponse, - RetrievalResult, - KnowledgeBaseStats, - HealthResponse, -) - -__all__ = [ - "DocumentStatus", - "ChunkMetadata", - "DocumentChunk", - "UploadRequest", - "UploadResponse", - "Citation", - "ChatRequest", - "ChatResponse", - "RetrievalResult", - "KnowledgeBaseStats", - "HealthResponse", -] - - - +""" +Pydantic models for the RAG backend. +""" +from app.models.schemas import ( + DocumentStatus, + ChunkMetadata, + DocumentChunk, + UploadRequest, + UploadResponse, + Citation, + ChatRequest, + ChatResponse, + RetrievalResult, + KnowledgeBaseStats, + HealthResponse, +) + +__all__ = [ + "DocumentStatus", + "ChunkMetadata", + "DocumentChunk", + "UploadRequest", + "UploadResponse", + "Citation", + "ChatRequest", + "ChatResponse", + "RetrievalResult", + "KnowledgeBaseStats", + "HealthResponse", +] + + + diff --git a/app/models/__pycache__/__init__.cpython-313.pyc b/app/models/__pycache__/__init__.cpython-313.pyc deleted file mode 100644 index 95b6ff7ec9ff159d99735ecb8e2ef5315b1b3da0..0000000000000000000000000000000000000000 Binary files a/app/models/__pycache__/__init__.cpython-313.pyc and /dev/null differ diff --git a/app/models/__pycache__/billing_schemas.cpython-313.pyc b/app/models/__pycache__/billing_schemas.cpython-313.pyc deleted file mode 100644 index 95849d35dff01c78c13c9077b0b54e56108d2ea9..0000000000000000000000000000000000000000 Binary files a/app/models/__pycache__/billing_schemas.cpython-313.pyc and /dev/null differ diff --git a/app/models/__pycache__/schemas.cpython-313.pyc b/app/models/__pycache__/schemas.cpython-313.pyc deleted file mode 100644 index 866862af35e8ce00c238dd94b6ce7fba372a1123..0000000000000000000000000000000000000000 Binary files a/app/models/__pycache__/schemas.cpython-313.pyc and /dev/null differ diff --git a/app/models/billing_schemas.py b/app/models/billing_schemas.py index 51c26fec17e4ac1497ef17bda4f1eb73a0dbf425..33798e4fba9950e48d38336bbc5a7b36decb4174 100644 --- a/app/models/billing_schemas.py +++ b/app/models/billing_schemas.py @@ -1,46 +1,46 @@ -""" -Pydantic schemas for billing endpoints. -""" -from pydantic import BaseModel -from typing import Optional, List -from datetime import datetime - - -class UsageResponse(BaseModel): - """Usage statistics response.""" - tenant_id: str - period: str # "day" or "month" - total_requests: int - total_tokens: int - total_cost_usd: float - gemini_requests: int = 0 - openai_requests: int = 0 - start_date: datetime - end_date: datetime - - -class PlanLimitsResponse(BaseModel): - """Current plan limits response.""" - tenant_id: str - plan_name: str - monthly_chat_limit: int # -1 for unlimited - current_month_usage: int - remaining_chats: Optional[int] # None if unlimited - - -class CostReportResponse(BaseModel): - """Cost report response.""" - tenant_id: str - period: str - total_cost_usd: float - total_requests: int - total_tokens: int - breakdown_by_provider: dict - breakdown_by_model: dict - - -class SetPlanRequest(BaseModel): - """Request to set tenant plan.""" - tenant_id: str - plan_name: str # "starter", "growth", or "pro" - +""" +Pydantic schemas for billing endpoints. +""" +from pydantic import BaseModel +from typing import Optional, List +from datetime import datetime + + +class UsageResponse(BaseModel): + """Usage statistics response.""" + tenant_id: str + period: str # "day" or "month" + total_requests: int + total_tokens: int + total_cost_usd: float + gemini_requests: int = 0 + openai_requests: int = 0 + start_date: datetime + end_date: datetime + + +class PlanLimitsResponse(BaseModel): + """Current plan limits response.""" + tenant_id: str + plan_name: str + monthly_chat_limit: int # -1 for unlimited + current_month_usage: int + remaining_chats: Optional[int] # None if unlimited + + +class CostReportResponse(BaseModel): + """Cost report response.""" + tenant_id: str + period: str + total_cost_usd: float + total_requests: int + total_tokens: int + breakdown_by_provider: dict + breakdown_by_model: dict + + +class SetPlanRequest(BaseModel): + """Request to set tenant plan.""" + tenant_id: str + plan_name: str # "starter", "growth", or "pro" + diff --git a/app/models/schemas.py b/app/models/schemas.py index d99a50081cf254b745a87c68b0652f1af62cf113..82254c1068f6ff22cc431f37017552e60666798a 100644 --- a/app/models/schemas.py +++ b/app/models/schemas.py @@ -1,112 +1,112 @@ -""" -Pydantic models for API request/response schemas. -""" -from pydantic import BaseModel, Field -from typing import List, Optional, Dict, Any -from datetime import datetime -from enum import Enum - - -class DocumentStatus(str, Enum): - PENDING = "pending" - PROCESSING = "processing" - COMPLETED = "completed" - FAILED = "failed" - - -class ChunkMetadata(BaseModel): - """Metadata for a document chunk.""" - tenant_id: str # CRITICAL: Multi-tenant isolation - kb_id: str - user_id: str - file_name: str - file_type: str - chunk_id: str - chunk_index: int - page_number: Optional[int] = None - total_chunks: int - document_id: Optional[str] = None # Track original document - created_at: datetime = Field(default_factory=datetime.utcnow) - - -class DocumentChunk(BaseModel): - """A chunk of text with metadata.""" - id: str - content: str - metadata: ChunkMetadata - embedding: Optional[List[float]] = None - - -class UploadRequest(BaseModel): - """Request model for file upload.""" - tenant_id: str # CRITICAL: Multi-tenant isolation - user_id: str - kb_id: str - - -class UploadResponse(BaseModel): - """Response model for file upload.""" - success: bool - message: str - document_id: Optional[str] = None - file_name: str - chunks_created: int = 0 - status: DocumentStatus = DocumentStatus.PENDING - - -class Citation(BaseModel): - """Citation reference for an answer.""" - file_name: str - chunk_id: str - page_number: Optional[int] = None - relevance_score: float - excerpt: str # Short excerpt from the chunk - - -class ChatRequest(BaseModel): - """Request model for chat endpoint.""" - tenant_id: str # CRITICAL: Multi-tenant isolation - user_id: str - kb_id: str - conversation_id: Optional[str] = None - question: str - - -class ChatResponse(BaseModel): - """Response model for chat endpoint.""" - success: bool - answer: str - citations: List[Citation] = [] - confidence: float # 0-1 score - from_knowledge_base: bool = True - escalation_suggested: bool = False - conversation_id: str - metadata: Dict[str, Any] = {} - - -class RetrievalResult(BaseModel): - """Result from vector store retrieval.""" - chunk_id: str - content: str - metadata: Dict[str, Any] - similarity_score: float - - -class KnowledgeBaseStats(BaseModel): - """Statistics for a knowledge base.""" - tenant_id: str # CRITICAL: Multi-tenant isolation - kb_id: str - user_id: str - total_documents: int - total_chunks: int - file_names: List[str] - last_updated: Optional[datetime] = None - - -class HealthResponse(BaseModel): - """Health check response.""" - status: str - version: str = "1.0.0" - vector_db_connected: bool - llm_configured: bool - +""" +Pydantic models for API request/response schemas. +""" +from pydantic import BaseModel, Field +from typing import List, Optional, Dict, Any +from datetime import datetime +from enum import Enum + + +class DocumentStatus(str, Enum): + PENDING = "pending" + PROCESSING = "processing" + COMPLETED = "completed" + FAILED = "failed" + + +class ChunkMetadata(BaseModel): + """Metadata for a document chunk.""" + tenant_id: str # CRITICAL: Multi-tenant isolation + kb_id: str + user_id: str + file_name: str + file_type: str + chunk_id: str + chunk_index: int + page_number: Optional[int] = None + total_chunks: int + document_id: Optional[str] = None # Track original document + created_at: datetime = Field(default_factory=datetime.utcnow) + + +class DocumentChunk(BaseModel): + """A chunk of text with metadata.""" + id: str + content: str + metadata: ChunkMetadata + embedding: Optional[List[float]] = None + + +class UploadRequest(BaseModel): + """Request model for file upload.""" + tenant_id: str # CRITICAL: Multi-tenant isolation + user_id: str + kb_id: str + + +class UploadResponse(BaseModel): + """Response model for file upload.""" + success: bool + message: str + document_id: Optional[str] = None + file_name: str + chunks_created: int = 0 + status: DocumentStatus = DocumentStatus.PENDING + + +class Citation(BaseModel): + """Citation reference for an answer.""" + file_name: str + chunk_id: str + page_number: Optional[int] = None + relevance_score: float + excerpt: str # Short excerpt from the chunk + + +class ChatRequest(BaseModel): + """Request model for chat endpoint.""" + tenant_id: str # CRITICAL: Multi-tenant isolation + user_id: str + kb_id: str + conversation_id: Optional[str] = None + question: str + + +class ChatResponse(BaseModel): + """Response model for chat endpoint.""" + success: bool + answer: str + citations: List[Citation] = [] + confidence: float # 0-1 score + from_knowledge_base: bool = True + escalation_suggested: bool = False + conversation_id: str + metadata: Dict[str, Any] = {} + + +class RetrievalResult(BaseModel): + """Result from vector store retrieval.""" + chunk_id: str + content: str + metadata: Dict[str, Any] + similarity_score: float + + +class KnowledgeBaseStats(BaseModel): + """Statistics for a knowledge base.""" + tenant_id: str # CRITICAL: Multi-tenant isolation + kb_id: str + user_id: str + total_documents: int + total_chunks: int + file_names: List[str] + last_updated: Optional[datetime] = None + + +class HealthResponse(BaseModel): + """Health check response.""" + status: str + version: str = "1.0.0" + vector_db_connected: bool + llm_configured: bool + diff --git a/app/rag/__init__.py b/app/rag/__init__.py index 95b30c0bc030b289b6aa3aab2f26057c7f70cf4b..c62ed3f159c36ec4ebc292e75c4cda417c6cc8d4 100644 --- a/app/rag/__init__.py +++ b/app/rag/__init__.py @@ -1,27 +1,27 @@ -""" -RAG (Retrieval-Augmented Generation) pipeline modules. -""" -from app.rag.ingest import parser, DocumentParser -from app.rag.chunking import chunker, DocumentChunker -from app.rag.embeddings import get_embedding_service, EmbeddingService -from app.rag.vectorstore import get_vector_store, VectorStore -from app.rag.retrieval import get_retrieval_service, RetrievalService -from app.rag.answer import get_answer_service, AnswerService - -__all__ = [ - "parser", - "DocumentParser", - "chunker", - "DocumentChunker", - "get_embedding_service", - "EmbeddingService", - "get_vector_store", - "VectorStore", - "get_retrieval_service", - "RetrievalService", - "get_answer_service", - "AnswerService", -] - - - +""" +RAG (Retrieval-Augmented Generation) pipeline modules. +""" +from app.rag.ingest import parser, DocumentParser +from app.rag.chunking import chunker, DocumentChunker +from app.rag.embeddings import get_embedding_service, EmbeddingService +from app.rag.vectorstore import get_vector_store, VectorStore +from app.rag.retrieval import get_retrieval_service, RetrievalService +from app.rag.answer import get_answer_service, AnswerService + +__all__ = [ + "parser", + "DocumentParser", + "chunker", + "DocumentChunker", + "get_embedding_service", + "EmbeddingService", + "get_vector_store", + "VectorStore", + "get_retrieval_service", + "RetrievalService", + "get_answer_service", + "AnswerService", +] + + + diff --git a/app/rag/__pycache__/__init__.cpython-313.pyc b/app/rag/__pycache__/__init__.cpython-313.pyc deleted file mode 100644 index c093719d5648f029ecb92720f214f11f478add69..0000000000000000000000000000000000000000 Binary files a/app/rag/__pycache__/__init__.cpython-313.pyc and /dev/null differ diff --git a/app/rag/__pycache__/answer.cpython-313.pyc b/app/rag/__pycache__/answer.cpython-313.pyc deleted file mode 100644 index c2880fb72e015542920dec13dc30fb5486762860..0000000000000000000000000000000000000000 Binary files a/app/rag/__pycache__/answer.cpython-313.pyc and /dev/null differ diff --git a/app/rag/__pycache__/chunking.cpython-313.pyc b/app/rag/__pycache__/chunking.cpython-313.pyc deleted file mode 100644 index b45bdda267276672a8be6b3b6200d83595264e82..0000000000000000000000000000000000000000 Binary files a/app/rag/__pycache__/chunking.cpython-313.pyc and /dev/null differ diff --git a/app/rag/__pycache__/embeddings.cpython-313.pyc b/app/rag/__pycache__/embeddings.cpython-313.pyc deleted file mode 100644 index 06676746587b8db8e25df1e535bb09d87f1a7c8c..0000000000000000000000000000000000000000 Binary files a/app/rag/__pycache__/embeddings.cpython-313.pyc and /dev/null differ diff --git a/app/rag/__pycache__/ingest.cpython-313.pyc b/app/rag/__pycache__/ingest.cpython-313.pyc deleted file mode 100644 index 6a4eac74a2ef070cedb6c097df085ce4e31658e1..0000000000000000000000000000000000000000 Binary files a/app/rag/__pycache__/ingest.cpython-313.pyc and /dev/null differ diff --git a/app/rag/__pycache__/intent.cpython-313.pyc b/app/rag/__pycache__/intent.cpython-313.pyc deleted file mode 100644 index aa36906543907097e1f1b98233c17a0a32ebaffa..0000000000000000000000000000000000000000 Binary files a/app/rag/__pycache__/intent.cpython-313.pyc and /dev/null differ diff --git a/app/rag/__pycache__/prompts.cpython-313.pyc b/app/rag/__pycache__/prompts.cpython-313.pyc deleted file mode 100644 index bde3096934253d63debe1720ded8d38248d59035..0000000000000000000000000000000000000000 Binary files a/app/rag/__pycache__/prompts.cpython-313.pyc and /dev/null differ diff --git a/app/rag/__pycache__/retrieval.cpython-313.pyc b/app/rag/__pycache__/retrieval.cpython-313.pyc deleted file mode 100644 index c4f6703630063ae9231587ce8afa11c58290d888..0000000000000000000000000000000000000000 Binary files a/app/rag/__pycache__/retrieval.cpython-313.pyc and /dev/null differ diff --git a/app/rag/__pycache__/vectorstore.cpython-313.pyc b/app/rag/__pycache__/vectorstore.cpython-313.pyc deleted file mode 100644 index 576a0de4f6e29a61e5f3c2991c4e1b0f86e39953..0000000000000000000000000000000000000000 Binary files a/app/rag/__pycache__/vectorstore.cpython-313.pyc and /dev/null differ diff --git a/app/rag/__pycache__/verifier.cpython-313.pyc b/app/rag/__pycache__/verifier.cpython-313.pyc deleted file mode 100644 index fbf584dd66f715dd30b8a3ee6f0fe59cbc83ab31..0000000000000000000000000000000000000000 Binary files a/app/rag/__pycache__/verifier.cpython-313.pyc and /dev/null differ diff --git a/app/rag/answer.py b/app/rag/answer.py index bc461cacce2f3e973b47237996696f032e313b52..a092ceb43a538a60142dc487eca62ed2c39588f6 100644 --- a/app/rag/answer.py +++ b/app/rag/answer.py @@ -1,444 +1,444 @@ -""" -Answer generation using LLM with RAG context. -Supports Gemini and OpenAI as providers. -""" -import google.generativeai as genai -from openai import OpenAI -from typing import Optional, Dict, Any, List -import logging -import os -import re - -from app.config import settings -from app.rag.prompts import ( - format_rag_prompt, - format_draft_prompt, - get_no_context_response, - get_low_confidence_response -) -from app.rag.verifier import get_verifier_service -from app.rag.intent import detect_intents -from app.models.schemas import Citation -from abc import ABC, abstractmethod - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - - -class LLMProvider(ABC): - """Base class for LLM providers.""" - - @abstractmethod - def generate(self, system_prompt: str, user_prompt: str) -> str: - """Generate response from prompts.""" - raise NotImplementedError - - @abstractmethod - def generate_with_usage(self, system_prompt: str, user_prompt: str) -> tuple[str, dict]: - """ - Generate response and return usage information. - - Returns: - (response_text, usage_info) - usage_info: dict with keys: prompt_tokens, completion_tokens, total_tokens, model_used - """ - raise NotImplementedError - - -class GeminiProvider(LLMProvider): - """Google Gemini LLM provider.""" - - def __init__(self, api_key: Optional[str] = None, model: str = None): - self.api_key = api_key or settings.GEMINI_API_KEY or os.getenv("GEMINI_API_KEY") - # Default to gemini-1.5-flash if not specified - self.model = model or settings.GEMINI_MODEL or "gemini-1.5-flash" - - if not self.api_key: - raise ValueError("Gemini API key not configured. Set GEMINI_API_KEY environment variable.") - - genai.configure(api_key=self.api_key) - - # Don't initialize client here - do it lazily in generate() to handle errors better - self._client = None - logger.info(f"Gemini provider initialized (will use model: {self.model})") - - def generate(self, system_prompt: str, user_prompt: str) -> str: - """Generate response using Gemini.""" - text, _ = self.generate_with_usage(system_prompt, user_prompt) - return text - - def generate_with_usage(self, system_prompt: str, user_prompt: str) -> tuple[str, dict]: - """Generate response using Gemini and return usage info.""" - # Combine system and user prompts for Gemini - full_prompt = f"{system_prompt}\n\n{user_prompt}" - - # Estimate prompt tokens (rough: 1 token โ‰ˆ 4 chars) - prompt_tokens = len(full_prompt) // 4 - - # Try to list available models first, then use the first available one - # If that fails, try common model names - models_to_try = [] - - # First, try to get available models - try: - available_models = genai.list_models() - model_names = [m.name for m in available_models if 'generateContent' in m.supported_generation_methods] - if model_names: - # Extract just the model name (remove 'models/' prefix if present) - clean_names = [name.split('/')[-1] if '/' in name else name for name in model_names] - models_to_try.extend(clean_names[:3]) # Use first 3 available models - logger.info(f"Found {len(model_names)} available models, will try: {clean_names[:3]}") - except Exception as e: - logger.warning(f"Could not list available models: {e}, using fallback list") - - # Fallback to common model names if listing failed - if not models_to_try: - models_to_try = ["gemini-pro", "gemini-1.0-pro", "models/gemini-pro"] - - # Add configured model if different - if self.model and self.model not in models_to_try: - models_to_try.insert(0, self.model) - - # Remove duplicates while preserving order - seen = set() - models_to_try = [m for m in models_to_try if not (m in seen or seen.add(m))] - - last_error = None - for model_name in models_to_try: - try: - logger.info(f"Attempting to generate with model: {model_name}") - # Create a new client for this model - client = genai.GenerativeModel(model_name) - response = client.generate_content( - full_prompt, - generation_config=genai.types.GenerationConfig( - temperature=settings.TEMPERATURE, - max_output_tokens=1024, - ) - ) - - # Extract response text - response_text = response.text - - # Try to get usage info from response - usage_info = { - "prompt_tokens": prompt_tokens, - "completion_tokens": len(response_text) // 4, # Estimate - "total_tokens": prompt_tokens + (len(response_text) // 4), - "model_used": model_name.split('/')[-1] if '/' in model_name else model_name - } - - # Try to get actual usage from response if available - if hasattr(response, 'usage_metadata'): - usage_metadata = response.usage_metadata - if hasattr(usage_metadata, 'prompt_token_count'): - usage_info["prompt_tokens"] = usage_metadata.prompt_token_count - if hasattr(usage_metadata, 'candidates_token_count'): - usage_info["completion_tokens"] = usage_metadata.candidates_token_count - if hasattr(usage_metadata, 'total_token_count'): - usage_info["total_tokens"] = usage_metadata.total_token_count - - if model_name != self.model: - logger.info(f"โœ… Successfully used model: {model_name}") - - return response_text, usage_info - except Exception as e: - error_str = str(e).lower() - last_error = e - if "not found" in error_str or "not supported" in error_str or "404" in error_str: - logger.warning(f"Model {model_name} failed: {e}") - continue # Try next model - else: - # Different error (not model not found), re-raise - logger.error(f"Gemini generation error with {model_name}: {e}") - raise - - # All models failed - return a helpful error message - error_msg = f"All Gemini model attempts failed. Last error: {last_error}. Please check your GEMINI_API_KEY and ensure it has access to Gemini models." - logger.error(error_msg) - raise Exception(error_msg) - - -class OpenAIProvider(LLMProvider): - """OpenAI LLM provider.""" - - def __init__(self, api_key: Optional[str] = None, model: str = settings.OPENAI_MODEL): - self.api_key = api_key or settings.OPENAI_API_KEY or os.getenv("OPENAI_API_KEY") - self.model = model - - if not self.api_key: - raise ValueError("OpenAI API key not configured. Set OPENAI_API_KEY environment variable.") - - self.client = OpenAI(api_key=self.api_key) - logger.info(f"OpenAI provider initialized with model: {model}") - - def generate(self, system_prompt: str, user_prompt: str) -> str: - """Generate response using OpenAI.""" - text, _ = self.generate_with_usage(system_prompt, user_prompt) - return text - - def generate_with_usage(self, system_prompt: str, user_prompt: str) -> tuple[str, dict]: - """Generate response using OpenAI and return usage info.""" - try: - response = self.client.chat.completions.create( - model=self.model, - messages=[ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": user_prompt} - ], - temperature=settings.TEMPERATURE, - max_tokens=1024 - ) - - response_text = response.choices[0].message.content - - # Extract usage info from OpenAI response - usage_info = { - "prompt_tokens": response.usage.prompt_tokens, - "completion_tokens": response.usage.completion_tokens, - "total_tokens": response.usage.total_tokens, - "model_used": self.model - } - - return response_text, usage_info - except Exception as e: - logger.error(f"OpenAI generation error: {e}") - raise - - -class AnswerService: - """ - Generates answers using RAG context and LLM. - Handles confidence scoring and citation extraction. - """ - - # Confidence thresholds - HIGH_CONFIDENCE_THRESHOLD = 0.5 - LOW_CONFIDENCE_THRESHOLD = 0.20 # Lowered to match similarity threshold - STRICT_CONFIDENCE_THRESHOLD = 0.30 # Strict threshold for answer generation (lowered from 0.45 to allow good retrieval results) - - def __init__(self, provider: str = settings.LLM_PROVIDER): - """ - Initialize the answer service. - - Args: - provider: LLM provider to use ("gemini" or "openai") - """ - self.provider_name = provider - self._provider: Optional[LLMProvider] = None - - @property - def provider(self) -> LLMProvider: - """Lazy load the LLM provider.""" - if self._provider is None: - if self.provider_name == "gemini": - self._provider = GeminiProvider() - elif self.provider_name == "openai": - self._provider = OpenAIProvider() - else: - raise ValueError(f"Unknown LLM provider: {self.provider_name}") - return self._provider - - def generate_answer( - self, - question: str, - context: str, - citations_info: List[Dict[str, Any]], - confidence: float, - has_relevant_results: bool, - use_verifier: bool = None # None = use config default - ) -> Dict[str, Any]: - """ - Generate an answer based on retrieved context with mandatory verifier. - - Args: - question: User's question - context: Retrieved context from knowledge base - citations_info: List of citation information - confidence: Average confidence score from retrieval - has_relevant_results: Whether any results passed the threshold - use_verifier: Whether to use verifier mode (None = use config default) - - Returns: - Dictionary with answer, citations, confidence, and metadata - """ - # Determine if verifier should be used (mandatory by default) - if use_verifier is None: - use_verifier = settings.REQUIRE_VERIFIER - - # GATE 1: No relevant results found - REFUSE - if not has_relevant_results or not context: - logger.info("No relevant context found, returning no-context response") - return { - "answer": get_no_context_response(), - "citations": [], - "confidence": 0.0, - "from_knowledge_base": False, - "escalation_suggested": True, - "refused": True - } - - # GATE 2: Strict confidence threshold - REFUSE if below strict threshold - if confidence < self.STRICT_CONFIDENCE_THRESHOLD: - logger.warning( - f"Confidence ({confidence:.3f}) below strict threshold ({self.STRICT_CONFIDENCE_THRESHOLD}), " - f"REFUSING to answer to prevent hallucination" - ) - return { - "answer": get_no_context_response(), - "citations": [], - "confidence": confidence, - "from_knowledge_base": False, - "escalation_suggested": True, - "refused": True, - "refusal_reason": f"Confidence {confidence:.3f} below strict threshold {self.STRICT_CONFIDENCE_THRESHOLD}" - } - - # GATE 3: Intent-based gating for specific intents (integration, API, etc.) - intents = detect_intents(question) - if "integration" in intents or "api" in question.lower(): - # For integration/API questions, require strong relevance - if confidence < 0.50: # Even stricter for integration questions - logger.warning( - f"Integration/API question with low confidence ({confidence:.3f}), " - f"REFUSING to prevent hallucination" - ) - return { - "answer": get_no_context_response(), - "citations": [], - "confidence": confidence, - "from_knowledge_base": False, - "escalation_suggested": True, - "refused": True, - "refusal_reason": "Integration/API questions require higher confidence" - } - - # Case 3: Passed all gates - generate answer with MANDATORY verifier - logger.info(f"Confidence ({confidence:.3f}) passed all gates, generating answer with verifier={use_verifier}") - - try: - # VERIFIER MODE IS MANDATORY: Draft โ†’ Verify โ†’ Final - # Step 1: Generate draft answer with usage tracking - draft_system, draft_user = format_draft_prompt(context, question) - draft_answer, usage_info = self.provider.generate_with_usage(draft_system, draft_user) - logger.info("Generated draft answer, running verifier...") - - # Step 2: Verify draft answer (MANDATORY) - verifier = get_verifier_service() - verification = verifier.verify_answer( - draft_answer=draft_answer, - context=context, - citations_info=citations_info - ) - - # Step 3: Handle verification result - if verification["pass"]: - logger.info("โœ… Verifier PASSED - Using draft answer") - citations = self._extract_citations(draft_answer, citations_info) - return { - "answer": draft_answer, - "citations": citations, - "confidence": confidence, - "from_knowledge_base": True, - "escalation_suggested": confidence < self.HIGH_CONFIDENCE_THRESHOLD, - "verifier_passed": True, - "refused": False, - "usage": usage_info # Include usage info for tracking - } - else: - # Verifier failed - REFUSE to answer - issues = verification.get('issues', []) - unsupported = verification.get('unsupported_claims', []) - logger.warning( - f"โŒ Verifier FAILED - Issues: {issues}, " - f"Unsupported claims: {unsupported}" - ) - refusal_message = ( - get_no_context_response() + - "\n\n**Note:** The system could not verify the accuracy of the information needed to answer your question. " - "This helps prevent providing incorrect information." - ) - return { - "answer": refusal_message, - "citations": [], - "confidence": 0.0, - "from_knowledge_base": False, - "escalation_suggested": True, - "verifier_passed": False, - "verifier_issues": issues, - "unsupported_claims": unsupported, - "refused": True, - "refusal_reason": "Verifier failed: claims not supported by context", - "usage": usage_info # Still track usage even if refused - } - - except ValueError as e: - # Configuration errors (e.g., missing API key) - error_msg = str(e) - logger.error(f"Configuration error in answer generation: {error_msg}") - if "API key" in error_msg.lower(): - raise ValueError(f"LLM API key not configured: {error_msg}") - raise - except Exception as e: - logger.error(f"Error generating answer: {e}", exc_info=True) - # Re-raise to be handled by the endpoint - raise - - def _extract_citations( - self, - answer: str, - citations_info: List[Dict[str, Any]] - ) -> List[Citation]: - """ - Extract and format citations from the answer. - - Args: - answer: Generated answer with [Source X] references - citations_info: Available citation information - - Returns: - List of Citation objects - """ - citations = [] - - # Find all [Source X] references in the answer - source_pattern = r'\[Source\s*(\d+)\]' - matches = re.findall(source_pattern, answer) - referenced_indices = set(int(m) for m in matches) - - # Build citation objects for referenced sources - for info in citations_info: - if info.get("index") in referenced_indices: - citations.append(Citation( - file_name=info.get("file_name", "Unknown"), - chunk_id=info.get("chunk_id", ""), - page_number=info.get("page_number"), - relevance_score=info.get("similarity_score", 0.0), - excerpt=info.get("excerpt", "") - )) - - # If no specific citations found but we have context, include top sources - if not citations and citations_info: - for info in citations_info[:3]: # Top 3 sources - citations.append(Citation( - file_name=info.get("file_name", "Unknown"), - chunk_id=info.get("chunk_id", ""), - page_number=info.get("page_number"), - relevance_score=info.get("similarity_score", 0.0), - excerpt=info.get("excerpt", "") - )) - - return citations - - -# Global answer service instance -_answer_service: Optional[AnswerService] = None - - -def get_answer_service() -> AnswerService: - """Get the global answer service instance.""" - global _answer_service - if _answer_service is None: - _answer_service = AnswerService() - return _answer_service - +""" +Answer generation using LLM with RAG context. +Supports Gemini and OpenAI as providers. +""" +import google.generativeai as genai +from openai import OpenAI +from typing import Optional, Dict, Any, List +import logging +import os +import re + +from app.config import settings +from app.rag.prompts import ( + format_rag_prompt, + format_draft_prompt, + get_no_context_response, + get_low_confidence_response +) +from app.rag.verifier import get_verifier_service +from app.rag.intent import detect_intents +from app.models.schemas import Citation +from abc import ABC, abstractmethod + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +class LLMProvider(ABC): + """Base class for LLM providers.""" + + @abstractmethod + def generate(self, system_prompt: str, user_prompt: str) -> str: + """Generate response from prompts.""" + raise NotImplementedError + + @abstractmethod + def generate_with_usage(self, system_prompt: str, user_prompt: str) -> tuple[str, dict]: + """ + Generate response and return usage information. + + Returns: + (response_text, usage_info) + usage_info: dict with keys: prompt_tokens, completion_tokens, total_tokens, model_used + """ + raise NotImplementedError + + +class GeminiProvider(LLMProvider): + """Google Gemini LLM provider.""" + + def __init__(self, api_key: Optional[str] = None, model: str = None): + self.api_key = api_key or settings.GEMINI_API_KEY or os.getenv("GEMINI_API_KEY") + # Default to gemini-1.5-flash if not specified + self.model = model or settings.GEMINI_MODEL or "gemini-1.5-flash" + + if not self.api_key: + raise ValueError("Gemini API key not configured. Set GEMINI_API_KEY environment variable.") + + genai.configure(api_key=self.api_key) + + # Don't initialize client here - do it lazily in generate() to handle errors better + self._client = None + logger.info(f"Gemini provider initialized (will use model: {self.model})") + + def generate(self, system_prompt: str, user_prompt: str) -> str: + """Generate response using Gemini.""" + text, _ = self.generate_with_usage(system_prompt, user_prompt) + return text + + def generate_with_usage(self, system_prompt: str, user_prompt: str) -> tuple[str, dict]: + """Generate response using Gemini and return usage info.""" + # Combine system and user prompts for Gemini + full_prompt = f"{system_prompt}\n\n{user_prompt}" + + # Estimate prompt tokens (rough: 1 token โ‰ˆ 4 chars) + prompt_tokens = len(full_prompt) // 4 + + # Try to list available models first, then use the first available one + # If that fails, try common model names + models_to_try = [] + + # First, try to get available models + try: + available_models = genai.list_models() + model_names = [m.name for m in available_models if 'generateContent' in m.supported_generation_methods] + if model_names: + # Extract just the model name (remove 'models/' prefix if present) + clean_names = [name.split('/')[-1] if '/' in name else name for name in model_names] + models_to_try.extend(clean_names[:3]) # Use first 3 available models + logger.info(f"Found {len(model_names)} available models, will try: {clean_names[:3]}") + except Exception as e: + logger.warning(f"Could not list available models: {e}, using fallback list") + + # Fallback to common model names if listing failed + if not models_to_try: + models_to_try = ["gemini-pro", "gemini-1.0-pro", "models/gemini-pro"] + + # Add configured model if different + if self.model and self.model not in models_to_try: + models_to_try.insert(0, self.model) + + # Remove duplicates while preserving order + seen = set() + models_to_try = [m for m in models_to_try if not (m in seen or seen.add(m))] + + last_error = None + for model_name in models_to_try: + try: + logger.info(f"Attempting to generate with model: {model_name}") + # Create a new client for this model + client = genai.GenerativeModel(model_name) + response = client.generate_content( + full_prompt, + generation_config=genai.types.GenerationConfig( + temperature=settings.TEMPERATURE, + max_output_tokens=1024, + ) + ) + + # Extract response text + response_text = response.text + + # Try to get usage info from response + usage_info = { + "prompt_tokens": prompt_tokens, + "completion_tokens": len(response_text) // 4, # Estimate + "total_tokens": prompt_tokens + (len(response_text) // 4), + "model_used": model_name.split('/')[-1] if '/' in model_name else model_name + } + + # Try to get actual usage from response if available + if hasattr(response, 'usage_metadata'): + usage_metadata = response.usage_metadata + if hasattr(usage_metadata, 'prompt_token_count'): + usage_info["prompt_tokens"] = usage_metadata.prompt_token_count + if hasattr(usage_metadata, 'candidates_token_count'): + usage_info["completion_tokens"] = usage_metadata.candidates_token_count + if hasattr(usage_metadata, 'total_token_count'): + usage_info["total_tokens"] = usage_metadata.total_token_count + + if model_name != self.model: + logger.info(f"โœ… Successfully used model: {model_name}") + + return response_text, usage_info + except Exception as e: + error_str = str(e).lower() + last_error = e + if "not found" in error_str or "not supported" in error_str or "404" in error_str: + logger.warning(f"Model {model_name} failed: {e}") + continue # Try next model + else: + # Different error (not model not found), re-raise + logger.error(f"Gemini generation error with {model_name}: {e}") + raise + + # All models failed - return a helpful error message + error_msg = f"All Gemini model attempts failed. Last error: {last_error}. Please check your GEMINI_API_KEY and ensure it has access to Gemini models." + logger.error(error_msg) + raise Exception(error_msg) + + +class OpenAIProvider(LLMProvider): + """OpenAI LLM provider.""" + + def __init__(self, api_key: Optional[str] = None, model: str = settings.OPENAI_MODEL): + self.api_key = api_key or settings.OPENAI_API_KEY or os.getenv("OPENAI_API_KEY") + self.model = model + + if not self.api_key: + raise ValueError("OpenAI API key not configured. Set OPENAI_API_KEY environment variable.") + + self.client = OpenAI(api_key=self.api_key) + logger.info(f"OpenAI provider initialized with model: {model}") + + def generate(self, system_prompt: str, user_prompt: str) -> str: + """Generate response using OpenAI.""" + text, _ = self.generate_with_usage(system_prompt, user_prompt) + return text + + def generate_with_usage(self, system_prompt: str, user_prompt: str) -> tuple[str, dict]: + """Generate response using OpenAI and return usage info.""" + try: + response = self.client.chat.completions.create( + model=self.model, + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt} + ], + temperature=settings.TEMPERATURE, + max_tokens=1024 + ) + + response_text = response.choices[0].message.content + + # Extract usage info from OpenAI response + usage_info = { + "prompt_tokens": response.usage.prompt_tokens, + "completion_tokens": response.usage.completion_tokens, + "total_tokens": response.usage.total_tokens, + "model_used": self.model + } + + return response_text, usage_info + except Exception as e: + logger.error(f"OpenAI generation error: {e}") + raise + + +class AnswerService: + """ + Generates answers using RAG context and LLM. + Handles confidence scoring and citation extraction. + """ + + # Confidence thresholds + HIGH_CONFIDENCE_THRESHOLD = 0.5 + LOW_CONFIDENCE_THRESHOLD = 0.20 # Lowered to match similarity threshold + STRICT_CONFIDENCE_THRESHOLD = 0.30 # Strict threshold for answer generation (lowered from 0.45 to allow good retrieval results) + + def __init__(self, provider: str = settings.LLM_PROVIDER): + """ + Initialize the answer service. + + Args: + provider: LLM provider to use ("gemini" or "openai") + """ + self.provider_name = provider + self._provider: Optional[LLMProvider] = None + + @property + def provider(self) -> LLMProvider: + """Lazy load the LLM provider.""" + if self._provider is None: + if self.provider_name == "gemini": + self._provider = GeminiProvider() + elif self.provider_name == "openai": + self._provider = OpenAIProvider() + else: + raise ValueError(f"Unknown LLM provider: {self.provider_name}") + return self._provider + + def generate_answer( + self, + question: str, + context: str, + citations_info: List[Dict[str, Any]], + confidence: float, + has_relevant_results: bool, + use_verifier: bool = None # None = use config default + ) -> Dict[str, Any]: + """ + Generate an answer based on retrieved context with mandatory verifier. + + Args: + question: User's question + context: Retrieved context from knowledge base + citations_info: List of citation information + confidence: Average confidence score from retrieval + has_relevant_results: Whether any results passed the threshold + use_verifier: Whether to use verifier mode (None = use config default) + + Returns: + Dictionary with answer, citations, confidence, and metadata + """ + # Determine if verifier should be used (mandatory by default) + if use_verifier is None: + use_verifier = settings.REQUIRE_VERIFIER + + # GATE 1: No relevant results found - REFUSE + if not has_relevant_results or not context: + logger.info("No relevant context found, returning no-context response") + return { + "answer": get_no_context_response(), + "citations": [], + "confidence": 0.0, + "from_knowledge_base": False, + "escalation_suggested": True, + "refused": True + } + + # GATE 2: Strict confidence threshold - REFUSE if below strict threshold + if confidence < self.STRICT_CONFIDENCE_THRESHOLD: + logger.warning( + f"Confidence ({confidence:.3f}) below strict threshold ({self.STRICT_CONFIDENCE_THRESHOLD}), " + f"REFUSING to answer to prevent hallucination" + ) + return { + "answer": get_no_context_response(), + "citations": [], + "confidence": confidence, + "from_knowledge_base": False, + "escalation_suggested": True, + "refused": True, + "refusal_reason": f"Confidence {confidence:.3f} below strict threshold {self.STRICT_CONFIDENCE_THRESHOLD}" + } + + # GATE 3: Intent-based gating for specific intents (integration, API, etc.) + intents = detect_intents(question) + if "integration" in intents or "api" in question.lower(): + # For integration/API questions, require strong relevance + if confidence < 0.50: # Even stricter for integration questions + logger.warning( + f"Integration/API question with low confidence ({confidence:.3f}), " + f"REFUSING to prevent hallucination" + ) + return { + "answer": get_no_context_response(), + "citations": [], + "confidence": confidence, + "from_knowledge_base": False, + "escalation_suggested": True, + "refused": True, + "refusal_reason": "Integration/API questions require higher confidence" + } + + # Case 3: Passed all gates - generate answer with MANDATORY verifier + logger.info(f"Confidence ({confidence:.3f}) passed all gates, generating answer with verifier={use_verifier}") + + try: + # VERIFIER MODE IS MANDATORY: Draft โ†’ Verify โ†’ Final + # Step 1: Generate draft answer with usage tracking + draft_system, draft_user = format_draft_prompt(context, question) + draft_answer, usage_info = self.provider.generate_with_usage(draft_system, draft_user) + logger.info("Generated draft answer, running verifier...") + + # Step 2: Verify draft answer (MANDATORY) + verifier = get_verifier_service() + verification = verifier.verify_answer( + draft_answer=draft_answer, + context=context, + citations_info=citations_info + ) + + # Step 3: Handle verification result + if verification["pass"]: + logger.info("โœ… Verifier PASSED - Using draft answer") + citations = self._extract_citations(draft_answer, citations_info) + return { + "answer": draft_answer, + "citations": citations, + "confidence": confidence, + "from_knowledge_base": True, + "escalation_suggested": confidence < self.HIGH_CONFIDENCE_THRESHOLD, + "verifier_passed": True, + "refused": False, + "usage": usage_info # Include usage info for tracking + } + else: + # Verifier failed - REFUSE to answer + issues = verification.get('issues', []) + unsupported = verification.get('unsupported_claims', []) + logger.warning( + f"โŒ Verifier FAILED - Issues: {issues}, " + f"Unsupported claims: {unsupported}" + ) + refusal_message = ( + get_no_context_response() + + "\n\n**Note:** The system could not verify the accuracy of the information needed to answer your question. " + "This helps prevent providing incorrect information." + ) + return { + "answer": refusal_message, + "citations": [], + "confidence": 0.0, + "from_knowledge_base": False, + "escalation_suggested": True, + "verifier_passed": False, + "verifier_issues": issues, + "unsupported_claims": unsupported, + "refused": True, + "refusal_reason": "Verifier failed: claims not supported by context", + "usage": usage_info # Still track usage even if refused + } + + except ValueError as e: + # Configuration errors (e.g., missing API key) + error_msg = str(e) + logger.error(f"Configuration error in answer generation: {error_msg}") + if "API key" in error_msg.lower(): + raise ValueError(f"LLM API key not configured: {error_msg}") + raise + except Exception as e: + logger.error(f"Error generating answer: {e}", exc_info=True) + # Re-raise to be handled by the endpoint + raise + + def _extract_citations( + self, + answer: str, + citations_info: List[Dict[str, Any]] + ) -> List[Citation]: + """ + Extract and format citations from the answer. + + Args: + answer: Generated answer with [Source X] references + citations_info: Available citation information + + Returns: + List of Citation objects + """ + citations = [] + + # Find all [Source X] references in the answer + source_pattern = r'\[Source\s*(\d+)\]' + matches = re.findall(source_pattern, answer) + referenced_indices = set(int(m) for m in matches) + + # Build citation objects for referenced sources + for info in citations_info: + if info.get("index") in referenced_indices: + citations.append(Citation( + file_name=info.get("file_name", "Unknown"), + chunk_id=info.get("chunk_id", ""), + page_number=info.get("page_number"), + relevance_score=info.get("similarity_score", 0.0), + excerpt=info.get("excerpt", "") + )) + + # If no specific citations found but we have context, include top sources + if not citations and citations_info: + for info in citations_info[:3]: # Top 3 sources + citations.append(Citation( + file_name=info.get("file_name", "Unknown"), + chunk_id=info.get("chunk_id", ""), + page_number=info.get("page_number"), + relevance_score=info.get("similarity_score", 0.0), + excerpt=info.get("excerpt", "") + )) + + return citations + + +# Global answer service instance +_answer_service: Optional[AnswerService] = None + + +def get_answer_service() -> AnswerService: + """Get the global answer service instance.""" + global _answer_service + if _answer_service is None: + _answer_service = AnswerService() + return _answer_service + diff --git a/app/rag/chunking.py b/app/rag/chunking.py index dccd54e782174e1426ff6b3e962ba4c974649d04..8b49bf05cbbab8482248ea8ecd0a01a8e2c6df5e 100644 --- a/app/rag/chunking.py +++ b/app/rag/chunking.py @@ -1,196 +1,196 @@ -""" -Document chunking with overlap and metadata preservation. -""" -import tiktoken -from typing import List, Dict, Any, Optional -from dataclasses import dataclass -import re -import uuid -from datetime import datetime - -from app.config import settings - - -@dataclass -class TextChunk: - """Represents a chunk of text with metadata.""" - content: str - chunk_index: int - start_char: int - end_char: int - page_number: Optional[int] = None - token_count: int = 0 - - -class DocumentChunker: - """ - Chunks documents into smaller pieces with overlap. - Uses tiktoken for accurate token counting. - """ - - def __init__( - self, - chunk_size: int = settings.CHUNK_SIZE, - chunk_overlap: int = settings.CHUNK_OVERLAP, - min_chunk_size: int = settings.MIN_CHUNK_SIZE - ): - self.chunk_size = chunk_size - self.chunk_overlap = chunk_overlap - self.min_chunk_size = min_chunk_size - # Use cl100k_base encoding (same as GPT-4, good general purpose) - self.encoding = tiktoken.get_encoding("cl100k_base") - - def count_tokens(self, text: str) -> int: - """Count tokens in text.""" - return len(self.encoding.encode(text)) - - def _split_into_sentences(self, text: str) -> List[str]: - """Split text into sentences while preserving structure.""" - # Split on sentence boundaries but keep delimiters - sentence_endings = r'(?<=[.!?])\s+' - sentences = re.split(sentence_endings, text) - return [s.strip() for s in sentences if s.strip()] - - def _split_into_paragraphs(self, text: str) -> List[str]: - """Split text into paragraphs.""" - paragraphs = re.split(r'\n\s*\n', text) - return [p.strip() for p in paragraphs if p.strip()] - - def chunk_text( - self, - text: str, - page_numbers: Optional[Dict[int, int]] = None # char_position -> page_number - ) -> List[TextChunk]: - """ - Chunk text into smaller pieces with overlap. - - Args: - text: The text to chunk - page_numbers: Optional mapping of character positions to page numbers - - Returns: - List of TextChunk objects - """ - if not text.strip(): - return [] - - chunks = [] - current_chunk = "" - current_start = 0 - chunk_index = 0 - - # First, split into paragraphs for natural boundaries - paragraphs = self._split_into_paragraphs(text) - - char_position = 0 - for para in paragraphs: - para_tokens = self.count_tokens(para) - current_tokens = self.count_tokens(current_chunk) - - # If adding this paragraph exceeds chunk size - if current_tokens + para_tokens > self.chunk_size and current_chunk: - # Save current chunk if it meets minimum size - if current_tokens >= self.min_chunk_size: - page_num = None - if page_numbers: - # Find the page number for this chunk's start position - for pos, page in sorted(page_numbers.items()): - if pos <= current_start: - page_num = page - - chunks.append(TextChunk( - content=current_chunk.strip(), - chunk_index=chunk_index, - start_char=current_start, - end_char=char_position, - page_number=page_num, - token_count=current_tokens - )) - chunk_index += 1 - - # Start new chunk with overlap - overlap_text = self._get_overlap_text(current_chunk) - current_chunk = overlap_text + "\n\n" + para if overlap_text else para - current_start = char_position - len(overlap_text) if overlap_text else char_position - else: - # Add paragraph to current chunk - if current_chunk: - current_chunk += "\n\n" + para - else: - current_chunk = para - current_start = char_position - - char_position += len(para) + 2 # +2 for paragraph separator - - # Don't forget the last chunk - if current_chunk and self.count_tokens(current_chunk) >= self.min_chunk_size: - page_num = None - if page_numbers: - for pos, page in sorted(page_numbers.items()): - if pos <= current_start: - page_num = page - - chunks.append(TextChunk( - content=current_chunk.strip(), - chunk_index=chunk_index, - start_char=current_start, - end_char=len(text), - page_number=page_num, - token_count=self.count_tokens(current_chunk) - )) - - return chunks - - def _get_overlap_text(self, text: str) -> str: - """Get the overlap text from the end of a chunk.""" - sentences = self._split_into_sentences(text) - if not sentences: - return "" - - overlap = "" - tokens = 0 - - # Work backwards through sentences - for sentence in reversed(sentences): - sentence_tokens = self.count_tokens(sentence) - if tokens + sentence_tokens <= self.chunk_overlap: - overlap = sentence + " " + overlap if overlap else sentence - tokens += sentence_tokens - else: - break - - return overlap.strip() - - def create_chunk_metadata( - self, - chunk: TextChunk, - tenant_id: str, # CRITICAL: Multi-tenant isolation - kb_id: str, - user_id: str, - file_name: str, - file_type: str, - total_chunks: int, - document_id: Optional[str] = None - ) -> Dict[str, Any]: - """Create metadata dictionary for a chunk.""" - chunk_id = f"{tenant_id}_{kb_id}_{file_name}_{chunk.chunk_index}_{uuid.uuid4().hex[:8]}" - - return { - "tenant_id": tenant_id, # CRITICAL: Multi-tenant isolation - "kb_id": kb_id, - "user_id": user_id, - "file_name": file_name, - "file_type": file_type, - "chunk_id": chunk_id, - "chunk_index": chunk.chunk_index, - "page_number": chunk.page_number, - "total_chunks": total_chunks, - "token_count": chunk.token_count, - "document_id": document_id, # Track original document - "created_at": datetime.utcnow().isoformat() - } - - -# Global chunker instance -chunker = DocumentChunker() - +""" +Document chunking with overlap and metadata preservation. +""" +import tiktoken +from typing import List, Dict, Any, Optional +from dataclasses import dataclass +import re +import uuid +from datetime import datetime + +from app.config import settings + + +@dataclass +class TextChunk: + """Represents a chunk of text with metadata.""" + content: str + chunk_index: int + start_char: int + end_char: int + page_number: Optional[int] = None + token_count: int = 0 + + +class DocumentChunker: + """ + Chunks documents into smaller pieces with overlap. + Uses tiktoken for accurate token counting. + """ + + def __init__( + self, + chunk_size: int = settings.CHUNK_SIZE, + chunk_overlap: int = settings.CHUNK_OVERLAP, + min_chunk_size: int = settings.MIN_CHUNK_SIZE + ): + self.chunk_size = chunk_size + self.chunk_overlap = chunk_overlap + self.min_chunk_size = min_chunk_size + # Use cl100k_base encoding (same as GPT-4, good general purpose) + self.encoding = tiktoken.get_encoding("cl100k_base") + + def count_tokens(self, text: str) -> int: + """Count tokens in text.""" + return len(self.encoding.encode(text)) + + def _split_into_sentences(self, text: str) -> List[str]: + """Split text into sentences while preserving structure.""" + # Split on sentence boundaries but keep delimiters + sentence_endings = r'(?<=[.!?])\s+' + sentences = re.split(sentence_endings, text) + return [s.strip() for s in sentences if s.strip()] + + def _split_into_paragraphs(self, text: str) -> List[str]: + """Split text into paragraphs.""" + paragraphs = re.split(r'\n\s*\n', text) + return [p.strip() for p in paragraphs if p.strip()] + + def chunk_text( + self, + text: str, + page_numbers: Optional[Dict[int, int]] = None # char_position -> page_number + ) -> List[TextChunk]: + """ + Chunk text into smaller pieces with overlap. + + Args: + text: The text to chunk + page_numbers: Optional mapping of character positions to page numbers + + Returns: + List of TextChunk objects + """ + if not text.strip(): + return [] + + chunks = [] + current_chunk = "" + current_start = 0 + chunk_index = 0 + + # First, split into paragraphs for natural boundaries + paragraphs = self._split_into_paragraphs(text) + + char_position = 0 + for para in paragraphs: + para_tokens = self.count_tokens(para) + current_tokens = self.count_tokens(current_chunk) + + # If adding this paragraph exceeds chunk size + if current_tokens + para_tokens > self.chunk_size and current_chunk: + # Save current chunk if it meets minimum size + if current_tokens >= self.min_chunk_size: + page_num = None + if page_numbers: + # Find the page number for this chunk's start position + for pos, page in sorted(page_numbers.items()): + if pos <= current_start: + page_num = page + + chunks.append(TextChunk( + content=current_chunk.strip(), + chunk_index=chunk_index, + start_char=current_start, + end_char=char_position, + page_number=page_num, + token_count=current_tokens + )) + chunk_index += 1 + + # Start new chunk with overlap + overlap_text = self._get_overlap_text(current_chunk) + current_chunk = overlap_text + "\n\n" + para if overlap_text else para + current_start = char_position - len(overlap_text) if overlap_text else char_position + else: + # Add paragraph to current chunk + if current_chunk: + current_chunk += "\n\n" + para + else: + current_chunk = para + current_start = char_position + + char_position += len(para) + 2 # +2 for paragraph separator + + # Don't forget the last chunk + if current_chunk and self.count_tokens(current_chunk) >= self.min_chunk_size: + page_num = None + if page_numbers: + for pos, page in sorted(page_numbers.items()): + if pos <= current_start: + page_num = page + + chunks.append(TextChunk( + content=current_chunk.strip(), + chunk_index=chunk_index, + start_char=current_start, + end_char=len(text), + page_number=page_num, + token_count=self.count_tokens(current_chunk) + )) + + return chunks + + def _get_overlap_text(self, text: str) -> str: + """Get the overlap text from the end of a chunk.""" + sentences = self._split_into_sentences(text) + if not sentences: + return "" + + overlap = "" + tokens = 0 + + # Work backwards through sentences + for sentence in reversed(sentences): + sentence_tokens = self.count_tokens(sentence) + if tokens + sentence_tokens <= self.chunk_overlap: + overlap = sentence + " " + overlap if overlap else sentence + tokens += sentence_tokens + else: + break + + return overlap.strip() + + def create_chunk_metadata( + self, + chunk: TextChunk, + tenant_id: str, # CRITICAL: Multi-tenant isolation + kb_id: str, + user_id: str, + file_name: str, + file_type: str, + total_chunks: int, + document_id: Optional[str] = None + ) -> Dict[str, Any]: + """Create metadata dictionary for a chunk.""" + chunk_id = f"{tenant_id}_{kb_id}_{file_name}_{chunk.chunk_index}_{uuid.uuid4().hex[:8]}" + + return { + "tenant_id": tenant_id, # CRITICAL: Multi-tenant isolation + "kb_id": kb_id, + "user_id": user_id, + "file_name": file_name, + "file_type": file_type, + "chunk_id": chunk_id, + "chunk_index": chunk.chunk_index, + "page_number": chunk.page_number, + "total_chunks": total_chunks, + "token_count": chunk.token_count, + "document_id": document_id, # Track original document + "created_at": datetime.utcnow().isoformat() + } + + +# Global chunker instance +chunker = DocumentChunker() + diff --git a/app/rag/embeddings.py b/app/rag/embeddings.py index d5b77cf73d4b971e21cc5d41c068f2f885ea3191..cd81a6f6f35e35ab3b0a43d2ef75614761ea38a7 100644 --- a/app/rag/embeddings.py +++ b/app/rag/embeddings.py @@ -1,145 +1,145 @@ -""" -Embedding generation using Sentence Transformers. -Supports local models for privacy and offline use. -""" -from sentence_transformers import SentenceTransformer -from typing import List, Optional -import numpy as np -import logging -from functools import lru_cache - -from app.config import settings - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - - -class EmbeddingService: - """ - Generates embeddings for text using Sentence Transformers. - Uses a lightweight model optimized for semantic search. - """ - - def __init__(self, model_name: str = settings.EMBEDDING_MODEL): - """ - Initialize the embedding service. - - Args: - model_name: Name of the Sentence Transformer model to use - """ - self.model_name = model_name - self._model: Optional[SentenceTransformer] = None - logger.info(f"Embedding service initialized with model: {model_name}") - - @property - def model(self) -> SentenceTransformer: - """Lazy load the model.""" - if self._model is None: - logger.info(f"Loading embedding model: {self.model_name}") - self._model = SentenceTransformer(self.model_name) - logger.info(f"Model loaded. Embedding dimension: {self._model.get_sentence_embedding_dimension()}") - return self._model - - def embed_text(self, text: str) -> List[float]: - """ - Generate embedding for a single text. - - Args: - text: Text to embed - - Returns: - List of floats representing the embedding vector - """ - if not text.strip(): - raise ValueError("Cannot embed empty text") - - embedding = self.model.encode(text, convert_to_numpy=True) - return embedding.tolist() - - def embed_texts(self, texts: List[str], batch_size: int = 32) -> List[List[float]]: - """ - Generate embeddings for multiple texts. - - Args: - texts: List of texts to embed - batch_size: Batch size for processing - - Returns: - List of embedding vectors - """ - if not texts: - return [] - - # Filter out empty texts - valid_texts = [t for t in texts if t.strip()] - if len(valid_texts) != len(texts): - logger.warning(f"Filtered out {len(texts) - len(valid_texts)} empty texts") - - logger.info(f"Generating embeddings for {len(valid_texts)} texts") - - embeddings = self.model.encode( - valid_texts, - batch_size=batch_size, - show_progress_bar=len(valid_texts) > 100, - convert_to_numpy=True - ) - - return embeddings.tolist() - - def embed_query(self, query: str) -> List[float]: - """ - Generate embedding for a search query. - Some models have different embeddings for queries vs documents. - - Args: - query: Search query to embed - - Returns: - Embedding vector for the query - """ - # For most models, query embedding is the same as document embedding - # But we keep this separate for models that differentiate - return self.embed_text(query) - - def get_dimension(self) -> int: - """Get the embedding dimension.""" - return self.model.get_sentence_embedding_dimension() - - def compute_similarity(self, embedding1: List[float], embedding2: List[float]) -> float: - """ - Compute cosine similarity between two embeddings. - - Args: - embedding1: First embedding vector - embedding2: Second embedding vector - - Returns: - Cosine similarity score (0-1) - """ - vec1 = np.array(embedding1) - vec2 = np.array(embedding2) - - # Cosine similarity - dot_product = np.dot(vec1, vec2) - norm1 = np.linalg.norm(vec1) - norm2 = np.linalg.norm(vec2) - - if norm1 == 0 or norm2 == 0: - return 0.0 - - return float(dot_product / (norm1 * norm2)) - - -# Global embedding service instance (lazy loaded) -_embedding_service: Optional[EmbeddingService] = None - - -def get_embedding_service() -> EmbeddingService: - """Get the global embedding service instance.""" - global _embedding_service - if _embedding_service is None: - _embedding_service = EmbeddingService() - return _embedding_service - - - +""" +Embedding generation using Sentence Transformers. +Supports local models for privacy and offline use. +""" +from sentence_transformers import SentenceTransformer +from typing import List, Optional +import numpy as np +import logging +from functools import lru_cache + +from app.config import settings + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +class EmbeddingService: + """ + Generates embeddings for text using Sentence Transformers. + Uses a lightweight model optimized for semantic search. + """ + + def __init__(self, model_name: str = settings.EMBEDDING_MODEL): + """ + Initialize the embedding service. + + Args: + model_name: Name of the Sentence Transformer model to use + """ + self.model_name = model_name + self._model: Optional[SentenceTransformer] = None + logger.info(f"Embedding service initialized with model: {model_name}") + + @property + def model(self) -> SentenceTransformer: + """Lazy load the model.""" + if self._model is None: + logger.info(f"Loading embedding model: {self.model_name}") + self._model = SentenceTransformer(self.model_name) + logger.info(f"Model loaded. Embedding dimension: {self._model.get_sentence_embedding_dimension()}") + return self._model + + def embed_text(self, text: str) -> List[float]: + """ + Generate embedding for a single text. + + Args: + text: Text to embed + + Returns: + List of floats representing the embedding vector + """ + if not text.strip(): + raise ValueError("Cannot embed empty text") + + embedding = self.model.encode(text, convert_to_numpy=True) + return embedding.tolist() + + def embed_texts(self, texts: List[str], batch_size: int = 32) -> List[List[float]]: + """ + Generate embeddings for multiple texts. + + Args: + texts: List of texts to embed + batch_size: Batch size for processing + + Returns: + List of embedding vectors + """ + if not texts: + return [] + + # Filter out empty texts + valid_texts = [t for t in texts if t.strip()] + if len(valid_texts) != len(texts): + logger.warning(f"Filtered out {len(texts) - len(valid_texts)} empty texts") + + logger.info(f"Generating embeddings for {len(valid_texts)} texts") + + embeddings = self.model.encode( + valid_texts, + batch_size=batch_size, + show_progress_bar=len(valid_texts) > 100, + convert_to_numpy=True + ) + + return embeddings.tolist() + + def embed_query(self, query: str) -> List[float]: + """ + Generate embedding for a search query. + Some models have different embeddings for queries vs documents. + + Args: + query: Search query to embed + + Returns: + Embedding vector for the query + """ + # For most models, query embedding is the same as document embedding + # But we keep this separate for models that differentiate + return self.embed_text(query) + + def get_dimension(self) -> int: + """Get the embedding dimension.""" + return self.model.get_sentence_embedding_dimension() + + def compute_similarity(self, embedding1: List[float], embedding2: List[float]) -> float: + """ + Compute cosine similarity between two embeddings. + + Args: + embedding1: First embedding vector + embedding2: Second embedding vector + + Returns: + Cosine similarity score (0-1) + """ + vec1 = np.array(embedding1) + vec2 = np.array(embedding2) + + # Cosine similarity + dot_product = np.dot(vec1, vec2) + norm1 = np.linalg.norm(vec1) + norm2 = np.linalg.norm(vec2) + + if norm1 == 0 or norm2 == 0: + return 0.0 + + return float(dot_product / (norm1 * norm2)) + + +# Global embedding service instance (lazy loaded) +_embedding_service: Optional[EmbeddingService] = None + + +def get_embedding_service() -> EmbeddingService: + """Get the global embedding service instance.""" + global _embedding_service + if _embedding_service is None: + _embedding_service = EmbeddingService() + return _embedding_service + + + diff --git a/app/rag/ingest.py b/app/rag/ingest.py index 348ac7102d07af1ba5ce8f758225dc5d89365aff..9e70c16a20ac064e3a39ed492b2af1f999eac452 100644 --- a/app/rag/ingest.py +++ b/app/rag/ingest.py @@ -1,231 +1,231 @@ -""" -Document ingestion and parsing pipeline. -Supports PDF, DOCX, TXT, and Markdown files. -""" -import fitz # PyMuPDF -from docx import Document as DocxDocument -import markdown -from pathlib import Path -from typing import Dict, List, Optional, Tuple, Any -import chardet -import re -from dataclasses import dataclass -import logging - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - - -@dataclass -class ParsedDocument: - """Represents a parsed document with text and metadata.""" - text: str - file_name: str - file_type: str - page_count: Optional[int] = None - page_map: Optional[Dict[int, int]] = None # char_position -> page_number - metadata: Dict[str, Any] = None - - -class DocumentParser: - """ - Parses various document formats into plain text. - Preserves page information for citations. - """ - - SUPPORTED_EXTENSIONS = {'.pdf', '.docx', '.doc', '.txt', '.md', '.markdown'} - - def __init__(self): - self.parsers = { - '.pdf': self._parse_pdf, - '.docx': self._parse_docx, - '.doc': self._parse_docx, # Try docx parser for doc files - '.txt': self._parse_text, - '.md': self._parse_markdown, - '.markdown': self._parse_markdown, - } - - def parse(self, file_path: Path) -> ParsedDocument: - """ - Parse a document file into text. - - Args: - file_path: Path to the document file - - Returns: - ParsedDocument with extracted text and metadata - """ - if not file_path.exists(): - raise FileNotFoundError(f"File not found: {file_path}") - - ext = file_path.suffix.lower() - if ext not in self.SUPPORTED_EXTENSIONS: - raise ValueError(f"Unsupported file type: {ext}") - - parser = self.parsers.get(ext) - if not parser: - raise ValueError(f"No parser available for: {ext}") - - logger.info(f"Parsing document: {file_path.name} ({ext})") - return parser(file_path) - - def _parse_pdf(self, file_path: Path) -> ParsedDocument: - """Parse PDF file with page tracking.""" - try: - doc = fitz.open(file_path) - text_parts = [] - page_map = {} - current_pos = 0 - - for page_num, page in enumerate(doc, start=1): - page_text = page.get_text("text") - if page_text.strip(): - # Record where this page starts - page_map[current_pos] = page_num - text_parts.append(page_text) - current_pos += len(page_text) + 2 # +2 for separator - - page_count = len(doc) - doc.close() - - full_text = "\n\n".join(text_parts) - - return ParsedDocument( - text=self._clean_text(full_text), - file_name=file_path.name, - file_type="pdf", - page_count=page_count, - page_map=page_map, - metadata={"source": str(file_path)} - ) - except Exception as e: - logger.error(f"Error parsing PDF {file_path}: {e}") - raise - - def _parse_docx(self, file_path: Path) -> ParsedDocument: - """Parse DOCX file.""" - try: - doc = DocxDocument(file_path) - paragraphs = [] - - for para in doc.paragraphs: - if para.text.strip(): - paragraphs.append(para.text) - - # Also extract text from tables - for table in doc.tables: - for row in table.rows: - row_text = [] - for cell in row.cells: - if cell.text.strip(): - row_text.append(cell.text.strip()) - if row_text: - paragraphs.append(" | ".join(row_text)) - - full_text = "\n\n".join(paragraphs) - - return ParsedDocument( - text=self._clean_text(full_text), - file_name=file_path.name, - file_type="docx", - metadata={"source": str(file_path)} - ) - except Exception as e: - logger.error(f"Error parsing DOCX {file_path}: {e}") - raise - - def _parse_text(self, file_path: Path) -> ParsedDocument: - """Parse plain text file with encoding detection.""" - try: - # Detect encoding - with open(file_path, 'rb') as f: - raw_data = f.read() - detected = chardet.detect(raw_data) - encoding = detected.get('encoding', 'utf-8') - - # Read with detected encoding - with open(file_path, 'r', encoding=encoding, errors='replace') as f: - text = f.read() - - return ParsedDocument( - text=self._clean_text(text), - file_name=file_path.name, - file_type="txt", - metadata={"source": str(file_path), "encoding": encoding} - ) - except Exception as e: - logger.error(f"Error parsing text file {file_path}: {e}") - raise - - def _parse_markdown(self, file_path: Path) -> ParsedDocument: - """Parse Markdown file, converting to plain text.""" - try: - with open(file_path, 'r', encoding='utf-8', errors='replace') as f: - md_content = f.read() - - # Convert markdown to HTML, then strip tags - html = markdown.markdown(md_content) - text = self._strip_html_tags(html) - - # Also keep the original markdown structure for better context - # Remove markdown syntax but keep structure - clean_md = self._clean_markdown(md_content) - - return ParsedDocument( - text=self._clean_text(clean_md), - file_name=file_path.name, - file_type="markdown", - metadata={"source": str(file_path)} - ) - except Exception as e: - logger.error(f"Error parsing Markdown {file_path}: {e}") - raise - - def _clean_text(self, text: str) -> str: - """Clean and normalize text.""" - # Replace multiple whitespace with single space - text = re.sub(r'[ \t]+', ' ', text) - # Replace multiple newlines with double newline - text = re.sub(r'\n{3,}', '\n\n', text) - # Remove leading/trailing whitespace from lines - lines = [line.strip() for line in text.split('\n')] - text = '\n'.join(lines) - # Remove leading/trailing whitespace - return text.strip() - - def _strip_html_tags(self, html: str) -> str: - """Remove HTML tags from text.""" - clean = re.sub(r'<[^>]+>', '', html) - return clean - - def _clean_markdown(self, md_text: str) -> str: - """Clean markdown syntax while preserving structure.""" - # Remove code blocks but keep content - md_text = re.sub(r'```[\s\S]*?```', '', md_text) - md_text = re.sub(r'`([^`]+)`', r'\1', md_text) - - # Convert headers to plain text with emphasis - md_text = re.sub(r'^#{1,6}\s+(.+)$', r'\1:', md_text, flags=re.MULTILINE) - - # Remove bold/italic markers - md_text = re.sub(r'\*\*([^*]+)\*\*', r'\1', md_text) - md_text = re.sub(r'\*([^*]+)\*', r'\1', md_text) - md_text = re.sub(r'__([^_]+)__', r'\1', md_text) - md_text = re.sub(r'_([^_]+)_', r'\1', md_text) - - # Remove links but keep text - md_text = re.sub(r'\[([^\]]+)\]\([^)]+\)', r'\1', md_text) - - # Remove images - md_text = re.sub(r'!\[([^\]]*)\]\([^)]+\)', '', md_text) - - # Convert lists to plain text - md_text = re.sub(r'^[\*\-\+]\s+', 'โ€ข ', md_text, flags=re.MULTILINE) - md_text = re.sub(r'^\d+\.\s+', '', md_text, flags=re.MULTILINE) - - return md_text - - -# Global parser instance -parser = DocumentParser() - +""" +Document ingestion and parsing pipeline. +Supports PDF, DOCX, TXT, and Markdown files. +""" +import fitz # PyMuPDF +from docx import Document as DocxDocument +import markdown +from pathlib import Path +from typing import Dict, List, Optional, Tuple, Any +import chardet +import re +from dataclasses import dataclass +import logging + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +@dataclass +class ParsedDocument: + """Represents a parsed document with text and metadata.""" + text: str + file_name: str + file_type: str + page_count: Optional[int] = None + page_map: Optional[Dict[int, int]] = None # char_position -> page_number + metadata: Dict[str, Any] = None + + +class DocumentParser: + """ + Parses various document formats into plain text. + Preserves page information for citations. + """ + + SUPPORTED_EXTENSIONS = {'.pdf', '.docx', '.doc', '.txt', '.md', '.markdown'} + + def __init__(self): + self.parsers = { + '.pdf': self._parse_pdf, + '.docx': self._parse_docx, + '.doc': self._parse_docx, # Try docx parser for doc files + '.txt': self._parse_text, + '.md': self._parse_markdown, + '.markdown': self._parse_markdown, + } + + def parse(self, file_path: Path) -> ParsedDocument: + """ + Parse a document file into text. + + Args: + file_path: Path to the document file + + Returns: + ParsedDocument with extracted text and metadata + """ + if not file_path.exists(): + raise FileNotFoundError(f"File not found: {file_path}") + + ext = file_path.suffix.lower() + if ext not in self.SUPPORTED_EXTENSIONS: + raise ValueError(f"Unsupported file type: {ext}") + + parser = self.parsers.get(ext) + if not parser: + raise ValueError(f"No parser available for: {ext}") + + logger.info(f"Parsing document: {file_path.name} ({ext})") + return parser(file_path) + + def _parse_pdf(self, file_path: Path) -> ParsedDocument: + """Parse PDF file with page tracking.""" + try: + doc = fitz.open(file_path) + text_parts = [] + page_map = {} + current_pos = 0 + + for page_num, page in enumerate(doc, start=1): + page_text = page.get_text("text") + if page_text.strip(): + # Record where this page starts + page_map[current_pos] = page_num + text_parts.append(page_text) + current_pos += len(page_text) + 2 # +2 for separator + + page_count = len(doc) + doc.close() + + full_text = "\n\n".join(text_parts) + + return ParsedDocument( + text=self._clean_text(full_text), + file_name=file_path.name, + file_type="pdf", + page_count=page_count, + page_map=page_map, + metadata={"source": str(file_path)} + ) + except Exception as e: + logger.error(f"Error parsing PDF {file_path}: {e}") + raise + + def _parse_docx(self, file_path: Path) -> ParsedDocument: + """Parse DOCX file.""" + try: + doc = DocxDocument(file_path) + paragraphs = [] + + for para in doc.paragraphs: + if para.text.strip(): + paragraphs.append(para.text) + + # Also extract text from tables + for table in doc.tables: + for row in table.rows: + row_text = [] + for cell in row.cells: + if cell.text.strip(): + row_text.append(cell.text.strip()) + if row_text: + paragraphs.append(" | ".join(row_text)) + + full_text = "\n\n".join(paragraphs) + + return ParsedDocument( + text=self._clean_text(full_text), + file_name=file_path.name, + file_type="docx", + metadata={"source": str(file_path)} + ) + except Exception as e: + logger.error(f"Error parsing DOCX {file_path}: {e}") + raise + + def _parse_text(self, file_path: Path) -> ParsedDocument: + """Parse plain text file with encoding detection.""" + try: + # Detect encoding + with open(file_path, 'rb') as f: + raw_data = f.read() + detected = chardet.detect(raw_data) + encoding = detected.get('encoding', 'utf-8') + + # Read with detected encoding + with open(file_path, 'r', encoding=encoding, errors='replace') as f: + text = f.read() + + return ParsedDocument( + text=self._clean_text(text), + file_name=file_path.name, + file_type="txt", + metadata={"source": str(file_path), "encoding": encoding} + ) + except Exception as e: + logger.error(f"Error parsing text file {file_path}: {e}") + raise + + def _parse_markdown(self, file_path: Path) -> ParsedDocument: + """Parse Markdown file, converting to plain text.""" + try: + with open(file_path, 'r', encoding='utf-8', errors='replace') as f: + md_content = f.read() + + # Convert markdown to HTML, then strip tags + html = markdown.markdown(md_content) + text = self._strip_html_tags(html) + + # Also keep the original markdown structure for better context + # Remove markdown syntax but keep structure + clean_md = self._clean_markdown(md_content) + + return ParsedDocument( + text=self._clean_text(clean_md), + file_name=file_path.name, + file_type="markdown", + metadata={"source": str(file_path)} + ) + except Exception as e: + logger.error(f"Error parsing Markdown {file_path}: {e}") + raise + + def _clean_text(self, text: str) -> str: + """Clean and normalize text.""" + # Replace multiple whitespace with single space + text = re.sub(r'[ \t]+', ' ', text) + # Replace multiple newlines with double newline + text = re.sub(r'\n{3,}', '\n\n', text) + # Remove leading/trailing whitespace from lines + lines = [line.strip() for line in text.split('\n')] + text = '\n'.join(lines) + # Remove leading/trailing whitespace + return text.strip() + + def _strip_html_tags(self, html: str) -> str: + """Remove HTML tags from text.""" + clean = re.sub(r'<[^>]+>', '', html) + return clean + + def _clean_markdown(self, md_text: str) -> str: + """Clean markdown syntax while preserving structure.""" + # Remove code blocks but keep content + md_text = re.sub(r'```[\s\S]*?```', '', md_text) + md_text = re.sub(r'`([^`]+)`', r'\1', md_text) + + # Convert headers to plain text with emphasis + md_text = re.sub(r'^#{1,6}\s+(.+)$', r'\1:', md_text, flags=re.MULTILINE) + + # Remove bold/italic markers + md_text = re.sub(r'\*\*([^*]+)\*\*', r'\1', md_text) + md_text = re.sub(r'\*([^*]+)\*', r'\1', md_text) + md_text = re.sub(r'__([^_]+)__', r'\1', md_text) + md_text = re.sub(r'_([^_]+)_', r'\1', md_text) + + # Remove links but keep text + md_text = re.sub(r'\[([^\]]+)\]\([^)]+\)', r'\1', md_text) + + # Remove images + md_text = re.sub(r'!\[([^\]]*)\]\([^)]+\)', '', md_text) + + # Convert lists to plain text + md_text = re.sub(r'^[\*\-\+]\s+', 'โ€ข ', md_text, flags=re.MULTILINE) + md_text = re.sub(r'^\d+\.\s+', '', md_text, flags=re.MULTILINE) + + return md_text + + +# Global parser instance +parser = DocumentParser() + diff --git a/app/rag/intent.py b/app/rag/intent.py index e4a7eb814c6f1ed8b99cf385068535a8c345dc17..0bd9780df42b1a14f79fe16299083b248ccebe7c 100644 --- a/app/rag/intent.py +++ b/app/rag/intent.py @@ -1,153 +1,153 @@ -""" -Intent detection module for RAG pipeline. -Detects user intent from queries to enable intent-based gating. -""" -import re -from typing import List, Dict, Set -import logging - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - - -# Intent keywords mapping -INTENT_KEYWORDS: Dict[str, List[str]] = { - "integration": [ - "integrate", "integration", "api", "connect", "connection", "webhook", - "shopify", "woocommerce", "stripe", "paypal", "payment gateway", - "whatsapp", "telegram", "slack", "zapier", "ifttt", "automation" - ], - "billing": [ - "billing", "invoice", "payment", "subscription", "plan", "pricing", - "cost", "price", "charge", "fee", "refund", "cancel", "renew" - ], - "account": [ - "account", "profile", "settings", "preferences", "user", "login", - "signup", "register", "authentication", "auth" - ], - "password_reset": [ - "password", "reset", "forgot", "change password", "update password", - "password reset link", "expire", "expiry" - ], - "pricing": [ - "pricing", "price", "plan", "cost", "subscription", "tier", "starter", - "pro", "enterprise", "monthly", "yearly", "billing" - ], - "general": [] # Catch-all for general queries -} - - -def detect_intents(query: str) -> List[str]: - """ - Detect intents from a user query. - - Args: - query: User's question - - Returns: - List of detected intent labels (e.g., ["integration", "billing"]) - """ - query_lower = query.lower() - detected = [] - - for intent, keywords in INTENT_KEYWORDS.items(): - if intent == "general": - continue # Skip general, it's a catch-all - - # Check if any keyword matches - for keyword in keywords: - # Use word boundary matching for better accuracy - pattern = r'\b' + re.escape(keyword.lower()) + r'\b' - if re.search(pattern, query_lower): - detected.append(intent) - break # Only add intent once - - # If no specific intent detected, return general - if not detected: - detected = ["general"] - - logger.info(f"Detected intents for query '{query[:50]}...': {detected}") - return detected - - -def get_intent_keywords(intents: List[str]) -> Set[str]: - """ - Get all keywords for a list of intents. - - Args: - intents: List of intent labels - - Returns: - Set of keywords for those intents - """ - keywords = set() - for intent in intents: - if intent in INTENT_KEYWORDS: - keywords.update(INTENT_KEYWORDS[intent]) - return keywords - - -def check_direct_match( - query: str, - retrieved_chunks: List[str], - intent_keywords: Set[str] = None -) -> bool: - """ - Check if at least one retrieved chunk contains direct matches for query intent. - - Args: - query: User's question - retrieved_chunks: List of retrieved chunk texts - intent_keywords: Optional set of intent keywords to check - - Returns: - True if at least one chunk has direct match, False otherwise - """ - if not retrieved_chunks: - return False - - query_lower = query.lower() - query_words = set(re.findall(r'\b\w+\b', query_lower)) - - # Get intent keywords if not provided - if intent_keywords is None: - intents = detect_intents(query) - intent_keywords = get_intent_keywords(intents) - - # Check each chunk for direct matches - for chunk in retrieved_chunks: - chunk_lower = chunk.lower() - - # Check 1: Intent keywords must be present in chunk - if intent_keywords: - intent_found = any( - re.search(r'\b' + re.escape(kw.lower()) + r'\b', chunk_lower) - for kw in intent_keywords - ) - if not intent_found: - continue # Skip this chunk if no intent keywords - - # Check 2: At least 2-3 important query words should be in chunk - # (excluding common stop words) - stop_words = {"the", "a", "an", "is", "are", "was", "were", "be", "been", - "to", "of", "and", "or", "but", "in", "on", "at", "for", - "with", "how", "what", "when", "where", "why", "do", "does"} - important_words = query_words - stop_words - - if len(important_words) >= 2: - # Need at least 2 important words to match - matches = sum(1 for word in important_words if word in chunk_lower) - if matches >= 2: - logger.info(f"Direct match found: {matches} important words matched in chunk") - return True - elif len(important_words) == 1: - # Single important word - require exact phrase match - for word in important_words: - if re.search(r'\b' + re.escape(word) + r'\b', chunk_lower): - logger.info(f"Direct match found: single important word '{word}' matched") - return True - - logger.warning("No direct match found in retrieved chunks") - return False - - +""" +Intent detection module for RAG pipeline. +Detects user intent from queries to enable intent-based gating. +""" +import re +from typing import List, Dict, Set +import logging + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +# Intent keywords mapping +INTENT_KEYWORDS: Dict[str, List[str]] = { + "integration": [ + "integrate", "integration", "api", "connect", "connection", "webhook", + "shopify", "woocommerce", "stripe", "paypal", "payment gateway", + "whatsapp", "telegram", "slack", "zapier", "ifttt", "automation" + ], + "billing": [ + "billing", "invoice", "payment", "subscription", "plan", "pricing", + "cost", "price", "charge", "fee", "refund", "cancel", "renew" + ], + "account": [ + "account", "profile", "settings", "preferences", "user", "login", + "signup", "register", "authentication", "auth" + ], + "password_reset": [ + "password", "reset", "forgot", "change password", "update password", + "password reset link", "expire", "expiry" + ], + "pricing": [ + "pricing", "price", "plan", "cost", "subscription", "tier", "starter", + "pro", "enterprise", "monthly", "yearly", "billing" + ], + "general": [] # Catch-all for general queries +} + + +def detect_intents(query: str) -> List[str]: + """ + Detect intents from a user query. + + Args: + query: User's question + + Returns: + List of detected intent labels (e.g., ["integration", "billing"]) + """ + query_lower = query.lower() + detected = [] + + for intent, keywords in INTENT_KEYWORDS.items(): + if intent == "general": + continue # Skip general, it's a catch-all + + # Check if any keyword matches + for keyword in keywords: + # Use word boundary matching for better accuracy + pattern = r'\b' + re.escape(keyword.lower()) + r'\b' + if re.search(pattern, query_lower): + detected.append(intent) + break # Only add intent once + + # If no specific intent detected, return general + if not detected: + detected = ["general"] + + logger.info(f"Detected intents for query '{query[:50]}...': {detected}") + return detected + + +def get_intent_keywords(intents: List[str]) -> Set[str]: + """ + Get all keywords for a list of intents. + + Args: + intents: List of intent labels + + Returns: + Set of keywords for those intents + """ + keywords = set() + for intent in intents: + if intent in INTENT_KEYWORDS: + keywords.update(INTENT_KEYWORDS[intent]) + return keywords + + +def check_direct_match( + query: str, + retrieved_chunks: List[str], + intent_keywords: Set[str] = None +) -> bool: + """ + Check if at least one retrieved chunk contains direct matches for query intent. + + Args: + query: User's question + retrieved_chunks: List of retrieved chunk texts + intent_keywords: Optional set of intent keywords to check + + Returns: + True if at least one chunk has direct match, False otherwise + """ + if not retrieved_chunks: + return False + + query_lower = query.lower() + query_words = set(re.findall(r'\b\w+\b', query_lower)) + + # Get intent keywords if not provided + if intent_keywords is None: + intents = detect_intents(query) + intent_keywords = get_intent_keywords(intents) + + # Check each chunk for direct matches + for chunk in retrieved_chunks: + chunk_lower = chunk.lower() + + # Check 1: Intent keywords must be present in chunk + if intent_keywords: + intent_found = any( + re.search(r'\b' + re.escape(kw.lower()) + r'\b', chunk_lower) + for kw in intent_keywords + ) + if not intent_found: + continue # Skip this chunk if no intent keywords + + # Check 2: At least 2-3 important query words should be in chunk + # (excluding common stop words) + stop_words = {"the", "a", "an", "is", "are", "was", "were", "be", "been", + "to", "of", "and", "or", "but", "in", "on", "at", "for", + "with", "how", "what", "when", "where", "why", "do", "does"} + important_words = query_words - stop_words + + if len(important_words) >= 2: + # Need at least 2 important words to match + matches = sum(1 for word in important_words if word in chunk_lower) + if matches >= 2: + logger.info(f"Direct match found: {matches} important words matched in chunk") + return True + elif len(important_words) == 1: + # Single important word - require exact phrase match + for word in important_words: + if re.search(r'\b' + re.escape(word) + r'\b', chunk_lower): + logger.info(f"Direct match found: single important word '{word}' matched") + return True + + logger.warning("No direct match found in retrieved chunks") + return False + + diff --git a/app/rag/prompts.py b/app/rag/prompts.py index cbc50f895ebc132341fef900d239c57cdb7fad75..abd122c6efe6523b5cc4d22fca166ba35e66c158 100644 --- a/app/rag/prompts.py +++ b/app/rag/prompts.py @@ -1,162 +1,162 @@ -""" -Prompt templates for RAG-based question answering. -Implements strict anti-hallucination rules. -""" - -# System prompt for RAG-based answering - ENHANCED for strict anti-hallucination -RAG_SYSTEM_PROMPT = """You are a helpful customer support assistant for ClientSphere. Your ONLY job is to answer questions based STRICTLY on the provided context. - -## CRITICAL RULES - YOU MUST FOLLOW THESE (NO EXCEPTIONS): - -1. **ONLY use information from the provided context** - Do NOT use any prior knowledge, training data, or general knowledge. If it's not in the context, you don't know it. - -2. **If the context doesn't contain the answer** - You MUST say: "I couldn't find this information in the knowledge base. Please contact our support team for assistance." DO NOT attempt to answer from memory. - -3. **NEVER guess, infer, or make up information** - If you're unsure, say you don't have that information. It's better to refuse than to hallucinate. - -4. **ALWAYS cite your sources** - Every factual statement MUST include [Source X] notation. If you can't cite it, don't say it. - -5. **Be concise and direct** - Answer the question without unnecessary elaboration or adding information not in the context. - -6. **If the question is unclear** - Ask for clarification rather than guessing what the user means. - -7. **For multi-part questions** - Address each part separately and cite sources for each. If any part isn't in the context, say so. - -8. **DO NOT use general knowledge** - Even if you "know" the answer from training, if it's not in the provided context, you cannot use it. - -9. **DO NOT extrapolate** - If the context says "30 days", don't say "about a month" or make assumptions. - -10. **Verify every claim** - Before stating anything, verify it exists in the provided context with a citation. - -## Response Format: -- Start with a direct answer to the question -- Include [Source X] citations inline where you use information -- End with a brief summary if the answer is complex -- If no relevant information: clearly state this and suggest contacting support - -## Example Good Response: -"Based on the documentation, the return policy allows returns within 30 days of purchase [Source 1]. Items must be in original packaging [Source 2]. For items purchased on sale, a 15-day window applies [Source 1]." - -## Example When Information Not Found: -"I couldn't find specific information about warranty extensions in the available documentation. I recommend contacting our support team at support@example.com for detailed warranty inquiries." -""" - -# User prompt template -RAG_USER_PROMPT_TEMPLATE = """## Context from Knowledge Base: - -{context} - ---- - -## User Question: -{question} - ---- - -Please answer the question using ONLY the information provided in the context above. Remember to cite sources using [Source X] notation. - -**IMPORTANT:** If the answer is not explicitly stated in the context above, you MUST say "I couldn't find this information in the knowledge base" and suggest contacting support. DO NOT attempt to answer from memory or general knowledge.""" - - -# Prompt for when no relevant context is found -NO_CONTEXT_RESPONSE = """I apologize, but I couldn't find relevant information in the knowledge base to answer your question. - -This could mean: -1. The topic isn't covered in the current documentation -2. The question might need to be rephrased for better matching - -**Recommended Actions:** -- Try rephrasing your question with different keywords -- Contact our support team directly for personalized assistance -- Check if there's additional documentation that might help - -Would you like me to help you with a different question, or would you prefer to connect with a human support agent?""" - - -# Prompt for low confidence responses -LOW_CONFIDENCE_RESPONSE = """I found some potentially relevant information, but I'm not confident it fully addresses your question. - -Based on what I found: {partial_answer} - -**However**, I recommend verifying this information with our support team, as the context may not fully cover your specific situation. - -Would you like me to connect you with a human support agent for more detailed assistance?""" - - -def format_rag_prompt(context: str, question: str) -> tuple: - """ - Format the RAG prompt for the LLM. - - Args: - context: Retrieved context from knowledge base - question: User's question - - Returns: - Tuple of (system_prompt, user_prompt) - """ - user_prompt = RAG_USER_PROMPT_TEMPLATE.format( - context=context, - question=question - ) - - return RAG_SYSTEM_PROMPT, user_prompt - - -def get_no_context_response() -> str: - """Get the response for when no context is found.""" - return NO_CONTEXT_RESPONSE - - -def get_low_confidence_response(partial_answer: str) -> str: - """Get the response for low confidence answers.""" - return LOW_CONFIDENCE_RESPONSE.format(partial_answer=partial_answer) - - -# Draft prompt for verifier mode (stricter than final prompt) -DRAFT_PROMPT_SYSTEM = """You are a customer support assistant. Generate a DRAFT answer based STRICTLY on the provided context. - -## CRITICAL RULES - THIS DRAFT WILL BE VERIFIED: - -1. **ONLY use information explicitly stated in the context** - Do NOT use any prior knowledge, training data, or general knowledge. - -2. **If the context doesn't contain the answer** - You MUST say: "I couldn't find this information in the knowledge base. Please contact our support team for assistance." DO NOT attempt to answer from memory. - -3. **NEVER guess, infer, or make up information** - If you're unsure, say you don't have that information. - -4. **ALWAYS cite your sources** - Every factual statement MUST include [Source X] notation. If you can't cite it, don't say it. - -5. **DO NOT use general knowledge** - Even if you "know" the answer from training, if it's not in the provided context, you cannot use it. - -6. **DO NOT extrapolate** - If the context says "30 days", don't say "about a month" or make assumptions. - -7. **Verify every claim** - Before stating anything, verify it exists in the provided context with a citation. - -Return ONLY the draft answer with citations. This will be verified for accuracy.""" - -DRAFT_PROMPT_USER = """## Context: -{context} - -## Question: -{question} - -Generate a DRAFT answer with citations. This will be verified for accuracy.""" - - -def format_draft_prompt(context: str, question: str) -> tuple: - """ - Format the draft prompt for initial answer generation. - - Args: - context: Retrieved context from knowledge base - question: User's question - - Returns: - Tuple of (system_prompt, user_prompt) - """ - user_prompt = DRAFT_PROMPT_USER.format( - context=context, - question=question - ) - - return DRAFT_PROMPT_SYSTEM, user_prompt - +""" +Prompt templates for RAG-based question answering. +Implements strict anti-hallucination rules. +""" + +# System prompt for RAG-based answering - ENHANCED for strict anti-hallucination +RAG_SYSTEM_PROMPT = """You are a helpful customer support assistant for ClientSphere. Your ONLY job is to answer questions based STRICTLY on the provided context. + +## CRITICAL RULES - YOU MUST FOLLOW THESE (NO EXCEPTIONS): + +1. **ONLY use information from the provided context** - Do NOT use any prior knowledge, training data, or general knowledge. If it's not in the context, you don't know it. + +2. **If the context doesn't contain the answer** - You MUST say: "I couldn't find this information in the knowledge base. Please contact our support team for assistance." DO NOT attempt to answer from memory. + +3. **NEVER guess, infer, or make up information** - If you're unsure, say you don't have that information. It's better to refuse than to hallucinate. + +4. **ALWAYS cite your sources** - Every factual statement MUST include [Source X] notation. If you can't cite it, don't say it. + +5. **Be concise and direct** - Answer the question without unnecessary elaboration or adding information not in the context. + +6. **If the question is unclear** - Ask for clarification rather than guessing what the user means. + +7. **For multi-part questions** - Address each part separately and cite sources for each. If any part isn't in the context, say so. + +8. **DO NOT use general knowledge** - Even if you "know" the answer from training, if it's not in the provided context, you cannot use it. + +9. **DO NOT extrapolate** - If the context says "30 days", don't say "about a month" or make assumptions. + +10. **Verify every claim** - Before stating anything, verify it exists in the provided context with a citation. + +## Response Format: +- Start with a direct answer to the question +- Include [Source X] citations inline where you use information +- End with a brief summary if the answer is complex +- If no relevant information: clearly state this and suggest contacting support + +## Example Good Response: +"Based on the documentation, the return policy allows returns within 30 days of purchase [Source 1]. Items must be in original packaging [Source 2]. For items purchased on sale, a 15-day window applies [Source 1]." + +## Example When Information Not Found: +"I couldn't find specific information about warranty extensions in the available documentation. I recommend contacting our support team at support@example.com for detailed warranty inquiries." +""" + +# User prompt template +RAG_USER_PROMPT_TEMPLATE = """## Context from Knowledge Base: + +{context} + +--- + +## User Question: +{question} + +--- + +Please answer the question using ONLY the information provided in the context above. Remember to cite sources using [Source X] notation. + +**IMPORTANT:** If the answer is not explicitly stated in the context above, you MUST say "I couldn't find this information in the knowledge base" and suggest contacting support. DO NOT attempt to answer from memory or general knowledge.""" + + +# Prompt for when no relevant context is found +NO_CONTEXT_RESPONSE = """I apologize, but I couldn't find relevant information in the knowledge base to answer your question. + +This could mean: +1. The topic isn't covered in the current documentation +2. The question might need to be rephrased for better matching + +**Recommended Actions:** +- Try rephrasing your question with different keywords +- Contact our support team directly for personalized assistance +- Check if there's additional documentation that might help + +Would you like me to help you with a different question, or would you prefer to connect with a human support agent?""" + + +# Prompt for low confidence responses +LOW_CONFIDENCE_RESPONSE = """I found some potentially relevant information, but I'm not confident it fully addresses your question. + +Based on what I found: {partial_answer} + +**However**, I recommend verifying this information with our support team, as the context may not fully cover your specific situation. + +Would you like me to connect you with a human support agent for more detailed assistance?""" + + +def format_rag_prompt(context: str, question: str) -> tuple: + """ + Format the RAG prompt for the LLM. + + Args: + context: Retrieved context from knowledge base + question: User's question + + Returns: + Tuple of (system_prompt, user_prompt) + """ + user_prompt = RAG_USER_PROMPT_TEMPLATE.format( + context=context, + question=question + ) + + return RAG_SYSTEM_PROMPT, user_prompt + + +def get_no_context_response() -> str: + """Get the response for when no context is found.""" + return NO_CONTEXT_RESPONSE + + +def get_low_confidence_response(partial_answer: str) -> str: + """Get the response for low confidence answers.""" + return LOW_CONFIDENCE_RESPONSE.format(partial_answer=partial_answer) + + +# Draft prompt for verifier mode (stricter than final prompt) +DRAFT_PROMPT_SYSTEM = """You are a customer support assistant. Generate a DRAFT answer based STRICTLY on the provided context. + +## CRITICAL RULES - THIS DRAFT WILL BE VERIFIED: + +1. **ONLY use information explicitly stated in the context** - Do NOT use any prior knowledge, training data, or general knowledge. + +2. **If the context doesn't contain the answer** - You MUST say: "I couldn't find this information in the knowledge base. Please contact our support team for assistance." DO NOT attempt to answer from memory. + +3. **NEVER guess, infer, or make up information** - If you're unsure, say you don't have that information. + +4. **ALWAYS cite your sources** - Every factual statement MUST include [Source X] notation. If you can't cite it, don't say it. + +5. **DO NOT use general knowledge** - Even if you "know" the answer from training, if it's not in the provided context, you cannot use it. + +6. **DO NOT extrapolate** - If the context says "30 days", don't say "about a month" or make assumptions. + +7. **Verify every claim** - Before stating anything, verify it exists in the provided context with a citation. + +Return ONLY the draft answer with citations. This will be verified for accuracy.""" + +DRAFT_PROMPT_USER = """## Context: +{context} + +## Question: +{question} + +Generate a DRAFT answer with citations. This will be verified for accuracy.""" + + +def format_draft_prompt(context: str, question: str) -> tuple: + """ + Format the draft prompt for initial answer generation. + + Args: + context: Retrieved context from knowledge base + question: User's question + + Returns: + Tuple of (system_prompt, user_prompt) + """ + user_prompt = DRAFT_PROMPT_USER.format( + context=context, + question=question + ) + + return DRAFT_PROMPT_SYSTEM, user_prompt + diff --git a/app/rag/retrieval.py b/app/rag/retrieval.py index cf1a7a3959dd478d5861bd2ca7b403e970c6ab19..cce9c7550c18cd7acd4c59de74142671f76c894f 100644 --- a/app/rag/retrieval.py +++ b/app/rag/retrieval.py @@ -1,242 +1,242 @@ -""" -Retrieval pipeline with confidence scoring and filtering. -""" -from typing import List, Dict, Any, Optional, Tuple -import logging -import re - -from app.config import settings -from app.rag.embeddings import get_embedding_service -from app.rag.vectorstore import get_vector_store -from app.rag.intent import detect_intents, check_direct_match, get_intent_keywords -from app.models.schemas import RetrievalResult - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - - -class RetrievalService: - """ - Handles document retrieval with confidence scoring. - Implements threshold-based filtering for quality control. - """ - - def __init__( - self, - top_k: int = settings.TOP_K, - similarity_threshold: float = settings.SIMILARITY_THRESHOLD - ): - """ - Initialize the retrieval service. - - Args: - top_k: Number of results to retrieve - similarity_threshold: Minimum similarity score to consider relevant - """ - self.top_k = top_k - self.similarity_threshold = similarity_threshold - self.embedding_service = get_embedding_service() - self.vector_store = get_vector_store() - - def retrieve( - self, - query: str, - tenant_id: str, # CRITICAL: Multi-tenant isolation - kb_id: str, - user_id: str, - top_k: Optional[int] = None - ) -> Tuple[List[RetrievalResult], float, bool]: - """ - Retrieve relevant documents for a query. - - Args: - query: User's question - tenant_id: Tenant ID for multi-tenant isolation (CRITICAL) - kb_id: Knowledge base ID to search - user_id: User ID for filtering - top_k: Optional override for number of results - - Returns: - Tuple of (results, average_confidence, has_relevant_results) - """ - k = top_k or self.top_k - - # Generate query embedding - logger.info(f"Generating embedding for query: {query[:50]}...") - query_embedding = self.embedding_service.embed_query(query) - - # Search vector store with filters - MUST include tenant_id for isolation - filter_dict = { - "tenant_id": tenant_id, # CRITICAL: Multi-tenant isolation - "kb_id": kb_id, - "user_id": user_id - } - - logger.info(f"Searching vector store with filters: {filter_dict}") - raw_results = self.vector_store.search( - query_embedding=query_embedding, - top_k=k, - filter_dict=filter_dict - ) - - if not raw_results: - logger.warning(f"No results found for query in kb_id={kb_id}") - return [], 0.0, False - - # Convert to RetrievalResult objects - results = [] - for r in raw_results: - results.append(RetrievalResult( - chunk_id=r['id'], - content=r['content'], - metadata=r['metadata'], - similarity_score=r['similarity_score'] - )) - - # HEAVY CONFIDENCE MODE: Use maximum similarity score from top results - # This ensures confidence reflects the best match found, not dragged down by weaker results - if results: - # Get top 3 results and use the maximum similarity score - # This gives maximum confidence if there's at least one strong match - top_results = results[:3] - max_score = max(r.similarity_score for r in top_results) - - # If max score is good (>=0.4), use it directly - # Otherwise, use weighted average of top 3 to avoid over-inflating weak matches - if max_score >= 0.4: - avg_confidence = max_score - else: - # For weaker matches, use weighted average of top 3 - scores = [r.similarity_score for r in top_results] - weights = [1.0, 0.7, 0.5][:len(scores)] # Aggressive weighting - weighted_sum = sum(s * w for s, w in zip(scores, weights)) - total_weight = sum(weights[:len(scores)]) - avg_confidence = weighted_sum / total_weight if total_weight > 0 else max_score - else: - avg_confidence = 0.0 - - # Filter results above threshold - filtered_results = [ - r for r in results - if r.similarity_score >= self.similarity_threshold - ] - - # If no results pass threshold but we have results, use top results anyway - # This prevents over-filtering when threshold is too strict - if not filtered_results and results: - logger.warning(f"No results passed threshold {self.similarity_threshold}, using top {min(3, len(results))} results anyway") - filtered_results = results[:min(3, len(results))] - # Recalculate confidence with the fallback results - if filtered_results: - scores = [r.similarity_score for r in filtered_results] - avg_confidence = sum(scores) / len(scores) if scores else 0.0 - - # DIRECT MATCH GATE: Check if at least one chunk directly matches query intent - # For integration/API questions, this gate is stricter - has_direct_match = False - if filtered_results: - chunk_texts = [r.content for r in filtered_results] - intents = detect_intents(query) - intent_keywords = get_intent_keywords(intents) - - # For integration/API questions, require direct match - if "integration" in intents or "api" in query.lower(): - has_direct_match = check_direct_match(query, chunk_texts, intent_keywords) - logger.info(f"Direct match check (strict for integration): {has_direct_match} (intents: {intents})") - else: - # For other questions, be more lenient - just check if important words match - query_words = set(re.findall(r'\b\w+\b', query.lower())) - stop_words = {"the", "a", "an", "is", "are", "was", "were", "be", "been", - "to", "of", "and", "or", "but", "in", "on", "at", "for", - "with", "how", "what", "when", "where", "why", "do", "does"} - important_words = query_words - stop_words - - # Check if at least one important word appears in chunks - for chunk in chunk_texts: - chunk_lower = chunk.lower() - matches = sum(1 for word in important_words if word in chunk_lower) - if matches >= 1 and len(important_words) > 0: # At least one important word - has_direct_match = True - break - - logger.info(f"Direct match check (lenient): {has_direct_match} (intents: {intents})") - - # Only consider relevant if we have filtered results AND (direct match OR high confidence) - # High confidence (>0.40) can bypass direct match requirement for non-integration questions - has_relevant = len(filtered_results) > 0 and (has_direct_match or avg_confidence > 0.40) - - logger.info( - f"Retrieved {len(results)} results, " - f"{len(filtered_results)} above threshold ({self.similarity_threshold}), " - f"avg confidence: {avg_confidence:.3f}, " - f"direct match: {has_direct_match}" - ) - - return filtered_results, avg_confidence, has_relevant - - def get_context_for_llm( - self, - results: List[RetrievalResult], - max_tokens: int = settings.MAX_CONTEXT_TOKENS - ) -> Tuple[str, List[Dict[str, Any]]]: - """ - Format retrieved results into context for the LLM. - - Args: - results: List of retrieval results - max_tokens: Maximum tokens for context - - Returns: - Tuple of (formatted_context, citation_info) - """ - if not results: - return "", [] - - context_parts = [] - citations = [] - current_tokens = 0 - - # Estimate tokens (rough approximation: 1 token โ‰ˆ 4 chars) - for i, result in enumerate(results): - chunk_text = result.content - estimated_tokens = len(chunk_text) // 4 - - if current_tokens + estimated_tokens > max_tokens: - logger.info(f"Truncating context at {i} chunks due to token limit") - break - - # Format chunk with source info - source_info = f"[Source {i+1}: {result.metadata.get('file_name', 'Unknown')}]" - if result.metadata.get('page_number'): - source_info += f" (Page {result.metadata['page_number']})" - - context_parts.append(f"{source_info}\n{chunk_text}") - - # Build citation info - citations.append({ - "index": i + 1, - "file_name": result.metadata.get('file_name', 'Unknown'), - "chunk_id": result.chunk_id, - "page_number": result.metadata.get('page_number'), - "similarity_score": result.similarity_score, - "excerpt": chunk_text[:200] + "..." if len(chunk_text) > 200 else chunk_text - }) - - current_tokens += estimated_tokens - - formatted_context = "\n\n---\n\n".join(context_parts) - - return formatted_context, citations - - -# Global retrieval service instance -_retrieval_service: Optional[RetrievalService] = None - - -def get_retrieval_service() -> RetrievalService: - """Get the global retrieval service instance.""" - global _retrieval_service - if _retrieval_service is None: - _retrieval_service = RetrievalService() - return _retrieval_service - +""" +Retrieval pipeline with confidence scoring and filtering. +""" +from typing import List, Dict, Any, Optional, Tuple +import logging +import re + +from app.config import settings +from app.rag.embeddings import get_embedding_service +from app.rag.vectorstore import get_vector_store +from app.rag.intent import detect_intents, check_direct_match, get_intent_keywords +from app.models.schemas import RetrievalResult + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +class RetrievalService: + """ + Handles document retrieval with confidence scoring. + Implements threshold-based filtering for quality control. + """ + + def __init__( + self, + top_k: int = settings.TOP_K, + similarity_threshold: float = settings.SIMILARITY_THRESHOLD + ): + """ + Initialize the retrieval service. + + Args: + top_k: Number of results to retrieve + similarity_threshold: Minimum similarity score to consider relevant + """ + self.top_k = top_k + self.similarity_threshold = similarity_threshold + self.embedding_service = get_embedding_service() + self.vector_store = get_vector_store() + + def retrieve( + self, + query: str, + tenant_id: str, # CRITICAL: Multi-tenant isolation + kb_id: str, + user_id: str, + top_k: Optional[int] = None + ) -> Tuple[List[RetrievalResult], float, bool]: + """ + Retrieve relevant documents for a query. + + Args: + query: User's question + tenant_id: Tenant ID for multi-tenant isolation (CRITICAL) + kb_id: Knowledge base ID to search + user_id: User ID for filtering + top_k: Optional override for number of results + + Returns: + Tuple of (results, average_confidence, has_relevant_results) + """ + k = top_k or self.top_k + + # Generate query embedding + logger.info(f"Generating embedding for query: {query[:50]}...") + query_embedding = self.embedding_service.embed_query(query) + + # Search vector store with filters - MUST include tenant_id for isolation + filter_dict = { + "tenant_id": tenant_id, # CRITICAL: Multi-tenant isolation + "kb_id": kb_id, + "user_id": user_id + } + + logger.info(f"Searching vector store with filters: {filter_dict}") + raw_results = self.vector_store.search( + query_embedding=query_embedding, + top_k=k, + filter_dict=filter_dict + ) + + if not raw_results: + logger.warning(f"No results found for query in kb_id={kb_id}") + return [], 0.0, False + + # Convert to RetrievalResult objects + results = [] + for r in raw_results: + results.append(RetrievalResult( + chunk_id=r['id'], + content=r['content'], + metadata=r['metadata'], + similarity_score=r['similarity_score'] + )) + + # HEAVY CONFIDENCE MODE: Use maximum similarity score from top results + # This ensures confidence reflects the best match found, not dragged down by weaker results + if results: + # Get top 3 results and use the maximum similarity score + # This gives maximum confidence if there's at least one strong match + top_results = results[:3] + max_score = max(r.similarity_score for r in top_results) + + # If max score is good (>=0.4), use it directly + # Otherwise, use weighted average of top 3 to avoid over-inflating weak matches + if max_score >= 0.4: + avg_confidence = max_score + else: + # For weaker matches, use weighted average of top 3 + scores = [r.similarity_score for r in top_results] + weights = [1.0, 0.7, 0.5][:len(scores)] # Aggressive weighting + weighted_sum = sum(s * w for s, w in zip(scores, weights)) + total_weight = sum(weights[:len(scores)]) + avg_confidence = weighted_sum / total_weight if total_weight > 0 else max_score + else: + avg_confidence = 0.0 + + # Filter results above threshold + filtered_results = [ + r for r in results + if r.similarity_score >= self.similarity_threshold + ] + + # If no results pass threshold but we have results, use top results anyway + # This prevents over-filtering when threshold is too strict + if not filtered_results and results: + logger.warning(f"No results passed threshold {self.similarity_threshold}, using top {min(3, len(results))} results anyway") + filtered_results = results[:min(3, len(results))] + # Recalculate confidence with the fallback results + if filtered_results: + scores = [r.similarity_score for r in filtered_results] + avg_confidence = sum(scores) / len(scores) if scores else 0.0 + + # DIRECT MATCH GATE: Check if at least one chunk directly matches query intent + # For integration/API questions, this gate is stricter + has_direct_match = False + if filtered_results: + chunk_texts = [r.content for r in filtered_results] + intents = detect_intents(query) + intent_keywords = get_intent_keywords(intents) + + # For integration/API questions, require direct match + if "integration" in intents or "api" in query.lower(): + has_direct_match = check_direct_match(query, chunk_texts, intent_keywords) + logger.info(f"Direct match check (strict for integration): {has_direct_match} (intents: {intents})") + else: + # For other questions, be more lenient - just check if important words match + query_words = set(re.findall(r'\b\w+\b', query.lower())) + stop_words = {"the", "a", "an", "is", "are", "was", "were", "be", "been", + "to", "of", "and", "or", "but", "in", "on", "at", "for", + "with", "how", "what", "when", "where", "why", "do", "does"} + important_words = query_words - stop_words + + # Check if at least one important word appears in chunks + for chunk in chunk_texts: + chunk_lower = chunk.lower() + matches = sum(1 for word in important_words if word in chunk_lower) + if matches >= 1 and len(important_words) > 0: # At least one important word + has_direct_match = True + break + + logger.info(f"Direct match check (lenient): {has_direct_match} (intents: {intents})") + + # Only consider relevant if we have filtered results AND (direct match OR high confidence) + # High confidence (>0.40) can bypass direct match requirement for non-integration questions + has_relevant = len(filtered_results) > 0 and (has_direct_match or avg_confidence > 0.40) + + logger.info( + f"Retrieved {len(results)} results, " + f"{len(filtered_results)} above threshold ({self.similarity_threshold}), " + f"avg confidence: {avg_confidence:.3f}, " + f"direct match: {has_direct_match}" + ) + + return filtered_results, avg_confidence, has_relevant + + def get_context_for_llm( + self, + results: List[RetrievalResult], + max_tokens: int = settings.MAX_CONTEXT_TOKENS + ) -> Tuple[str, List[Dict[str, Any]]]: + """ + Format retrieved results into context for the LLM. + + Args: + results: List of retrieval results + max_tokens: Maximum tokens for context + + Returns: + Tuple of (formatted_context, citation_info) + """ + if not results: + return "", [] + + context_parts = [] + citations = [] + current_tokens = 0 + + # Estimate tokens (rough approximation: 1 token โ‰ˆ 4 chars) + for i, result in enumerate(results): + chunk_text = result.content + estimated_tokens = len(chunk_text) // 4 + + if current_tokens + estimated_tokens > max_tokens: + logger.info(f"Truncating context at {i} chunks due to token limit") + break + + # Format chunk with source info + source_info = f"[Source {i+1}: {result.metadata.get('file_name', 'Unknown')}]" + if result.metadata.get('page_number'): + source_info += f" (Page {result.metadata['page_number']})" + + context_parts.append(f"{source_info}\n{chunk_text}") + + # Build citation info + citations.append({ + "index": i + 1, + "file_name": result.metadata.get('file_name', 'Unknown'), + "chunk_id": result.chunk_id, + "page_number": result.metadata.get('page_number'), + "similarity_score": result.similarity_score, + "excerpt": chunk_text[:200] + "..." if len(chunk_text) > 200 else chunk_text + }) + + current_tokens += estimated_tokens + + formatted_context = "\n\n---\n\n".join(context_parts) + + return formatted_context, citations + + +# Global retrieval service instance +_retrieval_service: Optional[RetrievalService] = None + + +def get_retrieval_service() -> RetrievalService: + """Get the global retrieval service instance.""" + global _retrieval_service + if _retrieval_service is None: + _retrieval_service = RetrievalService() + return _retrieval_service + diff --git a/app/rag/vectorstore.py b/app/rag/vectorstore.py index a4e03c9d0da109065f336a7ed40f09a5f7ce1e63..a4e4e67d380a6eea57b242a7e9994ddaa88a2b72 100644 --- a/app/rag/vectorstore.py +++ b/app/rag/vectorstore.py @@ -1,273 +1,273 @@ -""" -Vector store using ChromaDB for local storage. -Supports efficient similarity search and filtering. -""" -import chromadb -from chromadb.config import Settings as ChromaSettings -from typing import List, Dict, Any, Optional -import logging -from pathlib import Path - -from app.config import settings -from app.rag.embeddings import get_embedding_service - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - - -class VectorStore: - """ - Vector store using ChromaDB for persistent local storage. - Supports CRUD operations and similarity search. - """ - - def __init__( - self, - persist_directory: Path = settings.VECTORDB_DIR, - collection_name: str = settings.COLLECTION_NAME - ): - """ - Initialize the vector store. - - Args: - persist_directory: Directory to persist the database - collection_name: Name of the collection to use - """ - self.persist_directory = persist_directory - self.collection_name = collection_name - - # Initialize ChromaDB client with persistence - self.client = chromadb.PersistentClient( - path=str(persist_directory), - settings=ChromaSettings( - anonymized_telemetry=False, - allow_reset=True - ) - ) - - # Get or create collection - self.collection = self.client.get_or_create_collection( - name=collection_name, - metadata={"hnsw:space": "cosine"} # Use cosine similarity - ) - - logger.info(f"Vector store initialized. Collection: {collection_name}, Items: {self.collection.count()}") - - def add_documents( - self, - documents: List[str], - embeddings: List[List[float]], - metadatas: List[Dict[str, Any]], - ids: List[str] - ) -> None: - """ - Add documents to the vector store. - - Args: - documents: List of document texts - embeddings: List of embedding vectors - metadatas: List of metadata dictionaries - ids: List of unique document IDs - """ - if not documents: - logger.warning("No documents to add") - return - - # ChromaDB doesn't accept None values in metadata - clean_metadatas = [] - for meta in metadatas: - clean_meta = {} - for k, v in meta.items(): - if v is not None: - clean_meta[k] = v - clean_metadatas.append(clean_meta) - - self.collection.add( - documents=documents, - embeddings=embeddings, - metadatas=clean_metadatas, - ids=ids - ) - - logger.info(f"Added {len(documents)} documents to vector store") - - def search( - self, - query_embedding: List[float], - top_k: int = settings.TOP_K, - filter_dict: Optional[Dict[str, Any]] = None - ) -> List[Dict[str, Any]]: - """ - Search for similar documents. - - Args: - query_embedding: Query embedding vector - top_k: Number of results to return - filter_dict: Optional filter criteria (e.g., {"kb_id": "123"}) - - Returns: - List of results with document, metadata, and similarity score - """ - # ChromaDB requires filters in $and/$or format for multiple conditions - where_filter = None - if filter_dict: - if len(filter_dict) == 1: - # Single condition - use directly - where_filter = filter_dict - else: - # Multiple conditions - use $and operator - where_filter = { - "$and": [ - {k: v} for k, v in filter_dict.items() - ] - } - - results = self.collection.query( - query_embeddings=[query_embedding], - n_results=top_k, - where=where_filter, - include=["documents", "metadatas", "distances"] - ) - - # Format results - formatted_results = [] - if results and results['ids'] and results['ids'][0]: - for i, doc_id in enumerate(results['ids'][0]): - # ChromaDB returns distances, convert to similarity - # For cosine distance: similarity = 1 - distance - distance = results['distances'][0][i] if results['distances'] else 0 - similarity = 1 - distance # Convert distance to similarity - - formatted_results.append({ - 'id': doc_id, - 'content': results['documents'][0][i] if results['documents'] else "", - 'metadata': results['metadatas'][0][i] if results['metadatas'] else {}, - 'similarity_score': max(0, min(1, similarity)) # Clamp to 0-1 - }) - - return formatted_results - - def delete_by_filter(self, filter_dict: Dict[str, Any]) -> int: - """ - Delete documents matching a filter. - - Args: - filter_dict: Filter criteria - - Returns: - Number of documents deleted - """ - # ChromaDB requires filters in $and/$or format for multiple conditions - where_filter = None - if len(filter_dict) == 1: - where_filter = filter_dict - else: - where_filter = { - "$and": [ - {k: v} for k, v in filter_dict.items() - ] - } - - # First, find matching documents - results = self.collection.get( - where=where_filter, - include=["metadatas"] - ) - - if results and results['ids']: - self.collection.delete(ids=results['ids']) - logger.info(f"Deleted {len(results['ids'])} documents matching filter") - return len(results['ids']) - - return 0 - - def delete_by_ids(self, ids: List[str]) -> None: - """Delete documents by their IDs.""" - if ids: - self.collection.delete(ids=ids) - logger.info(f"Deleted {len(ids)} documents by ID") - - def get_stats( - self, - tenant_id: Optional[str] = None, # CRITICAL: Multi-tenant isolation - kb_id: Optional[str] = None, - user_id: Optional[str] = None - ) -> Dict[str, Any]: - """ - Get statistics about the vector store. - - Args: - tenant_id: Tenant ID for multi-tenant isolation (REQUIRED if filtering) - kb_id: Optional knowledge base ID to filter - user_id: Optional user ID to filter - - Returns: - Statistics dictionary - """ - filter_dict = {} - if tenant_id: - filter_dict["tenant_id"] = tenant_id # CRITICAL: Multi-tenant isolation - if kb_id: - filter_dict["kb_id"] = kb_id - if user_id: - filter_dict["user_id"] = user_id - - if filter_dict: - # ChromaDB requires filters in $and/$or format for multiple conditions - where_filter = None - if len(filter_dict) == 1: - where_filter = filter_dict - else: - where_filter = { - "$and": [ - {k: v} for k, v in filter_dict.items() - ] - } - - results = self.collection.get( - where=where_filter, - include=["metadatas"] - ) - count = len(results['ids']) if results and results['ids'] else 0 - - # Get unique file names - file_names = set() - if results and results['metadatas']: - for meta in results['metadatas']: - if 'file_name' in meta: - file_names.add(meta['file_name']) - - return { - "total_chunks": count, - "file_names": list(file_names), - "tenant_id": tenant_id, - "kb_id": kb_id, - "user_id": user_id - } - else: - return { - "total_chunks": self.collection.count(), - "collection_name": self.collection_name - } - - def clear_collection(self) -> None: - """Clear all documents from the collection.""" - self.client.delete_collection(self.collection_name) - self.collection = self.client.create_collection( - name=self.collection_name, - metadata={"hnsw:space": "cosine"} - ) - logger.info(f"Cleared collection: {self.collection_name}") - - -# Global vector store instance -_vector_store: Optional[VectorStore] = None - - -def get_vector_store() -> VectorStore: - """Get the global vector store instance.""" - global _vector_store - if _vector_store is None: - _vector_store = VectorStore() - return _vector_store - +""" +Vector store using ChromaDB for local storage. +Supports efficient similarity search and filtering. +""" +import chromadb +from chromadb.config import Settings as ChromaSettings +from typing import List, Dict, Any, Optional +import logging +from pathlib import Path + +from app.config import settings +from app.rag.embeddings import get_embedding_service + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +class VectorStore: + """ + Vector store using ChromaDB for persistent local storage. + Supports CRUD operations and similarity search. + """ + + def __init__( + self, + persist_directory: Path = settings.VECTORDB_DIR, + collection_name: str = settings.COLLECTION_NAME + ): + """ + Initialize the vector store. + + Args: + persist_directory: Directory to persist the database + collection_name: Name of the collection to use + """ + self.persist_directory = persist_directory + self.collection_name = collection_name + + # Initialize ChromaDB client with persistence + self.client = chromadb.PersistentClient( + path=str(persist_directory), + settings=ChromaSettings( + anonymized_telemetry=False, + allow_reset=True + ) + ) + + # Get or create collection + self.collection = self.client.get_or_create_collection( + name=collection_name, + metadata={"hnsw:space": "cosine"} # Use cosine similarity + ) + + logger.info(f"Vector store initialized. Collection: {collection_name}, Items: {self.collection.count()}") + + def add_documents( + self, + documents: List[str], + embeddings: List[List[float]], + metadatas: List[Dict[str, Any]], + ids: List[str] + ) -> None: + """ + Add documents to the vector store. + + Args: + documents: List of document texts + embeddings: List of embedding vectors + metadatas: List of metadata dictionaries + ids: List of unique document IDs + """ + if not documents: + logger.warning("No documents to add") + return + + # ChromaDB doesn't accept None values in metadata + clean_metadatas = [] + for meta in metadatas: + clean_meta = {} + for k, v in meta.items(): + if v is not None: + clean_meta[k] = v + clean_metadatas.append(clean_meta) + + self.collection.add( + documents=documents, + embeddings=embeddings, + metadatas=clean_metadatas, + ids=ids + ) + + logger.info(f"Added {len(documents)} documents to vector store") + + def search( + self, + query_embedding: List[float], + top_k: int = settings.TOP_K, + filter_dict: Optional[Dict[str, Any]] = None + ) -> List[Dict[str, Any]]: + """ + Search for similar documents. + + Args: + query_embedding: Query embedding vector + top_k: Number of results to return + filter_dict: Optional filter criteria (e.g., {"kb_id": "123"}) + + Returns: + List of results with document, metadata, and similarity score + """ + # ChromaDB requires filters in $and/$or format for multiple conditions + where_filter = None + if filter_dict: + if len(filter_dict) == 1: + # Single condition - use directly + where_filter = filter_dict + else: + # Multiple conditions - use $and operator + where_filter = { + "$and": [ + {k: v} for k, v in filter_dict.items() + ] + } + + results = self.collection.query( + query_embeddings=[query_embedding], + n_results=top_k, + where=where_filter, + include=["documents", "metadatas", "distances"] + ) + + # Format results + formatted_results = [] + if results and results['ids'] and results['ids'][0]: + for i, doc_id in enumerate(results['ids'][0]): + # ChromaDB returns distances, convert to similarity + # For cosine distance: similarity = 1 - distance + distance = results['distances'][0][i] if results['distances'] else 0 + similarity = 1 - distance # Convert distance to similarity + + formatted_results.append({ + 'id': doc_id, + 'content': results['documents'][0][i] if results['documents'] else "", + 'metadata': results['metadatas'][0][i] if results['metadatas'] else {}, + 'similarity_score': max(0, min(1, similarity)) # Clamp to 0-1 + }) + + return formatted_results + + def delete_by_filter(self, filter_dict: Dict[str, Any]) -> int: + """ + Delete documents matching a filter. + + Args: + filter_dict: Filter criteria + + Returns: + Number of documents deleted + """ + # ChromaDB requires filters in $and/$or format for multiple conditions + where_filter = None + if len(filter_dict) == 1: + where_filter = filter_dict + else: + where_filter = { + "$and": [ + {k: v} for k, v in filter_dict.items() + ] + } + + # First, find matching documents + results = self.collection.get( + where=where_filter, + include=["metadatas"] + ) + + if results and results['ids']: + self.collection.delete(ids=results['ids']) + logger.info(f"Deleted {len(results['ids'])} documents matching filter") + return len(results['ids']) + + return 0 + + def delete_by_ids(self, ids: List[str]) -> None: + """Delete documents by their IDs.""" + if ids: + self.collection.delete(ids=ids) + logger.info(f"Deleted {len(ids)} documents by ID") + + def get_stats( + self, + tenant_id: Optional[str] = None, # CRITICAL: Multi-tenant isolation + kb_id: Optional[str] = None, + user_id: Optional[str] = None + ) -> Dict[str, Any]: + """ + Get statistics about the vector store. + + Args: + tenant_id: Tenant ID for multi-tenant isolation (REQUIRED if filtering) + kb_id: Optional knowledge base ID to filter + user_id: Optional user ID to filter + + Returns: + Statistics dictionary + """ + filter_dict = {} + if tenant_id: + filter_dict["tenant_id"] = tenant_id # CRITICAL: Multi-tenant isolation + if kb_id: + filter_dict["kb_id"] = kb_id + if user_id: + filter_dict["user_id"] = user_id + + if filter_dict: + # ChromaDB requires filters in $and/$or format for multiple conditions + where_filter = None + if len(filter_dict) == 1: + where_filter = filter_dict + else: + where_filter = { + "$and": [ + {k: v} for k, v in filter_dict.items() + ] + } + + results = self.collection.get( + where=where_filter, + include=["metadatas"] + ) + count = len(results['ids']) if results and results['ids'] else 0 + + # Get unique file names + file_names = set() + if results and results['metadatas']: + for meta in results['metadatas']: + if 'file_name' in meta: + file_names.add(meta['file_name']) + + return { + "total_chunks": count, + "file_names": list(file_names), + "tenant_id": tenant_id, + "kb_id": kb_id, + "user_id": user_id + } + else: + return { + "total_chunks": self.collection.count(), + "collection_name": self.collection_name + } + + def clear_collection(self) -> None: + """Clear all documents from the collection.""" + self.client.delete_collection(self.collection_name) + self.collection = self.client.create_collection( + name=self.collection_name, + metadata={"hnsw:space": "cosine"} + ) + logger.info(f"Cleared collection: {self.collection_name}") + + +# Global vector store instance +_vector_store: Optional[VectorStore] = None + + +def get_vector_store() -> VectorStore: + """Get the global vector store instance.""" + global _vector_store + if _vector_store is None: + _vector_store = VectorStore() + return _vector_store + diff --git a/app/rag/verifier.py b/app/rag/verifier.py index 16d3a96fff7c55111ef0c13076af2dfc148d4bc5..7180eed4713258fb226dbb09880c20e0b349cfda 100644 --- a/app/rag/verifier.py +++ b/app/rag/verifier.py @@ -1,276 +1,276 @@ -""" -Verifier module for RAG pipeline. -Implements Draft โ†’ Verify โ†’ Final flow to minimize hallucination. -""" -import json -import re -from typing import Dict, Any, List, Optional -import logging - -from app.config import settings - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - - -# Verifier prompt template -VERIFIER_PROMPT = """You are a strict fact-checker for a customer support chatbot. Your job is to verify that every factual claim in a draft answer is supported by the provided context. - -## Your Task: -1. Review the DRAFT ANSWER below -2. Check each factual claim against the PROVIDED CONTEXT -3. Identify any claims that are NOT supported by the context -4. Return a JSON response with your verification results - -## CRITICAL RULES: -- If ANY claim is not explicitly supported by the context โ†’ FAIL -- If the answer adds information not in context โ†’ FAIL -- If citations are missing or incorrect โ†’ FAIL -- Only PASS if ALL claims are verifiable in the context - -## Response Format (JSON): -{{ - "pass": true/false, - "issues": ["list of issues found"], - "unsupported_claims": ["list of unsupported claims"], - "final_answer": "corrected answer if needed (optional)" -}} - -## Example FAIL Response: -{{ - "pass": false, - "issues": ["Claim about '30 days' not found in context", "Missing citation for pricing information"], - "unsupported_claims": ["Refund window is 30 days", "Starter plan costs โ‚น999"], - "final_answer": null -}} - -## Example PASS Response: -{{ - "pass": true, - "issues": [], - "unsupported_claims": [], - "final_answer": null -}} - ---- - -## PROVIDED CONTEXT: -{context} - ---- - -## DRAFT ANSWER TO VERIFY: -{draft_answer} - ---- - -Now verify the draft answer and return ONLY valid JSON (no markdown, no code blocks, just raw JSON):""" - - -class VerifierService: - """ - Verifies that draft answers are supported by retrieved context. - Implements strict factual validation to prevent hallucination. - """ - - def __init__(self, provider: Optional[Any] = None): - """ - Initialize the verifier service. - - Args: - provider: Optional LLM provider (uses same as answer service if not provided) - """ - self._provider = provider - - @property - def provider(self): - """Get or create the LLM provider for verification.""" - if self._provider is None: - from app.rag.answer import GeminiProvider, OpenAIProvider - if settings.LLM_PROVIDER == "gemini": - self._provider = GeminiProvider() - elif settings.LLM_PROVIDER == "openai": - self._provider = OpenAIProvider() - else: - raise ValueError(f"Unknown LLM provider: {settings.LLM_PROVIDER}") - return self._provider - - def verify_answer( - self, - draft_answer: str, - context: str, - citations_info: List[Dict[str, Any]] - ) -> Dict[str, Any]: - """ - Verify that a draft answer is supported by the context. - - Args: - draft_answer: The draft answer to verify - context: The retrieved context from knowledge base - citations_info: List of citation information - - Returns: - Dictionary with verification results: - { - "pass": bool, - "issues": List[str], - "unsupported_claims": List[str], - "final_answer": Optional[str] - } - """ - if not context or not draft_answer: - logger.warning("Empty context or draft answer provided to verifier") - return { - "pass": False, - "issues": ["Empty context or draft answer"], - "unsupported_claims": [], - "final_answer": None - } - - # Format verifier prompt - verifier_prompt = VERIFIER_PROMPT.format( - context=context, - draft_answer=draft_answer - ) - - try: - logger.info("Running verifier on draft answer...") - # Use a more deterministic temperature for verification - try: - raw_response = self.provider.generate( - system_prompt="You are a strict fact-checker. Return ONLY valid JSON.", - user_prompt=verifier_prompt - ) - except Exception as e: - logger.error(f"Error calling LLM in verifier: {e}", exc_info=True) - # On LLM error, fail conservatively - return { - "pass": False, - "issues": [f"Verifier LLM error: {str(e)}"], - "unsupported_claims": [], - "final_answer": None - } - - # Parse JSON response - try: - verification_result = self._parse_verifier_response(raw_response) - except Exception as e: - logger.error(f"Error parsing verifier response: {e}", exc_info=True) - logger.error(f"Raw response was: {raw_response[:500]}") - # On parse error, fail conservatively - return { - "pass": False, - "issues": [f"Verifier parse error: {str(e)}"], - "unsupported_claims": [], - "final_answer": None - } - - if verification_result["pass"]: - logger.info("โœ… Verifier PASSED - All claims supported by context") - else: - logger.warning( - f"โŒ Verifier FAILED - Issues: {verification_result.get('issues', [])}" - ) - - return verification_result - - except Exception as e: - logger.error(f"Unexpected error in verifier: {e}", exc_info=True) - # On error, fail conservatively - return { - "pass": False, - "issues": [f"Verifier error: {str(e)}"], - "unsupported_claims": [], - "final_answer": None - } - - def _parse_verifier_response(self, raw_response: str) -> Dict[str, Any]: - """ - Parse the verifier's JSON response. - - Args: - raw_response: Raw response from LLM - - Returns: - Parsed verification result - """ - # Try to extract JSON from response - # Remove markdown code blocks if present - cleaned = raw_response.strip() - if cleaned.startswith("```json"): - cleaned = cleaned[7:] - if cleaned.startswith("```"): - cleaned = cleaned[3:] - if cleaned.endswith("```"): - cleaned = cleaned[:-3] - cleaned = cleaned.strip() - - # Try to find JSON object in the response - # Look for { ... } pattern - json_match = re.search(r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}', cleaned, re.DOTALL) - if json_match: - cleaned = json_match.group(0) - - try: - result = json.loads(cleaned) - - # Validate structure - if not isinstance(result, dict): - raise ValueError("Response is not a dictionary") - - # Ensure required fields - return { - "pass": result.get("pass", False), - "issues": result.get("issues", []), - "unsupported_claims": result.get("unsupported_claims", []), - "final_answer": result.get("final_answer") - } - - except (json.JSONDecodeError, ValueError) as e: - logger.error(f"Failed to parse verifier JSON: {e}") - logger.error(f"Raw response (first 500 chars): {raw_response[:500]}") - logger.error(f"Cleaned response (first 500 chars): {cleaned[:500]}") - - # Fallback: try to infer pass/fail from text - response_lower = raw_response.lower() - # Check for explicit pass indicators - if ("pass" in response_lower and ("true" in response_lower or "yes" in response_lower)) or \ - ("all claims" in response_lower and "supported" in response_lower): - logger.warning("Using fallback: inferred PASS from text") - return { - "pass": True, - "issues": [], - "unsupported_claims": [], - "final_answer": None - } - elif ("pass" in response_lower and ("false" in response_lower or "no" in response_lower)) or \ - ("not supported" in response_lower or "unsupported" in response_lower): - logger.warning("Using fallback: inferred FAIL from text") - return { - "pass": False, - "issues": ["Failed to parse verifier response - inferred fail from text"], - "unsupported_claims": [], - "final_answer": None - } - else: - # Default to fail for safety - logger.warning("Using fallback: defaulting to FAIL (could not infer from text)") - return { - "pass": False, - "issues": [f"Failed to parse verifier response: {str(e)}"], - "unsupported_claims": [], - "final_answer": None - } - - -# Global verifier instance -_verifier_service: Optional[VerifierService] = None - - -def get_verifier_service() -> VerifierService: - """Get the global verifier service instance.""" - global _verifier_service - if _verifier_service is None: - _verifier_service = VerifierService() - return _verifier_service - +""" +Verifier module for RAG pipeline. +Implements Draft โ†’ Verify โ†’ Final flow to minimize hallucination. +""" +import json +import re +from typing import Dict, Any, List, Optional +import logging + +from app.config import settings + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +# Verifier prompt template +VERIFIER_PROMPT = """You are a strict fact-checker for a customer support chatbot. Your job is to verify that every factual claim in a draft answer is supported by the provided context. + +## Your Task: +1. Review the DRAFT ANSWER below +2. Check each factual claim against the PROVIDED CONTEXT +3. Identify any claims that are NOT supported by the context +4. Return a JSON response with your verification results + +## CRITICAL RULES: +- If ANY claim is not explicitly supported by the context โ†’ FAIL +- If the answer adds information not in context โ†’ FAIL +- If citations are missing or incorrect โ†’ FAIL +- Only PASS if ALL claims are verifiable in the context + +## Response Format (JSON): +{{ + "pass": true/false, + "issues": ["list of issues found"], + "unsupported_claims": ["list of unsupported claims"], + "final_answer": "corrected answer if needed (optional)" +}} + +## Example FAIL Response: +{{ + "pass": false, + "issues": ["Claim about '30 days' not found in context", "Missing citation for pricing information"], + "unsupported_claims": ["Refund window is 30 days", "Starter plan costs โ‚น999"], + "final_answer": null +}} + +## Example PASS Response: +{{ + "pass": true, + "issues": [], + "unsupported_claims": [], + "final_answer": null +}} + +--- + +## PROVIDED CONTEXT: +{context} + +--- + +## DRAFT ANSWER TO VERIFY: +{draft_answer} + +--- + +Now verify the draft answer and return ONLY valid JSON (no markdown, no code blocks, just raw JSON):""" + + +class VerifierService: + """ + Verifies that draft answers are supported by retrieved context. + Implements strict factual validation to prevent hallucination. + """ + + def __init__(self, provider: Optional[Any] = None): + """ + Initialize the verifier service. + + Args: + provider: Optional LLM provider (uses same as answer service if not provided) + """ + self._provider = provider + + @property + def provider(self): + """Get or create the LLM provider for verification.""" + if self._provider is None: + from app.rag.answer import GeminiProvider, OpenAIProvider + if settings.LLM_PROVIDER == "gemini": + self._provider = GeminiProvider() + elif settings.LLM_PROVIDER == "openai": + self._provider = OpenAIProvider() + else: + raise ValueError(f"Unknown LLM provider: {settings.LLM_PROVIDER}") + return self._provider + + def verify_answer( + self, + draft_answer: str, + context: str, + citations_info: List[Dict[str, Any]] + ) -> Dict[str, Any]: + """ + Verify that a draft answer is supported by the context. + + Args: + draft_answer: The draft answer to verify + context: The retrieved context from knowledge base + citations_info: List of citation information + + Returns: + Dictionary with verification results: + { + "pass": bool, + "issues": List[str], + "unsupported_claims": List[str], + "final_answer": Optional[str] + } + """ + if not context or not draft_answer: + logger.warning("Empty context or draft answer provided to verifier") + return { + "pass": False, + "issues": ["Empty context or draft answer"], + "unsupported_claims": [], + "final_answer": None + } + + # Format verifier prompt + verifier_prompt = VERIFIER_PROMPT.format( + context=context, + draft_answer=draft_answer + ) + + try: + logger.info("Running verifier on draft answer...") + # Use a more deterministic temperature for verification + try: + raw_response = self.provider.generate( + system_prompt="You are a strict fact-checker. Return ONLY valid JSON.", + user_prompt=verifier_prompt + ) + except Exception as e: + logger.error(f"Error calling LLM in verifier: {e}", exc_info=True) + # On LLM error, fail conservatively + return { + "pass": False, + "issues": [f"Verifier LLM error: {str(e)}"], + "unsupported_claims": [], + "final_answer": None + } + + # Parse JSON response + try: + verification_result = self._parse_verifier_response(raw_response) + except Exception as e: + logger.error(f"Error parsing verifier response: {e}", exc_info=True) + logger.error(f"Raw response was: {raw_response[:500]}") + # On parse error, fail conservatively + return { + "pass": False, + "issues": [f"Verifier parse error: {str(e)}"], + "unsupported_claims": [], + "final_answer": None + } + + if verification_result["pass"]: + logger.info("โœ… Verifier PASSED - All claims supported by context") + else: + logger.warning( + f"โŒ Verifier FAILED - Issues: {verification_result.get('issues', [])}" + ) + + return verification_result + + except Exception as e: + logger.error(f"Unexpected error in verifier: {e}", exc_info=True) + # On error, fail conservatively + return { + "pass": False, + "issues": [f"Verifier error: {str(e)}"], + "unsupported_claims": [], + "final_answer": None + } + + def _parse_verifier_response(self, raw_response: str) -> Dict[str, Any]: + """ + Parse the verifier's JSON response. + + Args: + raw_response: Raw response from LLM + + Returns: + Parsed verification result + """ + # Try to extract JSON from response + # Remove markdown code blocks if present + cleaned = raw_response.strip() + if cleaned.startswith("```json"): + cleaned = cleaned[7:] + if cleaned.startswith("```"): + cleaned = cleaned[3:] + if cleaned.endswith("```"): + cleaned = cleaned[:-3] + cleaned = cleaned.strip() + + # Try to find JSON object in the response + # Look for { ... } pattern + json_match = re.search(r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}', cleaned, re.DOTALL) + if json_match: + cleaned = json_match.group(0) + + try: + result = json.loads(cleaned) + + # Validate structure + if not isinstance(result, dict): + raise ValueError("Response is not a dictionary") + + # Ensure required fields + return { + "pass": result.get("pass", False), + "issues": result.get("issues", []), + "unsupported_claims": result.get("unsupported_claims", []), + "final_answer": result.get("final_answer") + } + + except (json.JSONDecodeError, ValueError) as e: + logger.error(f"Failed to parse verifier JSON: {e}") + logger.error(f"Raw response (first 500 chars): {raw_response[:500]}") + logger.error(f"Cleaned response (first 500 chars): {cleaned[:500]}") + + # Fallback: try to infer pass/fail from text + response_lower = raw_response.lower() + # Check for explicit pass indicators + if ("pass" in response_lower and ("true" in response_lower or "yes" in response_lower)) or \ + ("all claims" in response_lower and "supported" in response_lower): + logger.warning("Using fallback: inferred PASS from text") + return { + "pass": True, + "issues": [], + "unsupported_claims": [], + "final_answer": None + } + elif ("pass" in response_lower and ("false" in response_lower or "no" in response_lower)) or \ + ("not supported" in response_lower or "unsupported" in response_lower): + logger.warning("Using fallback: inferred FAIL from text") + return { + "pass": False, + "issues": ["Failed to parse verifier response - inferred fail from text"], + "unsupported_claims": [], + "final_answer": None + } + else: + # Default to fail for safety + logger.warning("Using fallback: defaulting to FAIL (could not infer from text)") + return { + "pass": False, + "issues": [f"Failed to parse verifier response: {str(e)}"], + "unsupported_claims": [], + "final_answer": None + } + + +# Global verifier instance +_verifier_service: Optional[VerifierService] = None + + +def get_verifier_service() -> VerifierService: + """Get the global verifier service instance.""" + global _verifier_service + if _verifier_service is None: + _verifier_service = VerifierService() + return _verifier_service + diff --git a/app/utils/__init__.py b/app/utils/__init__.py index 10b7dfe004fe2fedc72db0e075cc7bff7b84a2e0..bcb106f7a9003bf3b0264c729474e4154e5246a4 100644 --- a/app/utils/__init__.py +++ b/app/utils/__init__.py @@ -1,6 +1,6 @@ -""" -Utility functions for the RAG backend. -""" - - - +""" +Utility functions for the RAG backend. +""" + + + diff --git a/env.example.txt b/env.example.txt index 5f4628b9c7deb46c206407beb92aa992c2b99e24..ed42263c20c277c75eba7041ec086c00da6dd6a7 100644 --- a/env.example.txt +++ b/env.example.txt @@ -1,32 +1,32 @@ -# ClientSphere RAG Backend Configuration -# Rename this file to .env and fill in your values - -# LLM Provider (gemini or openai) -LLM_PROVIDER=gemini - -# Gemini API Key (get from https://makersuite.google.com/app/apikey) -GEMINI_API_KEY=your_gemini_api_key_here - -# OpenAI API Key (optional, if using OpenAI instead) -# OPENAI_API_KEY=your_openai_api_key_here - -# Model names (optional, defaults provided) -# GEMINI_MODEL=gemini-pro -# OPENAI_MODEL=gpt-3.5-turbo - -# Embedding model (optional, default is all-MiniLM-L6-v2) -# EMBEDDING_MODEL=all-MiniLM-L6-v2 - -# Chunking settings (optional) -# CHUNK_SIZE=500 -# CHUNK_OVERLAP=100 - -# Retrieval settings (optional) -# TOP_K=6 -# SIMILARITY_THRESHOLD=0.35 - -# Debug mode -DEBUG=true - - - +# ClientSphere RAG Backend Configuration +# Rename this file to .env and fill in your values + +# LLM Provider (gemini or openai) +LLM_PROVIDER=gemini + +# Gemini API Key (get from https://makersuite.google.com/app/apikey) +GEMINI_API_KEY=your_gemini_api_key_here + +# OpenAI API Key (optional, if using OpenAI instead) +# OPENAI_API_KEY=your_openai_api_key_here + +# Model names (optional, defaults provided) +# GEMINI_MODEL=gemini-pro +# OPENAI_MODEL=gpt-3.5-turbo + +# Embedding model (optional, default is all-MiniLM-L6-v2) +# EMBEDDING_MODEL=all-MiniLM-L6-v2 + +# Chunking settings (optional) +# CHUNK_SIZE=500 +# CHUNK_OVERLAP=100 + +# Retrieval settings (optional) +# TOP_K=6 +# SIMILARITY_THRESHOLD=0.35 + +# Debug mode +DEBUG=true + + + diff --git a/render.yaml b/render.yaml index e0dc60cfcab610a4709e40695aaa13c777b7bb74..7a952f503fa514d7e23db7ad8b35e1b16a688370 100644 --- a/render.yaml +++ b/render.yaml @@ -1,21 +1,20 @@ -services: - - type: web - name: clientsphere-rag-backend - env: python - buildCommand: pip install -r requirements.txt - startCommand: uvicorn app.main:app --host 0.0.0.0 --port $PORT - envVars: - - key: GEMINI_API_KEY - sync: false # Set in Render dashboard - - key: ENV - value: prod - - key: LLM_PROVIDER - value: gemini - - key: ALLOWED_ORIGINS - sync: false # Set in Render dashboard - - key: JWT_SECRET - sync: false # Set in Render dashboard (same as Node.js backend) - - key: DEBUG - value: "false" - - +services: + - type: web + name: clientsphere-rag-backend + env: python + buildCommand: pip install -r requirements.txt + startCommand: uvicorn app.main:app --host 0.0.0.0 --port $PORT + envVars: + - key: GEMINI_API_KEY + sync: false # Set in Render dashboard + - key: ENV + value: prod + - key: LLM_PROVIDER + value: gemini + - key: ALLOWED_ORIGINS + sync: false # Set in Render dashboard + - key: JWT_SECRET + sync: false # Set in Render dashboard (same as Node.js backend) + - key: DEBUG + value: "false" + diff --git a/scripts/__init__.py b/scripts/__init__.py index 906bde563db236c7834641915e893fbea964c0f2..ecda08c6a2faef2cd00d5a984a6980d3ad0ac186 100644 --- a/scripts/__init__.py +++ b/scripts/__init__.py @@ -1,4 +1,4 @@ -# Scripts package - - - +# Scripts package + + + diff --git a/scripts/create_billing_tables.py b/scripts/create_billing_tables.py index d8dbd6a71ae6efb6a6f7a40af2f0ca5b111e37ff..2a5acae3c3630ba933ca7fea898f565d0cc9e99c 100644 --- a/scripts/create_billing_tables.py +++ b/scripts/create_billing_tables.py @@ -1,33 +1,33 @@ -""" -Create billing database tables. -Run this script to initialize the billing database. -""" -import sys -from pathlib import Path - -# Add parent directory to path -sys.path.insert(0, str(Path(__file__).parent.parent)) - -from app.db.database import init_db, engine -from app.db.models import Base - -def main(): - """Create all billing tables.""" - print("Creating billing database tables...") - try: - # Create all tables - Base.metadata.create_all(bind=engine) - print("โœ… Billing tables created successfully!") - print("\nTables created:") - print(" - tenants") - print(" - tenant_plans") - print(" - usage_events") - print(" - usage_daily") - print(" - usage_monthly") - except Exception as e: - print(f"โŒ Error creating tables: {e}") - sys.exit(1) - -if __name__ == "__main__": - main() - +""" +Create billing database tables. +Run this script to initialize the billing database. +""" +import sys +from pathlib import Path + +# Add parent directory to path +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from app.db.database import init_db, engine +from app.db.models import Base + +def main(): + """Create all billing tables.""" + print("Creating billing database tables...") + try: + # Create all tables + Base.metadata.create_all(bind=engine) + print("โœ… Billing tables created successfully!") + print("\nTables created:") + print(" - tenants") + print(" - tenant_plans") + print(" - usage_events") + print(" - usage_daily") + print(" - usage_monthly") + except Exception as e: + print(f"โŒ Error creating tables: {e}") + sys.exit(1) + +if __name__ == "__main__": + main() +