ChiragPatankar commited on
Commit
c19c7bf
·
1 Parent(s): ea5123f

Add all RAG backend files - force add

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +37 -0
  2. Dockerfile +34 -0
  3. README.md +47 -0
  4. README_HF_SPACES.md +231 -0
  5. app.py +13 -0
  6. app/__init__.py +7 -0
  7. app/__pycache__/__init__.cpython-313.pyc +0 -0
  8. app/__pycache__/config.cpython-313.pyc +0 -0
  9. app/__pycache__/main.cpython-313.pyc +0 -0
  10. app/billing/__pycache__/pricing.cpython-313.pyc +0 -0
  11. app/billing/__pycache__/quota.cpython-313.pyc +0 -0
  12. app/billing/__pycache__/usage_tracker.cpython-313.pyc +0 -0
  13. app/billing/pricing.py +57 -0
  14. app/billing/quota.py +131 -0
  15. app/billing/usage_tracker.py +173 -0
  16. app/config.py +77 -0
  17. app/db/__init__.py +2 -0
  18. app/db/__pycache__/__init__.cpython-313.pyc +0 -0
  19. app/db/__pycache__/database.cpython-313.pyc +0 -0
  20. app/db/__pycache__/models.cpython-313.pyc +0 -0
  21. app/db/database.py +53 -0
  22. app/db/models.py +129 -0
  23. app/main.py +1039 -0
  24. app/middleware/__init__.py +13 -0
  25. app/middleware/__pycache__/__init__.cpython-313.pyc +0 -0
  26. app/middleware/__pycache__/auth.cpython-313.pyc +0 -0
  27. app/middleware/__pycache__/rate_limit.cpython-313.pyc +0 -0
  28. app/middleware/auth.py +212 -0
  29. app/middleware/rate_limit.py +40 -0
  30. app/models/__init__.py +33 -0
  31. app/models/__pycache__/__init__.cpython-313.pyc +0 -0
  32. app/models/__pycache__/billing_schemas.cpython-313.pyc +0 -0
  33. app/models/__pycache__/schemas.cpython-313.pyc +0 -0
  34. app/models/billing_schemas.py +46 -0
  35. app/models/schemas.py +112 -0
  36. app/rag/__init__.py +27 -0
  37. app/rag/__pycache__/__init__.cpython-313.pyc +0 -0
  38. app/rag/__pycache__/answer.cpython-313.pyc +0 -0
  39. app/rag/__pycache__/chunking.cpython-313.pyc +0 -0
  40. app/rag/__pycache__/embeddings.cpython-313.pyc +0 -0
  41. app/rag/__pycache__/ingest.cpython-313.pyc +0 -0
  42. app/rag/__pycache__/intent.cpython-313.pyc +0 -0
  43. app/rag/__pycache__/prompts.cpython-313.pyc +0 -0
  44. app/rag/__pycache__/retrieval.cpython-313.pyc +0 -0
  45. app/rag/__pycache__/vectorstore.cpython-313.pyc +0 -0
  46. app/rag/__pycache__/verifier.cpython-313.pyc +0 -0
  47. app/rag/answer.py +444 -0
  48. app/rag/chunking.py +196 -0
  49. app/rag/embeddings.py +145 -0
  50. app/rag/ingest.py +231 -0
.gitignore ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python files
2
+ *.py
3
+ *.pyc
4
+ __pycache__/
5
+
6
+ # Config files
7
+ *.txt
8
+ *.yaml
9
+ *.yml
10
+ *.toml
11
+ *.json
12
+ *.md
13
+ *.sh
14
+ *.bat
15
+ *.ps1
16
+
17
+ # Paths to keep: "!" negation is required — plain paths here do not re-include ignored files
18
+ !app/**
19
+ !requirements.txt
20
+ !Dockerfile
21
+ !app.py
22
+ !README.md
23
+ !.gitignore
24
+ !.env.example.txt
25
+
26
+ # Exclude binaries
27
+ *.png
28
+ *.jpg
29
+ *.jpeg
30
+ *.db
31
+ *.pdf
32
+ *.bin
33
+ *.sqlite3
34
+ public/
35
+ data/billing/
36
+ data/vectordb/
37
+ venv/
Dockerfile ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.11-slim
2
+
3
+ # Set working directory
4
+ WORKDIR /app
5
+
6
+ # Install system dependencies
7
+ RUN apt-get update && apt-get install -y \
8
+ gcc \
9
+ g++ \
10
+ curl \
11
+ && rm -rf /var/lib/apt/lists/*
12
+
13
+ # Copy requirements first (for better caching)
14
+ COPY requirements.txt .
15
+
16
+ # Install Python dependencies
17
+ RUN pip install --no-cache-dir -r requirements.txt
18
+
19
+ # Copy application code
20
+ COPY . .
21
+
22
+ # Create necessary directories
23
+ RUN mkdir -p data/uploads data/processed data/vectordb data/billing
24
+
25
+ # Expose port (Hugging Face Spaces uses 7860, but we'll use PORT env var)
26
+ EXPOSE 7860
27
+
28
+ # Health check
29
+ HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
30
+ CMD curl -f http://localhost:${PORT:-7860}/health/live || exit 1
31
+
32
+ # Start the application (Hugging Face Spaces provides PORT env var)
33
+ CMD uvicorn app.main:app --host 0.0.0.0 --port ${PORT:-7860}
34
+
README.md ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: ClientSphere RAG Backend
3
+ emoji: 🤖
4
+ colorFrom: blue
5
+ colorTo: purple
6
+ sdk: docker
7
+ sdk_version: "4.0.0"
8
+ python_version: "3.11"
9
+ app_file: app.py
10
+ pinned: false
11
+ ---
12
+
13
+ # ClientSphere RAG Backend
14
+
15
+ FastAPI-based RAG (Retrieval-Augmented Generation) backend for ClientSphere AI customer support platform.
16
+
17
+ ## Features
18
+
19
+ - 📚 Knowledge base management
20
+ - 🔍 Semantic search with embeddings
21
+ - 💬 AI-powered chat with citations
22
+ - 📊 Confidence scoring
23
+ - 🔒 Multi-tenant isolation
24
+ - 📈 Usage tracking and billing
25
+
26
+ ## API Endpoints
27
+
28
+ - `GET /health/live` - Health check
29
+ - `GET /kb/stats` - Knowledge base statistics
30
+ - `POST /kb/upload` - Upload documents
31
+ - `POST /chat` - Chat with RAG
32
+ - `GET /kb/search` - Search knowledge base
33
+
34
+ ## Environment Variables
35
+
36
+ Required:
37
+ - `GEMINI_API_KEY` - Google Gemini API key
38
+ - `ENV` - Set to `prod` for production
39
+ - `LLM_PROVIDER` - `gemini` or `openai`
40
+
41
+ Optional:
42
+ - `ALLOWED_ORIGINS` - CORS allowed origins (comma-separated)
43
+ - `JWT_SECRET` - JWT secret for authentication
44
+
45
+ ## Documentation
46
+
47
+ See `README_HF_SPACES.md` for deployment details.
README_HF_SPACES.md ADDED
@@ -0,0 +1,231 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 🚀 Deploy RAG Backend to Hugging Face Spaces
2
+
3
+ Hugging Face Spaces is **perfect** for deploying Python/FastAPI applications with ML dependencies!
4
+
5
+ ## ✅ Why Hugging Face Spaces?
6
+
7
+ - ✅ **Free tier** with generous limits
8
+ - ✅ **Full Python 3.11+** support
9
+ - ✅ **ML libraries** fully supported (sentence-transformers, chromadb, etc.)
10
+ - ✅ **Persistent storage** for vector database
11
+ - ✅ **No bundle size limits**
12
+ - ✅ **GPU support** available (paid)
13
+ - ✅ **Automatic HTTPS** and custom domains
14
+ - ✅ **GitHub integration** (auto-deploy on push)
15
+
16
+ ## 📋 Prerequisites
17
+
18
+ 1. **Hugging Face Account**: Sign up at [huggingface.co](https://huggingface.co)
19
+ 2. **GitHub Repository**: Your code should be in a GitHub repository
20
+ 3. **Gemini API Key**: Get from [Google AI Studio](https://aistudio.google.com/app/apikey)
21
+
22
+ ## 🚀 Step-by-Step Deployment
23
+
24
+ ### Step 1: Prepare Your Repository
25
+
26
+ Your `rag-backend/` directory should contain:
27
+ - ✅ `app.py` - Entry point (already created)
28
+ - ✅ `requirements.txt` - Dependencies
29
+ - ✅ `app/main.py` - FastAPI application
30
+ - ✅ All other application files
31
+
32
+ ### Step 2: Create Hugging Face Space
33
+
34
+ 1. Go to [Hugging Face Spaces](https://huggingface.co/spaces)
35
+ 2. Click **"Create new Space"**
36
+ 3. Configure:
37
+ - **Owner**: Your username
38
+ - **Space name**: `clientsphere-rag-backend` (or your choice)
39
+ - **SDK**: **Docker** (recommended) or **Gradio** (if you want UI)
40
+ - **Hardware**:
41
+ - **CPU basic** (free) - Good for testing
42
+ - **CPU upgrade** (paid) - Better performance
43
+ - **GPU** (paid) - For heavy ML workloads
44
+
45
+ ### Step 3: Connect GitHub Repository
46
+
47
+ 1. In Space creation, select **"Repository"** as source
48
+ 2. Choose your GitHub repository
49
+ 3. Set **Repository path** to: `rag-backend/` (subdirectory)
50
+ 4. Click **"Create Space"**
51
+
52
+ ### Step 4: Configure Environment Variables
53
+
54
+ 1. Go to your Space's **Settings** tab
55
+ 2. Scroll to **"Repository secrets"** or **"Variables"**
56
+ 3. Add these secrets:
57
+
58
+ **Required:**
59
+ ```
60
+ GEMINI_API_KEY=your_gemini_api_key_here
61
+ ENV=prod
62
+ LLM_PROVIDER=gemini
63
+ ```
64
+
65
+ **Optional (but recommended):**
66
+ ```
67
+ ALLOWED_ORIGINS=https://main.clientsphere.pages.dev,https://abaa49a3.clientsphere.pages.dev
68
+ JWT_SECRET=your_secure_jwt_secret
69
+ DEBUG=false
70
+ ```
71
+
72
+ ### Step 5: Configure Docker (if using Docker SDK)
73
+
74
+ If you selected **Docker** SDK, Hugging Face will use your `Dockerfile`.
75
+
76
+ **Your existing `Dockerfile` should work!** It's already configured correctly.
77
+
78
+ ### Step 6: Alternative - Use app.py (Simpler)
79
+
80
+ If you want to use the simpler `app.py` approach:
81
+
82
+ 1. In Space settings, set:
83
+ - **SDK**: **Gradio** or **Streamlit** (but we'll override)
84
+ - **App file**: `app.py`
85
+
86
+ 2. Hugging Face will automatically:
87
+ - Install dependencies from `requirements.txt`
88
+ - Run `python app.py`
89
+ - Expose on port 7860
90
+
91
+ ### Step 7: Deploy!
92
+
93
+ 1. **Push to GitHub** (if not already):
94
+ ```bash
95
+ git add rag-backend/app.py
96
+ git commit -m "Add Hugging Face Spaces entry point"
97
+ git push origin main
98
+ ```
99
+
100
+ 2. **Hugging Face will auto-deploy** from your GitHub repo!
101
+
102
+ 3. **Wait for build** (5-10 minutes first time, faster after)
103
+
104
+ 4. **Your Space URL**: `https://your-username-clientsphere-rag-backend.hf.space`
105
+
106
+ ## 🔧 Configuration Options
107
+
108
+ ### Option A: Docker (Recommended)
109
+
110
+ **Advantages:**
111
+ - Full control over environment
112
+ - Can customize Python version
113
+ - Better for production
114
+
115
+ **Setup:**
116
+ - Use existing `Dockerfile`
117
+ - Hugging Face will build and run it
118
+ - Exposes on port 7860 automatically
119
+
120
+ ### Option B: app.py (Simpler)
121
+
122
+ **Advantages:**
123
+ - Simpler setup
124
+ - Faster builds
125
+ - Good for development
126
+
127
+ **Setup:**
128
+ - Create `app.py` in `rag-backend/` (already done)
129
+ - Hugging Face runs it automatically
130
+
131
+ ## 📝 Environment Variables Reference
132
+
133
+ | Variable | Required | Description |
134
+ |----------|----------|-------------|
135
+ | `GEMINI_API_KEY` | ✅ Yes | Your Google Gemini API key |
136
+ | `ENV` | ✅ Yes | Set to `prod` for production |
137
+ | `LLM_PROVIDER` | ✅ Yes | `gemini` or `openai` |
138
+ | `ALLOWED_ORIGINS` | ⚠️ Recommended | CORS allowed origins (comma-separated) |
139
+ | `JWT_SECRET` | ⚠️ Recommended | JWT secret for authentication |
140
+ | `DEBUG` | ❌ Optional | Set to `false` in production |
141
+ | `OPENAI_API_KEY` | ❌ Optional | If using OpenAI instead of Gemini |
142
+
143
+ ## 🌐 CORS Configuration
144
+
145
+ After deployment, update `ALLOWED_ORIGINS` to include:
146
+ - Your Cloudflare Pages frontend URL
147
+ - Your Cloudflare Workers backend URL
148
+ - Any other origins that need access
149
+
150
+ Example:
151
+ ```
152
+ ALLOWED_ORIGINS=https://main.clientsphere.pages.dev,https://mcp-backend.officialchiragp1605.workers.dev
153
+ ```
154
+
155
+ ## 🔄 Updating Deployment
156
+
157
+ **Automatic (Recommended):**
158
+ - Push to GitHub → Hugging Face auto-deploys
159
+
160
+ **Manual:**
161
+ - Go to Space → Settings → "Rebuild Space"
162
+
163
+ ## 📊 Resource Limits
164
+
165
+ ### Free Tier:
166
+ - ✅ **CPU**: Basic (sufficient for RAG)
167
+ - ✅ **Storage**: 50GB (plenty for vector DB)
168
+ - ✅ **Memory**: 16GB RAM
169
+ - ✅ **Build time**: 20 minutes
170
+ - ✅ **Sleep after inactivity**: 48 hours (wakes on request)
171
+
172
+ ### Paid Tiers:
173
+ - **CPU upgrade**: Better performance
174
+ - **GPU**: For heavy ML workloads
175
+ - **No sleep**: Always-on service
176
+
177
+ ## 🧪 Testing Deployment
178
+
179
+ After deployment, test your endpoints:
180
+
181
+ ```bash
182
+ # Health check
183
+ curl https://your-username-clientsphere-rag-backend.hf.space/health/live
184
+
185
+ # KB Stats (with auth)
186
+ curl https://your-username-clientsphere-rag-backend.hf.space/kb/stats?kb_id=default&tenant_id=test&user_id=test
187
+ ```
188
+
189
+ ## 🔗 Update Frontend
190
+
191
+ After deployment, update Cloudflare Pages environment variable:
192
+
193
+ ```
194
+ VITE_RAG_API_URL=https://your-username-clientsphere-rag-backend.hf.space
195
+ ```
196
+
197
+ Then redeploy frontend:
198
+ ```bash
199
+ npm run build
200
+ npx wrangler pages deploy dist --project-name=clientsphere
201
+ ```
202
+
203
+ ## ✅ Advantages Over Render
204
+
205
+ | Feature | Hugging Face Spaces | Render |
206
+ |---------|-------------------|--------|
207
+ | Free Tier | ✅ Generous | ⚠️ Limited |
208
+ | ML Libraries | ✅ Full support | ✅ Full support |
209
+ | Auto-deploy | ✅ GitHub integration | ✅ GitHub integration |
210
+ | Storage | ✅ 50GB free | ⚠️ Limited |
211
+ | Sleep Mode | ✅ Wakes on request | ❌ No sleep mode |
212
+ | GPU Support | ✅ Available | ❌ Not available |
213
+ | Community | ✅ Large ML community | ⚠️ Smaller |
214
+
215
+ ## 🎯 Summary
216
+
217
+ 1. ✅ Create Hugging Face Space
218
+ 2. ✅ Connect GitHub repository (rag-backend/)
219
+ 3. ✅ Set environment variables
220
+ 4. ✅ Deploy (automatic on push)
221
+ 5. ✅ Update frontend `VITE_RAG_API_URL`
222
+ 6. ✅ Test and enjoy!
223
+
224
+ **Your RAG backend will be live at:**
225
+ `https://your-username-clientsphere-rag-backend.hf.space`
226
+
227
+ ---
228
+
229
+ **Need help?** Check [Hugging Face Spaces Docs](https://huggingface.co/docs/hub/spaces)
230
+
231
+
app.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Hugging Face Spaces entry point for RAG Backend.
3
+ This file is used when deploying to Hugging Face Spaces.
4
+ """
5
+ import os
6
+ import uvicorn
7
+ from app.main import app
8
+
9
+ if __name__ == "__main__":
10
+ # Hugging Face Spaces provides PORT environment variable (defaults to 7860)
11
+ port = int(os.getenv("PORT", 7860))
12
+ uvicorn.run(app, host="0.0.0.0", port=port)
13
+
app/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ """
2
+ ClientSphere RAG Backend Application.
3
+ """
4
+ __version__ = "1.0.0"
5
+
6
+
7
+
app/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (222 Bytes). View file
 
app/__pycache__/config.cpython-313.pyc ADDED
Binary file (3.22 kB). View file
 
app/__pycache__/main.cpython-313.pyc ADDED
Binary file (38.9 kB). View file
 
app/billing/__pycache__/pricing.cpython-313.pyc ADDED
Binary file (2.32 kB). View file
 
app/billing/__pycache__/quota.cpython-313.pyc ADDED
Binary file (5.12 kB). View file
 
app/billing/__pycache__/usage_tracker.cpython-313.pyc ADDED
Binary file (5.56 kB). View file
 
app/billing/pricing.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Pricing table for LLM providers.
3
+ Used to calculate estimated costs from token usage.
4
+ """
5
+ from typing import Dict, Optional
6
+
7
+ # Pricing per 1M tokens (as of 2024, update as needed)
8
+ PRICING_TABLE: Dict[str, Dict[str, float]] = {
9
+ "gemini": {
10
+ "gemini-pro": {"input": 0.50, "output": 1.50}, # $0.50/$1.50 per 1M tokens
11
+ "gemini-1.5-pro": {"input": 1.25, "output": 5.00},
12
+ "gemini-1.5-flash": {"input": 0.075, "output": 0.30},
13
+ "gemini-1.0-pro": {"input": 0.50, "output": 1.50},
14
+ "default": {"input": 0.50, "output": 1.50}
15
+ },
16
+ "openai": {
17
+ "gpt-4": {"input": 30.00, "output": 60.00},
18
+ "gpt-4-turbo": {"input": 10.00, "output": 30.00},
19
+ "gpt-3.5-turbo": {"input": 0.50, "output": 1.50},
20
+ "default": {"input": 0.50, "output": 1.50}
21
+ }
22
+ }
23
+
24
+
25
+ def calculate_cost(
26
+ provider: str,
27
+ model: str,
28
+ prompt_tokens: int,
29
+ completion_tokens: int
30
+ ) -> float:
31
+ """
32
+ Calculate estimated cost in USD based on token usage.
33
+
34
+ Args:
35
+ provider: "gemini" or "openai"
36
+ model: Model name (e.g., "gemini-pro", "gpt-3.5-turbo")
37
+ prompt_tokens: Number of input tokens
38
+ completion_tokens: Number of output tokens
39
+
40
+ Returns:
41
+ Estimated cost in USD
42
+ """
43
+ provider_pricing = PRICING_TABLE.get(provider.lower(), {})
44
+ model_pricing = provider_pricing.get(model.lower(), provider_pricing.get("default", {"input": 0.50, "output": 1.50}))
45
+
46
+ # Calculate cost: (tokens / 1M) * price_per_1M
47
+ input_cost = (prompt_tokens / 1_000_000) * model_pricing["input"]
48
+ output_cost = (completion_tokens / 1_000_000) * model_pricing["output"]
49
+
50
+ return input_cost + output_cost
51
+
52
+
53
+ def get_model_pricing(provider: str, model: str) -> Dict[str, float]:
54
+ """Get pricing for a specific model."""
55
+ provider_pricing = PRICING_TABLE.get(provider.lower(), {})
56
+ return provider_pricing.get(model.lower(), provider_pricing.get("default", {"input": 0.50, "output": 1.50}))
57
+
app/billing/quota.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Quota management and enforcement.
3
+ """
4
+ from sqlalchemy.orm import Session
5
+ from sqlalchemy import func, and_
6
+ from datetime import datetime, timedelta
7
+ from typing import Optional, Tuple
8
+ import logging
9
+
10
+ from app.db.models import TenantPlan, UsageMonthly, Tenant
11
+ logger = logging.getLogger(__name__)
12
+
13
+ # Plan limits (chats per month)
14
+ PLAN_LIMITS = {
15
+ "starter": 500,
16
+ "growth": 5000,
17
+ "pro": -1 # -1 means unlimited
18
+ }
19
+
20
+
21
def get_tenant_plan(db: Session, tenant_id: str) -> Optional[TenantPlan]:
    """Look up the subscription plan row for *tenant_id* (None if absent)."""
    query = db.query(TenantPlan).filter(TenantPlan.tenant_id == tenant_id)
    return query.first()
24
+
25
+
26
def get_monthly_usage(db: Session, tenant_id: str, year: Optional[int] = None, month: Optional[int] = None) -> Optional[UsageMonthly]:
    """Fetch the aggregated usage row for the given (or current UTC) month."""
    today = datetime.utcnow()
    wanted_year = year or today.year
    wanted_month = month or today.month

    criteria = and_(
        UsageMonthly.tenant_id == tenant_id,
        UsageMonthly.year == wanted_year,
        UsageMonthly.month == wanted_month
    )
    return db.query(UsageMonthly).filter(criteria).first()
39
+
40
+
41
def check_quota(db: Session, tenant_id: str) -> Tuple[bool, Optional[str]]:
    """
    Check whether the tenant may make another chat request this month.

    Returns:
        (has_quota, error_message) — (True, None) while quota remains,
        (False, reason) once the monthly limit is reached.
    """
    plan = get_tenant_plan(db, tenant_id)
    if plan:
        monthly_limit = plan.monthly_chat_limit
    else:
        # No plan on file: fall back to the starter tier instead of blocking.
        logger.warning(f"Tenant {tenant_id} has no plan assigned, defaulting to starter")
        monthly_limit = PLAN_LIMITS.get("starter", 500)

    # A limit of -1 marks an unlimited plan: always allow.
    if monthly_limit == -1:
        return True, None

    # Compare this month's request count against the plan limit.
    now = datetime.utcnow()
    usage_row = get_monthly_usage(db, tenant_id, now.year, now.month)
    current_usage = usage_row.total_requests if usage_row else 0

    if current_usage < monthly_limit:
        return True, None

    return False, f"AI quota exceeded ({current_usage}/{monthly_limit} chats this month). Upgrade your plan."
75
+
76
+
77
def ensure_tenant_exists(db: Session, tenant_id: str) -> None:
    """Create the tenant row (plus a default starter plan) if it is missing."""
    existing = db.query(Tenant).filter(Tenant.id == tenant_id).first()
    if existing is not None:
        return

    # First time we see this tenant: register it on the starter tier.
    db.add(Tenant(id=tenant_id, name=f"Tenant {tenant_id}"))
    db.add(
        TenantPlan(
            tenant_id=tenant_id,
            plan_name="starter",
            monthly_chat_limit=PLAN_LIMITS["starter"]
        )
    )
    db.commit()
    logger.info(f"Created tenant {tenant_id} with starter plan")
94
+
95
+
96
def set_tenant_plan(db: Session, tenant_id: str, plan_name: str) -> TenantPlan:
    """
    Assign a subscription plan to a tenant.

    Args:
        db: Database session
        tenant_id: Tenant ID
        plan_name: "starter", "growth", or "pro"

    Returns:
        The persisted TenantPlan.

    Raises:
        ValueError: if *plan_name* is not a known plan.
    """
    if plan_name not in PLAN_LIMITS:
        raise ValueError(f"Invalid plan name: {plan_name}. Must be one of: {list(PLAN_LIMITS.keys())}")

    # Make sure the tenant row exists before touching its plan.
    ensure_tenant_exists(db, tenant_id)

    new_limit = PLAN_LIMITS[plan_name]
    plan = get_tenant_plan(db, tenant_id)
    if plan is None:
        plan = TenantPlan(
            tenant_id=tenant_id,
            plan_name=plan_name,
            monthly_chat_limit=new_limit
        )
        db.add(plan)
    else:
        # Update the existing plan row in place.
        plan.plan_name = plan_name
        plan.monthly_chat_limit = new_limit
        plan.updated_at = datetime.utcnow()

    db.commit()
    db.refresh(plan)
    return plan
131
+
app/billing/usage_tracker.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Usage tracking service.
3
+ Tracks token usage and costs for each LLM request.
4
+ """
5
+ from sqlalchemy.orm import Session
6
+ from sqlalchemy import func, and_
7
+ from datetime import datetime, timedelta
8
+ from typing import Optional
9
+ import uuid
10
+ import logging
11
+
12
+ from app.db.models import UsageEvent, UsageDaily, UsageMonthly, Tenant
13
+ from app.billing.pricing import calculate_cost
14
+ from app.billing.quota import ensure_tenant_exists
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
def track_usage(
    db: Session,
    tenant_id: str,
    user_id: str,
    kb_id: str,
    provider: str,
    model: str,
    prompt_tokens: int,
    completion_tokens: int,
    request_timestamp: Optional[datetime] = None
) -> UsageEvent:
    """
    Record one LLM call: persist a UsageEvent and roll it into the
    tenant's daily and monthly aggregates.

    Args:
        db: Database session
        tenant_id: Tenant ID
        user_id: User ID
        kb_id: Knowledge base ID
        provider: "gemini" or "openai"
        model: Model name
        prompt_tokens: Input tokens
        completion_tokens: Output tokens
        request_timestamp: Request timestamp (defaults to now)

    Returns:
        Created UsageEvent
    """
    # The tenant row (and starter plan) must exist before we attach usage.
    ensure_tenant_exists(db, tenant_id)

    when = request_timestamp or datetime.utcnow()
    token_total = prompt_tokens + completion_tokens
    cost = calculate_cost(provider, model, prompt_tokens, completion_tokens)

    event = UsageEvent(
        request_id=f"req_{uuid.uuid4().hex[:16]}",
        tenant_id=tenant_id,
        user_id=user_id,
        kb_id=kb_id,
        provider=provider,
        model=model,
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        total_tokens=token_total,
        estimated_cost_usd=cost,
        request_timestamp=when
    )
    db.add(event)

    # Keep the per-day and per-month rollups in sync with the raw event,
    # all inside the same transaction.
    _update_daily_usage(db, tenant_id, when, provider, token_total, cost)
    _update_monthly_usage(db, tenant_id, when, provider, token_total, cost)

    db.commit()
    db.refresh(event)

    logger.info(
        f"Tracked usage: tenant={tenant_id}, provider={provider}, "
        f"tokens={token_total}, cost=${cost:.6f}"
    )

    return event
89
+
90
+
91
def _update_daily_usage(
    db: Session,
    tenant_id: str,
    timestamp: datetime,
    provider: str,
    tokens: int,
    cost: float
):
    """Fold one request into the tenant's per-day usage rollup row."""
    # Normalize to midnight so every calendar day maps to exactly one row.
    bucket = datetime.combine(timestamp.date(), datetime.min.time())

    row = db.query(UsageDaily).filter(
        and_(
            UsageDaily.tenant_id == tenant_id,
            UsageDaily.date == bucket
        )
    ).first()

    if row is None:
        # First request of the day: open a fresh rollup row.
        db.add(UsageDaily(
            tenant_id=tenant_id,
            date=bucket,
            total_requests=1,
            total_tokens=tokens,
            total_cost_usd=cost,
            gemini_requests=int(provider == "gemini"),
            openai_requests=int(provider == "openai")
        ))
        return

    row.total_requests += 1
    row.total_tokens += tokens
    row.total_cost_usd += cost
    if provider == "gemini":
        row.gemini_requests += 1
    elif provider == "openai":
        row.openai_requests += 1
    row.updated_at = datetime.utcnow()
130
+
131
+
132
def _update_monthly_usage(
    db: Session,
    tenant_id: str,
    timestamp: datetime,
    provider: str,
    tokens: int,
    cost: float
):
    """Fold one request into the tenant's per-month usage rollup row."""
    bucket_year = timestamp.year
    bucket_month = timestamp.month

    row = db.query(UsageMonthly).filter(
        and_(
            UsageMonthly.tenant_id == tenant_id,
            UsageMonthly.year == bucket_year,
            UsageMonthly.month == bucket_month
        )
    ).first()

    if row is None:
        # First request of the month: open a fresh rollup row.
        db.add(UsageMonthly(
            tenant_id=tenant_id,
            year=bucket_year,
            month=bucket_month,
            total_requests=1,
            total_tokens=tokens,
            total_cost_usd=cost,
            gemini_requests=int(provider == "gemini"),
            openai_requests=int(provider == "openai")
        ))
        return

    row.total_requests += 1
    row.total_tokens += tokens
    row.total_cost_usd += cost
    if provider == "gemini":
        row.gemini_requests += 1
    elif provider == "openai":
        row.openai_requests += 1
    row.updated_at = datetime.utcnow()
173
+
app/config.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Configuration settings for the RAG backend.
3
+ """
4
+ from pydantic_settings import BaseSettings
5
+ from pathlib import Path
6
+ from typing import Optional
7
+ import os
8
+
9
+
10
class Settings(BaseSettings):
    """Application settings with environment variable support.

    Every field can be overridden via an environment variable of the same
    name or a local .env file (see Config). Data directories are created
    on instantiation.
    """

    # App settings
    APP_NAME: str = "ClientSphere RAG Backend"
    DEBUG: bool = True
    ENV: str = "dev"  # "dev" or "prod" - controls tenant_id security

    # Paths (derived from the repository root, two levels above this file)
    BASE_DIR: Path = Path(__file__).parent.parent
    DATA_DIR: Path = BASE_DIR / "data"
    UPLOADS_DIR: Path = DATA_DIR / "uploads"
    PROCESSED_DIR: Path = DATA_DIR / "processed"
    VECTORDB_DIR: Path = DATA_DIR / "vectordb"

    # Chunking settings (optimized for retrieval quality)
    CHUNK_SIZE: int = 600  # tokens (increased for better context)
    CHUNK_OVERLAP: int = 150  # tokens (increased for better continuity)
    MIN_CHUNK_SIZE: int = 100  # minimum tokens per chunk (increased to avoid tiny chunks)

    # Embedding settings
    EMBEDDING_MODEL: str = "all-MiniLM-L6-v2"  # Fast, good quality
    EMBEDDING_DIMENSION: int = 384  # must match EMBEDDING_MODEL's output size

    # Vector store settings
    COLLECTION_NAME: str = "clientsphere_kb"

    # Retrieval settings (optimized for maximum confidence)
    TOP_K: int = 10  # Number of chunks to retrieve (increased to maximize chance of finding strong matches)
    SIMILARITY_THRESHOLD: float = 0.15  # Minimum similarity score (0-1) - lowered to include more potentially relevant chunks
    SIMILARITY_THRESHOLD_STRICT: float = 0.45  # Strict threshold for answer generation (anti-hallucination)

    # LLM settings
    LLM_PROVIDER: str = "gemini"  # Options: "gemini", "openai"
    GEMINI_API_KEY: Optional[str] = None
    OPENAI_API_KEY: Optional[str] = None
    GEMINI_MODEL: str = "gemini-1.5-flash"  # Use latest stable model
    OPENAI_MODEL: str = "gpt-3.5-turbo"

    # Response settings
    MAX_CONTEXT_TOKENS: int = 2500  # Max tokens for context in prompt (reduced for focus)
    TEMPERATURE: float = 0.0  # Zero temperature for maximum determinism (anti-hallucination)
    REQUIRE_VERIFIER: bool = True  # Always use verifier for hallucination prevention

    # Security settings
    MAX_FILE_SIZE_MB: int = 50  # Maximum file size in MB
    ALLOWED_ORIGINS: str = "*"  # CORS allowed origins (comma-separated, use "*" for all)
    RATE_LIMIT_PER_MINUTE: int = 60  # Rate limit per user per minute
    JWT_SECRET: Optional[str] = None  # JWT secret for authentication

    # Rate limiting
    RATE_LIMIT_ENABLED: bool = True  # Enable/disable rate limiting

    class Config:
        # Pydantic settings source: read overrides from a local .env file.
        env_file = ".env"
        env_file_encoding = "utf-8"

    def __init__(self, **kwargs):
        """Initialize settings and ensure all data directories exist on disk."""
        super().__init__(**kwargs)
        # Create directories if they don't exist
        self.UPLOADS_DIR.mkdir(parents=True, exist_ok=True)
        self.PROCESSED_DIR.mkdir(parents=True, exist_ok=True)
        self.VECTORDB_DIR.mkdir(parents=True, exist_ok=True)


# Global settings instance (import-time side effect: creates data dirs).
settings = Settings()
77
+
app/db/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # Database module
2
+
app/db/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (138 Bytes). View file
 
app/db/__pycache__/database.cpython-313.pyc ADDED
Binary file (2.12 kB). View file
 
app/db/__pycache__/models.cpython-313.pyc ADDED
Binary file (5.03 kB). View file
 
app/db/database.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Database setup and session management.
3
+ Uses SQLAlchemy with SQLite for local dev, Postgres-compatible schema.
4
+ """
5
+ from sqlalchemy import create_engine
6
+ from sqlalchemy.ext.declarative import declarative_base
7
+ from sqlalchemy.orm import sessionmaker, Session
8
+ from pathlib import Path
9
+ import logging
10
+
11
+ from app.config import settings
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+ # Database path
16
+ DB_DIR = settings.DATA_DIR / "billing"
17
+ DB_DIR.mkdir(parents=True, exist_ok=True)
18
+ DATABASE_URL = f"sqlite:///{DB_DIR / 'billing.db'}"
19
+
20
+ # Create engine
21
+ engine = create_engine(
22
+ DATABASE_URL,
23
+ connect_args={"check_same_thread": False}, # SQLite specific
24
+ echo=False # Set to True for SQL query logging
25
+ )
26
+
27
+ # Session factory
28
+ SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
29
+
30
+ # Base class for models
31
+ Base = declarative_base()
32
+
33
+
34
+ def get_db() -> Session:
35
+ """Get database session (dependency for FastAPI)."""
36
+ db = SessionLocal()
37
+ try:
38
+ yield db
39
+ finally:
40
+ db.close()
41
+
42
+
43
+ def init_db():
44
+ """Initialize database tables."""
45
+ Base.metadata.create_all(bind=engine)
46
+ logger.info("Database tables created/verified")
47
+
48
+
49
+ def drop_db():
50
+ """Drop all tables (use with caution!)."""
51
+ Base.metadata.drop_all(bind=engine)
52
+ logger.warning("All database tables dropped")
53
+
app/db/models.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Database models for billing and usage tracking.
3
+ """
4
from sqlalchemy import Column, String, Integer, Float, DateTime, Boolean, ForeignKey, Text, UniqueConstraint
5
+ from sqlalchemy.orm import relationship
6
+ from datetime import datetime
7
+ from typing import Optional
8
+
9
+ from app.db.database import Base
10
+
11
+
12
class Tenant(Base):
    """Tenant/organization model.

    One row per customer organization; all plan/usage rows reference it
    via ``tenant_id`` foreign keys.
    """
    __tablename__ = "tenants"

    # External string identifier supplied by the caller (not autoincrement).
    id = Column(String, primary_key=True, index=True)
    name = Column(String, nullable=False)
    created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False)

    # Relationships (uselist=False makes `plan` one-to-one)
    plan = relationship("TenantPlan", back_populates="tenant", uselist=False)
    usage_events = relationship("UsageEvent", back_populates="tenant")
    daily_usage = relationship("UsageDaily", back_populates="tenant")
    monthly_usage = relationship("UsageMonthly", back_populates="tenant")
26
+
27
+
28
class TenantPlan(Base):
    """Tenant subscription plan (one-to-one with Tenant via unique tenant_id)."""
    __tablename__ = "tenant_plans"

    id = Column(Integer, primary_key=True, autoincrement=True)
    # unique=True enforces at most one plan row per tenant.
    tenant_id = Column(String, ForeignKey("tenants.id"), unique=True, nullable=False, index=True)
    plan_name = Column(String, nullable=False, index=True)  # "starter", "growth", "pro"
    monthly_chat_limit = Column(Integer, nullable=False)  # -1 means unlimited
    created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False)

    # Relationships
    tenant = relationship("Tenant", back_populates="plan")
41
+
42
+
43
class UsageEvent(Base):
    """Individual usage event — one row per /chat request that hit the LLM."""
    __tablename__ = "usage_events"

    id = Column(Integer, primary_key=True, autoincrement=True)
    # Idempotency key: unique per request, so a retried insert fails loudly.
    request_id = Column(String, unique=True, nullable=False, index=True)
    tenant_id = Column(String, ForeignKey("tenants.id"), nullable=False, index=True)
    user_id = Column(String, nullable=False, index=True)
    kb_id = Column(String, nullable=False)

    # LLM details
    provider = Column(String, nullable=False)  # "gemini" or "openai"
    model = Column(String, nullable=False)

    # Token usage as reported by the provider
    prompt_tokens = Column(Integer, nullable=False, default=0)
    completion_tokens = Column(Integer, nullable=False, default=0)
    total_tokens = Column(Integer, nullable=False, default=0)

    # Cost tracking (estimate derived from token counts; see app/billing/pricing.py)
    estimated_cost_usd = Column(Float, nullable=False, default=0.0)

    # Timestamp (indexed for time-range reports)
    request_timestamp = Column(DateTime, default=datetime.utcnow, nullable=False, index=True)

    # Relationships
    tenant = relationship("Tenant", back_populates="usage_events")
70
+
71
+
72
class UsageDaily(Base):
    """Daily aggregated usage per tenant.

    Exactly one row per (tenant_id, date). FIX: the original comment
    claimed this uniqueness but ``__table_args__`` contained only the
    SQLite autoincrement flag — the ``UniqueConstraint`` now actually
    enforces it at the schema level.
    """
    __tablename__ = "usage_daily"

    id = Column(Integer, primary_key=True, autoincrement=True)
    tenant_id = Column(String, ForeignKey("tenants.id"), nullable=False, index=True)
    date = Column(DateTime, nullable=False, index=True)

    # Aggregated metrics
    total_requests = Column(Integer, nullable=False, default=0)
    total_tokens = Column(Integer, nullable=False, default=0)
    total_cost_usd = Column(Float, nullable=False, default=0.0)

    # Provider breakdown
    gemini_requests = Column(Integer, nullable=False, default=0)
    openai_requests = Column(Integer, nullable=False, default=0)

    created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False)

    # Unique constraint: one record per tenant per day (enforced, not just documented)
    __table_args__ = (
        UniqueConstraint("tenant_id", "date", name="uq_usage_daily_tenant_date"),
        {"sqlite_autoincrement": True},
    )

    # Relationships
    tenant = relationship("Tenant", back_populates="daily_usage")
99
+
100
+
101
class UsageMonthly(Base):
    """Monthly aggregated usage per tenant.

    Exactly one row per (tenant_id, year, month). FIX: the original
    comment claimed this uniqueness but ``__table_args__`` contained only
    the SQLite autoincrement flag — the ``UniqueConstraint`` now actually
    enforces it at the schema level.
    """
    __tablename__ = "usage_monthly"

    id = Column(Integer, primary_key=True, autoincrement=True)
    tenant_id = Column(String, ForeignKey("tenants.id"), nullable=False, index=True)
    year = Column(Integer, nullable=False, index=True)
    month = Column(Integer, nullable=False, index=True)  # 1-12

    # Aggregated metrics
    total_requests = Column(Integer, nullable=False, default=0)
    total_tokens = Column(Integer, nullable=False, default=0)
    total_cost_usd = Column(Float, nullable=False, default=0.0)

    # Provider breakdown
    gemini_requests = Column(Integer, nullable=False, default=0)
    openai_requests = Column(Integer, nullable=False, default=0)

    created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False)

    # Unique constraint: one record per tenant per month (enforced, not just documented)
    __table_args__ = (
        UniqueConstraint("tenant_id", "year", "month", name="uq_usage_monthly_tenant_period"),
        {"sqlite_autoincrement": True},
    )

    # Relationships
    tenant = relationship("Tenant", back_populates="monthly_usage")
129
+
app/main.py ADDED
@@ -0,0 +1,1039 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ FastAPI application for ClientSphere RAG Backend.
3
+ Provides endpoints for knowledge base management and chat.
4
+ """
5
+ from fastapi import FastAPI, File, UploadFile, HTTPException, Form, BackgroundTasks, Request, Depends
6
+ from fastapi.middleware.cors import CORSMiddleware
7
+ from fastapi.exceptions import RequestValidationError
8
+ from fastapi.responses import JSONResponse
9
+ from pathlib import Path
10
+ import shutil
11
+ import uuid
12
+ from datetime import datetime
13
+ from typing import Optional
14
+ import logging
15
+
16
+ from app.config import settings
17
+ from app.middleware.auth import get_auth_context, require_auth
18
+ from app.middleware.rate_limit import (
19
+ limiter,
20
+ get_tenant_rate_limit_key,
21
+ RateLimitExceeded,
22
+ _rate_limit_exceeded_handler
23
+ )
24
+ from app.models.schemas import (
25
+ UploadResponse,
26
+ ChatRequest,
27
+ ChatResponse,
28
+ KnowledgeBaseStats,
29
+ HealthResponse,
30
+ DocumentStatus,
31
+ Citation,
32
+ )
33
+ from app.models.billing_schemas import (
34
+ UsageResponse,
35
+ PlanLimitsResponse,
36
+ CostReportResponse,
37
+ SetPlanRequest
38
+ )
39
+ from app.rag.ingest import parser
40
+ from app.rag.chunking import chunker
41
+ from app.rag.embeddings import get_embedding_service
42
+ from app.rag.vectorstore import get_vector_store
43
+ from app.rag.retrieval import get_retrieval_service
44
+ from app.rag.answer import get_answer_service
45
+ from app.db.database import get_db, init_db
46
+ from app.billing.quota import check_quota, ensure_tenant_exists
47
+ from app.billing.usage_tracker import track_usage
48
+
49
# Root logging config; per-module logger follows stdlib convention.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI(
    title=settings.APP_NAME,
    description="RAG-based customer support chatbot API",
    version="1.0.0",
)
58
+
59
# Initialize database on startup.
# NOTE(review): @app.on_event is deprecated in recent FastAPI in favor of
# lifespan handlers — consider migrating when the FastAPI version allows.
@app.on_event("startup")
async def startup_event():
    """Create billing tables (idempotent) before serving requests."""
    init_db()
    logger.info("Database initialized")
65
+
66
# Configure CORS - SECURITY: Restrict in production via ALLOWED_ORIGINS.
if settings.ALLOWED_ORIGINS == "*":
    allowed_origins = ["*"]
else:
    # Split comma-separated origins and strip whitespace / empty entries.
    allowed_origins = [origin.strip() for origin in settings.ALLOWED_ORIGINS.split(",") if origin.strip()]

# Fall back to allowing all origins when nothing usable was configured
# (dev convenience; the original comment said "localhost" but the code uses "*").
if not allowed_origins or allowed_origins == ["*"]:
    allowed_origins = ["*"]  # Allow all in dev mode

logger.info(f"CORS configured with origins: {allowed_origins}")

# NOTE(review): allow_credentials=True combined with allow_origins=["*"] is
# rejected by the CORS spec (browsers won't send credentials to a wildcard
# origin) — confirm whether credentials are actually needed here.
app.add_middleware(
    CORSMiddleware,
    allow_origins=allowed_origins,
    allow_credentials=True,
    allow_methods=["GET", "POST", "DELETE", "OPTIONS"],  # OPTIONS needed for preflight
    allow_headers=["Content-Type", "Authorization", "X-Tenant-Id", "X-User-Id"],  # auth headers
)

# Configure rate limiting (slowapi-style: limiter stored on app.state,
# handler converts RateLimitExceeded into a 429 response)
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
90
+
91
# Handle request validation errors with detailed logging so malformed client
# payloads can be diagnosed from the server log.
# FIX: the file registered @app.exception_handler(RequestValidationError)
# twice (the second definition silently replaced the first) and re-imported
# RequestValidationError/JSONResponse mid-file even though both are already
# imported at the top. A single handler — the more complete of the two,
# which also logs headers — is kept.
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    """Return a 422 with validation details and echo the raw body for debugging."""
    body = await request.body()  # Starlette caches the body, so reading here is safe
    logger.error(f"Request validation error: {exc.errors()}")
    logger.error(f"Request body (raw): {body}")
    logger.error(f"Request headers: {dict(request.headers)}")
    return JSONResponse(
        status_code=422,
        content={"detail": exc.errors(), "body": body.decode('utf-8', errors='ignore')}
    )
117
+
118
+
119
+ # ============== Health & Status Endpoints ==============
120
+
121
@app.get("/", response_model=HealthResponse)
async def root():
    """Root endpoint with basic service info (always reports status "ok")."""
    llm_ready = bool(settings.GEMINI_API_KEY or settings.OPENAI_API_KEY)
    return HealthResponse(
        status="ok",
        version="1.0.0",
        vector_db_connected=True,
        llm_configured=llm_ready,
    )
130
+
131
+
132
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Health check endpoint.

    Calls the vector store's get_stats() purely as a connectivity probe;
    any exception marks the service unhealthy.
    """
    try:
        vector_store = get_vector_store()
        # Connectivity probe; FIX: result was bound to an unused local (`stats`),
        # now intentionally discarded.
        vector_store.get_stats()

        return HealthResponse(
            status="healthy",
            version="1.0.0",
            vector_db_connected=True,
            llm_configured=bool(settings.GEMINI_API_KEY or settings.OPENAI_API_KEY)
        )
    except Exception as e:
        logger.error(f"Health check failed: {e}")
        return HealthResponse(
            status="unhealthy",
            version="1.0.0",
            vector_db_connected=False,
            llm_configured=False
        )
153
+
154
+
155
@app.get("/health/live")
async def liveness():
    """Kubernetes liveness probe — unconditionally reports the process as up."""
    return {"status": "alive"}
159
+
160
+
161
@app.get("/health/ready")
async def readiness():
    """Kubernetes readiness probe - checks dependencies.

    Returns 200 with per-check booleans when every dependency passes,
    otherwise raises a 503 carrying the failing checks.
    """
    checks = {
        "vector_db": False,
        "llm_configured": bool(settings.GEMINI_API_KEY or settings.OPENAI_API_KEY)
    }

    # Check vector DB connection via a lightweight stats call
    try:
        vector_store = get_vector_store()
        vector_store.get_stats()
        checks["vector_db"] = True
    except Exception as e:
        logger.warning(f"Vector DB check failed: {e}")
        checks["vector_db"] = False

    # All checks must pass.
    # FIX: removed a redundant function-local `from fastapi import HTTPException`
    # — it is already imported at module scope.
    if all(checks.values()):
        return {"status": "ready", "checks": checks}
    raise HTTPException(status_code=503, detail={"status": "not_ready", "checks": checks})
184
+
185
+
186
+ # ============== Knowledge Base Endpoints ==============
187
+
188
@app.post("/kb/upload", response_model=UploadResponse)
@limiter.limit("20/hour", key_func=get_tenant_rate_limit_key)
async def upload_document(
    background_tasks: BackgroundTasks,
    request: Request,
    file: UploadFile = File(...),
    tenant_id: Optional[str] = Form(None),  # Optional in dev, ignored in prod
    user_id: Optional[str] = Form(None),  # Optional in dev, ignored in prod
    kb_id: str = Form(...)
):
    """
    Upload a document to the knowledge base.

    - Saves file to disk
    - Parses and chunks the document (in a background task)
    - Generates embeddings
    - Stores in vector database

    Returns immediately with status PROCESSING; chunks_created is 0 until
    the background task completes.
    """
    # SECURITY: Extract tenant_id from auth token in production.
    # NOTE(review): in prod only tenant_id is overridden from the token;
    # user_id still comes from the form — confirm this is intended.
    if settings.ENV == "prod":
        auth_context = await require_auth(request)
        tenant_id = auth_context.get("tenant_id")
        if not tenant_id:
            raise HTTPException(
                status_code=403,
                detail="tenant_id must come from authentication token in production mode"
            )
    elif not tenant_id:
        raise HTTPException(
            status_code=400,
            detail="tenant_id is required"
        )

    # Validate file type by extension
    file_ext = Path(file.filename).suffix.lower()
    if file_ext not in parser.SUPPORTED_EXTENSIONS:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported file type: {file_ext}. Supported: {parser.SUPPORTED_EXTENSIONS}"
        )

    # Validate file size (SECURITY) — measured by seeking, so the file is
    # never read into memory here.
    file.file.seek(0, 2)  # Seek to end
    file_size = file.file.tell()
    file.file.seek(0)  # Reset to start
    max_size_bytes = settings.MAX_FILE_SIZE_MB * 1024 * 1024
    if file_size > max_size_bytes:
        raise HTTPException(
            status_code=400,
            detail=f"File too large. Maximum size: {settings.MAX_FILE_SIZE_MB}MB"
        )

    # Generate document ID (tenant- and kb-scoped, with a short random suffix)
    doc_id = f"{tenant_id}_{kb_id}_{uuid.uuid4().hex[:8]}"

    # Save file to uploads directory
    upload_path = settings.UPLOADS_DIR / f"{doc_id}_{file.filename}"
    try:
        with open(upload_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)
        logger.info(f"Saved file: {upload_path}")
    except Exception as e:
        logger.error(f"Error saving file: {e}")
        raise HTTPException(status_code=500, detail="Failed to save file")

    # Process document in background (parse -> chunk -> embed -> store)
    background_tasks.add_task(
        process_document,
        upload_path,
        tenant_id,  # CRITICAL: Multi-tenant isolation
        user_id,
        kb_id,
        file.filename,
        doc_id
    )

    return UploadResponse(
        success=True,
        message="Document upload started. Processing in background.",
        document_id=doc_id,
        file_name=file.filename,
        chunks_created=0,
        status=DocumentStatus.PROCESSING
    )
272
+
273
+
274
async def process_document(
    file_path: Path,
    tenant_id: str,  # CRITICAL: Multi-tenant isolation — stamped into every chunk's metadata
    user_id: str,
    kb_id: str,
    original_filename: str,
    document_id: str
):
    """
    Background task to process an uploaded document.

    Pipeline: parse -> chunk -> per-chunk metadata -> embed -> store in the
    vector DB. Exceptions are logged and re-raised (BackgroundTasks swallows
    them from the client's perspective, so the log is the only signal).
    """
    try:
        logger.info(f"Processing document: {original_filename}")

        # Parse document into plain text plus a page map for citations
        parsed_doc = parser.parse(file_path)
        logger.info(f"Parsed document: {len(parsed_doc.text)} characters")

        # Chunk document
        chunks = chunker.chunk_text(
            parsed_doc.text,
            page_numbers=parsed_doc.page_map
        )
        logger.info(f"Created {len(chunks)} chunks")

        # Nothing to embed/store — bail out early (not an error)
        if not chunks:
            logger.warning(f"No chunks created from {original_filename}")
            return

        # Create metadata for each chunk; the three parallel lists below are
        # kept index-aligned for the vector-store call.
        metadatas = []
        chunk_ids = []
        chunk_texts = []

        for chunk in chunks:
            metadata = chunker.create_chunk_metadata(
                chunk=chunk,
                tenant_id=tenant_id,  # CRITICAL: Multi-tenant isolation
                kb_id=kb_id,
                user_id=user_id,
                file_name=original_filename,
                file_type=parsed_doc.file_type,
                total_chunks=len(chunks),
                document_id=document_id
            )
            metadatas.append(metadata)
            chunk_ids.append(metadata["chunk_id"])
            chunk_texts.append(chunk.content)

        # Generate embeddings (batched over all chunk texts)
        embedding_service = get_embedding_service()
        embeddings = embedding_service.embed_texts(chunk_texts)
        logger.info(f"Generated {len(embeddings)} embeddings")

        # Store in vector database
        vector_store = get_vector_store()
        vector_store.add_documents(
            documents=chunk_texts,
            embeddings=embeddings,
            metadatas=metadatas,
            ids=chunk_ids
        )

        logger.info(f"Successfully processed {original_filename}: {len(chunks)} chunks stored")

    except Exception as e:
        logger.error(f"Error processing document {original_filename}: {e}")
        raise
342
+
343
+
344
@app.get("/kb/stats", response_model=KnowledgeBaseStats)
async def get_kb_stats(
    request: Request,
    tenant_id: Optional[str] = None,  # Optional in dev, ignored in prod
    kb_id: Optional[str] = None,
    user_id: Optional[str] = None  # Optional in dev, ignored in prod
):
    """Get statistics for a knowledge base.

    In prod, tenant_id/user_id are taken exclusively from the auth token;
    in dev, query params are accepted with the auth context as fallback.
    """
    # SECURITY: Get tenant_id and user_id from auth context
    auth_context = await get_auth_context(request)
    tenant_id_from_auth = auth_context.get("tenant_id")
    user_id_from_auth = auth_context.get("user_id")

    if settings.ENV == "prod":
        if not tenant_id_from_auth or not user_id_from_auth:
            raise HTTPException(
                status_code=403,
                detail="tenant_id and user_id must come from authentication token in production mode"
            )
        tenant_id = tenant_id_from_auth
        user_id = user_id_from_auth
    else:
        # Dev: explicit query params win over auth context
        tenant_id = tenant_id or tenant_id_from_auth
        user_id = user_id or user_id_from_auth
        if not tenant_id or not kb_id or not user_id:
            raise HTTPException(
                status_code=400,
                detail="tenant_id, kb_id, and user_id are required"
            )

    try:
        vector_store = get_vector_store()
        stats = vector_store.get_stats(tenant_id=tenant_id, kb_id=kb_id, user_id=user_id)

        return KnowledgeBaseStats(
            tenant_id=tenant_id,  # CRITICAL: Multi-tenant isolation
            kb_id=kb_id,
            user_id=user_id,
            total_documents=len(stats.get("file_names", [])),
            total_chunks=stats.get("total_chunks", 0),
            file_names=stats.get("file_names", []),
            last_updated=datetime.utcnow()
        )
    except Exception as e:
        logger.error(f"Error getting KB stats: {e}")
        raise HTTPException(status_code=500, detail=str(e))
390
+
391
+
392
@app.delete("/kb/document")
async def delete_document(
    request: Request,
    tenant_id: Optional[str] = None,  # Optional in dev, ignored in prod
    kb_id: Optional[str] = None,
    user_id: Optional[str] = None,  # Optional in dev, ignored in prod
    file_name: Optional[str] = None
):
    """Delete a document (all of its chunks) from the knowledge base.

    The filter always includes tenant_id/user_id so one tenant can never
    delete another tenant's data.
    """
    # SECURITY: Get tenant_id and user_id from auth context
    auth_context = await get_auth_context(request)
    tenant_id_from_auth = auth_context.get("tenant_id")
    user_id_from_auth = auth_context.get("user_id")

    if settings.ENV == "prod":
        if not tenant_id_from_auth or not user_id_from_auth:
            raise HTTPException(
                status_code=403,
                detail="tenant_id and user_id must come from authentication token in production mode"
            )
        tenant_id = tenant_id_from_auth
        user_id = user_id_from_auth
    else:
        # Dev: explicit query params win over auth context
        tenant_id = tenant_id or tenant_id_from_auth
        user_id = user_id or user_id_from_auth
        if not tenant_id or not kb_id or not user_id or not file_name:
            raise HTTPException(
                status_code=400,
                detail="tenant_id, kb_id, user_id, and file_name are required (provide via headers or query params)"
            )

    try:
        vector_store = get_vector_store()
        deleted = vector_store.delete_by_filter({
            "tenant_id": tenant_id,  # CRITICAL: Multi-tenant isolation
            "kb_id": kb_id,
            "user_id": user_id,
            "file_name": file_name
        })

        return {
            "success": True,
            "message": f"Deleted {deleted} chunks",
            "file_name": file_name
        }
    except Exception as e:
        logger.error(f"Error deleting document: {e}")
        raise HTTPException(status_code=500, detail=str(e))
440
+
441
+
442
@app.delete("/kb/clear")
async def clear_kb(
    request: Request,
    tenant_id: Optional[str] = None,  # Optional in dev, ignored in prod
    kb_id: Optional[str] = None,
    user_id: Optional[str] = None  # Optional in dev, ignored in prod
):
    """Clear all documents from a knowledge base (scoped to tenant+user+kb)."""
    # SECURITY: Get tenant_id and user_id from auth context
    auth_context = await get_auth_context(request)
    tenant_id_from_auth = auth_context.get("tenant_id")
    user_id_from_auth = auth_context.get("user_id")

    if settings.ENV == "prod":
        if not tenant_id_from_auth or not user_id_from_auth:
            raise HTTPException(
                status_code=403,
                detail="tenant_id and user_id must come from authentication token in production mode"
            )
        tenant_id = tenant_id_from_auth
        user_id = user_id_from_auth
    else:
        # Dev: explicit query params win over auth context
        tenant_id = tenant_id or tenant_id_from_auth
        user_id = user_id or user_id_from_auth
        if not tenant_id or not kb_id or not user_id:
            raise HTTPException(
                status_code=400,
                detail="tenant_id, kb_id, and user_id are required"
            )
    try:
        vector_store = get_vector_store()
        deleted = vector_store.delete_by_filter({
            "tenant_id": tenant_id,  # CRITICAL: Multi-tenant isolation
            "kb_id": kb_id,
            "user_id": user_id
        })

        return {
            "success": True,
            "message": f"Cleared knowledge base. Deleted {deleted} chunks.",
            "kb_id": kb_id
        }
    except Exception as e:
        logger.error(f"Error clearing KB: {e}")
        raise HTTPException(status_code=500, detail=str(e))
487
+
488
+
489
+ # ============== Chat Endpoints ==============
490
+
491
@app.post("/chat", response_model=ChatResponse)
@limiter.limit("10/minute", key_func=get_tenant_rate_limit_key)
async def chat(chat_request: ChatRequest, request: Request):
    """
    Process a chat message using RAG.

    - Enforces tenant/user identity (from JWT in prod, body/headers in dev)
    - Checks the tenant's quota before any LLM call
    - Retrieves relevant context from knowledge base
    - Generates answer using LLM and records token usage
    - Returns answer with citations

    FIX: the billing DB session obtained via ``next(get_db())`` was never
    closed — the dependency generator stays suspended, so its own
    ``finally`` only ran at garbage collection. A ``finally: db.close()``
    now releases the session deterministically.
    """
    conversation_id = "unknown"
    try:
        logger.info(f"=== CHAT REQUEST RECEIVED ===")
        logger.info(f"Request body: tenant_id={chat_request.tenant_id}, user_id={chat_request.user_id}, kb_id={chat_request.kb_id}, question_length={len(chat_request.question)}")
        logger.info(f"Request headers: {dict(request.headers)}")

        # SECURITY: Get tenant_id and user_id from auth context
        # In PROD: MUST come from JWT token (never from request body)
        try:
            auth_context = await get_auth_context(request)
        except Exception as e:
            logger.error(f"Error getting auth context: {e}", exc_info=True)
            raise HTTPException(status_code=401, detail=f"Authentication error: {str(e)}")

        tenant_id_from_auth = auth_context.get("tenant_id")
        user_id_from_auth = auth_context.get("user_id")

        if settings.ENV == "prod":
            if not tenant_id_from_auth or not user_id_from_auth:
                raise HTTPException(
                    status_code=403,
                    detail="tenant_id and user_id must come from authentication token in production mode"
                )
            # Override request values with auth context (security enforcement)
            chat_request.tenant_id = tenant_id_from_auth
            chat_request.user_id = user_id_from_auth
        else:
            # DEV mode: use from request if provided, otherwise from auth context
            if not chat_request.tenant_id:
                chat_request.tenant_id = tenant_id_from_auth
            if not chat_request.user_id:
                chat_request.user_id = user_id_from_auth
            if not chat_request.tenant_id or not chat_request.user_id:
                raise HTTPException(
                    status_code=400,
                    detail="tenant_id and user_id are required (provide via X-Tenant-Id/X-User-Id headers or request body)"
                )

        # Log without PII in production (question text only logged in dev)
        if settings.ENV == "prod":
            logger.info(f"Chat request: tenant={chat_request.tenant_id}, user={chat_request.user_id}, kb={chat_request.kb_id}, q_length={len(chat_request.question)}")
        else:
            logger.info(f"Chat request: tenant={chat_request.tenant_id}, user={chat_request.user_id}, kb={chat_request.kb_id}, q={chat_request.question[:50]}...")

        # Generate conversation ID if not provided
        conversation_id = chat_request.conversation_id or f"conv_{uuid.uuid4().hex[:12]}"

        # Get database session (manually driven dependency generator)
        try:
            db = next(get_db())
        except Exception as e:
            logger.error(f"Database connection error: {e}", exc_info=True)
            raise HTTPException(status_code=500, detail=f"Database error: {str(e)}")

        try:
            # Ensure tenant exists in billing DB
            ensure_tenant_exists(db, chat_request.tenant_id)

            # Check quota BEFORE making LLM call
            has_quota, quota_error = check_quota(db, chat_request.tenant_id)
            if not has_quota:
                logger.warning(f"Quota exceeded for tenant {chat_request.tenant_id}")
                raise HTTPException(
                    status_code=402,
                    detail=quota_error or "AI quota exceeded. Upgrade your plan."
                )

            # Retrieve relevant context
            retrieval_service = get_retrieval_service()
            results, confidence, has_relevant = retrieval_service.retrieve(
                query=chat_request.question,
                tenant_id=chat_request.tenant_id,  # CRITICAL: Multi-tenant isolation
                kb_id=chat_request.kb_id,
                user_id=chat_request.user_id
            )

            logger.info(f"Retrieval results: {len(results)} results, confidence={confidence:.3f}, has_relevant={has_relevant}")

            # Format context for LLM
            context, citations_info = retrieval_service.get_context_for_llm(results)

            logger.info(f"Formatted context length: {len(context)} chars, citations: {len(citations_info)}")

            # Generate answer
            answer_service = get_answer_service()
            answer_result = answer_service.generate_answer(
                question=chat_request.question,
                context=context,
                citations_info=citations_info,
                confidence=confidence,
                has_relevant_results=has_relevant
            )

            # Track usage if LLM was called (usage info present)
            usage_info = answer_result.get("usage")
            if usage_info:
                try:
                    track_usage(
                        db=db,
                        tenant_id=chat_request.tenant_id,
                        user_id=chat_request.user_id,
                        kb_id=chat_request.kb_id,
                        provider=settings.LLM_PROVIDER,
                        model=usage_info.get("model_used", settings.GEMINI_MODEL if settings.LLM_PROVIDER == "gemini" else settings.OPENAI_MODEL),
                        prompt_tokens=usage_info.get("prompt_tokens", 0),
                        completion_tokens=usage_info.get("completion_tokens", 0)
                    )
                except Exception as e:
                    logger.error(f"Failed to track usage: {e}", exc_info=True)
                    # Don't fail the request if usage tracking fails

            # Build metadata with refusal info
            metadata = {
                "chunks_retrieved": len(results),
                "kb_id": chat_request.kb_id
            }
            if "refused" in answer_result:
                metadata["refused"] = answer_result["refused"]
            if "refusal_reason" in answer_result:
                metadata["refusal_reason"] = answer_result["refusal_reason"]
            if "verifier_passed" in answer_result:
                metadata["verifier_passed"] = answer_result["verifier_passed"]

            return ChatResponse(
                success=True,
                answer=answer_result["answer"],
                citations=answer_result["citations"],
                confidence=answer_result["confidence"],
                from_knowledge_base=answer_result["from_knowledge_base"],
                escalation_suggested=answer_result["escalation_suggested"],
                conversation_id=conversation_id,
                refused=answer_result.get("refused", False),
                metadata=metadata
            )

        except ValueError as e:
            # API key or configuration error
            error_msg = str(e)
            logger.error(f"Configuration error: {error_msg}")
            if "API key" in error_msg.lower():
                return ChatResponse(
                    success=False,
                    answer="⚠️ LLM API key not configured. Please set GEMINI_API_KEY in your .env file. Retrieval is working, but answer generation requires an API key.",
                    citations=[],
                    confidence=0.0,
                    from_knowledge_base=False,
                    escalation_suggested=True,
                    conversation_id=conversation_id,
                    metadata={"error": error_msg, "error_type": "configuration"}
                )
            else:
                return ChatResponse(
                    success=False,
                    answer=f"Configuration error: {error_msg}",
                    citations=[],
                    confidence=0.0,
                    from_knowledge_base=False,
                    escalation_suggested=True,
                    conversation_id=conversation_id,
                    metadata={"error": error_msg}
                )
        except HTTPException:
            # Re-raise HTTP exceptions (they have proper status codes)
            raise
        except Exception as e:
            logger.error(f"Chat error: {e}", exc_info=True)
            logger.error(f"Error type: {type(e).__name__}", exc_info=True)
            return ChatResponse(
                success=False,
                answer=f"I encountered an error processing your request: {str(e)}. Please check the server logs for details.",
                citations=[],
                confidence=0.0,
                from_knowledge_base=False,
                escalation_suggested=True,
                conversation_id=conversation_id,
                metadata={"error": str(e), "error_type": type(e).__name__}
            )
        finally:
            # FIX: release the billing DB session deterministically (see docstring)
            db.close()
    except HTTPException:
        # Re-raise HTTP exceptions from outer try block
        raise
    except Exception as e:
        logger.error(f"Outer chat error: {e}", exc_info=True)
        return ChatResponse(
            success=False,
            answer=f"I encountered an error processing your request: {str(e)}. Please check the server logs for details.",
            citations=[],
            confidence=0.0,
            from_knowledge_base=False,
            escalation_suggested=True,
            conversation_id=conversation_id,
            metadata={"error": str(e), "error_type": type(e).__name__}
        )
693
+
694
+
695
+ # ============== Utility Endpoints ==============
696
+
697
@app.get("/kb/search")
@limiter.limit("30/minute", key_func=get_tenant_rate_limit_key)
async def search_kb(
    request: Request,
    query: str,
    tenant_id: Optional[str] = None,  # Optional in dev, ignored in prod
    kb_id: Optional[str] = None,
    user_id: Optional[str] = None,  # Optional in dev, ignored in prod
    top_k: int = 5
):
    """
    Search the knowledge base without generating an answer.
    Useful for debugging and testing retrieval.

    Args:
        query: Free-text search query.
        tenant_id: Tenant scope (dev only; derived from the auth token in prod).
        kb_id: Knowledge-base ID to search (required in both modes).
        user_id: Requesting user (dev only; derived from the auth token in prod).
        top_k: Maximum number of chunks to return.

    Raises:
        HTTPException: 400 for missing identifiers, 403 for bad auth claims,
            500 on retrieval failure.
    """
    # SECURITY: Extract tenant_id from auth token in production; never trust
    # the query parameters for identity.
    if settings.ENV == "prod":
        auth_context = await require_auth(request)
        tenant_id = auth_context.get("tenant_id")
        user_id = auth_context.get("user_id")
        if not tenant_id or not user_id:
            raise HTTPException(
                status_code=403,
                detail="tenant_id and user_id must come from authentication token in production mode"
            )
        # BUG FIX: kb_id was only validated on the dev branch; in prod a
        # missing kb_id previously fell through and searched with kb_id=None.
        if not kb_id:
            raise HTTPException(status_code=400, detail="kb_id is required")
    elif not tenant_id or not kb_id or not user_id:
        raise HTTPException(
            status_code=400,
            detail="tenant_id, kb_id, and user_id are required"
        )

    try:
        retrieval_service = get_retrieval_service()
        results, confidence, has_relevant = retrieval_service.retrieve(
            query=query,
            tenant_id=tenant_id,  # CRITICAL: Multi-tenant isolation
            kb_id=kb_id,
            user_id=user_id,
            top_k=top_k
        )

        return {
            "success": True,
            "results": [
                {
                    "chunk_id": r.chunk_id,
                    # Truncate long chunks so debug responses stay small.
                    "content": r.content[:500] + "..." if len(r.content) > 500 else r.content,
                    "metadata": r.metadata,
                    "similarity_score": r.similarity_score
                }
                for r in results
            ],
            "confidence": confidence,
            "has_relevant_results": has_relevant
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Search error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
754
+
755
+
756
+ # ============== Billing & Usage Endpoints ==============
757
+
758
@app.get("/billing/usage", response_model=UsageResponse)
async def get_usage(
    request: Request,
    range: str = "month",  # "day" or "month"
    year: Optional[int] = None,
    month: Optional[int] = None,
    day: Optional[int] = None
):
    """
    Get usage statistics for the current tenant.

    Args:
        range: "day" or "month"
        year: Year (optional, defaults to current)
        month: Month 1-12 (optional, defaults to current)
        day: Day 1-31 (optional, defaults to current, only for range="day")

    Returns:
        UsageResponse with zeroed counters when no usage row exists.
    """
    # Get tenant from auth
    auth_context = await get_auth_context(request)
    tenant_id = auth_context.get("tenant_id")

    if not tenant_id:
        raise HTTPException(status_code=403, detail="tenant_id required")

    # BUG FIX: `db = next(get_db())` leaked the session — the dependency
    # generator's cleanup never ran. Hold the generator and close it.
    db_gen = get_db()
    db = next(db_gen)

    try:
        from app.db.models import UsageDaily, UsageMonthly
        from datetime import datetime
        from calendar import monthrange

        now = datetime.utcnow()
        target_year = year or now.year
        target_month = month or now.month

        if range == "day":
            target_day = day or now.day
            date_start = datetime(target_year, target_month, target_day)

            daily = db.query(UsageDaily).filter(
                UsageDaily.tenant_id == tenant_id,
                UsageDaily.date == date_start
            ).first()

            if not daily:
                # No row for that day: report zero usage rather than 404.
                return UsageResponse(
                    tenant_id=tenant_id,
                    period="day",
                    total_requests=0,
                    total_tokens=0,
                    total_cost_usd=0.0,
                    start_date=date_start,
                    end_date=date_start
                )

            return UsageResponse(
                tenant_id=tenant_id,
                period="day",
                total_requests=daily.total_requests,
                total_tokens=daily.total_tokens,
                total_cost_usd=daily.total_cost_usd,
                gemini_requests=daily.gemini_requests,
                openai_requests=daily.openai_requests,
                start_date=daily.date,
                end_date=daily.date
            )
        else:  # month
            monthly = db.query(UsageMonthly).filter(
                UsageMonthly.tenant_id == tenant_id,
                UsageMonthly.year == target_year,
                UsageMonthly.month == target_month
            ).first()

            if not monthly:
                # Calculate date range for the month
                _, last_day = monthrange(target_year, target_month)
                start_date = datetime(target_year, target_month, 1)
                end_date = datetime(target_year, target_month, last_day)

                return UsageResponse(
                    tenant_id=tenant_id,
                    period="month",
                    total_requests=0,
                    total_tokens=0,
                    total_cost_usd=0.0,
                    start_date=start_date,
                    end_date=end_date
                )

            _, last_day = monthrange(monthly.year, monthly.month)
            start_date = datetime(monthly.year, monthly.month, 1)
            end_date = datetime(monthly.year, monthly.month, last_day)

            return UsageResponse(
                tenant_id=tenant_id,
                period="month",
                total_requests=monthly.total_requests,
                total_tokens=monthly.total_tokens,
                total_cost_usd=monthly.total_cost_usd,
                gemini_requests=monthly.gemini_requests,
                openai_requests=monthly.openai_requests,
                start_date=start_date,
                end_date=end_date
            )
    except Exception as e:
        logger.error(f"Error getting usage: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # Run the dependency generator's teardown (closes the session).
        db_gen.close()
865
+
866
+
867
@app.get("/billing/limits", response_model=PlanLimitsResponse)
async def get_limits(request: Request):
    """
    Get current plan limits and usage for the tenant.

    Falls back to the "starter" plan (500 chats/month) when the tenant has
    no plan row. A monthly limit of -1 means unlimited (remaining is None).

    Raises:
        HTTPException: 403 without a tenant, 500 on lookup failure.
    """
    # Get tenant from auth
    auth_context = await get_auth_context(request)
    tenant_id = auth_context.get("tenant_id")

    if not tenant_id:
        raise HTTPException(status_code=403, detail="tenant_id required")

    # BUG FIX: `db = next(get_db())` leaked the session — keep the generator
    # so its teardown can close the session.
    db_gen = get_db()
    db = next(db_gen)

    try:
        from app.billing.quota import get_tenant_plan, get_monthly_usage
        from datetime import datetime

        plan = get_tenant_plan(db, tenant_id)
        if not plan:
            # Default to starter
            plan_name = "starter"
            monthly_limit = 500
        else:
            plan_name = plan.plan_name
            monthly_limit = plan.monthly_chat_limit

        # Get current month usage
        now = datetime.utcnow()
        monthly_usage = get_monthly_usage(db, tenant_id, now.year, now.month)
        current_usage = monthly_usage.total_requests if monthly_usage else 0

        # -1 means unlimited; otherwise clamp remaining at zero.
        remaining = None if monthly_limit == -1 else max(0, monthly_limit - current_usage)

        return PlanLimitsResponse(
            tenant_id=tenant_id,
            plan_name=plan_name,
            monthly_chat_limit=monthly_limit,
            current_month_usage=current_usage,
            remaining_chats=remaining
        )
    except Exception as e:
        logger.error(f"Error getting limits: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        db_gen.close()
909
+
910
+
911
@app.post("/billing/plan")
async def set_plan(request_body: SetPlanRequest, http_request: Request):
    """
    Set tenant's subscription plan (admin only in production).

    In dev mode, allows any tenant to set their plan.
    In prod mode, a tenant may only change its own plan (admin role check
    is still a TODO).

    Raises:
        HTTPException: 403 for cross-tenant changes in prod, 400 for an
            unknown plan name, 500 on persistence failure.
    """
    # Get tenant from auth
    auth_context = await get_auth_context(http_request)
    auth_tenant_id = auth_context.get("tenant_id")

    # In prod, verify admin role (placeholder - implement actual admin check)
    if settings.ENV == "prod":
        # TODO: Add admin role check
        if auth_tenant_id != request_body.tenant_id:
            raise HTTPException(status_code=403, detail="Cannot set plan for other tenants")

    # BUG FIX: `db = next(get_db())` leaked the session — keep the generator
    # so its teardown can close the session.
    db_gen = get_db()
    db = next(db_gen)

    try:
        from app.billing.quota import set_tenant_plan

        plan = set_tenant_plan(db, request_body.tenant_id, request_body.plan_name)

        return {
            "success": True,
            "tenant_id": request_body.tenant_id,
            "plan_name": plan.plan_name,
            "monthly_chat_limit": plan.monthly_chat_limit
        }
    except ValueError as e:
        # Raised for invalid plan names.
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.error(f"Error setting plan: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        db_gen.close()
947
+
948
+
949
@app.get("/billing/cost-report", response_model=CostReportResponse)
async def get_cost_report(
    request: Request,
    range: str = "month",
    year: Optional[int] = None,
    month: Optional[int] = None
):
    """
    Get cost report with breakdown by provider and model.

    Args:
        range: "month" for a single month; any other value reports all time.
        year: Target year (defaults to current).
        month: Target month 1-12 (defaults to current).

    Raises:
        HTTPException: 403 without a tenant, 500 on query failure.
    """
    # Get tenant from auth
    auth_context = await get_auth_context(request)
    tenant_id = auth_context.get("tenant_id")

    if not tenant_id:
        raise HTTPException(status_code=403, detail="tenant_id required")

    # BUG FIX: `db = next(get_db())` leaked the session — keep the generator
    # so its teardown can close the session.
    db_gen = get_db()
    db = next(db_gen)

    try:
        from app.db.models import UsageEvent
        from datetime import datetime
        from sqlalchemy import func, and_

        now = datetime.utcnow()
        target_year = year or now.year
        target_month = month or now.month

        # Query usage events for the period
        if range == "month":
            query = db.query(UsageEvent).filter(
                and_(
                    UsageEvent.tenant_id == tenant_id,
                    func.extract('year', UsageEvent.request_timestamp) == target_year,
                    func.extract('month', UsageEvent.request_timestamp) == target_month
                )
            )
        else:  # all time
            query = db.query(UsageEvent).filter(UsageEvent.tenant_id == tenant_id)

        events = query.all()

        # Calculate totals
        total_cost = sum(e.estimated_cost_usd for e in events)
        total_requests = len(events)
        total_tokens = sum(e.total_tokens for e in events)

        # Build both breakdowns in a single pass over the events
        # (previously two separate loops over the same list).
        breakdown_by_provider = {}
        breakdown_by_model = {}
        for event in events:
            for bucket, key in (
                (breakdown_by_provider, event.provider),
                (breakdown_by_model, event.model),
            ):
                entry = bucket.setdefault(
                    key, {"requests": 0, "tokens": 0, "cost_usd": 0.0}
                )
                entry["requests"] += 1
                entry["tokens"] += event.total_tokens
                entry["cost_usd"] += event.estimated_cost_usd

        return CostReportResponse(
            tenant_id=tenant_id,
            period=range,
            total_cost_usd=total_cost,
            total_requests=total_requests,
            total_tokens=total_tokens,
            breakdown_by_provider=breakdown_by_provider,
            breakdown_by_model=breakdown_by_model
        )
    except Exception as e:
        logger.error(f"Error getting cost report: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        db_gen.close()
1034
+
1035
+
1036
if __name__ == "__main__":
    # Local/dev entry point: serve the FastAPI app with uvicorn on all
    # interfaces, port 8000 (production deployments use their own runner).
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
1039
+
app/middleware/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Middleware for authentication, rate limiting, etc.
3
+ """
4
+ from app.middleware.auth import verify_tenant_access, get_tenant_from_token, require_auth
5
+
6
+ __all__ = [
7
+ "verify_tenant_access",
8
+ "get_tenant_from_token",
9
+ "require_auth",
10
+ ]
11
+
12
+
13
+
app/middleware/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (370 Bytes). View file
 
app/middleware/__pycache__/auth.cpython-313.pyc ADDED
Binary file (6.5 kB). View file
 
app/middleware/__pycache__/rate_limit.cpython-313.pyc ADDED
Binary file (1.45 kB). View file
 
app/middleware/auth.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Authentication and authorization middleware.
3
+ Extracts tenant_id from JWT token in production mode.
4
+ """
5
+ from fastapi import Request, HTTPException, Depends
6
+ from typing import Optional, Dict, Any
7
+ import logging
8
+ from jose import JWTError, jwt
9
+
10
+ from app.config import settings
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
async def verify_tenant_access(
    request: Request,
    tenant_id: str,
    user_id: str
) -> bool:
    """
    Verify that the user has access to the specified tenant.

    WARNING: This is currently a placeholder that grants access to everyone
    (it only logs a warning and returns True). Do NOT rely on it in prod.

    TODO: Implement actual authentication logic:
    1. Extract JWT token from Authorization header
    2. Verify token signature
    3. Extract user_id and tenant_id from token
    4. Verify user belongs to tenant
    5. Check permissions

    Args:
        request: FastAPI request object
        tenant_id: Tenant ID from request
        user_id: User ID from request

    Returns:
        True if access is granted, False otherwise
    """
    # TODO: Implement actual authentication
    # For now, this is a placeholder that always returns True
    # In production, you MUST implement proper auth

    # Example implementation:
    # token = request.headers.get("Authorization", "").replace("Bearer ", "")
    # if not token:
    #     return False
    #
    # decoded = verify_jwt_token(token)
    # if decoded["user_id"] != user_id or decoded["tenant_id"] != tenant_id:
    #     return False
    #
    # return True

    logger.warning("⚠️ Authentication middleware not implemented - using placeholder")
    return True
55
+
56
+
57
def get_tenant_from_token(request: Request) -> Optional[str]:
    """
    Extract tenant_id from authentication token.

    In production mode, reads the tenant_id claim from the Bearer JWT.
    In dev mode, returns None (allows request tenant_id).

    WARNING: The claims are read WITHOUT signature verification. Use
    get_auth_context for verified authentication; this helper is only a
    best-effort tenant hint.

    Args:
        request: FastAPI request object

    Returns:
        Tenant ID if found in token, None otherwise
    """
    if settings.ENV == "dev":
        # Dev mode: allow request tenant_id
        return None

    # Production mode: extract from JWT
    auth_header = request.headers.get("Authorization", "")
    if not auth_header.startswith("Bearer "):
        logger.warning("Missing or invalid Authorization header")
        return None

    token = auth_header.replace("Bearer ", "").strip()
    if not token:
        return None

    try:
        # BUG FIX: python-jose's `jwt.decode` requires a key argument and the
        # `jwt` module has no `DecodeError` attribute, so the previous
        # `jwt.decode(token, options={"verify_signature": False})` /
        # `except jwt.DecodeError` pair failed at runtime.
        # `get_unverified_claims` is jose's API for reading claims without
        # verifying the signature.
        decoded = jwt.get_unverified_claims(token)
        tenant_id = decoded.get("tenant_id")
        if tenant_id:
            logger.info(f"Extracted tenant_id from token: {tenant_id}")
            return tenant_id
    except JWTError:
        logger.warning("Failed to decode JWT token")
        return None
    except Exception as e:
        logger.error(f"Error extracting tenant from token: {e}")
        return None

    # Token decoded but carried no tenant_id claim.
    return None
114
+
115
+
116
async def get_auth_context(request: Request) -> Dict[str, Any]:
    """
    Get authentication context from request.

    DEV mode:
    - Allows X-Tenant-Id and X-User-Id headers
    - Falls back to defaults if missing

    PROD mode:
    - Requires Authorization: Bearer <JWT>
    - Verifies JWT using JWT_SECRET
    - Extracts tenant_id and user_id from token claims
    - NEVER accepts tenant_id from request body/query params

    Args:
        request: FastAPI request object

    Returns:
        Dictionary with user_id and tenant_id (plus email/role in prod)

    Raises:
        HTTPException: 401 on auth failure; 500 if JWT_SECRET is unset (prod)
    """
    if settings.ENV == "dev":
        # Dev mode: allow headers, fallback to defaults
        tenant_id = request.headers.get("X-Tenant-Id", "dev_tenant")
        user_id = request.headers.get("X-User-Id", "dev_user")
        return {
            "user_id": user_id,
            "tenant_id": tenant_id
        }

    # Production mode: require JWT token
    auth_header = request.headers.get("Authorization")
    if not auth_header or not auth_header.startswith("Bearer "):
        raise HTTPException(
            status_code=401,
            detail="Authentication required. Provide valid Bearer token in Authorization header."
        )

    token = auth_header.replace("Bearer ", "").strip()
    if not token:
        raise HTTPException(
            status_code=401,
            detail="Invalid token format."
        )

    # Verify JWT token
    if not settings.JWT_SECRET:
        logger.error("JWT_SECRET not configured in production mode")
        raise HTTPException(
            status_code=500,
            detail="Server configuration error: JWT_SECRET not set"
        )

    try:
        decoded = jwt.decode(token, settings.JWT_SECRET, algorithms=["HS256"])

        user_id = decoded.get("user_id") or decoded.get("sub")
        tenant_id = decoded.get("tenant_id")

        if not user_id or not tenant_id:
            raise HTTPException(
                status_code=401,
                detail="Token missing required claims (user_id, tenant_id)."
            )

        logger.info(f"Authenticated user: {user_id}, tenant: {tenant_id}")
        return {
            "user_id": user_id,
            "tenant_id": tenant_id,
            "email": decoded.get("email"),
            "role": decoded.get("role")
        }

    except HTTPException:
        # BUG FIX: the missing-claims HTTPException above was previously
        # swallowed by the broad `except Exception` and replaced with a
        # generic "Authentication failed." message; re-raise it unchanged.
        raise
    except JWTError as e:
        logger.warning(f"JWT verification failed: {e}")
        raise HTTPException(
            status_code=401,
            detail="Invalid or expired token."
        )
    except Exception as e:
        logger.error(f"Auth error: {e}", exc_info=True)
        raise HTTPException(
            status_code=401,
            detail="Authentication failed."
        )
203
+
204
+
205
+ # FastAPI dependency for easy use in endpoints
206
async def require_auth(request: Request) -> Dict[str, Any]:
    """
    FastAPI dependency for requiring authentication.

    Alias for get_auth_context for backward compatibility: returns the same
    context dict and raises the same HTTPExceptions in prod mode.
    """
    return await get_auth_context(request)
212
+
app/middleware/rate_limit.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Rate limiting middleware using slowapi.
3
+ """
4
+ from slowapi import Limiter, _rate_limit_exceeded_handler
5
+ from slowapi.util import get_remote_address
6
+ from slowapi.errors import RateLimitExceeded
7
+ from fastapi import Request
8
+ import logging
9
+
10
+ from app.config import settings
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
# Initialize limiter with default limits (can be overridden per endpoint).
# When RATE_LIMIT_ENABLED is False, no default limit is applied at all;
# endpoints may still attach their own limits via @limiter.limit(...).
limiter = Limiter(
    key_func=get_remote_address,
    default_limits=["1000/hour"] if settings.RATE_LIMIT_ENABLED else []
)
19
+
20
+
21
def get_tenant_rate_limit_key(request: Request) -> str:
    """
    Build the rate-limit bucket key for a request.

    Prefers a per-tenant bucket derived from the X-Tenant-Id header (present
    in dev mode); otherwise falls back to the client IP address. Must stay
    synchronous: slowapi calls key functions before async auth runs, so we
    cannot await the full auth context here.
    """
    header_tenant = request.headers.get("X-Tenant-Id")
    if header_tenant:
        return f"tenant:{header_tenant}"
    # No tenant header (prod mode, or header omitted): key by client IP.
    return get_remote_address(request)
36
+
37
+
38
+ # Export limiter and key function
39
+ __all__ = ["limiter", "get_tenant_rate_limit_key", "RateLimitExceeded", "_rate_limit_exceeded_handler"]
40
+
app/models/__init__.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Pydantic models for the RAG backend.
3
+ """
4
+ from app.models.schemas import (
5
+ DocumentStatus,
6
+ ChunkMetadata,
7
+ DocumentChunk,
8
+ UploadRequest,
9
+ UploadResponse,
10
+ Citation,
11
+ ChatRequest,
12
+ ChatResponse,
13
+ RetrievalResult,
14
+ KnowledgeBaseStats,
15
+ HealthResponse,
16
+ )
17
+
18
+ __all__ = [
19
+ "DocumentStatus",
20
+ "ChunkMetadata",
21
+ "DocumentChunk",
22
+ "UploadRequest",
23
+ "UploadResponse",
24
+ "Citation",
25
+ "ChatRequest",
26
+ "ChatResponse",
27
+ "RetrievalResult",
28
+ "KnowledgeBaseStats",
29
+ "HealthResponse",
30
+ ]
31
+
32
+
33
+
app/models/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (542 Bytes). View file
 
app/models/__pycache__/billing_schemas.cpython-313.pyc ADDED
Binary file (2.09 kB). View file
 
app/models/__pycache__/schemas.cpython-313.pyc ADDED
Binary file (5.26 kB). View file
 
app/models/billing_schemas.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Pydantic schemas for billing endpoints.
3
+ """
4
+ from pydantic import BaseModel
5
+ from typing import Optional, List
6
+ from datetime import datetime
7
+
8
+
9
class UsageResponse(BaseModel):
    """Usage statistics for one tenant over a single day or month."""
    tenant_id: str
    period: str  # "day" or "month"
    total_requests: int
    total_tokens: int
    total_cost_usd: float
    gemini_requests: int = 0  # Per-provider request counts default to 0
    openai_requests: int = 0
    start_date: datetime  # Inclusive period bounds (equal for period="day")
    end_date: datetime
20
+
21
+
22
class PlanLimitsResponse(BaseModel):
    """Current plan limits and month-to-date consumption for a tenant."""
    tenant_id: str
    plan_name: str  # e.g. "starter", "growth", "pro"
    monthly_chat_limit: int  # -1 for unlimited
    current_month_usage: int
    remaining_chats: Optional[int]  # None if unlimited
29
+
30
+
31
class CostReportResponse(BaseModel):
    """Aggregate LLM cost report with per-provider and per-model breakdowns."""
    tenant_id: str
    period: str  # Echoes the requested range (e.g. "month")
    total_cost_usd: float
    total_requests: int
    total_tokens: int
    # Keyed by provider/model name; values hold requests/tokens/cost_usd.
    breakdown_by_provider: dict
    breakdown_by_model: dict
40
+
41
+
42
class SetPlanRequest(BaseModel):
    """Request body for POST /billing/plan."""
    tenant_id: str  # Tenant whose plan is being changed
    plan_name: str  # "starter", "growth", or "pro"
46
+
app/models/schemas.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Pydantic models for API request/response schemas.
3
+ """
4
+ from pydantic import BaseModel, Field
5
+ from typing import List, Optional, Dict, Any
6
+ from datetime import datetime
7
+ from enum import Enum
8
+
9
+
10
class DocumentStatus(str, Enum):
    """Lifecycle states of an uploaded document during ingestion."""
    PENDING = "pending"
    PROCESSING = "processing"
    COMPLETED = "completed"
    FAILED = "failed"
15
+
16
+
17
class ChunkMetadata(BaseModel):
    """Metadata attached to every stored document chunk."""
    tenant_id: str  # CRITICAL: Multi-tenant isolation
    kb_id: str  # Knowledge base the chunk belongs to
    user_id: str
    file_name: str
    file_type: str
    chunk_id: str
    chunk_index: int  # Position of this chunk within the document
    page_number: Optional[int] = None  # Only set for paginated formats
    total_chunks: int
    document_id: Optional[str] = None  # Track original document
    created_at: datetime = Field(default_factory=datetime.utcnow)  # Set when the model is built
30
+
31
+
32
class DocumentChunk(BaseModel):
    """A chunk of text with metadata, optionally carrying its embedding."""
    id: str
    content: str
    metadata: ChunkMetadata
    embedding: Optional[List[float]] = None  # None until embedded
38
+
39
+
40
class UploadRequest(BaseModel):
    """Request model for file upload."""
    tenant_id: str  # CRITICAL: Multi-tenant isolation
    user_id: str
    kb_id: str  # Target knowledge base for the uploaded file
45
+
46
+
47
class UploadResponse(BaseModel):
    """Response model for file upload."""
    success: bool
    message: str
    document_id: Optional[str] = None  # Absent when ingestion failed early
    file_name: str
    chunks_created: int = 0
    status: DocumentStatus = DocumentStatus.PENDING  # Ingestion lifecycle state
55
+
56
+
57
class Citation(BaseModel):
    """Citation reference backing a generated answer."""
    file_name: str
    chunk_id: str
    page_number: Optional[int] = None
    relevance_score: float
    excerpt: str  # Short excerpt from the chunk
64
+
65
+
66
class ChatRequest(BaseModel):
    """Request model for chat endpoint."""
    tenant_id: str  # CRITICAL: Multi-tenant isolation
    user_id: str
    kb_id: str
    conversation_id: Optional[str] = None  # Omitted on the first turn
    question: str
73
+
74
+
75
class ChatResponse(BaseModel):
    """Response model for chat endpoint."""
    success: bool
    answer: str
    # NOTE: mutable defaults are safe here — Pydantic copies field defaults
    # per instance, unlike plain Python default arguments.
    citations: List[Citation] = []
    confidence: float  # 0-1 score
    from_knowledge_base: bool = True
    escalation_suggested: bool = False
    conversation_id: str
    metadata: Dict[str, Any] = {}  # Extra info (e.g. error details, usage)
85
+
86
+
87
class RetrievalResult(BaseModel):
    """Single result from vector store retrieval."""
    chunk_id: str
    content: str
    metadata: Dict[str, Any]
    similarity_score: float
93
+
94
+
95
class KnowledgeBaseStats(BaseModel):
    """Summary statistics for a single knowledge base."""
    tenant_id: str  # CRITICAL: Multi-tenant isolation
    kb_id: str
    user_id: str
    total_documents: int
    total_chunks: int
    file_names: List[str]
    last_updated: Optional[datetime] = None  # None when the KB is empty
104
+
105
+
106
class HealthResponse(BaseModel):
    """Health check response."""
    status: str
    version: str = "1.0.0"
    vector_db_connected: bool
    llm_configured: bool  # True when an LLM API key is available
112
+
app/rag/__init__.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ RAG (Retrieval-Augmented Generation) pipeline modules.
3
+ """
4
+ from app.rag.ingest import parser, DocumentParser
5
+ from app.rag.chunking import chunker, DocumentChunker
6
+ from app.rag.embeddings import get_embedding_service, EmbeddingService
7
+ from app.rag.vectorstore import get_vector_store, VectorStore
8
+ from app.rag.retrieval import get_retrieval_service, RetrievalService
9
+ from app.rag.answer import get_answer_service, AnswerService
10
+
11
+ __all__ = [
12
+ "parser",
13
+ "DocumentParser",
14
+ "chunker",
15
+ "DocumentChunker",
16
+ "get_embedding_service",
17
+ "EmbeddingService",
18
+ "get_vector_store",
19
+ "VectorStore",
20
+ "get_retrieval_service",
21
+ "RetrievalService",
22
+ "get_answer_service",
23
+ "AnswerService",
24
+ ]
25
+
26
+
27
+
app/rag/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (799 Bytes). View file
 
app/rag/__pycache__/answer.cpython-313.pyc ADDED
Binary file (18.6 kB). View file
 
app/rag/__pycache__/chunking.cpython-313.pyc ADDED
Binary file (7.26 kB). View file
 
app/rag/__pycache__/embeddings.cpython-313.pyc ADDED
Binary file (6.06 kB). View file
 
app/rag/__pycache__/ingest.cpython-313.pyc ADDED
Binary file (10.5 kB). View file
 
app/rag/__pycache__/intent.cpython-313.pyc ADDED
Binary file (5.51 kB). View file
 
app/rag/__pycache__/prompts.cpython-313.pyc ADDED
Binary file (6.94 kB). View file
 
app/rag/__pycache__/retrieval.cpython-313.pyc ADDED
Binary file (9.74 kB). View file
 
app/rag/__pycache__/vectorstore.cpython-313.pyc ADDED
Binary file (9.35 kB). View file
 
app/rag/__pycache__/verifier.cpython-313.pyc ADDED
Binary file (9.77 kB). View file
 
app/rag/answer.py ADDED
@@ -0,0 +1,444 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Answer generation using LLM with RAG context.
3
+ Supports Gemini and OpenAI as providers.
4
+ """
5
+ import google.generativeai as genai
6
+ from openai import OpenAI
7
+ from typing import Optional, Dict, Any, List
8
+ import logging
9
+ import os
10
+ import re
11
+
12
+ from app.config import settings
13
+ from app.rag.prompts import (
14
+ format_rag_prompt,
15
+ format_draft_prompt,
16
+ get_no_context_response,
17
+ get_low_confidence_response
18
+ )
19
+ from app.rag.verifier import get_verifier_service
20
+ from app.rag.intent import detect_intents
21
+ from app.models.schemas import Citation
22
+ from abc import ABC, abstractmethod
23
+
24
+ logging.basicConfig(level=logging.INFO)
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
class LLMProvider(ABC):
    """
    Base class for LLM providers.

    Concrete providers (e.g. Gemini, OpenAI) implement both methods;
    `generate` is the plain-text variant, `generate_with_usage` additionally
    reports token accounting for billing.
    """

    @abstractmethod
    def generate(self, system_prompt: str, user_prompt: str) -> str:
        """Generate response from prompts."""
        raise NotImplementedError

    @abstractmethod
    def generate_with_usage(self, system_prompt: str, user_prompt: str) -> tuple[str, dict]:
        """
        Generate response and return usage information.

        Returns:
            (response_text, usage_info)
            usage_info: dict with keys: prompt_tokens, completion_tokens, total_tokens, model_used
        """
        raise NotImplementedError
+ raise NotImplementedError
46
+
47
+
48
+ class GeminiProvider(LLMProvider):
49
+ """Google Gemini LLM provider."""
50
+
51
    def __init__(self, api_key: Optional[str] = None, model: Optional[str] = None):
        """
        Configure the Gemini SDK.

        Args:
            api_key: Explicit API key; falls back to settings / GEMINI_API_KEY env var.
            model: Preferred model name; falls back to settings.GEMINI_MODEL,
                then "gemini-1.5-flash". Actual model selection happens at
                generation time (other models may be tried).

        Raises:
            ValueError: If no API key can be resolved.
        """
        self.api_key = api_key or settings.GEMINI_API_KEY or os.getenv("GEMINI_API_KEY")
        # Default to gemini-1.5-flash if not specified
        self.model = model or settings.GEMINI_MODEL or "gemini-1.5-flash"

        if not self.api_key:
            raise ValueError("Gemini API key not configured. Set GEMINI_API_KEY environment variable.")

        genai.configure(api_key=self.api_key)

        # Don't initialize client here - do it lazily in generate() to handle errors better
        self._client = None
        logger.info(f"Gemini provider initialized (will use model: {self.model})")
64
+
65
+ def generate(self, system_prompt: str, user_prompt: str) -> str:
66
+ """Generate response using Gemini."""
67
+ text, _ = self.generate_with_usage(system_prompt, user_prompt)
68
+ return text
69
+
70
+ def generate_with_usage(self, system_prompt: str, user_prompt: str) -> tuple[str, dict]:
71
+ """Generate response using Gemini and return usage info."""
72
+ # Combine system and user prompts for Gemini
73
+ full_prompt = f"{system_prompt}\n\n{user_prompt}"
74
+
75
+ # Estimate prompt tokens (rough: 1 token ≈ 4 chars)
76
+ prompt_tokens = len(full_prompt) // 4
77
+
78
+ # Try to list available models first, then use the first available one
79
+ # If that fails, try common model names
80
+ models_to_try = []
81
+
82
+ # First, try to get available models
83
+ try:
84
+ available_models = genai.list_models()
85
+ model_names = [m.name for m in available_models if 'generateContent' in m.supported_generation_methods]
86
+ if model_names:
87
+ # Extract just the model name (remove 'models/' prefix if present)
88
+ clean_names = [name.split('/')[-1] if '/' in name else name for name in model_names]
89
+ models_to_try.extend(clean_names[:3]) # Use first 3 available models
90
+ logger.info(f"Found {len(model_names)} available models, will try: {clean_names[:3]}")
91
+ except Exception as e:
92
+ logger.warning(f"Could not list available models: {e}, using fallback list")
93
+
94
+ # Fallback to common model names if listing failed
95
+ if not models_to_try:
96
+ models_to_try = ["gemini-pro", "gemini-1.0-pro", "models/gemini-pro"]
97
+
98
+ # Add configured model if different
99
+ if self.model and self.model not in models_to_try:
100
+ models_to_try.insert(0, self.model)
101
+
102
+ # Remove duplicates while preserving order
103
+ seen = set()
104
+ models_to_try = [m for m in models_to_try if not (m in seen or seen.add(m))]
105
+
106
+ last_error = None
107
+ for model_name in models_to_try:
108
+ try:
109
+ logger.info(f"Attempting to generate with model: {model_name}")
110
+ # Create a new client for this model
111
+ client = genai.GenerativeModel(model_name)
112
+ response = client.generate_content(
113
+ full_prompt,
114
+ generation_config=genai.types.GenerationConfig(
115
+ temperature=settings.TEMPERATURE,
116
+ max_output_tokens=1024,
117
+ )
118
+ )
119
+
120
+ # Extract response text
121
+ response_text = response.text
122
+
123
+ # Try to get usage info from response
124
+ usage_info = {
125
+ "prompt_tokens": prompt_tokens,
126
+ "completion_tokens": len(response_text) // 4, # Estimate
127
+ "total_tokens": prompt_tokens + (len(response_text) // 4),
128
+ "model_used": model_name.split('/')[-1] if '/' in model_name else model_name
129
+ }
130
+
131
+ # Try to get actual usage from response if available
132
+ if hasattr(response, 'usage_metadata'):
133
+ usage_metadata = response.usage_metadata
134
+ if hasattr(usage_metadata, 'prompt_token_count'):
135
+ usage_info["prompt_tokens"] = usage_metadata.prompt_token_count
136
+ if hasattr(usage_metadata, 'candidates_token_count'):
137
+ usage_info["completion_tokens"] = usage_metadata.candidates_token_count
138
+ if hasattr(usage_metadata, 'total_token_count'):
139
+ usage_info["total_tokens"] = usage_metadata.total_token_count
140
+
141
+ if model_name != self.model:
142
+ logger.info(f"✅ Successfully used model: {model_name}")
143
+
144
+ return response_text, usage_info
145
+ except Exception as e:
146
+ error_str = str(e).lower()
147
+ last_error = e
148
+ if "not found" in error_str or "not supported" in error_str or "404" in error_str:
149
+ logger.warning(f"Model {model_name} failed: {e}")
150
+ continue # Try next model
151
+ else:
152
+ # Different error (not model not found), re-raise
153
+ logger.error(f"Gemini generation error with {model_name}: {e}")
154
+ raise
155
+
156
+ # All models failed - return a helpful error message
157
+ error_msg = f"All Gemini model attempts failed. Last error: {last_error}. Please check your GEMINI_API_KEY and ensure it has access to Gemini models."
158
+ logger.error(error_msg)
159
+ raise Exception(error_msg)
160
+
161
+
162
class OpenAIProvider(LLMProvider):
    """LLM provider backed by the OpenAI chat-completions API."""

    def __init__(self, api_key: Optional[str] = None, model: str = settings.OPENAI_MODEL):
        # Key precedence: explicit argument > settings > environment.
        self.api_key = api_key or settings.OPENAI_API_KEY or os.getenv("OPENAI_API_KEY")
        self.model = model

        if not self.api_key:
            raise ValueError("OpenAI API key not configured. Set OPENAI_API_KEY environment variable.")

        self.client = OpenAI(api_key=self.api_key)
        logger.info(f"OpenAI provider initialized with model: {model}")

    def generate(self, system_prompt: str, user_prompt: str) -> str:
        """Return only the response text, discarding usage accounting."""
        answer_text, _unused_usage = self.generate_with_usage(system_prompt, user_prompt)
        return answer_text

    def generate_with_usage(self, system_prompt: str, user_prompt: str) -> tuple[str, dict]:
        """Call the chat-completions endpoint; return (text, usage dict).

        The usage dict carries prompt_tokens / completion_tokens /
        total_tokens exactly as reported by the API, plus the model name.
        Any API error is logged and re-raised to the caller.
        """
        try:
            chat_messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt}
            ]
            response = self.client.chat.completions.create(
                model=self.model,
                messages=chat_messages,
                temperature=settings.TEMPERATURE,
                max_tokens=1024
            )

            usage = response.usage
            usage_info = {
                "prompt_tokens": usage.prompt_tokens,
                "completion_tokens": usage.completion_tokens,
                "total_tokens": usage.total_tokens,
                "model_used": self.model
            }

            return response.choices[0].message.content, usage_info
        except Exception as e:
            logger.error(f"OpenAI generation error: {e}")
            raise
207
+
208
+
209
class AnswerService:
    """
    Generates answers using RAG context and LLM.
    Handles confidence scoring and citation extraction.

    The generation path is gated: missing context, low retrieval
    confidence, or a failed verifier each cause an explicit refusal
    instead of a hallucinated answer.
    """

    # Confidence thresholds
    HIGH_CONFIDENCE_THRESHOLD = 0.5
    LOW_CONFIDENCE_THRESHOLD = 0.20  # Lowered to match similarity threshold
    STRICT_CONFIDENCE_THRESHOLD = 0.30  # Strict threshold for answer generation (lowered from 0.45 to allow good retrieval results)

    def __init__(self, provider: str = settings.LLM_PROVIDER):
        """
        Initialize the answer service.

        Args:
            provider: LLM provider to use ("gemini" or "openai")
        """
        self.provider_name = provider
        self._provider: Optional[LLMProvider] = None

    @property
    def provider(self) -> LLMProvider:
        """Lazily construct and cache the configured LLM provider.

        Raises:
            ValueError: if the configured provider name is unknown.
        """
        if self._provider is None:
            if self.provider_name == "gemini":
                self._provider = GeminiProvider()
            elif self.provider_name == "openai":
                self._provider = OpenAIProvider()
            else:
                raise ValueError(f"Unknown LLM provider: {self.provider_name}")
        return self._provider

    def generate_answer(
        self,
        question: str,
        context: str,
        citations_info: List[Dict[str, Any]],
        confidence: float,
        has_relevant_results: bool,
        use_verifier: Optional[bool] = None  # None = use config default
    ) -> Dict[str, Any]:
        """
        Generate an answer based on retrieved context with mandatory verifier.

        Args:
            question: User's question
            context: Retrieved context from knowledge base
            citations_info: List of citation information
            confidence: Average confidence score from retrieval
            has_relevant_results: Whether any results passed the threshold
            use_verifier: Whether to use verifier mode (None = use config default)

        Returns:
            Dictionary with answer, citations, confidence, and metadata
            (refusal responses carry ``refused=True`` and a ``refusal_reason``).
        """
        # Determine if verifier should be used (mandatory by default)
        if use_verifier is None:
            use_verifier = settings.REQUIRE_VERIFIER
        # NOTE(review): use_verifier is resolved here but never consulted below;
        # the verifier step always runs. Confirm whether it should gate Step 2.

        # GATE 1: No relevant results found - REFUSE
        if not has_relevant_results or not context:
            logger.info("No relevant context found, returning no-context response")
            return {
                "answer": get_no_context_response(),
                "citations": [],
                "confidence": 0.0,
                "from_knowledge_base": False,
                "escalation_suggested": True,
                "refused": True
            }

        # GATE 2: Strict confidence threshold - REFUSE if below strict threshold
        if confidence < self.STRICT_CONFIDENCE_THRESHOLD:
            logger.warning(
                f"Confidence ({confidence:.3f}) below strict threshold ({self.STRICT_CONFIDENCE_THRESHOLD}), "
                f"REFUSING to answer to prevent hallucination"
            )
            return {
                "answer": get_no_context_response(),
                "citations": [],
                "confidence": confidence,
                "from_knowledge_base": False,
                "escalation_suggested": True,
                "refused": True,
                "refusal_reason": f"Confidence {confidence:.3f} below strict threshold {self.STRICT_CONFIDENCE_THRESHOLD}"
            }

        # GATE 3: Intent-based gating for specific intents (integration, API, etc.)
        intents = detect_intents(question)
        if "integration" in intents or "api" in question.lower():
            # For integration/API questions, require strong relevance
            if confidence < 0.50:  # Even stricter for integration questions
                logger.warning(
                    f"Integration/API question with low confidence ({confidence:.3f}), "
                    f"REFUSING to prevent hallucination"
                )
                return {
                    "answer": get_no_context_response(),
                    "citations": [],
                    "confidence": confidence,
                    "from_knowledge_base": False,
                    "escalation_suggested": True,
                    "refused": True,
                    "refusal_reason": "Integration/API questions require higher confidence"
                }

        # Passed all gates - generate answer with MANDATORY verifier
        logger.info(f"Confidence ({confidence:.3f}) passed all gates, generating answer with verifier={use_verifier}")

        try:
            # VERIFIER MODE IS MANDATORY: Draft → Verify → Final
            # Step 1: Generate draft answer with usage tracking
            draft_system, draft_user = format_draft_prompt(context, question)
            draft_answer, usage_info = self.provider.generate_with_usage(draft_system, draft_user)
            logger.info("Generated draft answer, running verifier...")

            # Step 2: Verify draft answer (MANDATORY)
            verifier = get_verifier_service()
            verification = verifier.verify_answer(
                draft_answer=draft_answer,
                context=context,
                citations_info=citations_info
            )

            # Step 3: Handle verification result
            if verification["pass"]:
                logger.info("✅ Verifier PASSED - Using draft answer")
                citations = self._extract_citations(draft_answer, citations_info)
                return {
                    "answer": draft_answer,
                    "citations": citations,
                    "confidence": confidence,
                    "from_knowledge_base": True,
                    "escalation_suggested": confidence < self.HIGH_CONFIDENCE_THRESHOLD,
                    "verifier_passed": True,
                    "refused": False,
                    "usage": usage_info  # Include usage info for tracking
                }
            else:
                # Verifier failed - REFUSE to answer
                issues = verification.get('issues', [])
                unsupported = verification.get('unsupported_claims', [])
                logger.warning(
                    f"❌ Verifier FAILED - Issues: {issues}, "
                    f"Unsupported claims: {unsupported}"
                )
                refusal_message = (
                    get_no_context_response() +
                    "\n\n**Note:** The system could not verify the accuracy of the information needed to answer your question. "
                    "This helps prevent providing incorrect information."
                )
                return {
                    "answer": refusal_message,
                    "citations": [],
                    "confidence": 0.0,
                    "from_knowledge_base": False,
                    "escalation_suggested": True,
                    "verifier_passed": False,
                    "verifier_issues": issues,
                    "unsupported_claims": unsupported,
                    "refused": True,
                    "refusal_reason": "Verifier failed: claims not supported by context",
                    "usage": usage_info  # Still track usage even if refused
                }

        except ValueError as e:
            # Configuration errors (e.g., missing API key)
            error_msg = str(e)
            logger.error(f"Configuration error in answer generation: {error_msg}")
            # FIX: the needle must be lower-case -- comparing mixed-case
            # "API key" against error_msg.lower() could never match, so the
            # friendlier re-wrapped error below was unreachable.
            if "api key" in error_msg.lower():
                raise ValueError(f"LLM API key not configured: {error_msg}") from e
            raise
        except Exception as e:
            logger.error(f"Error generating answer: {e}", exc_info=True)
            # Re-raise to be handled by the endpoint
            raise

    def _extract_citations(
        self,
        answer: str,
        citations_info: List[Dict[str, Any]]
    ) -> List[Citation]:
        """
        Extract and format citations from the answer.

        Args:
            answer: Generated answer with [Source X] references
            citations_info: Available citation information

        Returns:
            List of Citation objects. When the answer contains no explicit
            [Source X] markers, the top 3 retrieved sources are returned
            instead so the response is never citation-free.
        """
        citations = []

        # Find all [Source X] references in the answer
        source_pattern = r'\[Source\s*(\d+)\]'
        matches = re.findall(source_pattern, answer)
        referenced_indices = set(int(m) for m in matches)

        # Build citation objects for referenced sources
        for info in citations_info:
            if info.get("index") in referenced_indices:
                citations.append(Citation(
                    file_name=info.get("file_name", "Unknown"),
                    chunk_id=info.get("chunk_id", ""),
                    page_number=info.get("page_number"),
                    relevance_score=info.get("similarity_score", 0.0),
                    excerpt=info.get("excerpt", "")
                ))

        # If no specific citations found but we have context, include top sources
        if not citations and citations_info:
            for info in citations_info[:3]:  # Top 3 sources
                citations.append(Citation(
                    file_name=info.get("file_name", "Unknown"),
                    chunk_id=info.get("chunk_id", ""),
                    page_number=info.get("page_number"),
                    relevance_score=info.get("similarity_score", 0.0),
                    excerpt=info.get("excerpt", "")
                ))

        return citations
432
+
433
+
434
# Lazily-created module-wide singleton.
_answer_service: Optional[AnswerService] = None


def get_answer_service() -> AnswerService:
    """Return the shared AnswerService, constructing it on first call."""
    global _answer_service
    if _answer_service is not None:
        return _answer_service
    _answer_service = AnswerService()
    return _answer_service
444
+
app/rag/chunking.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Document chunking with overlap and metadata preservation.
3
+ """
4
+ import tiktoken
5
+ from typing import List, Dict, Any, Optional
6
+ from dataclasses import dataclass
7
+ import re
8
+ import uuid
9
+ from datetime import datetime
10
+
11
+ from app.config import settings
12
+
13
+
14
@dataclass
class TextChunk:
    """Represents a chunk of text with metadata."""
    content: str  # The chunk's text (stripped by the chunker before storing)
    chunk_index: int  # 0-based position of this chunk within its document
    start_char: int  # Offset of the chunk's start in the source text (approximate once overlap is added)
    end_char: int  # Offset of the chunk's end in the source text
    page_number: Optional[int] = None  # Source page when known (PDF inputs), else None
    token_count: int = 0  # Tokens in `content` per the chunker's encoding
23
+
24
+
25
class DocumentChunker:
    """
    Chunks documents into smaller pieces with overlap.
    Uses tiktoken for accurate token counting.

    Paragraphs are the primary split unit; when a chunk would exceed the
    token budget, it is emitted and a sentence-level tail of it is carried
    into the next chunk as overlap.
    """

    def __init__(
        self,
        chunk_size: int = settings.CHUNK_SIZE,
        chunk_overlap: int = settings.CHUNK_OVERLAP,
        min_chunk_size: int = settings.MIN_CHUNK_SIZE
    ):
        # NOTE: the defaults above are read from settings once, at import time.
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.min_chunk_size = min_chunk_size
        # Use cl100k_base encoding (same as GPT-4, good general purpose)
        self.encoding = tiktoken.get_encoding("cl100k_base")

    def count_tokens(self, text: str) -> int:
        """Count tokens in text using the configured encoding."""
        return len(self.encoding.encode(text))

    def _split_into_sentences(self, text: str) -> List[str]:
        """Split text into sentences while preserving structure."""
        # Split on sentence boundaries but keep delimiters
        sentence_endings = r'(?<=[.!?])\s+'
        sentences = re.split(sentence_endings, text)
        return [s.strip() for s in sentences if s.strip()]

    def _split_into_paragraphs(self, text: str) -> List[str]:
        """Split text into paragraphs on blank lines."""
        paragraphs = re.split(r'\n\s*\n', text)
        return [p.strip() for p in paragraphs if p.strip()]

    def chunk_text(
        self,
        text: str,
        page_numbers: Optional[Dict[int, int]] = None  # char_position -> page_number
    ) -> List[TextChunk]:
        """
        Chunk text into smaller pieces with overlap.

        Args:
            text: The text to chunk
            page_numbers: Optional mapping of character positions to page numbers

        Returns:
            List of TextChunk objects (chunks below min_chunk_size tokens are dropped)
        """
        if not text.strip():
            return []

        chunks = []
        current_chunk = ""
        current_start = 0
        chunk_index = 0

        # First, split into paragraphs for natural boundaries
        paragraphs = self._split_into_paragraphs(text)

        char_position = 0
        for para in paragraphs:
            para_tokens = self.count_tokens(para)
            # NOTE(review): current_chunk is re-tokenized on every iteration,
            # which is quadratic in chunk length -- acceptable for typical
            # document sizes, but worth confirming for very large inputs.
            current_tokens = self.count_tokens(current_chunk)

            # If adding this paragraph exceeds chunk size
            if current_tokens + para_tokens > self.chunk_size and current_chunk:
                # Save current chunk if it meets minimum size
                if current_tokens >= self.min_chunk_size:
                    page_num = None
                    if page_numbers:
                        # No break: iterating ascending positions leaves page_num
                        # at the page with the greatest start <= current_start.
                        for pos, page in sorted(page_numbers.items()):
                            if pos <= current_start:
                                page_num = page

                    chunks.append(TextChunk(
                        content=current_chunk.strip(),
                        chunk_index=chunk_index,
                        start_char=current_start,
                        end_char=char_position,
                        page_number=page_num,
                        token_count=current_tokens
                    ))
                    chunk_index += 1

                # Start new chunk with overlap (a sentence tail of the old chunk)
                overlap_text = self._get_overlap_text(current_chunk)
                current_chunk = overlap_text + "\n\n" + para if overlap_text else para
                # NOTE(review): offsets are approximate -- paragraphs were
                # stripped and re-joined, so start_char/end_char (and page
                # attribution) can drift from true source positions.
                current_start = char_position - len(overlap_text) if overlap_text else char_position
            else:
                # Add paragraph to current chunk
                if current_chunk:
                    current_chunk += "\n\n" + para
                else:
                    current_chunk = para
                    current_start = char_position

            char_position += len(para) + 2  # +2 for paragraph separator

        # Don't forget the last chunk
        if current_chunk and self.count_tokens(current_chunk) >= self.min_chunk_size:
            page_num = None
            if page_numbers:
                for pos, page in sorted(page_numbers.items()):
                    if pos <= current_start:
                        page_num = page

            chunks.append(TextChunk(
                content=current_chunk.strip(),
                chunk_index=chunk_index,
                start_char=current_start,
                end_char=len(text),
                page_number=page_num,
                token_count=self.count_tokens(current_chunk)
            ))

        return chunks

    def _get_overlap_text(self, text: str) -> str:
        """Get the overlap text from the end of a chunk.

        Collects whole trailing sentences until adding one more would
        exceed the chunk_overlap token budget.
        """
        sentences = self._split_into_sentences(text)
        if not sentences:
            return ""

        overlap = ""
        tokens = 0

        # Work backwards through sentences
        for sentence in reversed(sentences):
            sentence_tokens = self.count_tokens(sentence)
            if tokens + sentence_tokens <= self.chunk_overlap:
                overlap = sentence + " " + overlap if overlap else sentence
                tokens += sentence_tokens
            else:
                break

        return overlap.strip()

    def create_chunk_metadata(
        self,
        chunk: TextChunk,
        tenant_id: str,  # CRITICAL: Multi-tenant isolation
        kb_id: str,
        user_id: str,
        file_name: str,
        file_type: str,
        total_chunks: int,
        document_id: Optional[str] = None
    ) -> Dict[str, Any]:
        """Create metadata dictionary for a chunk.

        The chunk_id embeds tenant/kb/file/index plus a random suffix so
        re-ingesting the same file never collides with earlier ids.
        """
        chunk_id = f"{tenant_id}_{kb_id}_{file_name}_{chunk.chunk_index}_{uuid.uuid4().hex[:8]}"

        return {
            "tenant_id": tenant_id,  # CRITICAL: Multi-tenant isolation
            "kb_id": kb_id,
            "user_id": user_id,
            "file_name": file_name,
            "file_type": file_type,
            "chunk_id": chunk_id,
            "chunk_index": chunk.chunk_index,
            "page_number": chunk.page_number,
            "total_chunks": total_chunks,
            "token_count": chunk.token_count,
            "document_id": document_id,  # Track original document
            # NOTE(review): utcnow() is naive and deprecated in 3.12; consider
            # datetime.now(timezone.utc) if the stored format may change.
            "created_at": datetime.utcnow().isoformat()
        }
192
+
193
+
194
# Module-level singleton chunker, built once at import with settings defaults.
chunker = DocumentChunker()
196
+
app/rag/embeddings.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Embedding generation using Sentence Transformers.
3
+ Supports local models for privacy and offline use.
4
+ """
5
+ from sentence_transformers import SentenceTransformer
6
+ from typing import List, Optional
7
+ import numpy as np
8
+ import logging
9
+ from functools import lru_cache
10
+
11
+ from app.config import settings
12
+
13
+ logging.basicConfig(level=logging.INFO)
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
class EmbeddingService:
    """
    Generates embeddings for text using Sentence Transformers.
    Uses a lightweight model optimized for semantic search.
    """

    def __init__(self, model_name: str = settings.EMBEDDING_MODEL):
        """Set up the service; the model itself is loaded on first use.

        Args:
            model_name: Name of the Sentence Transformer model to use
        """
        self.model_name = model_name
        self._model: Optional[SentenceTransformer] = None
        logger.info(f"Embedding service initialized with model: {model_name}")

    @property
    def model(self) -> SentenceTransformer:
        """Load and cache the underlying model on first access."""
        if self._model is not None:
            return self._model
        logger.info(f"Loading embedding model: {self.model_name}")
        self._model = SentenceTransformer(self.model_name)
        logger.info(f"Model loaded. Embedding dimension: {self._model.get_sentence_embedding_dimension()}")
        return self._model

    def embed_text(self, text: str) -> List[float]:
        """Embed a single text.

        Args:
            text: Text to embed

        Returns:
            Embedding vector as a list of floats

        Raises:
            ValueError: on empty/whitespace-only input
        """
        if not text.strip():
            raise ValueError("Cannot embed empty text")

        vector = self.model.encode(text, convert_to_numpy=True)
        return vector.tolist()

    def embed_texts(self, texts: List[str], batch_size: int = 32) -> List[List[float]]:
        """Embed many texts in batches; empty strings are dropped first.

        Args:
            texts: List of texts to embed
            batch_size: Batch size for processing

        Returns:
            One embedding vector per non-empty input text
        """
        if not texts:
            return []

        # Drop blank entries before encoding.
        valid_texts = [t for t in texts if t.strip()]
        if len(valid_texts) != len(texts):
            logger.warning(f"Filtered out {len(texts) - len(valid_texts)} empty texts")

        logger.info(f"Generating embeddings for {len(valid_texts)} texts")

        matrix = self.model.encode(
            valid_texts,
            batch_size=batch_size,
            show_progress_bar=len(valid_texts) > 100,
            convert_to_numpy=True
        )
        return matrix.tolist()

    def embed_query(self, query: str) -> List[float]:
        """Embed a search query.

        For this model family query embedding is identical to document
        embedding; the separate entry point is kept for models that differ.
        """
        return self.embed_text(query)

    def get_dimension(self) -> int:
        """Return the embedding dimensionality of the loaded model."""
        return self.model.get_sentence_embedding_dimension()

    def compute_similarity(self, embedding1: List[float], embedding2: List[float]) -> float:
        """Cosine similarity between two vectors.

        Returns 0.0 when either vector has zero norm.
        """
        v1 = np.array(embedding1)
        v2 = np.array(embedding2)

        # A zero denominator means at least one vector has zero norm.
        denom = np.linalg.norm(v1) * np.linalg.norm(v2)
        if denom == 0:
            return 0.0

        return float(np.dot(v1, v2) / denom)
131
+
132
+
133
# Lazily-created module-wide singleton.
_embedding_service: Optional[EmbeddingService] = None


def get_embedding_service() -> EmbeddingService:
    """Return the shared EmbeddingService, constructing it on first call."""
    global _embedding_service
    if _embedding_service is not None:
        return _embedding_service
    _embedding_service = EmbeddingService()
    return _embedding_service
142
+ return _embedding_service
143
+
144
+
145
+
app/rag/ingest.py ADDED
@@ -0,0 +1,231 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Document ingestion and parsing pipeline.
3
+ Supports PDF, DOCX, TXT, and Markdown files.
4
+ """
5
+ import fitz # PyMuPDF
6
+ from docx import Document as DocxDocument
7
+ import markdown
8
+ from pathlib import Path
9
+ from typing import Dict, List, Optional, Tuple, Any
10
+ import chardet
11
+ import re
12
+ from dataclasses import dataclass
13
+ import logging
14
+
15
+ logging.basicConfig(level=logging.INFO)
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
@dataclass
class ParsedDocument:
    """Represents a parsed document with text and metadata."""
    text: str  # Cleaned plain-text content of the whole document
    file_name: str  # Original file name (no directory)
    file_type: str  # Normalized type label: "pdf", "docx", "txt" or "markdown"
    page_count: Optional[int] = None  # Number of pages (PDF only)
    page_map: Optional[Dict[int, int]] = None  # char_position -> page_number
    metadata: Optional[Dict[str, Any]] = None  # Extra info, e.g. source path / detected encoding
28
+
29
+
30
class DocumentParser:
    """
    Parses various document formats into plain text.

    Supported formats: PDF (PyMuPDF), DOCX/DOC (python-docx), plain text
    (with chardet encoding detection) and Markdown. PDF parsing records a
    char-position -> page-number map so citations can reference pages.
    """

    SUPPORTED_EXTENSIONS = {'.pdf', '.docx', '.doc', '.txt', '.md', '.markdown'}

    def __init__(self):
        # Dispatch table: lower-cased file extension -> bound parser method.
        self.parsers = {
            '.pdf': self._parse_pdf,
            '.docx': self._parse_docx,
            '.doc': self._parse_docx,  # best effort: python-docx may reject legacy .doc
            '.txt': self._parse_text,
            '.md': self._parse_markdown,
            '.markdown': self._parse_markdown,
        }

    def parse(self, file_path: Path) -> ParsedDocument:
        """
        Parse a document file into text.

        Args:
            file_path: Path to the document file

        Returns:
            ParsedDocument with extracted text and metadata

        Raises:
            FileNotFoundError: if the file does not exist
            ValueError: if the extension is unsupported
        """
        if not file_path.exists():
            raise FileNotFoundError(f"File not found: {file_path}")

        ext = file_path.suffix.lower()
        if ext not in self.SUPPORTED_EXTENSIONS:
            raise ValueError(f"Unsupported file type: {ext}")

        parser = self.parsers.get(ext)
        if not parser:
            raise ValueError(f"No parser available for: {ext}")

        logger.info(f"Parsing document: {file_path.name} ({ext})")
        return parser(file_path)

    def _parse_pdf(self, file_path: Path) -> ParsedDocument:
        """Parse PDF file with page tracking."""
        try:
            doc = fitz.open(file_path)
            # FIX: close the document even if text extraction raises,
            # otherwise the PyMuPDF handle (and its file) leaks.
            try:
                text_parts = []
                page_map = {}
                current_pos = 0

                for page_num, page in enumerate(doc, start=1):
                    page_text = page.get_text("text")
                    if page_text.strip():
                        # Record the char offset at which this page begins.
                        page_map[current_pos] = page_num
                        text_parts.append(page_text)
                        current_pos += len(page_text) + 2  # +2 for the "\n\n" separator

                page_count = len(doc)
            finally:
                doc.close()

            full_text = "\n\n".join(text_parts)

            return ParsedDocument(
                text=self._clean_text(full_text),
                file_name=file_path.name,
                file_type="pdf",
                page_count=page_count,
                page_map=page_map,
                metadata={"source": str(file_path)}
            )
        except Exception as e:
            logger.error(f"Error parsing PDF {file_path}: {e}")
            raise

    def _parse_docx(self, file_path: Path) -> ParsedDocument:
        """Parse DOCX file, including text inside tables."""
        try:
            doc = DocxDocument(file_path)
            paragraphs = []

            for para in doc.paragraphs:
                if para.text.strip():
                    paragraphs.append(para.text)

            # Also extract text from tables, one row per line, cells joined
            # with " | " so tabular structure survives as readable text.
            for table in doc.tables:
                for row in table.rows:
                    row_text = []
                    for cell in row.cells:
                        if cell.text.strip():
                            row_text.append(cell.text.strip())
                    if row_text:
                        paragraphs.append(" | ".join(row_text))

            full_text = "\n\n".join(paragraphs)

            return ParsedDocument(
                text=self._clean_text(full_text),
                file_name=file_path.name,
                file_type="docx",
                metadata={"source": str(file_path)}
            )
        except Exception as e:
            logger.error(f"Error parsing DOCX {file_path}: {e}")
            raise

    def _parse_text(self, file_path: Path) -> ParsedDocument:
        """Parse plain text file with encoding detection."""
        try:
            # Detect encoding
            with open(file_path, 'rb') as f:
                raw_data = f.read()
            detected = chardet.detect(raw_data)
            # FIX: chardet reports {'encoding': None} for undecidable input,
            # and dict.get's default does not apply when the key exists with
            # a None value -- `or` is required for the utf-8 fallback.
            encoding = detected.get('encoding') or 'utf-8'

            # Read with detected encoding; undecodable bytes become U+FFFD.
            with open(file_path, 'r', encoding=encoding, errors='replace') as f:
                text = f.read()

            return ParsedDocument(
                text=self._clean_text(text),
                file_name=file_path.name,
                file_type="txt",
                metadata={"source": str(file_path), "encoding": encoding}
            )
        except Exception as e:
            logger.error(f"Error parsing text file {file_path}: {e}")
            raise

    def _parse_markdown(self, file_path: Path) -> ParsedDocument:
        """Parse Markdown file, converting to plain text."""
        try:
            with open(file_path, 'r', encoding='utf-8', errors='replace') as f:
                md_content = f.read()

            # Strip markdown syntax but keep the document's structure.
            # (A previous markdown->HTML->strip pass computed a result that
            # was never used; that dead work has been removed.)
            clean_md = self._clean_markdown(md_content)

            return ParsedDocument(
                text=self._clean_text(clean_md),
                file_name=file_path.name,
                file_type="markdown",
                metadata={"source": str(file_path)}
            )
        except Exception as e:
            logger.error(f"Error parsing Markdown {file_path}: {e}")
            raise

    def _clean_text(self, text: str) -> str:
        """Clean and normalize text: collapse runs of spaces and blank lines."""
        # Replace multiple whitespace with single space
        text = re.sub(r'[ \t]+', ' ', text)
        # Replace multiple newlines with double newline
        text = re.sub(r'\n{3,}', '\n\n', text)
        # Remove leading/trailing whitespace from lines
        lines = [line.strip() for line in text.split('\n')]
        text = '\n'.join(lines)
        # Remove leading/trailing whitespace
        return text.strip()

    def _strip_html_tags(self, html: str) -> str:
        """Remove HTML tags from text (naive regex strip)."""
        clean = re.sub(r'<[^>]+>', '', html)
        return clean

    def _clean_markdown(self, md_text: str) -> str:
        """Clean markdown syntax while preserving structure."""
        # Remove fenced code blocks entirely; unwrap inline code.
        md_text = re.sub(r'```[\s\S]*?```', '', md_text)
        md_text = re.sub(r'`([^`]+)`', r'\1', md_text)

        # Convert headers to plain text with a trailing colon
        md_text = re.sub(r'^#{1,6}\s+(.+)$', r'\1:', md_text, flags=re.MULTILINE)

        # Remove bold/italic markers
        md_text = re.sub(r'\*\*([^*]+)\*\*', r'\1', md_text)
        md_text = re.sub(r'\*([^*]+)\*', r'\1', md_text)
        md_text = re.sub(r'__([^_]+)__', r'\1', md_text)
        md_text = re.sub(r'_([^_]+)_', r'\1', md_text)

        # Remove links but keep the link text
        md_text = re.sub(r'\[([^\]]+)\]\([^)]+\)', r'\1', md_text)

        # Remove images entirely
        md_text = re.sub(r'!\[([^\]]*)\]\([^)]+\)', '', md_text)

        # Convert list markers to plain text
        md_text = re.sub(r'^[\*\-\+]\s+', '• ', md_text, flags=re.MULTILINE)
        md_text = re.sub(r'^\d+\.\s+', '', md_text, flags=re.MULTILINE)

        return md_text
227
+
228
+
229
# Module-level singleton parser, shared by the ingestion pipeline.
parser = DocumentParser()
231
+