diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..52252fbf3d731f6c0ff561ae05589bb95d222a09 --- /dev/null +++ b/.gitignore @@ -0,0 +1,37 @@ +๏ปฟ# Python files +*.py +*.pyc +__pycache__/ + +# Config files +*.txt +*.yaml +*.yml +*.toml +*.json +*.md +*.sh +*.bat +*.ps1 + +# Directories to include +app/ +requirements.txt +Dockerfile +app.py +README.md +.gitignore +.env.example.txt + +# Exclude binaries +*.png +*.jpg +*.jpeg +*.db +*.pdf +*.bin +*.sqlite3 +public/ +data/billing/ +data/vectordb/ +venv/ diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..fdcdc2c2faa1cd512f9982d4d32c18462c73882d --- /dev/null +++ b/Dockerfile @@ -0,0 +1,34 @@ +FROM python:3.11-slim + +# Set working directory +WORKDIR /app + +# Install system dependencies +RUN apt-get update && apt-get install -y \ + gcc \ + g++ \ + curl \ + && rm -rf /var/lib/apt/lists/* + +# Copy requirements first (for better caching) +COPY requirements.txt . + +# Install Python dependencies +RUN pip install --no-cache-dir -r requirements.txt + +# Copy application code +COPY . . + +# Create necessary directories +RUN mkdir -p data/uploads data/processed data/vectordb data/billing + +# Expose port (Hugging Face Spaces uses 7860, but we'll use PORT env var) +EXPOSE 7860 + +# Health check +HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \ + CMD curl -f http://localhost:${PORT:-7860}/health/live || exit 1 + +# Start the application (Hugging Face Spaces provides PORT env var) +CMD uvicorn app.main:app --host 0.0.0.0 --port ${PORT:-7860} + diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..098f39f9861376a5fc28a208d4e8d9c43f8ecd3b --- /dev/null +++ b/README.md @@ -0,0 +1,47 @@ +--- +title: ClientSphere RAG Backend +emoji: ๐Ÿค– +colorFrom: blue +colorTo: purple +sdk: docker +sdk_version: "4.0.0" +python_version: "3.11" +app_file: app.py +pinned: false +--- + +# ClientSphere RAG Backend + +FastAPI-based RAG (Retrieval-Augmented Generation) backend for ClientSphere AI customer support platform. + +## Features + +- ๐Ÿ“š Knowledge base management +- ๐Ÿ” Semantic search with embeddings +- ๐Ÿ’ฌ AI-powered chat with citations +- ๐Ÿ“Š Confidence scoring +- ๐Ÿ”’ Multi-tenant isolation +- ๐Ÿ“ˆ Usage tracking and billing + +## API Endpoints + +- `GET /health/live` - Health check +- `GET /kb/stats` - Knowledge base statistics +- `POST /kb/upload` - Upload documents +- `POST /chat` - Chat with RAG +- `GET /kb/search` - Search knowledge base + +## Environment Variables + +Required: +- `GEMINI_API_KEY` - Google Gemini API key +- `ENV` - Set to `prod` for production +- `LLM_PROVIDER` - `gemini` or `openai` + +Optional: +- `ALLOWED_ORIGINS` - CORS allowed origins (comma-separated) +- `JWT_SECRET` - JWT secret for authentication + +## Documentation + +See `README_HF_SPACES.md` for deployment details. diff --git a/README_HF_SPACES.md b/README_HF_SPACES.md new file mode 100644 index 0000000000000000000000000000000000000000..428acd39865110495f94d2ee6bf1ada09f3a5ff4 --- /dev/null +++ b/README_HF_SPACES.md @@ -0,0 +1,231 @@ +# ๐Ÿš€ Deploy RAG Backend to Hugging Face Spaces + +Hugging Face Spaces is **perfect** for deploying Python/FastAPI applications with ML dependencies! + +## โœ… Why Hugging Face Spaces? + +- โœ… **Free tier** with generous limits +- โœ… **Full Python 3.11+** support +- โœ… **ML libraries** fully supported (sentence-transformers, chromadb, etc.) +- โœ… **Persistent storage** for vector database +- โœ… **No bundle size limits** +- โœ… **GPU support** available (paid) +- โœ… **Automatic HTTPS** and custom domains +- โœ… **GitHub integration** (auto-deploy on push) + +## ๐Ÿ“‹ Prerequisites + +1. **Hugging Face Account**: Sign up at [huggingface.co](https://huggingface.co) +2. **GitHub Repository**: Your code should be in a GitHub repository +3. **Gemini API Key**: Get from [Google AI Studio](https://aistudio.google.com/app/apikey) + +## ๐Ÿš€ Step-by-Step Deployment + +### Step 1: Prepare Your Repository + +Your `rag-backend/` directory should contain: +- โœ… `app.py` - Entry point (already created) +- โœ… `requirements.txt` - Dependencies +- โœ… `app/main.py` - FastAPI application +- โœ… All other application files + +### Step 2: Create Hugging Face Space + +1. Go to [Hugging Face Spaces](https://huggingface.co/spaces) +2. Click **"Create new Space"** +3. Configure: + - **Owner**: Your username + - **Space name**: `clientsphere-rag-backend` (or your choice) + - **SDK**: **Docker** (recommended) or **Gradio** (if you want UI) + - **Hardware**: + - **CPU basic** (free) - Good for testing + - **CPU upgrade** (paid) - Better performance + - **GPU** (paid) - For heavy ML workloads + +### Step 3: Connect GitHub Repository + +1. In Space creation, select **"Repository"** as source +2. Choose your GitHub repository +3. Set **Repository path** to: `rag-backend/` (subdirectory) +4. Click **"Create Space"** + +### Step 4: Configure Environment Variables + +1. Go to your Space's **Settings** tab +2. Scroll to **"Repository secrets"** or **"Variables"** +3. Add these secrets: + +**Required:** +``` +GEMINI_API_KEY=your_gemini_api_key_here +ENV=prod +LLM_PROVIDER=gemini +``` + +**Optional (but recommended):** +``` +ALLOWED_ORIGINS=https://main.clientsphere.pages.dev,https://abaa49a3.clientsphere.pages.dev +JWT_SECRET=your_secure_jwt_secret +DEBUG=false +``` + +### Step 5: Configure Docker (if using Docker SDK) + +If you selected **Docker** SDK, Hugging Face will use your `Dockerfile`. + +**Your existing `Dockerfile` should work!** It's already configured correctly. + +### Step 6: Alternative - Use app.py (Simpler) + +If you want to use the simpler `app.py` approach: + +1. In Space settings, set: + - **SDK**: **Gradio** or **Streamlit** (but we'll override) + - **App file**: `app.py` + +2. Hugging Face will automatically: + - Install dependencies from `requirements.txt` + - Run `python app.py` + - Expose on port 7860 + +### Step 7: Deploy! + +1. **Push to GitHub** (if not already): + ```bash + git add rag-backend/app.py + git commit -m "Add Hugging Face Spaces entry point" + git push origin main + ``` + +2. **Hugging Face will auto-deploy** from your GitHub repo! + +3. **Wait for build** (5-10 minutes first time, faster after) + +4. **Your Space URL**: `https://your-username-clientsphere-rag-backend.hf.space` + +## ๐Ÿ”ง Configuration Options + +### Option A: Docker (Recommended) + +**Advantages:** +- Full control over environment +- Can customize Python version +- Better for production + +**Setup:** +- Use existing `Dockerfile` +- Hugging Face will build and run it +- Exposes on port 7860 automatically + +### Option B: app.py (Simpler) + +**Advantages:** +- Simpler setup +- Faster builds +- Good for development + +**Setup:** +- Create `app.py` in `rag-backend/` (already done) +- Hugging Face runs it automatically + +## ๐Ÿ“ Environment Variables Reference + +| Variable | Required | Description | +|----------|----------|-------------| +| `GEMINI_API_KEY` | โœ… Yes | Your Google Gemini API key | +| `ENV` | โœ… Yes | Set to `prod` for production | +| `LLM_PROVIDER` | โœ… Yes | `gemini` or `openai` | +| `ALLOWED_ORIGINS` | โš ๏ธ Recommended | CORS allowed origins (comma-separated) | +| `JWT_SECRET` | โš ๏ธ Recommended | JWT secret for authentication | +| `DEBUG` | โŒ Optional | Set to `false` in production | +| `OPENAI_API_KEY` | โŒ Optional | If using OpenAI instead of Gemini | + +## ๐ŸŒ CORS Configuration + +After deployment, update `ALLOWED_ORIGINS` to include: +- Your Cloudflare Pages frontend URL +- Your Cloudflare Workers backend URL +- Any other origins that need access + +Example: +``` +ALLOWED_ORIGINS=https://main.clientsphere.pages.dev,https://mcp-backend.officialchiragp1605.workers.dev +``` + +## ๐Ÿ”„ Updating Deployment + +**Automatic (Recommended):** +- Push to GitHub โ†’ Hugging Face auto-deploys + +**Manual:** +- Go to Space โ†’ Settings โ†’ "Rebuild Space" + +## ๐Ÿ“Š Resource Limits + +### Free Tier: +- โœ… **CPU**: Basic (sufficient for RAG) +- โœ… **Storage**: 50GB (plenty for vector DB) +- โœ… **Memory**: 16GB RAM +- โœ… **Build time**: 20 minutes +- โœ… **Sleep after inactivity**: 48 hours (wakes on request) + +### Paid Tiers: +- **CPU upgrade**: Better performance +- **GPU**: For heavy ML workloads +- **No sleep**: Always-on service + +## ๐Ÿงช Testing Deployment + +After deployment, test your endpoints: + +```bash +# Health check +curl https://your-username-clientsphere-rag-backend.hf.space/health/live + +# KB Stats (with auth) +curl https://your-username-clientsphere-rag-backend.hf.space/kb/stats?kb_id=default&tenant_id=test&user_id=test +``` + +## ๐Ÿ”— Update Frontend + +After deployment, update Cloudflare Pages environment variable: + +``` +VITE_RAG_API_URL=https://your-username-clientsphere-rag-backend.hf.space +``` + +Then redeploy frontend: +```bash +npm run build +npx wrangler pages deploy dist --project-name=clientsphere +``` + +## โœ… Advantages Over Render + +| Feature | Hugging Face Spaces | Render | +|---------|-------------------|--------| +| Free Tier | โœ… Generous | โš ๏ธ Limited | +| ML Libraries | โœ… Full support | โœ… Full support | +| Auto-deploy | โœ… GitHub integration | โœ… GitHub integration | +| Storage | โœ… 50GB free | โš ๏ธ Limited | +| Sleep Mode | โœ… Wakes on request | โŒ No sleep mode | +| GPU Support | โœ… Available | โŒ Not available | +| Community | โœ… Large ML community | โš ๏ธ Smaller | + +## ๐ŸŽฏ Summary + +1. โœ… Create Hugging Face Space +2. โœ… Connect GitHub repository (rag-backend/) +3. โœ… Set environment variables +4. โœ… Deploy (automatic on push) +5. โœ… Update frontend `VITE_RAG_API_URL` +6. โœ… Test and enjoy! + +**Your RAG backend will be live at:** +`https://your-username-clientsphere-rag-backend.hf.space` + +--- + +**Need help?** Check [Hugging Face Spaces Docs](https://huggingface.co/docs/hub/spaces) + + diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..ab7cbfba8081e99baf78061bdfc4b9b0d469bc07 --- /dev/null +++ b/app.py @@ -0,0 +1,13 @@ +""" +Hugging Face Spaces entry point for RAG Backend. +This file is used when deploying to Hugging Face Spaces. +""" +import os +import uvicorn +from app.main import app + +if __name__ == "__main__": + # Hugging Face Spaces provides PORT environment variable (defaults to 7860) + port = int(os.getenv("PORT", 7860)) + uvicorn.run(app, host="0.0.0.0", port=port) + diff --git a/app/__init__.py b/app/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2927b3f94e820da8fb6011f7728162ff6588b1e0 --- /dev/null +++ b/app/__init__.py @@ -0,0 +1,7 @@ +""" +ClientSphere RAG Backend Application. +""" +__version__ = "1.0.0" + + + diff --git a/app/__pycache__/__init__.cpython-313.pyc b/app/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cdbcbc7f7b1606aa3483782b38a838732e262abd Binary files /dev/null and b/app/__pycache__/__init__.cpython-313.pyc differ diff --git a/app/__pycache__/config.cpython-313.pyc b/app/__pycache__/config.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1ff6d5db559d78685cbb1b28e0111047ae20418e Binary files /dev/null and b/app/__pycache__/config.cpython-313.pyc differ diff --git a/app/__pycache__/main.cpython-313.pyc b/app/__pycache__/main.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..73ca17bb267d5c5aaa06bafd899f1f4fe20eda53 Binary files /dev/null and b/app/__pycache__/main.cpython-313.pyc differ diff --git a/app/billing/__pycache__/pricing.cpython-313.pyc b/app/billing/__pycache__/pricing.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..43ab282823946c4f77c61b7fa9bfaece57eb5052 Binary files /dev/null and b/app/billing/__pycache__/pricing.cpython-313.pyc differ diff --git a/app/billing/__pycache__/quota.cpython-313.pyc b/app/billing/__pycache__/quota.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4cf06d53ff56842e6bd33bac131f1abc7d744757 Binary files /dev/null and b/app/billing/__pycache__/quota.cpython-313.pyc differ diff --git a/app/billing/__pycache__/usage_tracker.cpython-313.pyc b/app/billing/__pycache__/usage_tracker.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9613384d8067dd1a3c356c90ab5f648563fbe573 Binary files /dev/null and b/app/billing/__pycache__/usage_tracker.cpython-313.pyc differ diff --git a/app/billing/pricing.py b/app/billing/pricing.py new file mode 100644 index 0000000000000000000000000000000000000000..186832f56e89bc5d8d3c501d065f816414e6b4fb --- /dev/null +++ b/app/billing/pricing.py @@ -0,0 +1,57 @@ +""" +Pricing table for LLM providers. +Used to calculate estimated costs from token usage. +""" +from typing import Dict, Optional + +# Pricing per 1M tokens (as of 2024, update as needed) +PRICING_TABLE: Dict[str, Dict[str, float]] = { + "gemini": { + "gemini-pro": {"input": 0.50, "output": 1.50}, # $0.50/$1.50 per 1M tokens + "gemini-1.5-pro": {"input": 1.25, "output": 5.00}, + "gemini-1.5-flash": {"input": 0.075, "output": 0.30}, + "gemini-1.0-pro": {"input": 0.50, "output": 1.50}, + "default": {"input": 0.50, "output": 1.50} + }, + "openai": { + "gpt-4": {"input": 30.00, "output": 60.00}, + "gpt-4-turbo": {"input": 10.00, "output": 30.00}, + "gpt-3.5-turbo": {"input": 0.50, "output": 1.50}, + "default": {"input": 0.50, "output": 1.50} + } +} + + +def calculate_cost( + provider: str, + model: str, + prompt_tokens: int, + completion_tokens: int +) -> float: + """ + Calculate estimated cost in USD based on token usage. + + Args: + provider: "gemini" or "openai" + model: Model name (e.g., "gemini-pro", "gpt-3.5-turbo") + prompt_tokens: Number of input tokens + completion_tokens: Number of output tokens + + Returns: + Estimated cost in USD + """ + provider_pricing = PRICING_TABLE.get(provider.lower(), {}) + model_pricing = provider_pricing.get(model.lower(), provider_pricing.get("default", {"input": 0.50, "output": 1.50})) + + # Calculate cost: (tokens / 1M) * price_per_1M + input_cost = (prompt_tokens / 1_000_000) * model_pricing["input"] + output_cost = (completion_tokens / 1_000_000) * model_pricing["output"] + + return input_cost + output_cost + + +def get_model_pricing(provider: str, model: str) -> Dict[str, float]: + """Get pricing for a specific model.""" + provider_pricing = PRICING_TABLE.get(provider.lower(), {}) + return provider_pricing.get(model.lower(), provider_pricing.get("default", {"input": 0.50, "output": 1.50})) + diff --git a/app/billing/quota.py b/app/billing/quota.py new file mode 100644 index 0000000000000000000000000000000000000000..c527a7b5fe3f79c0362eb72f50a6b8d152447541 --- /dev/null +++ b/app/billing/quota.py @@ -0,0 +1,131 @@ +""" +Quota management and enforcement. +""" +from sqlalchemy.orm import Session +from sqlalchemy import func, and_ +from datetime import datetime, timedelta +from typing import Optional, Tuple +import logging + +from app.db.models import TenantPlan, UsageMonthly, Tenant +logger = logging.getLogger(__name__) + +# Plan limits (chats per month) +PLAN_LIMITS = { + "starter": 500, + "growth": 5000, + "pro": -1 # -1 means unlimited +} + + +def get_tenant_plan(db: Session, tenant_id: str) -> Optional[TenantPlan]: + """Get tenant's current plan.""" + return db.query(TenantPlan).filter(TenantPlan.tenant_id == tenant_id).first() + + +def get_monthly_usage(db: Session, tenant_id: str, year: Optional[int] = None, month: Optional[int] = None) -> Optional[UsageMonthly]: + """Get monthly usage for tenant.""" + now = datetime.utcnow() + target_year = year or now.year + target_month = month or now.month + + return db.query(UsageMonthly).filter( + and_( + UsageMonthly.tenant_id == tenant_id, + UsageMonthly.year == target_year, + UsageMonthly.month == target_month + ) + ).first() + + +def check_quota(db: Session, tenant_id: str) -> Tuple[bool, Optional[str]]: + """ + Check if tenant has quota remaining for the current month. + + Returns: + (has_quota, error_message) + has_quota: True if quota available, False if exceeded + error_message: None if quota available, error message if exceeded + """ + # Get tenant plan + plan = get_tenant_plan(db, tenant_id) + + if not plan: + # Default to starter plan if no plan assigned + logger.warning(f"Tenant {tenant_id} has no plan assigned, defaulting to starter") + monthly_limit = PLAN_LIMITS.get("starter", 500) + else: + monthly_limit = plan.monthly_chat_limit + + # Unlimited plan (-1) always passes + if monthly_limit == -1: + return True, None + + # Get current month usage + now = datetime.utcnow() + monthly_usage = get_monthly_usage(db, tenant_id, now.year, now.month) + + current_usage = monthly_usage.total_requests if monthly_usage else 0 + + # Check if quota exceeded + if current_usage >= monthly_limit: + return False, f"AI quota exceeded ({current_usage}/{monthly_limit} chats this month). Upgrade your plan." + + return True, None + + +def ensure_tenant_exists(db: Session, tenant_id: str) -> None: + """Ensure tenant record exists in database.""" + tenant = db.query(Tenant).filter(Tenant.id == tenant_id).first() + if not tenant: + # Create tenant with default starter plan + tenant = Tenant(id=tenant_id, name=f"Tenant {tenant_id}") + db.add(tenant) + + # Create default starter plan + plan = TenantPlan( + tenant_id=tenant_id, + plan_name="starter", + monthly_chat_limit=PLAN_LIMITS["starter"] + ) + db.add(plan) + db.commit() + logger.info(f"Created tenant {tenant_id} with starter plan") + + +def set_tenant_plan(db: Session, tenant_id: str, plan_name: str) -> TenantPlan: + """ + Set tenant's subscription plan. + + Args: + db: Database session + tenant_id: Tenant ID + plan_name: "starter", "growth", or "pro" + + Returns: + Updated TenantPlan + """ + if plan_name not in PLAN_LIMITS: + raise ValueError(f"Invalid plan name: {plan_name}. Must be one of: {list(PLAN_LIMITS.keys())}") + + # Ensure tenant exists + ensure_tenant_exists(db, tenant_id) + + # Get or create plan + plan = get_tenant_plan(db, tenant_id) + if plan: + plan.plan_name = plan_name + plan.monthly_chat_limit = PLAN_LIMITS[plan_name] + plan.updated_at = datetime.utcnow() + else: + plan = TenantPlan( + tenant_id=tenant_id, + plan_name=plan_name, + monthly_chat_limit=PLAN_LIMITS[plan_name] + ) + db.add(plan) + + db.commit() + db.refresh(plan) + return plan + diff --git a/app/billing/usage_tracker.py b/app/billing/usage_tracker.py new file mode 100644 index 0000000000000000000000000000000000000000..390dda761ef82aa5ae678798a906c3232799bcb8 --- /dev/null +++ b/app/billing/usage_tracker.py @@ -0,0 +1,173 @@ +""" +Usage tracking service. +Tracks token usage and costs for each LLM request. +""" +from sqlalchemy.orm import Session +from sqlalchemy import func, and_ +from datetime import datetime, timedelta +from typing import Optional +import uuid +import logging + +from app.db.models import UsageEvent, UsageDaily, UsageMonthly, Tenant +from app.billing.pricing import calculate_cost +from app.billing.quota import ensure_tenant_exists + +logger = logging.getLogger(__name__) + + +def track_usage( + db: Session, + tenant_id: str, + user_id: str, + kb_id: str, + provider: str, + model: str, + prompt_tokens: int, + completion_tokens: int, + request_timestamp: Optional[datetime] = None +) -> UsageEvent: + """ + Track a single usage event. + + Args: + db: Database session + tenant_id: Tenant ID + user_id: User ID + kb_id: Knowledge base ID + provider: "gemini" or "openai" + model: Model name + prompt_tokens: Input tokens + completion_tokens: Output tokens + request_timestamp: Request timestamp (defaults to now) + + Returns: + Created UsageEvent + """ + # Ensure tenant exists + ensure_tenant_exists(db, tenant_id) + + # Calculate cost + total_tokens = prompt_tokens + completion_tokens + estimated_cost = calculate_cost(provider, model, prompt_tokens, completion_tokens) + + # Create usage event + request_id = f"req_{uuid.uuid4().hex[:16]}" + timestamp = request_timestamp or datetime.utcnow() + + usage_event = UsageEvent( + request_id=request_id, + tenant_id=tenant_id, + user_id=user_id, + kb_id=kb_id, + provider=provider, + model=model, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=total_tokens, + estimated_cost_usd=estimated_cost, + request_timestamp=timestamp + ) + + db.add(usage_event) + + # Update daily aggregation + _update_daily_usage(db, tenant_id, timestamp, provider, total_tokens, estimated_cost) + + # Update monthly aggregation + _update_monthly_usage(db, tenant_id, timestamp, provider, total_tokens, estimated_cost) + + db.commit() + db.refresh(usage_event) + + logger.info( + f"Tracked usage: tenant={tenant_id}, provider={provider}, " + f"tokens={total_tokens}, cost=${estimated_cost:.6f}" + ) + + return usage_event + + +def _update_daily_usage( + db: Session, + tenant_id: str, + timestamp: datetime, + provider: str, + tokens: int, + cost: float +): + """Update daily usage aggregation.""" + date = timestamp.date() + date_start = datetime.combine(date, datetime.min.time()) + + daily = db.query(UsageDaily).filter( + and_( + UsageDaily.tenant_id == tenant_id, + UsageDaily.date == date_start + ) + ).first() + + if daily: + daily.total_requests += 1 + daily.total_tokens += tokens + daily.total_cost_usd += cost + if provider == "gemini": + daily.gemini_requests += 1 + elif provider == "openai": + daily.openai_requests += 1 + daily.updated_at = datetime.utcnow() + else: + daily = UsageDaily( + tenant_id=tenant_id, + date=date_start, + total_requests=1, + total_tokens=tokens, + total_cost_usd=cost, + gemini_requests=1 if provider == "gemini" else 0, + openai_requests=1 if provider == "openai" else 0 + ) + db.add(daily) + + +def _update_monthly_usage( + db: Session, + tenant_id: str, + timestamp: datetime, + provider: str, + tokens: int, + cost: float +): + """Update monthly usage aggregation.""" + year = timestamp.year + month = timestamp.month + + monthly = db.query(UsageMonthly).filter( + and_( + UsageMonthly.tenant_id == tenant_id, + UsageMonthly.year == year, + UsageMonthly.month == month + ) + ).first() + + if monthly: + monthly.total_requests += 1 + monthly.total_tokens += tokens + monthly.total_cost_usd += cost + if provider == "gemini": + monthly.gemini_requests += 1 + elif provider == "openai": + monthly.openai_requests += 1 + monthly.updated_at = datetime.utcnow() + else: + monthly = UsageMonthly( + tenant_id=tenant_id, + year=year, + month=month, + total_requests=1, + total_tokens=tokens, + total_cost_usd=cost, + gemini_requests=1 if provider == "gemini" else 0, + openai_requests=1 if provider == "openai" else 0 + ) + db.add(monthly) + diff --git a/app/config.py b/app/config.py new file mode 100644 index 0000000000000000000000000000000000000000..093e2aa07dbf518609daaac9b79d06f5d9df2c03 --- /dev/null +++ b/app/config.py @@ -0,0 +1,77 @@ +""" +Configuration settings for the RAG backend. +""" +from pydantic_settings import BaseSettings +from pathlib import Path +from typing import Optional +import os + + +class Settings(BaseSettings): + """Application settings with environment variable support.""" + + # App settings + APP_NAME: str = "ClientSphere RAG Backend" + DEBUG: bool = True + ENV: str = "dev" # "dev" or "prod" - controls tenant_id security + + # Paths + BASE_DIR: Path = Path(__file__).parent.parent + DATA_DIR: Path = BASE_DIR / "data" + UPLOADS_DIR: Path = DATA_DIR / "uploads" + PROCESSED_DIR: Path = DATA_DIR / "processed" + VECTORDB_DIR: Path = DATA_DIR / "vectordb" + + # Chunking settings (optimized for retrieval quality) + CHUNK_SIZE: int = 600 # tokens (increased for better context) + CHUNK_OVERLAP: int = 150 # tokens (increased for better continuity) + MIN_CHUNK_SIZE: int = 100 # minimum tokens per chunk (increased to avoid tiny chunks) + + # Embedding settings + EMBEDDING_MODEL: str = "all-MiniLM-L6-v2" # Fast, good quality + EMBEDDING_DIMENSION: int = 384 + + # Vector store settings + COLLECTION_NAME: str = "clientsphere_kb" + + # Retrieval settings (optimized for maximum confidence) + TOP_K: int = 10 # Number of chunks to retrieve (increased to maximize chance of finding strong matches) + SIMILARITY_THRESHOLD: float = 0.15 # Minimum similarity score (0-1) - lowered to include more potentially relevant chunks + SIMILARITY_THRESHOLD_STRICT: float = 0.45 # Strict threshold for answer generation (anti-hallucination) + + # LLM settings + LLM_PROVIDER: str = "gemini" # Options: "gemini", "openai" + GEMINI_API_KEY: Optional[str] = None + OPENAI_API_KEY: Optional[str] = None + GEMINI_MODEL: str = "gemini-1.5-flash" # Use latest stable model + OPENAI_MODEL: str = "gpt-3.5-turbo" + + # Response settings + MAX_CONTEXT_TOKENS: int = 2500 # Max tokens for context in prompt (reduced for focus) + TEMPERATURE: float = 0.0 # Zero temperature for maximum determinism (anti-hallucination) + REQUIRE_VERIFIER: bool = True # Always use verifier for hallucination prevention + + # Security settings + MAX_FILE_SIZE_MB: int = 50 # Maximum file size in MB + ALLOWED_ORIGINS: str = "*" # CORS allowed origins (comma-separated, use "*" for all) + RATE_LIMIT_PER_MINUTE: int = 60 # Rate limit per user per minute + JWT_SECRET: Optional[str] = None # JWT secret for authentication + + # Rate limiting + RATE_LIMIT_ENABLED: bool = True # Enable/disable rate limiting + + class Config: + env_file = ".env" + env_file_encoding = "utf-8" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + # Create directories if they don't exist + self.UPLOADS_DIR.mkdir(parents=True, exist_ok=True) + self.PROCESSED_DIR.mkdir(parents=True, exist_ok=True) + self.VECTORDB_DIR.mkdir(parents=True, exist_ok=True) + + +# Global settings instance +settings = Settings() + diff --git a/app/db/__init__.py b/app/db/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cae2a10692977751489e2380a25106789a1b3d25 --- /dev/null +++ b/app/db/__init__.py @@ -0,0 +1,2 @@ +# Database module + diff --git a/app/db/__pycache__/__init__.cpython-313.pyc b/app/db/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5e4546ddb536d3c73c9d6ccea735b061980d2cb3 Binary files /dev/null and b/app/db/__pycache__/__init__.cpython-313.pyc differ diff --git a/app/db/__pycache__/database.cpython-313.pyc b/app/db/__pycache__/database.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7bc47a58f3e4cf8175f382a1b85d5d5ae853d18d Binary files /dev/null and b/app/db/__pycache__/database.cpython-313.pyc differ diff --git a/app/db/__pycache__/models.cpython-313.pyc b/app/db/__pycache__/models.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..33cf42d384da0659676e20df5a2cd086b03c3361 Binary files /dev/null and b/app/db/__pycache__/models.cpython-313.pyc differ diff --git a/app/db/database.py b/app/db/database.py new file mode 100644 index 0000000000000000000000000000000000000000..3ecb1d6017dc1c1f51fca7e96a3a7130b1d3947b --- /dev/null +++ b/app/db/database.py @@ -0,0 +1,53 @@ +""" +Database setup and session management. +Uses SQLAlchemy with SQLite for local dev, Postgres-compatible schema. +""" +from sqlalchemy import create_engine +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import sessionmaker, Session +from pathlib import Path +import logging + +from app.config import settings + +logger = logging.getLogger(__name__) + +# Database path +DB_DIR = settings.DATA_DIR / "billing" +DB_DIR.mkdir(parents=True, exist_ok=True) +DATABASE_URL = f"sqlite:///{DB_DIR / 'billing.db'}" + +# Create engine +engine = create_engine( + DATABASE_URL, + connect_args={"check_same_thread": False}, # SQLite specific + echo=False # Set to True for SQL query logging +) + +# Session factory +SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + +# Base class for models +Base = declarative_base() + + +def get_db() -> Session: + """Get database session (dependency for FastAPI).""" + db = SessionLocal() + try: + yield db + finally: + db.close() + + +def init_db(): + """Initialize database tables.""" + Base.metadata.create_all(bind=engine) + logger.info("Database tables created/verified") + + +def drop_db(): + """Drop all tables (use with caution!).""" + Base.metadata.drop_all(bind=engine) + logger.warning("All database tables dropped") + diff --git a/app/db/models.py b/app/db/models.py new file mode 100644 index 0000000000000000000000000000000000000000..f07c57061c84a15bff09f00a89467dcd41d9cb09 --- /dev/null +++ b/app/db/models.py @@ -0,0 +1,129 @@ +""" +Database models for billing and usage tracking. +""" +from sqlalchemy import Column, String, Integer, Float, DateTime, Boolean, ForeignKey, Text +from sqlalchemy.orm import relationship +from datetime import datetime +from typing import Optional + +from app.db.database import Base + + +class Tenant(Base): + """Tenant/organization model.""" + __tablename__ = "tenants" + + id = Column(String, primary_key=True, index=True) + name = Column(String, nullable=False) + created_at = Column(DateTime, default=datetime.utcnow, nullable=False) + updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False) + + # Relationships + plan = relationship("TenantPlan", back_populates="tenant", uselist=False) + usage_events = relationship("UsageEvent", back_populates="tenant") + daily_usage = relationship("UsageDaily", back_populates="tenant") + monthly_usage = relationship("UsageMonthly", back_populates="tenant") + + +class TenantPlan(Base): + """Tenant subscription plan.""" + __tablename__ = "tenant_plans" + + id = Column(Integer, primary_key=True, autoincrement=True) + tenant_id = Column(String, ForeignKey("tenants.id"), unique=True, nullable=False, index=True) + plan_name = Column(String, nullable=False, index=True) # "starter", "growth", "pro" + monthly_chat_limit = Column(Integer, nullable=False) # -1 for unlimited + created_at = Column(DateTime, default=datetime.utcnow, nullable=False) + updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False) + + # Relationships + tenant = relationship("Tenant", back_populates="plan") + + +class UsageEvent(Base): + """Individual usage event (each /chat request).""" + __tablename__ = "usage_events" + + id = Column(Integer, primary_key=True, autoincrement=True) + request_id = Column(String, unique=True, nullable=False, index=True) + tenant_id = Column(String, ForeignKey("tenants.id"), nullable=False, index=True) + user_id = Column(String, nullable=False, index=True) + kb_id = Column(String, nullable=False) + + # LLM details + provider = Column(String, nullable=False) # "gemini" or "openai" + model = Column(String, nullable=False) + + # Token usage + prompt_tokens = Column(Integer, nullable=False, default=0) + completion_tokens = Column(Integer, nullable=False, default=0) + total_tokens = Column(Integer, nullable=False, default=0) + + # Cost tracking + estimated_cost_usd = Column(Float, nullable=False, default=0.0) + + # Timestamp + request_timestamp = Column(DateTime, default=datetime.utcnow, nullable=False, index=True) + + # Relationships + tenant = relationship("Tenant", back_populates="usage_events") + + +class UsageDaily(Base): + """Daily aggregated usage per tenant.""" + __tablename__ = "usage_daily" + + id = Column(Integer, primary_key=True, autoincrement=True) + tenant_id = Column(String, ForeignKey("tenants.id"), nullable=False, index=True) + date = Column(DateTime, nullable=False, index=True) + + # Aggregated metrics + total_requests = Column(Integer, nullable=False, default=0) + total_tokens = Column(Integer, nullable=False, default=0) + total_cost_usd = Column(Float, nullable=False, default=0.0) + + # Provider breakdown + gemini_requests = Column(Integer, nullable=False, default=0) + openai_requests = Column(Integer, nullable=False, default=0) + + created_at = Column(DateTime, default=datetime.utcnow, nullable=False) + updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False) + + # Unique constraint: one record per tenant per day + __table_args__ = ( + {"sqlite_autoincrement": True}, + ) + + # Relationships + tenant = relationship("Tenant", back_populates="daily_usage") + + +class UsageMonthly(Base): + """Monthly aggregated usage per tenant.""" + __tablename__ = "usage_monthly" + + id = Column(Integer, primary_key=True, autoincrement=True) + tenant_id = Column(String, ForeignKey("tenants.id"), nullable=False, index=True) + year = Column(Integer, nullable=False, index=True) + month = Column(Integer, nullable=False, index=True) # 1-12 + + # Aggregated metrics + total_requests = Column(Integer, nullable=False, default=0) + total_tokens = Column(Integer, nullable=False, default=0) + total_cost_usd = Column(Float, nullable=False, default=0.0) + + # Provider breakdown + gemini_requests = Column(Integer, nullable=False, default=0) + openai_requests = Column(Integer, nullable=False, default=0) + + created_at = Column(DateTime, default=datetime.utcnow, nullable=False) + updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False) + + # Unique constraint: one record per tenant per month + __table_args__ = ( + {"sqlite_autoincrement": True}, + ) + + # Relationships + tenant = relationship("Tenant", back_populates="monthly_usage") + diff --git a/app/main.py b/app/main.py new file mode 100644 index 0000000000000000000000000000000000000000..bbe9293084087a8cf8545b2702f91549db453aa6 --- /dev/null +++ b/app/main.py @@ -0,0 +1,1039 @@ +""" +FastAPI application for ClientSphere RAG Backend. +Provides endpoints for knowledge base management and chat. +""" +from fastapi import FastAPI, File, UploadFile, HTTPException, Form, BackgroundTasks, Request, Depends +from fastapi.middleware.cors import CORSMiddleware +from fastapi.exceptions import RequestValidationError +from fastapi.responses import JSONResponse +from pathlib import Path +import shutil +import uuid +from datetime import datetime +from typing import Optional +import logging + +from app.config import settings +from app.middleware.auth import get_auth_context, require_auth +from app.middleware.rate_limit import ( + limiter, + get_tenant_rate_limit_key, + RateLimitExceeded, + _rate_limit_exceeded_handler +) +from app.models.schemas import ( + UploadResponse, + ChatRequest, + ChatResponse, + KnowledgeBaseStats, + HealthResponse, + DocumentStatus, + Citation, +) +from app.models.billing_schemas import ( + UsageResponse, + PlanLimitsResponse, + CostReportResponse, + SetPlanRequest +) +from app.rag.ingest import parser +from app.rag.chunking import chunker +from app.rag.embeddings import get_embedding_service +from app.rag.vectorstore import get_vector_store +from app.rag.retrieval import get_retrieval_service +from app.rag.answer import get_answer_service +from app.db.database import get_db, init_db +from app.billing.quota import check_quota, ensure_tenant_exists +from app.billing.usage_tracker import track_usage + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# Initialize FastAPI app +app = FastAPI( + title=settings.APP_NAME, + description="RAG-based customer support chatbot API", + version="1.0.0", +) + +# Initialize database on startup +@app.on_event("startup") +async def startup_event(): + """Initialize database on application startup.""" + init_db() + logger.info("Database initialized") + +# Configure CORS - SECURITY: Restrict in production +if settings.ALLOWED_ORIGINS == "*": + allowed_origins = ["*"] +else: + # Split by comma and strip whitespace + allowed_origins = [origin.strip() for origin in settings.ALLOWED_ORIGINS.split(",") if origin.strip()] + +# Default to allowing localhost if no origins specified +if not allowed_origins or allowed_origins == ["*"]: + allowed_origins = ["*"] # Allow all in dev mode + +logger.info(f"CORS configured with origins: {allowed_origins}") + +app.add_middleware( + CORSMiddleware, + allow_origins=allowed_origins, + allow_credentials=True, + allow_methods=["GET", "POST", "DELETE", "OPTIONS"], # Include OPTIONS for preflight + allow_headers=["Content-Type", "Authorization", "X-Tenant-Id", "X-User-Id"], # Include auth headers +) + +# Configure rate limiting +app.state.limiter = limiter +app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler) + +# Add exception handler for validation errors +@app.exception_handler(RequestValidationError) +async def validation_exception_handler(request: Request, exc: RequestValidationError): + """Handle request validation errors with detailed logging.""" + body = await request.body() + logger.error(f"Request validation error: {exc.errors()}") + logger.error(f"Request body (raw): {body}") + logger.error(f"Request headers: {dict(request.headers)}") + return JSONResponse( + status_code=422, + content={"detail": exc.errors(), "body": body.decode('utf-8', errors='ignore')} + ) + +# Add exception handler for validation errors +from fastapi.exceptions import RequestValidationError +from fastapi.responses import JSONResponse + +@app.exception_handler(RequestValidationError) +async def validation_exception_handler(request: Request, exc: RequestValidationError): + """Handle request validation errors with detailed logging.""" + logger.error(f"Request validation error: {exc.errors()}") + logger.error(f"Request body: {await request.body()}") + return JSONResponse( + status_code=422, + content={"detail": exc.errors(), "body": str(await request.body())} + ) + + +# ============== Health & Status Endpoints ============== + +@app.get("/", response_model=HealthResponse) +async def root(): + """Root endpoint with basic info.""" + return HealthResponse( + status="ok", + version="1.0.0", + vector_db_connected=True, + llm_configured=bool(settings.GEMINI_API_KEY or settings.OPENAI_API_KEY) + ) + + +@app.get("/health", response_model=HealthResponse) +async def health_check(): + """Health check endpoint.""" + try: + vector_store = get_vector_store() + stats = vector_store.get_stats() + + return HealthResponse( + status="healthy", + version="1.0.0", + vector_db_connected=True, + llm_configured=bool(settings.GEMINI_API_KEY or settings.OPENAI_API_KEY) + ) + except Exception as e: + logger.error(f"Health check failed: {e}") + return HealthResponse( + status="unhealthy", + version="1.0.0", + vector_db_connected=False, + llm_configured=False + ) + + +@app.get("/health/live") +async def liveness(): + """Kubernetes liveness probe - always returns alive.""" + return {"status": "alive"} + + +@app.get("/health/ready") +async def readiness(): + """Kubernetes readiness probe - checks dependencies.""" + checks = { + "vector_db": False, + "llm_configured": bool(settings.GEMINI_API_KEY or settings.OPENAI_API_KEY) + } + + # Check vector DB connection + try: + vector_store = get_vector_store() + vector_store.get_stats() + checks["vector_db"] = True + except Exception as e: + logger.warning(f"Vector DB check failed: {e}") + checks["vector_db"] = False + + # All checks must pass + if all(checks.values()): + return {"status": "ready", "checks": checks} + else: + from fastapi import HTTPException + raise HTTPException(status_code=503, detail={"status": "not_ready", "checks": checks}) + + +# ============== Knowledge Base Endpoints ============== + +@app.post("/kb/upload", response_model=UploadResponse) +@limiter.limit("20/hour", key_func=get_tenant_rate_limit_key) +async def upload_document( + background_tasks: BackgroundTasks, + request: Request, + file: UploadFile = File(...), + tenant_id: Optional[str] = Form(None), # Optional in dev, ignored in prod + user_id: Optional[str] = Form(None), # Optional in dev, ignored in prod + kb_id: str = Form(...) +): + """ + Upload a document to the knowledge base. + + - Saves file to disk + - Parses and chunks the document + - Generates embeddings + - Stores in vector database + """ + # SECURITY: Extract tenant_id from auth token in production + if settings.ENV == "prod": + auth_context = await require_auth(request) + tenant_id = auth_context.get("tenant_id") + if not tenant_id: + raise HTTPException( + status_code=403, + detail="tenant_id must come from authentication token in production mode" + ) + elif not tenant_id: + raise HTTPException( + status_code=400, + detail="tenant_id is required" + ) + + # Validate file type + file_ext = Path(file.filename).suffix.lower() + if file_ext not in parser.SUPPORTED_EXTENSIONS: + raise HTTPException( + status_code=400, + detail=f"Unsupported file type: {file_ext}. Supported: {parser.SUPPORTED_EXTENSIONS}" + ) + + # Validate file size (SECURITY) + file.file.seek(0, 2) # Seek to end + file_size = file.file.tell() + file.file.seek(0) # Reset to start + max_size_bytes = settings.MAX_FILE_SIZE_MB * 1024 * 1024 + if file_size > max_size_bytes: + raise HTTPException( + status_code=400, + detail=f"File too large. Maximum size: {settings.MAX_FILE_SIZE_MB}MB" + ) + + # Generate document ID + doc_id = f"{tenant_id}_{kb_id}_{uuid.uuid4().hex[:8]}" + + # Save file to uploads directory + upload_path = settings.UPLOADS_DIR / f"{doc_id}_{file.filename}" + try: + with open(upload_path, "wb") as buffer: + shutil.copyfileobj(file.file, buffer) + logger.info(f"Saved file: {upload_path}") + except Exception as e: + logger.error(f"Error saving file: {e}") + raise HTTPException(status_code=500, detail="Failed to save file") + + # Process document in background + background_tasks.add_task( + process_document, + upload_path, + tenant_id, # CRITICAL: Multi-tenant isolation + user_id, + kb_id, + file.filename, + doc_id + ) + + return UploadResponse( + success=True, + message="Document upload started. Processing in background.", + document_id=doc_id, + file_name=file.filename, + chunks_created=0, + status=DocumentStatus.PROCESSING + ) + + +async def process_document( + file_path: Path, + tenant_id: str, # CRITICAL: Multi-tenant isolation + user_id: str, + kb_id: str, + original_filename: str, + document_id: str +): + """ + Background task to process an uploaded document. + """ + try: + logger.info(f"Processing document: {original_filename}") + + # Parse document + parsed_doc = parser.parse(file_path) + logger.info(f"Parsed document: {len(parsed_doc.text)} characters") + + # Chunk document + chunks = chunker.chunk_text( + parsed_doc.text, + page_numbers=parsed_doc.page_map + ) + logger.info(f"Created {len(chunks)} chunks") + + if not chunks: + logger.warning(f"No chunks created from {original_filename}") + return + + # Create metadata for each chunk + metadatas = [] + chunk_ids = [] + chunk_texts = [] + + for chunk in chunks: + metadata = chunker.create_chunk_metadata( + chunk=chunk, + tenant_id=tenant_id, # CRITICAL: Multi-tenant isolation + kb_id=kb_id, + user_id=user_id, + file_name=original_filename, + file_type=parsed_doc.file_type, + total_chunks=len(chunks), + document_id=document_id + ) + metadatas.append(metadata) + chunk_ids.append(metadata["chunk_id"]) + chunk_texts.append(chunk.content) + + # Generate embeddings + embedding_service = get_embedding_service() + embeddings = embedding_service.embed_texts(chunk_texts) + logger.info(f"Generated {len(embeddings)} embeddings") + + # Store in vector database + vector_store = get_vector_store() + vector_store.add_documents( + documents=chunk_texts, + embeddings=embeddings, + metadatas=metadatas, + ids=chunk_ids + ) + + logger.info(f"Successfully processed {original_filename}: {len(chunks)} chunks stored") + + except Exception as e: + logger.error(f"Error processing document {original_filename}: {e}") + raise + + +@app.get("/kb/stats", response_model=KnowledgeBaseStats) +async def get_kb_stats( + request: Request, + tenant_id: Optional[str] = None, # Optional in dev, ignored in prod + kb_id: Optional[str] = None, + user_id: Optional[str] = None # Optional in dev, ignored in prod +): + """Get statistics for a knowledge base.""" + # SECURITY: Get tenant_id and user_id from auth context + auth_context = await get_auth_context(request) + tenant_id_from_auth = auth_context.get("tenant_id") + user_id_from_auth = auth_context.get("user_id") + + if settings.ENV == "prod": + if not tenant_id_from_auth or not user_id_from_auth: + raise HTTPException( + status_code=403, + detail="tenant_id and user_id must come from authentication token in production mode" + ) + tenant_id = tenant_id_from_auth + user_id = user_id_from_auth + else: + tenant_id = tenant_id or tenant_id_from_auth + user_id = user_id or user_id_from_auth + if not tenant_id or not kb_id or not user_id: + raise HTTPException( + status_code=400, + detail="tenant_id, kb_id, and user_id are required" + ) + + try: + vector_store = get_vector_store() + stats = vector_store.get_stats(tenant_id=tenant_id, kb_id=kb_id, user_id=user_id) + + return KnowledgeBaseStats( + tenant_id=tenant_id, # CRITICAL: Multi-tenant isolation + kb_id=kb_id, + user_id=user_id, + total_documents=len(stats.get("file_names", [])), + total_chunks=stats.get("total_chunks", 0), + file_names=stats.get("file_names", []), + last_updated=datetime.utcnow() + ) + except Exception as e: + logger.error(f"Error getting KB stats: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@app.delete("/kb/document") +async def delete_document( + request: Request, + tenant_id: Optional[str] = None, # Optional in dev, ignored in prod + kb_id: Optional[str] = None, + user_id: Optional[str] = None, # Optional in dev, ignored in prod + file_name: Optional[str] = None +): + """Delete a document from the knowledge base.""" + # SECURITY: Get tenant_id and user_id from auth context + auth_context = await get_auth_context(request) + tenant_id_from_auth = auth_context.get("tenant_id") + user_id_from_auth = auth_context.get("user_id") + + if settings.ENV == "prod": + if not tenant_id_from_auth or not user_id_from_auth: + raise HTTPException( + status_code=403, + detail="tenant_id and user_id must come from authentication token in production mode" + ) + tenant_id = tenant_id_from_auth + user_id = user_id_from_auth + else: + tenant_id = tenant_id or tenant_id_from_auth + user_id = user_id or user_id_from_auth + if not tenant_id or not kb_id or not user_id or not file_name: + raise HTTPException( + status_code=400, + detail="tenant_id, kb_id, user_id, and file_name are required (provide via headers or query params)" + ) + + try: + vector_store = get_vector_store() + deleted = vector_store.delete_by_filter({ + "tenant_id": tenant_id, # CRITICAL: Multi-tenant isolation + "kb_id": kb_id, + "user_id": user_id, + "file_name": file_name + }) + + return { + "success": True, + "message": f"Deleted {deleted} chunks", + "file_name": file_name + } + except Exception as e: + logger.error(f"Error deleting document: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@app.delete("/kb/clear") +async def clear_kb( + request: Request, + tenant_id: Optional[str] = None, # Optional in dev, ignored in prod + kb_id: Optional[str] = None, + user_id: Optional[str] = None # Optional in dev, ignored in prod +): + """Clear all documents from a knowledge base.""" + # SECURITY: Get tenant_id and user_id from auth context + auth_context = await get_auth_context(request) + tenant_id_from_auth = auth_context.get("tenant_id") + user_id_from_auth = auth_context.get("user_id") + + if settings.ENV == "prod": + if not tenant_id_from_auth or not user_id_from_auth: + raise HTTPException( + status_code=403, + detail="tenant_id and user_id must come from authentication token in production mode" + ) + tenant_id = tenant_id_from_auth + user_id = user_id_from_auth + else: + tenant_id = tenant_id or tenant_id_from_auth + user_id = user_id or user_id_from_auth + if not tenant_id or not kb_id or not user_id: + raise HTTPException( + status_code=400, + detail="tenant_id, kb_id, and user_id are required" + ) + try: + vector_store = get_vector_store() + deleted = vector_store.delete_by_filter({ + "tenant_id": tenant_id, # CRITICAL: Multi-tenant isolation + "kb_id": kb_id, + "user_id": user_id + }) + + return { + "success": True, + "message": f"Cleared knowledge base. Deleted {deleted} chunks.", + "kb_id": kb_id + } + except Exception as e: + logger.error(f"Error clearing KB: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +# ============== Chat Endpoints ============== + +@app.post("/chat", response_model=ChatResponse) +@limiter.limit("10/minute", key_func=get_tenant_rate_limit_key) +async def chat(chat_request: ChatRequest, request: Request): + """ + Process a chat message using RAG. + + - Retrieves relevant context from knowledge base + - Generates answer using LLM + - Returns answer with citations + """ + conversation_id = "unknown" + try: + logger.info(f"=== CHAT REQUEST RECEIVED ===") + logger.info(f"Request body: tenant_id={chat_request.tenant_id}, user_id={chat_request.user_id}, kb_id={chat_request.kb_id}, question_length={len(chat_request.question)}") + logger.info(f"Request headers: {dict(request.headers)}") + + # SECURITY: Get tenant_id and user_id from auth context + # In PROD: MUST come from JWT token (never from request body) + try: + auth_context = await get_auth_context(request) + except Exception as e: + logger.error(f"Error getting auth context: {e}", exc_info=True) + raise HTTPException(status_code=401, detail=f"Authentication error: {str(e)}") + + tenant_id_from_auth = auth_context.get("tenant_id") + user_id_from_auth = auth_context.get("user_id") + + if settings.ENV == "prod": + if not tenant_id_from_auth or not user_id_from_auth: + raise HTTPException( + status_code=403, + detail="tenant_id and user_id must come from authentication token in production mode" + ) + # Override request values with auth context (security enforcement) + chat_request.tenant_id = tenant_id_from_auth + chat_request.user_id = user_id_from_auth + else: + # DEV mode: use from request if provided, otherwise from auth context + if not chat_request.tenant_id: + chat_request.tenant_id = tenant_id_from_auth + if not chat_request.user_id: + chat_request.user_id = user_id_from_auth + if not chat_request.tenant_id or not chat_request.user_id: + raise HTTPException( + status_code=400, + detail="tenant_id and user_id are required (provide via X-Tenant-Id/X-User-Id headers or request body)" + ) + + # Log without PII in production + if settings.ENV == "prod": + logger.info(f"Chat request: tenant={chat_request.tenant_id}, user={chat_request.user_id}, kb={chat_request.kb_id}, q_length={len(chat_request.question)}") + else: + logger.info(f"Chat request: tenant={chat_request.tenant_id}, user={chat_request.user_id}, kb={chat_request.kb_id}, q={chat_request.question[:50]}...") + + # Generate conversation ID if not provided + conversation_id = chat_request.conversation_id or f"conv_{uuid.uuid4().hex[:12]}" + + # Get database session + try: + db = next(get_db()) + except Exception as e: + logger.error(f"Database connection error: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=f"Database error: {str(e)}") + + try: + # Ensure tenant exists in billing DB + ensure_tenant_exists(db, chat_request.tenant_id) + + # Check quota BEFORE making LLM call + has_quota, quota_error = check_quota(db, chat_request.tenant_id) + if not has_quota: + logger.warning(f"Quota exceeded for tenant {chat_request.tenant_id}") + raise HTTPException( + status_code=402, + detail=quota_error or "AI quota exceeded. Upgrade your plan." + ) + + # Retrieve relevant context + retrieval_service = get_retrieval_service() + results, confidence, has_relevant = retrieval_service.retrieve( + query=chat_request.question, + tenant_id=chat_request.tenant_id, # CRITICAL: Multi-tenant isolation + kb_id=chat_request.kb_id, + user_id=chat_request.user_id + ) + + logger.info(f"Retrieval results: {len(results)} results, confidence={confidence:.3f}, has_relevant={has_relevant}") + + # Format context for LLM + context, citations_info = retrieval_service.get_context_for_llm(results) + + logger.info(f"Formatted context length: {len(context)} chars, citations: {len(citations_info)}") + + # Generate answer + answer_service = get_answer_service() + answer_result = answer_service.generate_answer( + question=chat_request.question, + context=context, + citations_info=citations_info, + confidence=confidence, + has_relevant_results=has_relevant + ) + + # Track usage if LLM was called (usage info present) + usage_info = answer_result.get("usage") + if usage_info: + try: + track_usage( + db=db, + tenant_id=chat_request.tenant_id, + user_id=chat_request.user_id, + kb_id=chat_request.kb_id, + provider=settings.LLM_PROVIDER, + model=usage_info.get("model_used", settings.GEMINI_MODEL if settings.LLM_PROVIDER == "gemini" else settings.OPENAI_MODEL), + prompt_tokens=usage_info.get("prompt_tokens", 0), + completion_tokens=usage_info.get("completion_tokens", 0) + ) + except Exception as e: + logger.error(f"Failed to track usage: {e}", exc_info=True) + # Don't fail the request if usage tracking fails + + # Build metadata with refusal info + metadata = { + "chunks_retrieved": len(results), + "kb_id": chat_request.kb_id + } + if "refused" in answer_result: + metadata["refused"] = answer_result["refused"] + if "refusal_reason" in answer_result: + metadata["refusal_reason"] = answer_result["refusal_reason"] + if "verifier_passed" in answer_result: + metadata["verifier_passed"] = answer_result["verifier_passed"] + + return ChatResponse( + success=True, + answer=answer_result["answer"], + citations=answer_result["citations"], + confidence=answer_result["confidence"], + from_knowledge_base=answer_result["from_knowledge_base"], + escalation_suggested=answer_result["escalation_suggested"], + conversation_id=conversation_id, + refused=answer_result.get("refused", False), + metadata=metadata + ) + + except ValueError as e: + # API key or configuration error + error_msg = str(e) + logger.error(f"Configuration error: {error_msg}") + if "API key" in error_msg.lower(): + return ChatResponse( + success=False, + answer="โš ๏ธ LLM API key not configured. Please set GEMINI_API_KEY in your .env file. Retrieval is working, but answer generation requires an API key.", + citations=[], + confidence=0.0, + from_knowledge_base=False, + escalation_suggested=True, + conversation_id=conversation_id, + metadata={"error": error_msg, "error_type": "configuration"} + ) + else: + return ChatResponse( + success=False, + answer=f"Configuration error: {error_msg}", + citations=[], + confidence=0.0, + from_knowledge_base=False, + escalation_suggested=True, + conversation_id=conversation_id, + metadata={"error": error_msg} + ) + except HTTPException: + # Re-raise HTTP exceptions (they have proper status codes) + raise + except Exception as e: + logger.error(f"Chat error: {e}", exc_info=True) + logger.error(f"Error type: {type(e).__name__}", exc_info=True) + return ChatResponse( + success=False, + answer=f"I encountered an error processing your request: {str(e)}. Please check the server logs for details.", + citations=[], + confidence=0.0, + from_knowledge_base=False, + escalation_suggested=True, + conversation_id=conversation_id, + metadata={"error": str(e), "error_type": type(e).__name__} + ) + except HTTPException: + # Re-raise HTTP exceptions from outer try block + raise + except Exception as e: + logger.error(f"Outer chat error: {e}", exc_info=True) + return ChatResponse( + success=False, + answer=f"I encountered an error processing your request: {str(e)}. Please check the server logs for details.", + citations=[], + confidence=0.0, + from_knowledge_base=False, + escalation_suggested=True, + conversation_id=conversation_id, + metadata={"error": str(e), "error_type": type(e).__name__} + ) + + +# ============== Utility Endpoints ============== + +@app.get("/kb/search") +@limiter.limit("30/minute", key_func=get_tenant_rate_limit_key) +async def search_kb( + request: Request, + query: str, + tenant_id: Optional[str] = None, # Optional in dev, ignored in prod + kb_id: Optional[str] = None, + user_id: Optional[str] = None, # Optional in dev, ignored in prod + top_k: int = 5 +): + """ + Search the knowledge base without generating an answer. + Useful for debugging and testing retrieval. + """ + # SECURITY: Extract tenant_id from auth token in production + if settings.ENV == "prod": + auth_context = await require_auth(request) + tenant_id = auth_context.get("tenant_id") + user_id = auth_context.get("user_id") + if not tenant_id or not user_id: + raise HTTPException( + status_code=403, + detail="tenant_id and user_id must come from authentication token in production mode" + ) + elif not tenant_id or not kb_id or not user_id: + raise HTTPException( + status_code=400, + detail="tenant_id, kb_id, and user_id are required" + ) + + try: + retrieval_service = get_retrieval_service() + results, confidence, has_relevant = retrieval_service.retrieve( + query=query, + tenant_id=tenant_id, # CRITICAL: Multi-tenant isolation + kb_id=kb_id, + user_id=user_id, + top_k=top_k + ) + + return { + "success": True, + "results": [ + { + "chunk_id": r.chunk_id, + "content": r.content[:500] + "..." if len(r.content) > 500 else r.content, + "metadata": r.metadata, + "similarity_score": r.similarity_score + } + for r in results + ], + "confidence": confidence, + "has_relevant_results": has_relevant + } + except Exception as e: + logger.error(f"Search error: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +# ============== Billing & Usage Endpoints ============== + +@app.get("/billing/usage", response_model=UsageResponse) +async def get_usage( + request: Request, + range: str = "month", # "day" or "month" + year: Optional[int] = None, + month: Optional[int] = None, + day: Optional[int] = None +): + """ + Get usage statistics for the current tenant. + + Args: + range: "day" or "month" + year: Year (optional, defaults to current) + month: Month 1-12 (optional, defaults to current) + day: Day 1-31 (optional, defaults to current, only for range="day") + """ + # Get tenant from auth + auth_context = await get_auth_context(request) + tenant_id = auth_context.get("tenant_id") + + if not tenant_id: + raise HTTPException(status_code=403, detail="tenant_id required") + + db = next(get_db()) + + try: + from app.db.models import UsageDaily, UsageMonthly + from datetime import datetime + from calendar import monthrange + + now = datetime.utcnow() + target_year = year or now.year + target_month = month or now.month + + if range == "day": + target_day = day or now.day + date_start = datetime(target_year, target_month, target_day) + + daily = db.query(UsageDaily).filter( + UsageDaily.tenant_id == tenant_id, + UsageDaily.date == date_start + ).first() + + if not daily: + return UsageResponse( + tenant_id=tenant_id, + period="day", + total_requests=0, + total_tokens=0, + total_cost_usd=0.0, + start_date=date_start, + end_date=date_start + ) + + return UsageResponse( + tenant_id=tenant_id, + period="day", + total_requests=daily.total_requests, + total_tokens=daily.total_tokens, + total_cost_usd=daily.total_cost_usd, + gemini_requests=daily.gemini_requests, + openai_requests=daily.openai_requests, + start_date=daily.date, + end_date=daily.date + ) + else: # month + monthly = db.query(UsageMonthly).filter( + UsageMonthly.tenant_id == tenant_id, + UsageMonthly.year == target_year, + UsageMonthly.month == target_month + ).first() + + if not monthly: + # Calculate date range for the month + _, last_day = monthrange(target_year, target_month) + start_date = datetime(target_year, target_month, 1) + end_date = datetime(target_year, target_month, last_day) + + return UsageResponse( + tenant_id=tenant_id, + period="month", + total_requests=0, + total_tokens=0, + total_cost_usd=0.0, + start_date=start_date, + end_date=end_date + ) + + _, last_day = monthrange(monthly.year, monthly.month) + start_date = datetime(monthly.year, monthly.month, 1) + end_date = datetime(monthly.year, monthly.month, last_day) + + return UsageResponse( + tenant_id=tenant_id, + period="month", + total_requests=monthly.total_requests, + total_tokens=monthly.total_tokens, + total_cost_usd=monthly.total_cost_usd, + gemini_requests=monthly.gemini_requests, + openai_requests=monthly.openai_requests, + start_date=start_date, + end_date=end_date + ) + except Exception as e: + logger.error(f"Error getting usage: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=str(e)) + + +@app.get("/billing/limits", response_model=PlanLimitsResponse) +async def get_limits(request: Request): + """Get current plan limits and usage for the tenant.""" + # Get tenant from auth + auth_context = await get_auth_context(request) + tenant_id = auth_context.get("tenant_id") + + if not tenant_id: + raise HTTPException(status_code=403, detail="tenant_id required") + + db = next(get_db()) + + try: + from app.billing.quota import get_tenant_plan, get_monthly_usage + from datetime import datetime + + plan = get_tenant_plan(db, tenant_id) + if not plan: + # Default to starter + plan_name = "starter" + monthly_limit = 500 + else: + plan_name = plan.plan_name + monthly_limit = plan.monthly_chat_limit + + # Get current month usage + now = datetime.utcnow() + monthly_usage = get_monthly_usage(db, tenant_id, now.year, now.month) + current_usage = monthly_usage.total_requests if monthly_usage else 0 + + remaining = None if monthly_limit == -1 else max(0, monthly_limit - current_usage) + + return PlanLimitsResponse( + tenant_id=tenant_id, + plan_name=plan_name, + monthly_chat_limit=monthly_limit, + current_month_usage=current_usage, + remaining_chats=remaining + ) + except Exception as e: + logger.error(f"Error getting limits: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=str(e)) + + +@app.post("/billing/plan") +async def set_plan(request_body: SetPlanRequest, http_request: Request): + """ + Set tenant's subscription plan (admin only in production). + + In dev mode, allows any tenant to set their plan. + In prod mode, should be restricted to admin users. + """ + # Get tenant from auth + auth_context = await get_auth_context(http_request) + auth_tenant_id = auth_context.get("tenant_id") + + # In prod, verify admin role (placeholder - implement actual admin check) + if settings.ENV == "prod": + # TODO: Add admin role check + if auth_tenant_id != request_body.tenant_id: + raise HTTPException(status_code=403, detail="Cannot set plan for other tenants") + + db = next(get_db()) + + try: + from app.billing.quota import set_tenant_plan + + plan = set_tenant_plan(db, request_body.tenant_id, request_body.plan_name) + + return { + "success": True, + "tenant_id": request_body.tenant_id, + "plan_name": plan.plan_name, + "monthly_chat_limit": plan.monthly_chat_limit + } + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + except Exception as e: + logger.error(f"Error setting plan: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=str(e)) + + +@app.get("/billing/cost-report", response_model=CostReportResponse) +async def get_cost_report( + request: Request, + range: str = "month", + year: Optional[int] = None, + month: Optional[int] = None +): + """Get cost report with breakdown by provider and model.""" + # Get tenant from auth + auth_context = await get_auth_context(request) + tenant_id = auth_context.get("tenant_id") + + if not tenant_id: + raise HTTPException(status_code=403, detail="tenant_id required") + + db = next(get_db()) + + try: + from app.db.models import UsageEvent + from datetime import datetime + from sqlalchemy import func, and_ + + now = datetime.utcnow() + target_year = year or now.year + target_month = month or now.month + + # Query usage events for the period + if range == "month": + query = db.query(UsageEvent).filter( + and_( + UsageEvent.tenant_id == tenant_id, + func.extract('year', UsageEvent.request_timestamp) == target_year, + func.extract('month', UsageEvent.request_timestamp) == target_month + ) + ) + else: # all time + query = db.query(UsageEvent).filter(UsageEvent.tenant_id == tenant_id) + + events = query.all() + + # Calculate totals + total_cost = sum(e.estimated_cost_usd for e in events) + total_requests = len(events) + total_tokens = sum(e.total_tokens for e in events) + + # Breakdown by provider + breakdown_by_provider = {} + for event in events: + provider = event.provider + if provider not in breakdown_by_provider: + breakdown_by_provider[provider] = { + "requests": 0, + "tokens": 0, + "cost_usd": 0.0 + } + breakdown_by_provider[provider]["requests"] += 1 + breakdown_by_provider[provider]["tokens"] += event.total_tokens + breakdown_by_provider[provider]["cost_usd"] += event.estimated_cost_usd + + # Breakdown by model + breakdown_by_model = {} + for event in events: + model = event.model + if model not in breakdown_by_model: + breakdown_by_model[model] = { + "requests": 0, + "tokens": 0, + "cost_usd": 0.0 + } + breakdown_by_model[model]["requests"] += 1 + breakdown_by_model[model]["tokens"] += event.total_tokens + breakdown_by_model[model]["cost_usd"] += event.estimated_cost_usd + + return CostReportResponse( + tenant_id=tenant_id, + period=range, + total_cost_usd=total_cost, + total_requests=total_requests, + total_tokens=total_tokens, + breakdown_by_provider=breakdown_by_provider, + breakdown_by_model=breakdown_by_model + ) + except Exception as e: + logger.error(f"Error getting cost report: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=str(e)) + + +if __name__ == "__main__": + import uvicorn + uvicorn.run(app, host="0.0.0.0", port=8000) + diff --git a/app/middleware/__init__.py b/app/middleware/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c669ae02552cd455662b0442a17d5819ccf09f0d --- /dev/null +++ b/app/middleware/__init__.py @@ -0,0 +1,13 @@ +""" +Middleware for authentication, rate limiting, etc. +""" +from app.middleware.auth import verify_tenant_access, get_tenant_from_token, require_auth + +__all__ = [ + "verify_tenant_access", + "get_tenant_from_token", + "require_auth", +] + + + diff --git a/app/middleware/__pycache__/__init__.cpython-313.pyc b/app/middleware/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a3d66b94a8536c5164d1b6f2815522d1f7adbe3e Binary files /dev/null and b/app/middleware/__pycache__/__init__.cpython-313.pyc differ diff --git a/app/middleware/__pycache__/auth.cpython-313.pyc b/app/middleware/__pycache__/auth.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..91aa6e10d78b1ff9efc4a7abbf5a1c4c094d72ba Binary files /dev/null and b/app/middleware/__pycache__/auth.cpython-313.pyc differ diff --git a/app/middleware/__pycache__/rate_limit.cpython-313.pyc b/app/middleware/__pycache__/rate_limit.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..719e95084d7c04e1142d97f811d06d6fe5d68a79 Binary files /dev/null and b/app/middleware/__pycache__/rate_limit.cpython-313.pyc differ diff --git a/app/middleware/auth.py b/app/middleware/auth.py new file mode 100644 index 0000000000000000000000000000000000000000..337c397b266e8dd463d0335f878f4998287fb3c2 --- /dev/null +++ b/app/middleware/auth.py @@ -0,0 +1,212 @@ +""" +Authentication and authorization middleware. +Extracts tenant_id from JWT token in production mode. +""" +from fastapi import Request, HTTPException, Depends +from typing import Optional, Dict, Any +import logging +from jose import JWTError, jwt + +from app.config import settings + +logger = logging.getLogger(__name__) + + +async def verify_tenant_access( + request: Request, + tenant_id: str, + user_id: str +) -> bool: + """ + Verify that the user has access to the specified tenant. + + TODO: Implement actual authentication logic: + 1. Extract JWT token from Authorization header + 2. Verify token signature + 3. Extract user_id and tenant_id from token + 4. Verify user belongs to tenant + 5. Check permissions + + Args: + request: FastAPI request object + tenant_id: Tenant ID from request + user_id: User ID from request + + Returns: + True if access is granted, False otherwise + """ + # TODO: Implement actual authentication + # For now, this is a placeholder that always returns True + # In production, you MUST implement proper auth + + # Example implementation: + # token = request.headers.get("Authorization", "").replace("Bearer ", "") + # if not token: + # return False + # + # decoded = verify_jwt_token(token) + # if decoded["user_id"] != user_id or decoded["tenant_id"] != tenant_id: + # return False + # + # return True + + logger.warning("โš ๏ธ Authentication middleware not implemented - using placeholder") + return True + + +def get_tenant_from_token(request: Request) -> Optional[str]: + """ + Extract tenant_id from authentication token. + + In production mode, extracts tenant_id from JWT token. + In dev mode, returns None (allows request tenant_id). + + Args: + request: FastAPI request object + + Returns: + Tenant ID if found in token, None otherwise + """ + if settings.ENV == "dev": + # Dev mode: allow request tenant_id + return None + + # Production mode: extract from JWT + auth_header = request.headers.get("Authorization", "") + if not auth_header.startswith("Bearer "): + logger.warning("Missing or invalid Authorization header") + return None + + token = auth_header.replace("Bearer ", "").strip() + if not token: + return None + + try: + # TODO: Replace with your actual JWT secret key + # For now, this is a placeholder that expects a specific token format + # In production, you should: + # 1. Get JWT_SECRET from environment + # 2. Verify token signature + # 3. Extract tenant_id from token payload + + # Example implementation (replace with your actual JWT verification): + # JWT_SECRET = os.getenv("JWT_SECRET", "your-secret-key") + # decoded = jwt.decode(token, JWT_SECRET, algorithms=["HS256"]) + # return decoded.get("tenant_id") + + # Placeholder: Try to decode without verification (for testing) + # In production, you MUST verify the signature + try: + decoded = jwt.decode(token, options={"verify_signature": False}) + tenant_id = decoded.get("tenant_id") + if tenant_id: + logger.info(f"Extracted tenant_id from token: {tenant_id}") + return tenant_id + except jwt.DecodeError: + logger.warning("Failed to decode JWT token") + return None + + except Exception as e: + logger.error(f"Error extracting tenant from token: {e}") + return None + + return None + + +async def get_auth_context(request: Request) -> Dict[str, Any]: + """ + Get authentication context from request. + + DEV mode: + - Allows X-Tenant-Id and X-User-Id headers + - Falls back to defaults if missing + + PROD mode: + - Requires Authorization: Bearer + - Verifies JWT using JWT_SECRET + - Extracts tenant_id and user_id from token claims + - NEVER accepts tenant_id from request body/query params + + Args: + request: FastAPI request object + + Returns: + Dictionary with user_id and tenant_id + + Raises: + HTTPException: If authentication fails (production mode only) + """ + if settings.ENV == "dev": + # Dev mode: allow headers, fallback to defaults + tenant_id = request.headers.get("X-Tenant-Id", "dev_tenant") + user_id = request.headers.get("X-User-Id", "dev_user") + return { + "user_id": user_id, + "tenant_id": tenant_id + } + + # Production mode: require JWT token + auth_header = request.headers.get("Authorization") + if not auth_header or not auth_header.startswith("Bearer "): + raise HTTPException( + status_code=401, + detail="Authentication required. Provide valid Bearer token in Authorization header." + ) + + token = auth_header.replace("Bearer ", "").strip() + if not token: + raise HTTPException( + status_code=401, + detail="Invalid token format." + ) + + # Verify JWT token + if not settings.JWT_SECRET: + logger.error("JWT_SECRET not configured in production mode") + raise HTTPException( + status_code=500, + detail="Server configuration error: JWT_SECRET not set" + ) + + try: + decoded = jwt.decode(token, settings.JWT_SECRET, algorithms=["HS256"]) + + user_id = decoded.get("user_id") or decoded.get("sub") + tenant_id = decoded.get("tenant_id") + + if not user_id or not tenant_id: + raise HTTPException( + status_code=401, + detail="Token missing required claims (user_id, tenant_id)." + ) + + logger.info(f"Authenticated user: {user_id}, tenant: {tenant_id}") + return { + "user_id": user_id, + "tenant_id": tenant_id, + "email": decoded.get("email"), + "role": decoded.get("role") + } + + except JWTError as e: + logger.warning(f"JWT verification failed: {e}") + raise HTTPException( + status_code=401, + detail="Invalid or expired token." + ) + except Exception as e: + logger.error(f"Auth error: {e}", exc_info=True) + raise HTTPException( + status_code=401, + detail="Authentication failed." + ) + + +# FastAPI dependency for easy use in endpoints +async def require_auth(request: Request) -> Dict[str, Any]: + """ + FastAPI dependency for requiring authentication. + Alias for get_auth_context for backward compatibility. + """ + return await get_auth_context(request) + diff --git a/app/middleware/rate_limit.py b/app/middleware/rate_limit.py new file mode 100644 index 0000000000000000000000000000000000000000..3a5c6768e4a53f40556415f00cf35c9587811fc9 --- /dev/null +++ b/app/middleware/rate_limit.py @@ -0,0 +1,40 @@ +""" +Rate limiting middleware using slowapi. +""" +from slowapi import Limiter, _rate_limit_exceeded_handler +from slowapi.util import get_remote_address +from slowapi.errors import RateLimitExceeded +from fastapi import Request +import logging + +from app.config import settings + +logger = logging.getLogger(__name__) + +# Initialize limiter with default limits (can be overridden per endpoint) +limiter = Limiter( + key_func=get_remote_address, + default_limits=["1000/hour"] if settings.RATE_LIMIT_ENABLED else [] +) + + +def get_tenant_rate_limit_key(request: Request) -> str: + """ + Get rate limit key based on tenant_id from headers (dev) or IP (prod). + + Note: This is a sync function called by slowapi, so we can't await async functions. + In dev mode, we extract tenant_id from X-Tenant-Id header. + In prod mode, we fall back to IP address (rate limiting happens before auth). + """ + # Try to get tenant_id from headers (works in dev mode) + tenant_id = request.headers.get("X-Tenant-Id") + if tenant_id: + return f"tenant:{tenant_id}" + + # Fallback to IP address (for prod mode or if no header) + return get_remote_address(request) + + +# Export limiter and key function +__all__ = ["limiter", "get_tenant_rate_limit_key", "RateLimitExceeded", "_rate_limit_exceeded_handler"] + diff --git a/app/models/__init__.py b/app/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0410d8f5408171132e9f3086da9407c2607fec70 --- /dev/null +++ b/app/models/__init__.py @@ -0,0 +1,33 @@ +""" +Pydantic models for the RAG backend. +""" +from app.models.schemas import ( + DocumentStatus, + ChunkMetadata, + DocumentChunk, + UploadRequest, + UploadResponse, + Citation, + ChatRequest, + ChatResponse, + RetrievalResult, + KnowledgeBaseStats, + HealthResponse, +) + +__all__ = [ + "DocumentStatus", + "ChunkMetadata", + "DocumentChunk", + "UploadRequest", + "UploadResponse", + "Citation", + "ChatRequest", + "ChatResponse", + "RetrievalResult", + "KnowledgeBaseStats", + "HealthResponse", +] + + + diff --git a/app/models/__pycache__/__init__.cpython-313.pyc b/app/models/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..95b6ff7ec9ff159d99735ecb8e2ef5315b1b3da0 Binary files /dev/null and b/app/models/__pycache__/__init__.cpython-313.pyc differ diff --git a/app/models/__pycache__/billing_schemas.cpython-313.pyc b/app/models/__pycache__/billing_schemas.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..95849d35dff01c78c13c9077b0b54e56108d2ea9 Binary files /dev/null and b/app/models/__pycache__/billing_schemas.cpython-313.pyc differ diff --git a/app/models/__pycache__/schemas.cpython-313.pyc b/app/models/__pycache__/schemas.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..866862af35e8ce00c238dd94b6ce7fba372a1123 Binary files /dev/null and b/app/models/__pycache__/schemas.cpython-313.pyc differ diff --git a/app/models/billing_schemas.py b/app/models/billing_schemas.py new file mode 100644 index 0000000000000000000000000000000000000000..51c26fec17e4ac1497ef17bda4f1eb73a0dbf425 --- /dev/null +++ b/app/models/billing_schemas.py @@ -0,0 +1,46 @@ +""" +Pydantic schemas for billing endpoints. +""" +from pydantic import BaseModel +from typing import Optional, List +from datetime import datetime + + +class UsageResponse(BaseModel): + """Usage statistics response.""" + tenant_id: str + period: str # "day" or "month" + total_requests: int + total_tokens: int + total_cost_usd: float + gemini_requests: int = 0 + openai_requests: int = 0 + start_date: datetime + end_date: datetime + + +class PlanLimitsResponse(BaseModel): + """Current plan limits response.""" + tenant_id: str + plan_name: str + monthly_chat_limit: int # -1 for unlimited + current_month_usage: int + remaining_chats: Optional[int] # None if unlimited + + +class CostReportResponse(BaseModel): + """Cost report response.""" + tenant_id: str + period: str + total_cost_usd: float + total_requests: int + total_tokens: int + breakdown_by_provider: dict + breakdown_by_model: dict + + +class SetPlanRequest(BaseModel): + """Request to set tenant plan.""" + tenant_id: str + plan_name: str # "starter", "growth", or "pro" + diff --git a/app/models/schemas.py b/app/models/schemas.py new file mode 100644 index 0000000000000000000000000000000000000000..d99a50081cf254b745a87c68b0652f1af62cf113 --- /dev/null +++ b/app/models/schemas.py @@ -0,0 +1,112 @@ +""" +Pydantic models for API request/response schemas. +""" +from pydantic import BaseModel, Field +from typing import List, Optional, Dict, Any +from datetime import datetime +from enum import Enum + + +class DocumentStatus(str, Enum): + PENDING = "pending" + PROCESSING = "processing" + COMPLETED = "completed" + FAILED = "failed" + + +class ChunkMetadata(BaseModel): + """Metadata for a document chunk.""" + tenant_id: str # CRITICAL: Multi-tenant isolation + kb_id: str + user_id: str + file_name: str + file_type: str + chunk_id: str + chunk_index: int + page_number: Optional[int] = None + total_chunks: int + document_id: Optional[str] = None # Track original document + created_at: datetime = Field(default_factory=datetime.utcnow) + + +class DocumentChunk(BaseModel): + """A chunk of text with metadata.""" + id: str + content: str + metadata: ChunkMetadata + embedding: Optional[List[float]] = None + + +class UploadRequest(BaseModel): + """Request model for file upload.""" + tenant_id: str # CRITICAL: Multi-tenant isolation + user_id: str + kb_id: str + + +class UploadResponse(BaseModel): + """Response model for file upload.""" + success: bool + message: str + document_id: Optional[str] = None + file_name: str + chunks_created: int = 0 + status: DocumentStatus = DocumentStatus.PENDING + + +class Citation(BaseModel): + """Citation reference for an answer.""" + file_name: str + chunk_id: str + page_number: Optional[int] = None + relevance_score: float + excerpt: str # Short excerpt from the chunk + + +class ChatRequest(BaseModel): + """Request model for chat endpoint.""" + tenant_id: str # CRITICAL: Multi-tenant isolation + user_id: str + kb_id: str + conversation_id: Optional[str] = None + question: str + + +class ChatResponse(BaseModel): + """Response model for chat endpoint.""" + success: bool + answer: str + citations: List[Citation] = [] + confidence: float # 0-1 score + from_knowledge_base: bool = True + escalation_suggested: bool = False + conversation_id: str + metadata: Dict[str, Any] = {} + + +class RetrievalResult(BaseModel): + """Result from vector store retrieval.""" + chunk_id: str + content: str + metadata: Dict[str, Any] + similarity_score: float + + +class KnowledgeBaseStats(BaseModel): + """Statistics for a knowledge base.""" + tenant_id: str # CRITICAL: Multi-tenant isolation + kb_id: str + user_id: str + total_documents: int + total_chunks: int + file_names: List[str] + last_updated: Optional[datetime] = None + + +class HealthResponse(BaseModel): + """Health check response.""" + status: str + version: str = "1.0.0" + vector_db_connected: bool + llm_configured: bool + diff --git a/app/rag/__init__.py b/app/rag/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..95b30c0bc030b289b6aa3aab2f26057c7f70cf4b --- /dev/null +++ b/app/rag/__init__.py @@ -0,0 +1,27 @@ +""" +RAG (Retrieval-Augmented Generation) pipeline modules. +""" +from app.rag.ingest import parser, DocumentParser +from app.rag.chunking import chunker, DocumentChunker +from app.rag.embeddings import get_embedding_service, EmbeddingService +from app.rag.vectorstore import get_vector_store, VectorStore +from app.rag.retrieval import get_retrieval_service, RetrievalService +from app.rag.answer import get_answer_service, AnswerService + +__all__ = [ + "parser", + "DocumentParser", + "chunker", + "DocumentChunker", + "get_embedding_service", + "EmbeddingService", + "get_vector_store", + "VectorStore", + "get_retrieval_service", + "RetrievalService", + "get_answer_service", + "AnswerService", +] + + + diff --git a/app/rag/__pycache__/__init__.cpython-313.pyc b/app/rag/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c093719d5648f029ecb92720f214f11f478add69 Binary files /dev/null and b/app/rag/__pycache__/__init__.cpython-313.pyc differ diff --git a/app/rag/__pycache__/answer.cpython-313.pyc b/app/rag/__pycache__/answer.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c2880fb72e015542920dec13dc30fb5486762860 Binary files /dev/null and b/app/rag/__pycache__/answer.cpython-313.pyc differ diff --git a/app/rag/__pycache__/chunking.cpython-313.pyc b/app/rag/__pycache__/chunking.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b45bdda267276672a8be6b3b6200d83595264e82 Binary files /dev/null and b/app/rag/__pycache__/chunking.cpython-313.pyc differ diff --git a/app/rag/__pycache__/embeddings.cpython-313.pyc b/app/rag/__pycache__/embeddings.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..06676746587b8db8e25df1e535bb09d87f1a7c8c Binary files /dev/null and b/app/rag/__pycache__/embeddings.cpython-313.pyc differ diff --git a/app/rag/__pycache__/ingest.cpython-313.pyc b/app/rag/__pycache__/ingest.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a4eac74a2ef070cedb6c097df085ce4e31658e1 Binary files /dev/null and b/app/rag/__pycache__/ingest.cpython-313.pyc differ diff --git a/app/rag/__pycache__/intent.cpython-313.pyc b/app/rag/__pycache__/intent.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aa36906543907097e1f1b98233c17a0a32ebaffa Binary files /dev/null and b/app/rag/__pycache__/intent.cpython-313.pyc differ diff --git a/app/rag/__pycache__/prompts.cpython-313.pyc b/app/rag/__pycache__/prompts.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bde3096934253d63debe1720ded8d38248d59035 Binary files /dev/null and b/app/rag/__pycache__/prompts.cpython-313.pyc differ diff --git a/app/rag/__pycache__/retrieval.cpython-313.pyc b/app/rag/__pycache__/retrieval.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c4f6703630063ae9231587ce8afa11c58290d888 Binary files /dev/null and b/app/rag/__pycache__/retrieval.cpython-313.pyc differ diff --git a/app/rag/__pycache__/vectorstore.cpython-313.pyc b/app/rag/__pycache__/vectorstore.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..576a0de4f6e29a61e5f3c2991c4e1b0f86e39953 Binary files /dev/null and b/app/rag/__pycache__/vectorstore.cpython-313.pyc differ diff --git a/app/rag/__pycache__/verifier.cpython-313.pyc b/app/rag/__pycache__/verifier.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fbf584dd66f715dd30b8a3ee6f0fe59cbc83ab31 Binary files /dev/null and b/app/rag/__pycache__/verifier.cpython-313.pyc differ diff --git a/app/rag/answer.py b/app/rag/answer.py new file mode 100644 index 0000000000000000000000000000000000000000..bc461cacce2f3e973b47237996696f032e313b52 --- /dev/null +++ b/app/rag/answer.py @@ -0,0 +1,444 @@ +""" +Answer generation using LLM with RAG context. +Supports Gemini and OpenAI as providers. +""" +import google.generativeai as genai +from openai import OpenAI +from typing import Optional, Dict, Any, List +import logging +import os +import re + +from app.config import settings +from app.rag.prompts import ( + format_rag_prompt, + format_draft_prompt, + get_no_context_response, + get_low_confidence_response +) +from app.rag.verifier import get_verifier_service +from app.rag.intent import detect_intents +from app.models.schemas import Citation +from abc import ABC, abstractmethod + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +class LLMProvider(ABC): + """Base class for LLM providers.""" + + @abstractmethod + def generate(self, system_prompt: str, user_prompt: str) -> str: + """Generate response from prompts.""" + raise NotImplementedError + + @abstractmethod + def generate_with_usage(self, system_prompt: str, user_prompt: str) -> tuple[str, dict]: + """ + Generate response and return usage information. + + Returns: + (response_text, usage_info) + usage_info: dict with keys: prompt_tokens, completion_tokens, total_tokens, model_used + """ + raise NotImplementedError + + +class GeminiProvider(LLMProvider): + """Google Gemini LLM provider.""" + + def __init__(self, api_key: Optional[str] = None, model: str = None): + self.api_key = api_key or settings.GEMINI_API_KEY or os.getenv("GEMINI_API_KEY") + # Default to gemini-1.5-flash if not specified + self.model = model or settings.GEMINI_MODEL or "gemini-1.5-flash" + + if not self.api_key: + raise ValueError("Gemini API key not configured. Set GEMINI_API_KEY environment variable.") + + genai.configure(api_key=self.api_key) + + # Don't initialize client here - do it lazily in generate() to handle errors better + self._client = None + logger.info(f"Gemini provider initialized (will use model: {self.model})") + + def generate(self, system_prompt: str, user_prompt: str) -> str: + """Generate response using Gemini.""" + text, _ = self.generate_with_usage(system_prompt, user_prompt) + return text + + def generate_with_usage(self, system_prompt: str, user_prompt: str) -> tuple[str, dict]: + """Generate response using Gemini and return usage info.""" + # Combine system and user prompts for Gemini + full_prompt = f"{system_prompt}\n\n{user_prompt}" + + # Estimate prompt tokens (rough: 1 token โ‰ˆ 4 chars) + prompt_tokens = len(full_prompt) // 4 + + # Try to list available models first, then use the first available one + # If that fails, try common model names + models_to_try = [] + + # First, try to get available models + try: + available_models = genai.list_models() + model_names = [m.name for m in available_models if 'generateContent' in m.supported_generation_methods] + if model_names: + # Extract just the model name (remove 'models/' prefix if present) + clean_names = [name.split('/')[-1] if '/' in name else name for name in model_names] + models_to_try.extend(clean_names[:3]) # Use first 3 available models + logger.info(f"Found {len(model_names)} available models, will try: {clean_names[:3]}") + except Exception as e: + logger.warning(f"Could not list available models: {e}, using fallback list") + + # Fallback to common model names if listing failed + if not models_to_try: + models_to_try = ["gemini-pro", "gemini-1.0-pro", "models/gemini-pro"] + + # Add configured model if different + if self.model and self.model not in models_to_try: + models_to_try.insert(0, self.model) + + # Remove duplicates while preserving order + seen = set() + models_to_try = [m for m in models_to_try if not (m in seen or seen.add(m))] + + last_error = None + for model_name in models_to_try: + try: + logger.info(f"Attempting to generate with model: {model_name}") + # Create a new client for this model + client = genai.GenerativeModel(model_name) + response = client.generate_content( + full_prompt, + generation_config=genai.types.GenerationConfig( + temperature=settings.TEMPERATURE, + max_output_tokens=1024, + ) + ) + + # Extract response text + response_text = response.text + + # Try to get usage info from response + usage_info = { + "prompt_tokens": prompt_tokens, + "completion_tokens": len(response_text) // 4, # Estimate + "total_tokens": prompt_tokens + (len(response_text) // 4), + "model_used": model_name.split('/')[-1] if '/' in model_name else model_name + } + + # Try to get actual usage from response if available + if hasattr(response, 'usage_metadata'): + usage_metadata = response.usage_metadata + if hasattr(usage_metadata, 'prompt_token_count'): + usage_info["prompt_tokens"] = usage_metadata.prompt_token_count + if hasattr(usage_metadata, 'candidates_token_count'): + usage_info["completion_tokens"] = usage_metadata.candidates_token_count + if hasattr(usage_metadata, 'total_token_count'): + usage_info["total_tokens"] = usage_metadata.total_token_count + + if model_name != self.model: + logger.info(f"โœ… Successfully used model: {model_name}") + + return response_text, usage_info + except Exception as e: + error_str = str(e).lower() + last_error = e + if "not found" in error_str or "not supported" in error_str or "404" in error_str: + logger.warning(f"Model {model_name} failed: {e}") + continue # Try next model + else: + # Different error (not model not found), re-raise + logger.error(f"Gemini generation error with {model_name}: {e}") + raise + + # All models failed - return a helpful error message + error_msg = f"All Gemini model attempts failed. Last error: {last_error}. Please check your GEMINI_API_KEY and ensure it has access to Gemini models." + logger.error(error_msg) + raise Exception(error_msg) + + +class OpenAIProvider(LLMProvider): + """OpenAI LLM provider.""" + + def __init__(self, api_key: Optional[str] = None, model: str = settings.OPENAI_MODEL): + self.api_key = api_key or settings.OPENAI_API_KEY or os.getenv("OPENAI_API_KEY") + self.model = model + + if not self.api_key: + raise ValueError("OpenAI API key not configured. Set OPENAI_API_KEY environment variable.") + + self.client = OpenAI(api_key=self.api_key) + logger.info(f"OpenAI provider initialized with model: {model}") + + def generate(self, system_prompt: str, user_prompt: str) -> str: + """Generate response using OpenAI.""" + text, _ = self.generate_with_usage(system_prompt, user_prompt) + return text + + def generate_with_usage(self, system_prompt: str, user_prompt: str) -> tuple[str, dict]: + """Generate response using OpenAI and return usage info.""" + try: + response = self.client.chat.completions.create( + model=self.model, + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt} + ], + temperature=settings.TEMPERATURE, + max_tokens=1024 + ) + + response_text = response.choices[0].message.content + + # Extract usage info from OpenAI response + usage_info = { + "prompt_tokens": response.usage.prompt_tokens, + "completion_tokens": response.usage.completion_tokens, + "total_tokens": response.usage.total_tokens, + "model_used": self.model + } + + return response_text, usage_info + except Exception as e: + logger.error(f"OpenAI generation error: {e}") + raise + + +class AnswerService: + """ + Generates answers using RAG context and LLM. + Handles confidence scoring and citation extraction. + """ + + # Confidence thresholds + HIGH_CONFIDENCE_THRESHOLD = 0.5 + LOW_CONFIDENCE_THRESHOLD = 0.20 # Lowered to match similarity threshold + STRICT_CONFIDENCE_THRESHOLD = 0.30 # Strict threshold for answer generation (lowered from 0.45 to allow good retrieval results) + + def __init__(self, provider: str = settings.LLM_PROVIDER): + """ + Initialize the answer service. + + Args: + provider: LLM provider to use ("gemini" or "openai") + """ + self.provider_name = provider + self._provider: Optional[LLMProvider] = None + + @property + def provider(self) -> LLMProvider: + """Lazy load the LLM provider.""" + if self._provider is None: + if self.provider_name == "gemini": + self._provider = GeminiProvider() + elif self.provider_name == "openai": + self._provider = OpenAIProvider() + else: + raise ValueError(f"Unknown LLM provider: {self.provider_name}") + return self._provider + + def generate_answer( + self, + question: str, + context: str, + citations_info: List[Dict[str, Any]], + confidence: float, + has_relevant_results: bool, + use_verifier: bool = None # None = use config default + ) -> Dict[str, Any]: + """ + Generate an answer based on retrieved context with mandatory verifier. + + Args: + question: User's question + context: Retrieved context from knowledge base + citations_info: List of citation information + confidence: Average confidence score from retrieval + has_relevant_results: Whether any results passed the threshold + use_verifier: Whether to use verifier mode (None = use config default) + + Returns: + Dictionary with answer, citations, confidence, and metadata + """ + # Determine if verifier should be used (mandatory by default) + if use_verifier is None: + use_verifier = settings.REQUIRE_VERIFIER + + # GATE 1: No relevant results found - REFUSE + if not has_relevant_results or not context: + logger.info("No relevant context found, returning no-context response") + return { + "answer": get_no_context_response(), + "citations": [], + "confidence": 0.0, + "from_knowledge_base": False, + "escalation_suggested": True, + "refused": True + } + + # GATE 2: Strict confidence threshold - REFUSE if below strict threshold + if confidence < self.STRICT_CONFIDENCE_THRESHOLD: + logger.warning( + f"Confidence ({confidence:.3f}) below strict threshold ({self.STRICT_CONFIDENCE_THRESHOLD}), " + f"REFUSING to answer to prevent hallucination" + ) + return { + "answer": get_no_context_response(), + "citations": [], + "confidence": confidence, + "from_knowledge_base": False, + "escalation_suggested": True, + "refused": True, + "refusal_reason": f"Confidence {confidence:.3f} below strict threshold {self.STRICT_CONFIDENCE_THRESHOLD}" + } + + # GATE 3: Intent-based gating for specific intents (integration, API, etc.) + intents = detect_intents(question) + if "integration" in intents or "api" in question.lower(): + # For integration/API questions, require strong relevance + if confidence < 0.50: # Even stricter for integration questions + logger.warning( + f"Integration/API question with low confidence ({confidence:.3f}), " + f"REFUSING to prevent hallucination" + ) + return { + "answer": get_no_context_response(), + "citations": [], + "confidence": confidence, + "from_knowledge_base": False, + "escalation_suggested": True, + "refused": True, + "refusal_reason": "Integration/API questions require higher confidence" + } + + # Case 3: Passed all gates - generate answer with MANDATORY verifier + logger.info(f"Confidence ({confidence:.3f}) passed all gates, generating answer with verifier={use_verifier}") + + try: + # VERIFIER MODE IS MANDATORY: Draft โ†’ Verify โ†’ Final + # Step 1: Generate draft answer with usage tracking + draft_system, draft_user = format_draft_prompt(context, question) + draft_answer, usage_info = self.provider.generate_with_usage(draft_system, draft_user) + logger.info("Generated draft answer, running verifier...") + + # Step 2: Verify draft answer (MANDATORY) + verifier = get_verifier_service() + verification = verifier.verify_answer( + draft_answer=draft_answer, + context=context, + citations_info=citations_info + ) + + # Step 3: Handle verification result + if verification["pass"]: + logger.info("โœ… Verifier PASSED - Using draft answer") + citations = self._extract_citations(draft_answer, citations_info) + return { + "answer": draft_answer, + "citations": citations, + "confidence": confidence, + "from_knowledge_base": True, + "escalation_suggested": confidence < self.HIGH_CONFIDENCE_THRESHOLD, + "verifier_passed": True, + "refused": False, + "usage": usage_info # Include usage info for tracking + } + else: + # Verifier failed - REFUSE to answer + issues = verification.get('issues', []) + unsupported = verification.get('unsupported_claims', []) + logger.warning( + f"โŒ Verifier FAILED - Issues: {issues}, " + f"Unsupported claims: {unsupported}" + ) + refusal_message = ( + get_no_context_response() + + "\n\n**Note:** The system could not verify the accuracy of the information needed to answer your question. " + "This helps prevent providing incorrect information." + ) + return { + "answer": refusal_message, + "citations": [], + "confidence": 0.0, + "from_knowledge_base": False, + "escalation_suggested": True, + "verifier_passed": False, + "verifier_issues": issues, + "unsupported_claims": unsupported, + "refused": True, + "refusal_reason": "Verifier failed: claims not supported by context", + "usage": usage_info # Still track usage even if refused + } + + except ValueError as e: + # Configuration errors (e.g., missing API key) + error_msg = str(e) + logger.error(f"Configuration error in answer generation: {error_msg}") + if "API key" in error_msg.lower(): + raise ValueError(f"LLM API key not configured: {error_msg}") + raise + except Exception as e: + logger.error(f"Error generating answer: {e}", exc_info=True) + # Re-raise to be handled by the endpoint + raise + + def _extract_citations( + self, + answer: str, + citations_info: List[Dict[str, Any]] + ) -> List[Citation]: + """ + Extract and format citations from the answer. + + Args: + answer: Generated answer with [Source X] references + citations_info: Available citation information + + Returns: + List of Citation objects + """ + citations = [] + + # Find all [Source X] references in the answer + source_pattern = r'\[Source\s*(\d+)\]' + matches = re.findall(source_pattern, answer) + referenced_indices = set(int(m) for m in matches) + + # Build citation objects for referenced sources + for info in citations_info: + if info.get("index") in referenced_indices: + citations.append(Citation( + file_name=info.get("file_name", "Unknown"), + chunk_id=info.get("chunk_id", ""), + page_number=info.get("page_number"), + relevance_score=info.get("similarity_score", 0.0), + excerpt=info.get("excerpt", "") + )) + + # If no specific citations found but we have context, include top sources + if not citations and citations_info: + for info in citations_info[:3]: # Top 3 sources + citations.append(Citation( + file_name=info.get("file_name", "Unknown"), + chunk_id=info.get("chunk_id", ""), + page_number=info.get("page_number"), + relevance_score=info.get("similarity_score", 0.0), + excerpt=info.get("excerpt", "") + )) + + return citations + + +# Global answer service instance +_answer_service: Optional[AnswerService] = None + + +def get_answer_service() -> AnswerService: + """Get the global answer service instance.""" + global _answer_service + if _answer_service is None: + _answer_service = AnswerService() + return _answer_service + diff --git a/app/rag/chunking.py b/app/rag/chunking.py new file mode 100644 index 0000000000000000000000000000000000000000..dccd54e782174e1426ff6b3e962ba4c974649d04 --- /dev/null +++ b/app/rag/chunking.py @@ -0,0 +1,196 @@ +""" +Document chunking with overlap and metadata preservation. +""" +import tiktoken +from typing import List, Dict, Any, Optional +from dataclasses import dataclass +import re +import uuid +from datetime import datetime + +from app.config import settings + + +@dataclass +class TextChunk: + """Represents a chunk of text with metadata.""" + content: str + chunk_index: int + start_char: int + end_char: int + page_number: Optional[int] = None + token_count: int = 0 + + +class DocumentChunker: + """ + Chunks documents into smaller pieces with overlap. + Uses tiktoken for accurate token counting. + """ + + def __init__( + self, + chunk_size: int = settings.CHUNK_SIZE, + chunk_overlap: int = settings.CHUNK_OVERLAP, + min_chunk_size: int = settings.MIN_CHUNK_SIZE + ): + self.chunk_size = chunk_size + self.chunk_overlap = chunk_overlap + self.min_chunk_size = min_chunk_size + # Use cl100k_base encoding (same as GPT-4, good general purpose) + self.encoding = tiktoken.get_encoding("cl100k_base") + + def count_tokens(self, text: str) -> int: + """Count tokens in text.""" + return len(self.encoding.encode(text)) + + def _split_into_sentences(self, text: str) -> List[str]: + """Split text into sentences while preserving structure.""" + # Split on sentence boundaries but keep delimiters + sentence_endings = r'(?<=[.!?])\s+' + sentences = re.split(sentence_endings, text) + return [s.strip() for s in sentences if s.strip()] + + def _split_into_paragraphs(self, text: str) -> List[str]: + """Split text into paragraphs.""" + paragraphs = re.split(r'\n\s*\n', text) + return [p.strip() for p in paragraphs if p.strip()] + + def chunk_text( + self, + text: str, + page_numbers: Optional[Dict[int, int]] = None # char_position -> page_number + ) -> List[TextChunk]: + """ + Chunk text into smaller pieces with overlap. + + Args: + text: The text to chunk + page_numbers: Optional mapping of character positions to page numbers + + Returns: + List of TextChunk objects + """ + if not text.strip(): + return [] + + chunks = [] + current_chunk = "" + current_start = 0 + chunk_index = 0 + + # First, split into paragraphs for natural boundaries + paragraphs = self._split_into_paragraphs(text) + + char_position = 0 + for para in paragraphs: + para_tokens = self.count_tokens(para) + current_tokens = self.count_tokens(current_chunk) + + # If adding this paragraph exceeds chunk size + if current_tokens + para_tokens > self.chunk_size and current_chunk: + # Save current chunk if it meets minimum size + if current_tokens >= self.min_chunk_size: + page_num = None + if page_numbers: + # Find the page number for this chunk's start position + for pos, page in sorted(page_numbers.items()): + if pos <= current_start: + page_num = page + + chunks.append(TextChunk( + content=current_chunk.strip(), + chunk_index=chunk_index, + start_char=current_start, + end_char=char_position, + page_number=page_num, + token_count=current_tokens + )) + chunk_index += 1 + + # Start new chunk with overlap + overlap_text = self._get_overlap_text(current_chunk) + current_chunk = overlap_text + "\n\n" + para if overlap_text else para + current_start = char_position - len(overlap_text) if overlap_text else char_position + else: + # Add paragraph to current chunk + if current_chunk: + current_chunk += "\n\n" + para + else: + current_chunk = para + current_start = char_position + + char_position += len(para) + 2 # +2 for paragraph separator + + # Don't forget the last chunk + if current_chunk and self.count_tokens(current_chunk) >= self.min_chunk_size: + page_num = None + if page_numbers: + for pos, page in sorted(page_numbers.items()): + if pos <= current_start: + page_num = page + + chunks.append(TextChunk( + content=current_chunk.strip(), + chunk_index=chunk_index, + start_char=current_start, + end_char=len(text), + page_number=page_num, + token_count=self.count_tokens(current_chunk) + )) + + return chunks + + def _get_overlap_text(self, text: str) -> str: + """Get the overlap text from the end of a chunk.""" + sentences = self._split_into_sentences(text) + if not sentences: + return "" + + overlap = "" + tokens = 0 + + # Work backwards through sentences + for sentence in reversed(sentences): + sentence_tokens = self.count_tokens(sentence) + if tokens + sentence_tokens <= self.chunk_overlap: + overlap = sentence + " " + overlap if overlap else sentence + tokens += sentence_tokens + else: + break + + return overlap.strip() + + def create_chunk_metadata( + self, + chunk: TextChunk, + tenant_id: str, # CRITICAL: Multi-tenant isolation + kb_id: str, + user_id: str, + file_name: str, + file_type: str, + total_chunks: int, + document_id: Optional[str] = None + ) -> Dict[str, Any]: + """Create metadata dictionary for a chunk.""" + chunk_id = f"{tenant_id}_{kb_id}_{file_name}_{chunk.chunk_index}_{uuid.uuid4().hex[:8]}" + + return { + "tenant_id": tenant_id, # CRITICAL: Multi-tenant isolation + "kb_id": kb_id, + "user_id": user_id, + "file_name": file_name, + "file_type": file_type, + "chunk_id": chunk_id, + "chunk_index": chunk.chunk_index, + "page_number": chunk.page_number, + "total_chunks": total_chunks, + "token_count": chunk.token_count, + "document_id": document_id, # Track original document + "created_at": datetime.utcnow().isoformat() + } + + +# Global chunker instance +chunker = DocumentChunker() + diff --git a/app/rag/embeddings.py b/app/rag/embeddings.py new file mode 100644 index 0000000000000000000000000000000000000000..d5b77cf73d4b971e21cc5d41c068f2f885ea3191 --- /dev/null +++ b/app/rag/embeddings.py @@ -0,0 +1,145 @@ +""" +Embedding generation using Sentence Transformers. +Supports local models for privacy and offline use. +""" +from sentence_transformers import SentenceTransformer +from typing import List, Optional +import numpy as np +import logging +from functools import lru_cache + +from app.config import settings + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +class EmbeddingService: + """ + Generates embeddings for text using Sentence Transformers. + Uses a lightweight model optimized for semantic search. + """ + + def __init__(self, model_name: str = settings.EMBEDDING_MODEL): + """ + Initialize the embedding service. + + Args: + model_name: Name of the Sentence Transformer model to use + """ + self.model_name = model_name + self._model: Optional[SentenceTransformer] = None + logger.info(f"Embedding service initialized with model: {model_name}") + + @property + def model(self) -> SentenceTransformer: + """Lazy load the model.""" + if self._model is None: + logger.info(f"Loading embedding model: {self.model_name}") + self._model = SentenceTransformer(self.model_name) + logger.info(f"Model loaded. Embedding dimension: {self._model.get_sentence_embedding_dimension()}") + return self._model + + def embed_text(self, text: str) -> List[float]: + """ + Generate embedding for a single text. + + Args: + text: Text to embed + + Returns: + List of floats representing the embedding vector + """ + if not text.strip(): + raise ValueError("Cannot embed empty text") + + embedding = self.model.encode(text, convert_to_numpy=True) + return embedding.tolist() + + def embed_texts(self, texts: List[str], batch_size: int = 32) -> List[List[float]]: + """ + Generate embeddings for multiple texts. + + Args: + texts: List of texts to embed + batch_size: Batch size for processing + + Returns: + List of embedding vectors + """ + if not texts: + return [] + + # Filter out empty texts + valid_texts = [t for t in texts if t.strip()] + if len(valid_texts) != len(texts): + logger.warning(f"Filtered out {len(texts) - len(valid_texts)} empty texts") + + logger.info(f"Generating embeddings for {len(valid_texts)} texts") + + embeddings = self.model.encode( + valid_texts, + batch_size=batch_size, + show_progress_bar=len(valid_texts) > 100, + convert_to_numpy=True + ) + + return embeddings.tolist() + + def embed_query(self, query: str) -> List[float]: + """ + Generate embedding for a search query. + Some models have different embeddings for queries vs documents. + + Args: + query: Search query to embed + + Returns: + Embedding vector for the query + """ + # For most models, query embedding is the same as document embedding + # But we keep this separate for models that differentiate + return self.embed_text(query) + + def get_dimension(self) -> int: + """Get the embedding dimension.""" + return self.model.get_sentence_embedding_dimension() + + def compute_similarity(self, embedding1: List[float], embedding2: List[float]) -> float: + """ + Compute cosine similarity between two embeddings. + + Args: + embedding1: First embedding vector + embedding2: Second embedding vector + + Returns: + Cosine similarity score (0-1) + """ + vec1 = np.array(embedding1) + vec2 = np.array(embedding2) + + # Cosine similarity + dot_product = np.dot(vec1, vec2) + norm1 = np.linalg.norm(vec1) + norm2 = np.linalg.norm(vec2) + + if norm1 == 0 or norm2 == 0: + return 0.0 + + return float(dot_product / (norm1 * norm2)) + + +# Global embedding service instance (lazy loaded) +_embedding_service: Optional[EmbeddingService] = None + + +def get_embedding_service() -> EmbeddingService: + """Get the global embedding service instance.""" + global _embedding_service + if _embedding_service is None: + _embedding_service = EmbeddingService() + return _embedding_service + + + diff --git a/app/rag/ingest.py b/app/rag/ingest.py new file mode 100644 index 0000000000000000000000000000000000000000..348ac7102d07af1ba5ce8f758225dc5d89365aff --- /dev/null +++ b/app/rag/ingest.py @@ -0,0 +1,231 @@ +""" +Document ingestion and parsing pipeline. +Supports PDF, DOCX, TXT, and Markdown files. +""" +import fitz # PyMuPDF +from docx import Document as DocxDocument +import markdown +from pathlib import Path +from typing import Dict, List, Optional, Tuple, Any +import chardet +import re +from dataclasses import dataclass +import logging + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +@dataclass +class ParsedDocument: + """Represents a parsed document with text and metadata.""" + text: str + file_name: str + file_type: str + page_count: Optional[int] = None + page_map: Optional[Dict[int, int]] = None # char_position -> page_number + metadata: Dict[str, Any] = None + + +class DocumentParser: + """ + Parses various document formats into plain text. + Preserves page information for citations. + """ + + SUPPORTED_EXTENSIONS = {'.pdf', '.docx', '.doc', '.txt', '.md', '.markdown'} + + def __init__(self): + self.parsers = { + '.pdf': self._parse_pdf, + '.docx': self._parse_docx, + '.doc': self._parse_docx, # Try docx parser for doc files + '.txt': self._parse_text, + '.md': self._parse_markdown, + '.markdown': self._parse_markdown, + } + + def parse(self, file_path: Path) -> ParsedDocument: + """ + Parse a document file into text. + + Args: + file_path: Path to the document file + + Returns: + ParsedDocument with extracted text and metadata + """ + if not file_path.exists(): + raise FileNotFoundError(f"File not found: {file_path}") + + ext = file_path.suffix.lower() + if ext not in self.SUPPORTED_EXTENSIONS: + raise ValueError(f"Unsupported file type: {ext}") + + parser = self.parsers.get(ext) + if not parser: + raise ValueError(f"No parser available for: {ext}") + + logger.info(f"Parsing document: {file_path.name} ({ext})") + return parser(file_path) + + def _parse_pdf(self, file_path: Path) -> ParsedDocument: + """Parse PDF file with page tracking.""" + try: + doc = fitz.open(file_path) + text_parts = [] + page_map = {} + current_pos = 0 + + for page_num, page in enumerate(doc, start=1): + page_text = page.get_text("text") + if page_text.strip(): + # Record where this page starts + page_map[current_pos] = page_num + text_parts.append(page_text) + current_pos += len(page_text) + 2 # +2 for separator + + page_count = len(doc) + doc.close() + + full_text = "\n\n".join(text_parts) + + return ParsedDocument( + text=self._clean_text(full_text), + file_name=file_path.name, + file_type="pdf", + page_count=page_count, + page_map=page_map, + metadata={"source": str(file_path)} + ) + except Exception as e: + logger.error(f"Error parsing PDF {file_path}: {e}") + raise + + def _parse_docx(self, file_path: Path) -> ParsedDocument: + """Parse DOCX file.""" + try: + doc = DocxDocument(file_path) + paragraphs = [] + + for para in doc.paragraphs: + if para.text.strip(): + paragraphs.append(para.text) + + # Also extract text from tables + for table in doc.tables: + for row in table.rows: + row_text = [] + for cell in row.cells: + if cell.text.strip(): + row_text.append(cell.text.strip()) + if row_text: + paragraphs.append(" | ".join(row_text)) + + full_text = "\n\n".join(paragraphs) + + return ParsedDocument( + text=self._clean_text(full_text), + file_name=file_path.name, + file_type="docx", + metadata={"source": str(file_path)} + ) + except Exception as e: + logger.error(f"Error parsing DOCX {file_path}: {e}") + raise + + def _parse_text(self, file_path: Path) -> ParsedDocument: + """Parse plain text file with encoding detection.""" + try: + # Detect encoding + with open(file_path, 'rb') as f: + raw_data = f.read() + detected = chardet.detect(raw_data) + encoding = detected.get('encoding', 'utf-8') + + # Read with detected encoding + with open(file_path, 'r', encoding=encoding, errors='replace') as f: + text = f.read() + + return ParsedDocument( + text=self._clean_text(text), + file_name=file_path.name, + file_type="txt", + metadata={"source": str(file_path), "encoding": encoding} + ) + except Exception as e: + logger.error(f"Error parsing text file {file_path}: {e}") + raise + + def _parse_markdown(self, file_path: Path) -> ParsedDocument: + """Parse Markdown file, converting to plain text.""" + try: + with open(file_path, 'r', encoding='utf-8', errors='replace') as f: + md_content = f.read() + + # Convert markdown to HTML, then strip tags + html = markdown.markdown(md_content) + text = self._strip_html_tags(html) + + # Also keep the original markdown structure for better context + # Remove markdown syntax but keep structure + clean_md = self._clean_markdown(md_content) + + return ParsedDocument( + text=self._clean_text(clean_md), + file_name=file_path.name, + file_type="markdown", + metadata={"source": str(file_path)} + ) + except Exception as e: + logger.error(f"Error parsing Markdown {file_path}: {e}") + raise + + def _clean_text(self, text: str) -> str: + """Clean and normalize text.""" + # Replace multiple whitespace with single space + text = re.sub(r'[ \t]+', ' ', text) + # Replace multiple newlines with double newline + text = re.sub(r'\n{3,}', '\n\n', text) + # Remove leading/trailing whitespace from lines + lines = [line.strip() for line in text.split('\n')] + text = '\n'.join(lines) + # Remove leading/trailing whitespace + return text.strip() + + def _strip_html_tags(self, html: str) -> str: + """Remove HTML tags from text.""" + clean = re.sub(r'<[^>]+>', '', html) + return clean + + def _clean_markdown(self, md_text: str) -> str: + """Clean markdown syntax while preserving structure.""" + # Remove code blocks but keep content + md_text = re.sub(r'```[\s\S]*?```', '', md_text) + md_text = re.sub(r'`([^`]+)`', r'\1', md_text) + + # Convert headers to plain text with emphasis + md_text = re.sub(r'^#{1,6}\s+(.+)$', r'\1:', md_text, flags=re.MULTILINE) + + # Remove bold/italic markers + md_text = re.sub(r'\*\*([^*]+)\*\*', r'\1', md_text) + md_text = re.sub(r'\*([^*]+)\*', r'\1', md_text) + md_text = re.sub(r'__([^_]+)__', r'\1', md_text) + md_text = re.sub(r'_([^_]+)_', r'\1', md_text) + + # Remove links but keep text + md_text = re.sub(r'\[([^\]]+)\]\([^)]+\)', r'\1', md_text) + + # Remove images + md_text = re.sub(r'!\[([^\]]*)\]\([^)]+\)', '', md_text) + + # Convert lists to plain text + md_text = re.sub(r'^[\*\-\+]\s+', 'โ€ข ', md_text, flags=re.MULTILINE) + md_text = re.sub(r'^\d+\.\s+', '', md_text, flags=re.MULTILINE) + + return md_text + + +# Global parser instance +parser = DocumentParser() + diff --git a/app/rag/intent.py b/app/rag/intent.py new file mode 100644 index 0000000000000000000000000000000000000000..e4a7eb814c6f1ed8b99cf385068535a8c345dc17 --- /dev/null +++ b/app/rag/intent.py @@ -0,0 +1,153 @@ +""" +Intent detection module for RAG pipeline. +Detects user intent from queries to enable intent-based gating. +""" +import re +from typing import List, Dict, Set +import logging + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +# Intent keywords mapping +INTENT_KEYWORDS: Dict[str, List[str]] = { + "integration": [ + "integrate", "integration", "api", "connect", "connection", "webhook", + "shopify", "woocommerce", "stripe", "paypal", "payment gateway", + "whatsapp", "telegram", "slack", "zapier", "ifttt", "automation" + ], + "billing": [ + "billing", "invoice", "payment", "subscription", "plan", "pricing", + "cost", "price", "charge", "fee", "refund", "cancel", "renew" + ], + "account": [ + "account", "profile", "settings", "preferences", "user", "login", + "signup", "register", "authentication", "auth" + ], + "password_reset": [ + "password", "reset", "forgot", "change password", "update password", + "password reset link", "expire", "expiry" + ], + "pricing": [ + "pricing", "price", "plan", "cost", "subscription", "tier", "starter", + "pro", "enterprise", "monthly", "yearly", "billing" + ], + "general": [] # Catch-all for general queries +} + + +def detect_intents(query: str) -> List[str]: + """ + Detect intents from a user query. + + Args: + query: User's question + + Returns: + List of detected intent labels (e.g., ["integration", "billing"]) + """ + query_lower = query.lower() + detected = [] + + for intent, keywords in INTENT_KEYWORDS.items(): + if intent == "general": + continue # Skip general, it's a catch-all + + # Check if any keyword matches + for keyword in keywords: + # Use word boundary matching for better accuracy + pattern = r'\b' + re.escape(keyword.lower()) + r'\b' + if re.search(pattern, query_lower): + detected.append(intent) + break # Only add intent once + + # If no specific intent detected, return general + if not detected: + detected = ["general"] + + logger.info(f"Detected intents for query '{query[:50]}...': {detected}") + return detected + + +def get_intent_keywords(intents: List[str]) -> Set[str]: + """ + Get all keywords for a list of intents. + + Args: + intents: List of intent labels + + Returns: + Set of keywords for those intents + """ + keywords = set() + for intent in intents: + if intent in INTENT_KEYWORDS: + keywords.update(INTENT_KEYWORDS[intent]) + return keywords + + +def check_direct_match( + query: str, + retrieved_chunks: List[str], + intent_keywords: Set[str] = None +) -> bool: + """ + Check if at least one retrieved chunk contains direct matches for query intent. + + Args: + query: User's question + retrieved_chunks: List of retrieved chunk texts + intent_keywords: Optional set of intent keywords to check + + Returns: + True if at least one chunk has direct match, False otherwise + """ + if not retrieved_chunks: + return False + + query_lower = query.lower() + query_words = set(re.findall(r'\b\w+\b', query_lower)) + + # Get intent keywords if not provided + if intent_keywords is None: + intents = detect_intents(query) + intent_keywords = get_intent_keywords(intents) + + # Check each chunk for direct matches + for chunk in retrieved_chunks: + chunk_lower = chunk.lower() + + # Check 1: Intent keywords must be present in chunk + if intent_keywords: + intent_found = any( + re.search(r'\b' + re.escape(kw.lower()) + r'\b', chunk_lower) + for kw in intent_keywords + ) + if not intent_found: + continue # Skip this chunk if no intent keywords + + # Check 2: At least 2-3 important query words should be in chunk + # (excluding common stop words) + stop_words = {"the", "a", "an", "is", "are", "was", "were", "be", "been", + "to", "of", "and", "or", "but", "in", "on", "at", "for", + "with", "how", "what", "when", "where", "why", "do", "does"} + important_words = query_words - stop_words + + if len(important_words) >= 2: + # Need at least 2 important words to match + matches = sum(1 for word in important_words if word in chunk_lower) + if matches >= 2: + logger.info(f"Direct match found: {matches} important words matched in chunk") + return True + elif len(important_words) == 1: + # Single important word - require exact phrase match + for word in important_words: + if re.search(r'\b' + re.escape(word) + r'\b', chunk_lower): + logger.info(f"Direct match found: single important word '{word}' matched") + return True + + logger.warning("No direct match found in retrieved chunks") + return False + + diff --git a/app/rag/prompts.py b/app/rag/prompts.py new file mode 100644 index 0000000000000000000000000000000000000000..cbc50f895ebc132341fef900d239c57cdb7fad75 --- /dev/null +++ b/app/rag/prompts.py @@ -0,0 +1,162 @@ +""" +Prompt templates for RAG-based question answering. +Implements strict anti-hallucination rules. +""" + +# System prompt for RAG-based answering - ENHANCED for strict anti-hallucination +RAG_SYSTEM_PROMPT = """You are a helpful customer support assistant for ClientSphere. Your ONLY job is to answer questions based STRICTLY on the provided context. + +## CRITICAL RULES - YOU MUST FOLLOW THESE (NO EXCEPTIONS): + +1. **ONLY use information from the provided context** - Do NOT use any prior knowledge, training data, or general knowledge. If it's not in the context, you don't know it. + +2. **If the context doesn't contain the answer** - You MUST say: "I couldn't find this information in the knowledge base. Please contact our support team for assistance." DO NOT attempt to answer from memory. + +3. **NEVER guess, infer, or make up information** - If you're unsure, say you don't have that information. It's better to refuse than to hallucinate. + +4. **ALWAYS cite your sources** - Every factual statement MUST include [Source X] notation. If you can't cite it, don't say it. + +5. **Be concise and direct** - Answer the question without unnecessary elaboration or adding information not in the context. + +6. **If the question is unclear** - Ask for clarification rather than guessing what the user means. + +7. **For multi-part questions** - Address each part separately and cite sources for each. If any part isn't in the context, say so. + +8. **DO NOT use general knowledge** - Even if you "know" the answer from training, if it's not in the provided context, you cannot use it. + +9. **DO NOT extrapolate** - If the context says "30 days", don't say "about a month" or make assumptions. + +10. **Verify every claim** - Before stating anything, verify it exists in the provided context with a citation. + +## Response Format: +- Start with a direct answer to the question +- Include [Source X] citations inline where you use information +- End with a brief summary if the answer is complex +- If no relevant information: clearly state this and suggest contacting support + +## Example Good Response: +"Based on the documentation, the return policy allows returns within 30 days of purchase [Source 1]. Items must be in original packaging [Source 2]. For items purchased on sale, a 15-day window applies [Source 1]." + +## Example When Information Not Found: +"I couldn't find specific information about warranty extensions in the available documentation. I recommend contacting our support team at support@example.com for detailed warranty inquiries." +""" + +# User prompt template +RAG_USER_PROMPT_TEMPLATE = """## Context from Knowledge Base: + +{context} + +--- + +## User Question: +{question} + +--- + +Please answer the question using ONLY the information provided in the context above. Remember to cite sources using [Source X] notation. + +**IMPORTANT:** If the answer is not explicitly stated in the context above, you MUST say "I couldn't find this information in the knowledge base" and suggest contacting support. DO NOT attempt to answer from memory or general knowledge.""" + + +# Prompt for when no relevant context is found +NO_CONTEXT_RESPONSE = """I apologize, but I couldn't find relevant information in the knowledge base to answer your question. + +This could mean: +1. The topic isn't covered in the current documentation +2. The question might need to be rephrased for better matching + +**Recommended Actions:** +- Try rephrasing your question with different keywords +- Contact our support team directly for personalized assistance +- Check if there's additional documentation that might help + +Would you like me to help you with a different question, or would you prefer to connect with a human support agent?""" + + +# Prompt for low confidence responses +LOW_CONFIDENCE_RESPONSE = """I found some potentially relevant information, but I'm not confident it fully addresses your question. + +Based on what I found: {partial_answer} + +**However**, I recommend verifying this information with our support team, as the context may not fully cover your specific situation. + +Would you like me to connect you with a human support agent for more detailed assistance?""" + + +def format_rag_prompt(context: str, question: str) -> tuple: + """ + Format the RAG prompt for the LLM. + + Args: + context: Retrieved context from knowledge base + question: User's question + + Returns: + Tuple of (system_prompt, user_prompt) + """ + user_prompt = RAG_USER_PROMPT_TEMPLATE.format( + context=context, + question=question + ) + + return RAG_SYSTEM_PROMPT, user_prompt + + +def get_no_context_response() -> str: + """Get the response for when no context is found.""" + return NO_CONTEXT_RESPONSE + + +def get_low_confidence_response(partial_answer: str) -> str: + """Get the response for low confidence answers.""" + return LOW_CONFIDENCE_RESPONSE.format(partial_answer=partial_answer) + + +# Draft prompt for verifier mode (stricter than final prompt) +DRAFT_PROMPT_SYSTEM = """You are a customer support assistant. Generate a DRAFT answer based STRICTLY on the provided context. + +## CRITICAL RULES - THIS DRAFT WILL BE VERIFIED: + +1. **ONLY use information explicitly stated in the context** - Do NOT use any prior knowledge, training data, or general knowledge. + +2. **If the context doesn't contain the answer** - You MUST say: "I couldn't find this information in the knowledge base. Please contact our support team for assistance." DO NOT attempt to answer from memory. + +3. **NEVER guess, infer, or make up information** - If you're unsure, say you don't have that information. + +4. **ALWAYS cite your sources** - Every factual statement MUST include [Source X] notation. If you can't cite it, don't say it. + +5. **DO NOT use general knowledge** - Even if you "know" the answer from training, if it's not in the provided context, you cannot use it. + +6. **DO NOT extrapolate** - If the context says "30 days", don't say "about a month" or make assumptions. + +7. **Verify every claim** - Before stating anything, verify it exists in the provided context with a citation. + +Return ONLY the draft answer with citations. This will be verified for accuracy.""" + +DRAFT_PROMPT_USER = """## Context: +{context} + +## Question: +{question} + +Generate a DRAFT answer with citations. This will be verified for accuracy.""" + + +def format_draft_prompt(context: str, question: str) -> tuple: + """ + Format the draft prompt for initial answer generation. + + Args: + context: Retrieved context from knowledge base + question: User's question + + Returns: + Tuple of (system_prompt, user_prompt) + """ + user_prompt = DRAFT_PROMPT_USER.format( + context=context, + question=question + ) + + return DRAFT_PROMPT_SYSTEM, user_prompt + diff --git a/app/rag/retrieval.py b/app/rag/retrieval.py new file mode 100644 index 0000000000000000000000000000000000000000..cf1a7a3959dd478d5861bd2ca7b403e970c6ab19 --- /dev/null +++ b/app/rag/retrieval.py @@ -0,0 +1,242 @@ +""" +Retrieval pipeline with confidence scoring and filtering. +""" +from typing import List, Dict, Any, Optional, Tuple +import logging +import re + +from app.config import settings +from app.rag.embeddings import get_embedding_service +from app.rag.vectorstore import get_vector_store +from app.rag.intent import detect_intents, check_direct_match, get_intent_keywords +from app.models.schemas import RetrievalResult + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +class RetrievalService: + """ + Handles document retrieval with confidence scoring. + Implements threshold-based filtering for quality control. + """ + + def __init__( + self, + top_k: int = settings.TOP_K, + similarity_threshold: float = settings.SIMILARITY_THRESHOLD + ): + """ + Initialize the retrieval service. + + Args: + top_k: Number of results to retrieve + similarity_threshold: Minimum similarity score to consider relevant + """ + self.top_k = top_k + self.similarity_threshold = similarity_threshold + self.embedding_service = get_embedding_service() + self.vector_store = get_vector_store() + + def retrieve( + self, + query: str, + tenant_id: str, # CRITICAL: Multi-tenant isolation + kb_id: str, + user_id: str, + top_k: Optional[int] = None + ) -> Tuple[List[RetrievalResult], float, bool]: + """ + Retrieve relevant documents for a query. + + Args: + query: User's question + tenant_id: Tenant ID for multi-tenant isolation (CRITICAL) + kb_id: Knowledge base ID to search + user_id: User ID for filtering + top_k: Optional override for number of results + + Returns: + Tuple of (results, average_confidence, has_relevant_results) + """ + k = top_k or self.top_k + + # Generate query embedding + logger.info(f"Generating embedding for query: {query[:50]}...") + query_embedding = self.embedding_service.embed_query(query) + + # Search vector store with filters - MUST include tenant_id for isolation + filter_dict = { + "tenant_id": tenant_id, # CRITICAL: Multi-tenant isolation + "kb_id": kb_id, + "user_id": user_id + } + + logger.info(f"Searching vector store with filters: {filter_dict}") + raw_results = self.vector_store.search( + query_embedding=query_embedding, + top_k=k, + filter_dict=filter_dict + ) + + if not raw_results: + logger.warning(f"No results found for query in kb_id={kb_id}") + return [], 0.0, False + + # Convert to RetrievalResult objects + results = [] + for r in raw_results: + results.append(RetrievalResult( + chunk_id=r['id'], + content=r['content'], + metadata=r['metadata'], + similarity_score=r['similarity_score'] + )) + + # HEAVY CONFIDENCE MODE: Use maximum similarity score from top results + # This ensures confidence reflects the best match found, not dragged down by weaker results + if results: + # Get top 3 results and use the maximum similarity score + # This gives maximum confidence if there's at least one strong match + top_results = results[:3] + max_score = max(r.similarity_score for r in top_results) + + # If max score is good (>=0.4), use it directly + # Otherwise, use weighted average of top 3 to avoid over-inflating weak matches + if max_score >= 0.4: + avg_confidence = max_score + else: + # For weaker matches, use weighted average of top 3 + scores = [r.similarity_score for r in top_results] + weights = [1.0, 0.7, 0.5][:len(scores)] # Aggressive weighting + weighted_sum = sum(s * w for s, w in zip(scores, weights)) + total_weight = sum(weights[:len(scores)]) + avg_confidence = weighted_sum / total_weight if total_weight > 0 else max_score + else: + avg_confidence = 0.0 + + # Filter results above threshold + filtered_results = [ + r for r in results + if r.similarity_score >= self.similarity_threshold + ] + + # If no results pass threshold but we have results, use top results anyway + # This prevents over-filtering when threshold is too strict + if not filtered_results and results: + logger.warning(f"No results passed threshold {self.similarity_threshold}, using top {min(3, len(results))} results anyway") + filtered_results = results[:min(3, len(results))] + # Recalculate confidence with the fallback results + if filtered_results: + scores = [r.similarity_score for r in filtered_results] + avg_confidence = sum(scores) / len(scores) if scores else 0.0 + + # DIRECT MATCH GATE: Check if at least one chunk directly matches query intent + # For integration/API questions, this gate is stricter + has_direct_match = False + if filtered_results: + chunk_texts = [r.content for r in filtered_results] + intents = detect_intents(query) + intent_keywords = get_intent_keywords(intents) + + # For integration/API questions, require direct match + if "integration" in intents or "api" in query.lower(): + has_direct_match = check_direct_match(query, chunk_texts, intent_keywords) + logger.info(f"Direct match check (strict for integration): {has_direct_match} (intents: {intents})") + else: + # For other questions, be more lenient - just check if important words match + query_words = set(re.findall(r'\b\w+\b', query.lower())) + stop_words = {"the", "a", "an", "is", "are", "was", "were", "be", "been", + "to", "of", "and", "or", "but", "in", "on", "at", "for", + "with", "how", "what", "when", "where", "why", "do", "does"} + important_words = query_words - stop_words + + # Check if at least one important word appears in chunks + for chunk in chunk_texts: + chunk_lower = chunk.lower() + matches = sum(1 for word in important_words if word in chunk_lower) + if matches >= 1 and len(important_words) > 0: # At least one important word + has_direct_match = True + break + + logger.info(f"Direct match check (lenient): {has_direct_match} (intents: {intents})") + + # Only consider relevant if we have filtered results AND (direct match OR high confidence) + # High confidence (>0.40) can bypass direct match requirement for non-integration questions + has_relevant = len(filtered_results) > 0 and (has_direct_match or avg_confidence > 0.40) + + logger.info( + f"Retrieved {len(results)} results, " + f"{len(filtered_results)} above threshold ({self.similarity_threshold}), " + f"avg confidence: {avg_confidence:.3f}, " + f"direct match: {has_direct_match}" + ) + + return filtered_results, avg_confidence, has_relevant + + def get_context_for_llm( + self, + results: List[RetrievalResult], + max_tokens: int = settings.MAX_CONTEXT_TOKENS + ) -> Tuple[str, List[Dict[str, Any]]]: + """ + Format retrieved results into context for the LLM. + + Args: + results: List of retrieval results + max_tokens: Maximum tokens for context + + Returns: + Tuple of (formatted_context, citation_info) + """ + if not results: + return "", [] + + context_parts = [] + citations = [] + current_tokens = 0 + + # Estimate tokens (rough approximation: 1 token โ‰ˆ 4 chars) + for i, result in enumerate(results): + chunk_text = result.content + estimated_tokens = len(chunk_text) // 4 + + if current_tokens + estimated_tokens > max_tokens: + logger.info(f"Truncating context at {i} chunks due to token limit") + break + + # Format chunk with source info + source_info = f"[Source {i+1}: {result.metadata.get('file_name', 'Unknown')}]" + if result.metadata.get('page_number'): + source_info += f" (Page {result.metadata['page_number']})" + + context_parts.append(f"{source_info}\n{chunk_text}") + + # Build citation info + citations.append({ + "index": i + 1, + "file_name": result.metadata.get('file_name', 'Unknown'), + "chunk_id": result.chunk_id, + "page_number": result.metadata.get('page_number'), + "similarity_score": result.similarity_score, + "excerpt": chunk_text[:200] + "..." if len(chunk_text) > 200 else chunk_text + }) + + current_tokens += estimated_tokens + + formatted_context = "\n\n---\n\n".join(context_parts) + + return formatted_context, citations + + +# Global retrieval service instance +_retrieval_service: Optional[RetrievalService] = None + + +def get_retrieval_service() -> RetrievalService: + """Get the global retrieval service instance.""" + global _retrieval_service + if _retrieval_service is None: + _retrieval_service = RetrievalService() + return _retrieval_service + diff --git a/app/rag/vectorstore.py b/app/rag/vectorstore.py new file mode 100644 index 0000000000000000000000000000000000000000..a4e03c9d0da109065f336a7ed40f09a5f7ce1e63 --- /dev/null +++ b/app/rag/vectorstore.py @@ -0,0 +1,273 @@ +""" +Vector store using ChromaDB for local storage. +Supports efficient similarity search and filtering. +""" +import chromadb +from chromadb.config import Settings as ChromaSettings +from typing import List, Dict, Any, Optional +import logging +from pathlib import Path + +from app.config import settings +from app.rag.embeddings import get_embedding_service + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +class VectorStore: + """ + Vector store using ChromaDB for persistent local storage. + Supports CRUD operations and similarity search. + """ + + def __init__( + self, + persist_directory: Path = settings.VECTORDB_DIR, + collection_name: str = settings.COLLECTION_NAME + ): + """ + Initialize the vector store. + + Args: + persist_directory: Directory to persist the database + collection_name: Name of the collection to use + """ + self.persist_directory = persist_directory + self.collection_name = collection_name + + # Initialize ChromaDB client with persistence + self.client = chromadb.PersistentClient( + path=str(persist_directory), + settings=ChromaSettings( + anonymized_telemetry=False, + allow_reset=True + ) + ) + + # Get or create collection + self.collection = self.client.get_or_create_collection( + name=collection_name, + metadata={"hnsw:space": "cosine"} # Use cosine similarity + ) + + logger.info(f"Vector store initialized. Collection: {collection_name}, Items: {self.collection.count()}") + + def add_documents( + self, + documents: List[str], + embeddings: List[List[float]], + metadatas: List[Dict[str, Any]], + ids: List[str] + ) -> None: + """ + Add documents to the vector store. + + Args: + documents: List of document texts + embeddings: List of embedding vectors + metadatas: List of metadata dictionaries + ids: List of unique document IDs + """ + if not documents: + logger.warning("No documents to add") + return + + # ChromaDB doesn't accept None values in metadata + clean_metadatas = [] + for meta in metadatas: + clean_meta = {} + for k, v in meta.items(): + if v is not None: + clean_meta[k] = v + clean_metadatas.append(clean_meta) + + self.collection.add( + documents=documents, + embeddings=embeddings, + metadatas=clean_metadatas, + ids=ids + ) + + logger.info(f"Added {len(documents)} documents to vector store") + + def search( + self, + query_embedding: List[float], + top_k: int = settings.TOP_K, + filter_dict: Optional[Dict[str, Any]] = None + ) -> List[Dict[str, Any]]: + """ + Search for similar documents. + + Args: + query_embedding: Query embedding vector + top_k: Number of results to return + filter_dict: Optional filter criteria (e.g., {"kb_id": "123"}) + + Returns: + List of results with document, metadata, and similarity score + """ + # ChromaDB requires filters in $and/$or format for multiple conditions + where_filter = None + if filter_dict: + if len(filter_dict) == 1: + # Single condition - use directly + where_filter = filter_dict + else: + # Multiple conditions - use $and operator + where_filter = { + "$and": [ + {k: v} for k, v in filter_dict.items() + ] + } + + results = self.collection.query( + query_embeddings=[query_embedding], + n_results=top_k, + where=where_filter, + include=["documents", "metadatas", "distances"] + ) + + # Format results + formatted_results = [] + if results and results['ids'] and results['ids'][0]: + for i, doc_id in enumerate(results['ids'][0]): + # ChromaDB returns distances, convert to similarity + # For cosine distance: similarity = 1 - distance + distance = results['distances'][0][i] if results['distances'] else 0 + similarity = 1 - distance # Convert distance to similarity + + formatted_results.append({ + 'id': doc_id, + 'content': results['documents'][0][i] if results['documents'] else "", + 'metadata': results['metadatas'][0][i] if results['metadatas'] else {}, + 'similarity_score': max(0, min(1, similarity)) # Clamp to 0-1 + }) + + return formatted_results + + def delete_by_filter(self, filter_dict: Dict[str, Any]) -> int: + """ + Delete documents matching a filter. + + Args: + filter_dict: Filter criteria + + Returns: + Number of documents deleted + """ + # ChromaDB requires filters in $and/$or format for multiple conditions + where_filter = None + if len(filter_dict) == 1: + where_filter = filter_dict + else: + where_filter = { + "$and": [ + {k: v} for k, v in filter_dict.items() + ] + } + + # First, find matching documents + results = self.collection.get( + where=where_filter, + include=["metadatas"] + ) + + if results and results['ids']: + self.collection.delete(ids=results['ids']) + logger.info(f"Deleted {len(results['ids'])} documents matching filter") + return len(results['ids']) + + return 0 + + def delete_by_ids(self, ids: List[str]) -> None: + """Delete documents by their IDs.""" + if ids: + self.collection.delete(ids=ids) + logger.info(f"Deleted {len(ids)} documents by ID") + + def get_stats( + self, + tenant_id: Optional[str] = None, # CRITICAL: Multi-tenant isolation + kb_id: Optional[str] = None, + user_id: Optional[str] = None + ) -> Dict[str, Any]: + """ + Get statistics about the vector store. + + Args: + tenant_id: Tenant ID for multi-tenant isolation (REQUIRED if filtering) + kb_id: Optional knowledge base ID to filter + user_id: Optional user ID to filter + + Returns: + Statistics dictionary + """ + filter_dict = {} + if tenant_id: + filter_dict["tenant_id"] = tenant_id # CRITICAL: Multi-tenant isolation + if kb_id: + filter_dict["kb_id"] = kb_id + if user_id: + filter_dict["user_id"] = user_id + + if filter_dict: + # ChromaDB requires filters in $and/$or format for multiple conditions + where_filter = None + if len(filter_dict) == 1: + where_filter = filter_dict + else: + where_filter = { + "$and": [ + {k: v} for k, v in filter_dict.items() + ] + } + + results = self.collection.get( + where=where_filter, + include=["metadatas"] + ) + count = len(results['ids']) if results and results['ids'] else 0 + + # Get unique file names + file_names = set() + if results and results['metadatas']: + for meta in results['metadatas']: + if 'file_name' in meta: + file_names.add(meta['file_name']) + + return { + "total_chunks": count, + "file_names": list(file_names), + "tenant_id": tenant_id, + "kb_id": kb_id, + "user_id": user_id + } + else: + return { + "total_chunks": self.collection.count(), + "collection_name": self.collection_name + } + + def clear_collection(self) -> None: + """Clear all documents from the collection.""" + self.client.delete_collection(self.collection_name) + self.collection = self.client.create_collection( + name=self.collection_name, + metadata={"hnsw:space": "cosine"} + ) + logger.info(f"Cleared collection: {self.collection_name}") + + +# Global vector store instance +_vector_store: Optional[VectorStore] = None + + +def get_vector_store() -> VectorStore: + """Get the global vector store instance.""" + global _vector_store + if _vector_store is None: + _vector_store = VectorStore() + return _vector_store + diff --git a/app/rag/verifier.py b/app/rag/verifier.py new file mode 100644 index 0000000000000000000000000000000000000000..16d3a96fff7c55111ef0c13076af2dfc148d4bc5 --- /dev/null +++ b/app/rag/verifier.py @@ -0,0 +1,276 @@ +""" +Verifier module for RAG pipeline. +Implements Draft โ†’ Verify โ†’ Final flow to minimize hallucination. +""" +import json +import re +from typing import Dict, Any, List, Optional +import logging + +from app.config import settings + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +# Verifier prompt template +VERIFIER_PROMPT = """You are a strict fact-checker for a customer support chatbot. Your job is to verify that every factual claim in a draft answer is supported by the provided context. + +## Your Task: +1. Review the DRAFT ANSWER below +2. Check each factual claim against the PROVIDED CONTEXT +3. Identify any claims that are NOT supported by the context +4. Return a JSON response with your verification results + +## CRITICAL RULES: +- If ANY claim is not explicitly supported by the context โ†’ FAIL +- If the answer adds information not in context โ†’ FAIL +- If citations are missing or incorrect โ†’ FAIL +- Only PASS if ALL claims are verifiable in the context + +## Response Format (JSON): +{{ + "pass": true/false, + "issues": ["list of issues found"], + "unsupported_claims": ["list of unsupported claims"], + "final_answer": "corrected answer if needed (optional)" +}} + +## Example FAIL Response: +{{ + "pass": false, + "issues": ["Claim about '30 days' not found in context", "Missing citation for pricing information"], + "unsupported_claims": ["Refund window is 30 days", "Starter plan costs โ‚น999"], + "final_answer": null +}} + +## Example PASS Response: +{{ + "pass": true, + "issues": [], + "unsupported_claims": [], + "final_answer": null +}} + +--- + +## PROVIDED CONTEXT: +{context} + +--- + +## DRAFT ANSWER TO VERIFY: +{draft_answer} + +--- + +Now verify the draft answer and return ONLY valid JSON (no markdown, no code blocks, just raw JSON):""" + + +class VerifierService: + """ + Verifies that draft answers are supported by retrieved context. + Implements strict factual validation to prevent hallucination. + """ + + def __init__(self, provider: Optional[Any] = None): + """ + Initialize the verifier service. + + Args: + provider: Optional LLM provider (uses same as answer service if not provided) + """ + self._provider = provider + + @property + def provider(self): + """Get or create the LLM provider for verification.""" + if self._provider is None: + from app.rag.answer import GeminiProvider, OpenAIProvider + if settings.LLM_PROVIDER == "gemini": + self._provider = GeminiProvider() + elif settings.LLM_PROVIDER == "openai": + self._provider = OpenAIProvider() + else: + raise ValueError(f"Unknown LLM provider: {settings.LLM_PROVIDER}") + return self._provider + + def verify_answer( + self, + draft_answer: str, + context: str, + citations_info: List[Dict[str, Any]] + ) -> Dict[str, Any]: + """ + Verify that a draft answer is supported by the context. + + Args: + draft_answer: The draft answer to verify + context: The retrieved context from knowledge base + citations_info: List of citation information + + Returns: + Dictionary with verification results: + { + "pass": bool, + "issues": List[str], + "unsupported_claims": List[str], + "final_answer": Optional[str] + } + """ + if not context or not draft_answer: + logger.warning("Empty context or draft answer provided to verifier") + return { + "pass": False, + "issues": ["Empty context or draft answer"], + "unsupported_claims": [], + "final_answer": None + } + + # Format verifier prompt + verifier_prompt = VERIFIER_PROMPT.format( + context=context, + draft_answer=draft_answer + ) + + try: + logger.info("Running verifier on draft answer...") + # Use a more deterministic temperature for verification + try: + raw_response = self.provider.generate( + system_prompt="You are a strict fact-checker. Return ONLY valid JSON.", + user_prompt=verifier_prompt + ) + except Exception as e: + logger.error(f"Error calling LLM in verifier: {e}", exc_info=True) + # On LLM error, fail conservatively + return { + "pass": False, + "issues": [f"Verifier LLM error: {str(e)}"], + "unsupported_claims": [], + "final_answer": None + } + + # Parse JSON response + try: + verification_result = self._parse_verifier_response(raw_response) + except Exception as e: + logger.error(f"Error parsing verifier response: {e}", exc_info=True) + logger.error(f"Raw response was: {raw_response[:500]}") + # On parse error, fail conservatively + return { + "pass": False, + "issues": [f"Verifier parse error: {str(e)}"], + "unsupported_claims": [], + "final_answer": None + } + + if verification_result["pass"]: + logger.info("โœ… Verifier PASSED - All claims supported by context") + else: + logger.warning( + f"โŒ Verifier FAILED - Issues: {verification_result.get('issues', [])}" + ) + + return verification_result + + except Exception as e: + logger.error(f"Unexpected error in verifier: {e}", exc_info=True) + # On error, fail conservatively + return { + "pass": False, + "issues": [f"Verifier error: {str(e)}"], + "unsupported_claims": [], + "final_answer": None + } + + def _parse_verifier_response(self, raw_response: str) -> Dict[str, Any]: + """ + Parse the verifier's JSON response. + + Args: + raw_response: Raw response from LLM + + Returns: + Parsed verification result + """ + # Try to extract JSON from response + # Remove markdown code blocks if present + cleaned = raw_response.strip() + if cleaned.startswith("```json"): + cleaned = cleaned[7:] + if cleaned.startswith("```"): + cleaned = cleaned[3:] + if cleaned.endswith("```"): + cleaned = cleaned[:-3] + cleaned = cleaned.strip() + + # Try to find JSON object in the response + # Look for { ... } pattern + json_match = re.search(r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}', cleaned, re.DOTALL) + if json_match: + cleaned = json_match.group(0) + + try: + result = json.loads(cleaned) + + # Validate structure + if not isinstance(result, dict): + raise ValueError("Response is not a dictionary") + + # Ensure required fields + return { + "pass": result.get("pass", False), + "issues": result.get("issues", []), + "unsupported_claims": result.get("unsupported_claims", []), + "final_answer": result.get("final_answer") + } + + except (json.JSONDecodeError, ValueError) as e: + logger.error(f"Failed to parse verifier JSON: {e}") + logger.error(f"Raw response (first 500 chars): {raw_response[:500]}") + logger.error(f"Cleaned response (first 500 chars): {cleaned[:500]}") + + # Fallback: try to infer pass/fail from text + response_lower = raw_response.lower() + # Check for explicit pass indicators + if ("pass" in response_lower and ("true" in response_lower or "yes" in response_lower)) or \ + ("all claims" in response_lower and "supported" in response_lower): + logger.warning("Using fallback: inferred PASS from text") + return { + "pass": True, + "issues": [], + "unsupported_claims": [], + "final_answer": None + } + elif ("pass" in response_lower and ("false" in response_lower or "no" in response_lower)) or \ + ("not supported" in response_lower or "unsupported" in response_lower): + logger.warning("Using fallback: inferred FAIL from text") + return { + "pass": False, + "issues": ["Failed to parse verifier response - inferred fail from text"], + "unsupported_claims": [], + "final_answer": None + } + else: + # Default to fail for safety + logger.warning("Using fallback: defaulting to FAIL (could not infer from text)") + return { + "pass": False, + "issues": [f"Failed to parse verifier response: {str(e)}"], + "unsupported_claims": [], + "final_answer": None + } + + +# Global verifier instance +_verifier_service: Optional[VerifierService] = None + + +def get_verifier_service() -> VerifierService: + """Get the global verifier service instance.""" + global _verifier_service + if _verifier_service is None: + _verifier_service = VerifierService() + return _verifier_service + diff --git a/app/utils/__init__.py b/app/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..10b7dfe004fe2fedc72db0e075cc7bff7b84a2e0 --- /dev/null +++ b/app/utils/__init__.py @@ -0,0 +1,6 @@ +""" +Utility functions for the RAG backend. +""" + + + diff --git a/env.example.txt b/env.example.txt new file mode 100644 index 0000000000000000000000000000000000000000..5f4628b9c7deb46c206407beb92aa992c2b99e24 --- /dev/null +++ b/env.example.txt @@ -0,0 +1,32 @@ +# ClientSphere RAG Backend Configuration +# Rename this file to .env and fill in your values + +# LLM Provider (gemini or openai) +LLM_PROVIDER=gemini + +# Gemini API Key (get from https://makersuite.google.com/app/apikey) +GEMINI_API_KEY=your_gemini_api_key_here + +# OpenAI API Key (optional, if using OpenAI instead) +# OPENAI_API_KEY=your_openai_api_key_here + +# Model names (optional, defaults provided) +# GEMINI_MODEL=gemini-pro +# OPENAI_MODEL=gpt-3.5-turbo + +# Embedding model (optional, default is all-MiniLM-L6-v2) +# EMBEDDING_MODEL=all-MiniLM-L6-v2 + +# Chunking settings (optional) +# CHUNK_SIZE=500 +# CHUNK_OVERLAP=100 + +# Retrieval settings (optional) +# TOP_K=6 +# SIMILARITY_THRESHOLD=0.35 + +# Debug mode +DEBUG=true + + + diff --git a/render.yaml b/render.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e0dc60cfcab610a4709e40695aaa13c777b7bb74 --- /dev/null +++ b/render.yaml @@ -0,0 +1,21 @@ +services: + - type: web + name: clientsphere-rag-backend + env: python + buildCommand: pip install -r requirements.txt + startCommand: uvicorn app.main:app --host 0.0.0.0 --port $PORT + envVars: + - key: GEMINI_API_KEY + sync: false # Set in Render dashboard + - key: ENV + value: prod + - key: LLM_PROVIDER + value: gemini + - key: ALLOWED_ORIGINS + sync: false # Set in Render dashboard + - key: JWT_SECRET + sync: false # Set in Render dashboard (same as Node.js backend) + - key: DEBUG + value: "false" + + diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..4615f55c2bab2451fa533a1ff9537e06a75620fc --- /dev/null +++ b/requirements.txt @@ -0,0 +1,47 @@ +# FastAPI & Server +fastapi>=0.109.0 +uvicorn[standard]>=0.27.0 +python-multipart>=0.0.6 + +# Document Processing (use latest for pre-built wheels) +pymupdf>=1.24.0 +python-docx>=1.1.0 +markdown>=3.5.2 +chardet>=5.2.0 + +# Text Processing & Chunking +tiktoken>=0.5.2 + +# Embeddings +sentence-transformers>=2.3.1 + +# Vector Database (ChromaDB - local, easy setup) +chromadb>=0.4.22 + +# LLM Providers +google-generativeai>=0.4.0 +openai>=1.12.0 + +# Utilities +pydantic>=2.6.0 +pydantic-settings>=2.1.0 +python-dotenv>=1.0.1 +aiofiles>=23.2.1 + +# Authentication +python-jose[cryptography]>=3.3.0 + +# Rate Limiting +slowapi>=0.1.9 + +# Monitoring & Metrics +prometheus-client>=0.19.0 +prometheus-fastapi-instrumentator>=6.1.0 + +# Database +sqlalchemy>=2.0.23 + +# Evaluation & Testing +pytest>=8.0.0 +httpx>=0.26.0 + diff --git a/scripts/__init__.py b/scripts/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..906bde563db236c7834641915e893fbea964c0f2 --- /dev/null +++ b/scripts/__init__.py @@ -0,0 +1,4 @@ +# Scripts package + + + diff --git a/scripts/create_billing_tables.py b/scripts/create_billing_tables.py new file mode 100644 index 0000000000000000000000000000000000000000..d8dbd6a71ae6efb6a6f7a40af2f0ca5b111e37ff --- /dev/null +++ b/scripts/create_billing_tables.py @@ -0,0 +1,33 @@ +""" +Create billing database tables. +Run this script to initialize the billing database. +""" +import sys +from pathlib import Path + +# Add parent directory to path +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from app.db.database import init_db, engine +from app.db.models import Base + +def main(): + """Create all billing tables.""" + print("Creating billing database tables...") + try: + # Create all tables + Base.metadata.create_all(bind=engine) + print("โœ… Billing tables created successfully!") + print("\nTables created:") + print(" - tenants") + print(" - tenant_plans") + print(" - usage_events") + print(" - usage_daily") + print(" - usage_monthly") + except Exception as e: + print(f"โŒ Error creating tables: {e}") + sys.exit(1) + +if __name__ == "__main__": + main() + diff --git a/scripts/validate_rag.py b/scripts/validate_rag.py new file mode 100644 index 0000000000000000000000000000000000000000..2eda62ceb6c0aaca0f21f4c67ef08bbfed7be2ae --- /dev/null +++ b/scripts/validate_rag.py @@ -0,0 +1,466 @@ +""" +Automated RAG pipeline validation script. +Tests end-to-end functionality, multi-tenant isolation, and anti-hallucination. +""" +import httpx +import time +import json +from pathlib import Path +from typing import Dict, List, Any, Tuple +import sys + +# Add parent directory to path +sys.path.insert(0, str(Path(__file__).parent.parent)) + +BASE_URL = "http://localhost:8000" +TEST_TENANT_A = "tenant_A" +TEST_TENANT_B = "tenant_B" +TEST_USER_A = "user_A" +TEST_USER_B = "user_B" +TEST_KB_A = "kb_A" +TEST_KB_B = "kb_B" + +# Test documents +TENANT_A_DOC = Path(__file__).parent.parent / "data" / "test_docs" / "tenant_A_kb.md" +TENANT_B_DOC = Path(__file__).parent.parent / "data" / "test_docs" / "tenant_B_kb.md" + +# Test results storage +test_results: List[Dict[str, Any]] = [] + + +def print_header(text: str): + """Print a formatted header.""" + print("\n" + "=" * 80) + print(f" {text}") + print("=" * 80) + + +def print_test(test_name: str, passed: bool, reason: str = ""): + """Print test result.""" + status = "[PASS]" if passed else "[FAIL]" + print(f"{status} | {test_name}") + if reason: + print(f" โ””โ”€ {reason}") + test_results.append({ + "test": test_name, + "passed": passed, + "reason": reason + }) + + +def wait_for_server(max_retries: int = 10, delay: int = 2) -> bool: + """Wait for the server to be ready.""" + print("Waiting for server to be ready...") + for i in range(max_retries): + try: + response = httpx.get(f"{BASE_URL}/health", timeout=5) + if response.status_code == 200: + print("[OK] Server is ready") + return True + except Exception: + pass + time.sleep(delay) + print(f" Retry {i+1}/{max_retries}...") + print("[FAIL] Server not ready after max retries") + return False + + +def upload_document( + client: httpx.Client, + file_path: Path, + tenant_id: str, + user_id: str, + kb_id: str +) -> Dict[str, Any]: + """Upload a document to the knowledge base.""" + try: + with open(file_path, "rb") as f: + files = {"file": (file_path.name, f, "text/markdown")} + data = { + "tenant_id": tenant_id, + "user_id": user_id, + "kb_id": kb_id + } + response = client.post( + f"{BASE_URL}/kb/upload", + files=files, + data=data, + timeout=60 + ) + if response.status_code == 200: + return {"success": True, "data": response.json()} + else: + return {"success": False, "error": response.text} + except Exception as e: + return {"success": False, "error": str(e)} + + +def test_retrieval( + client: httpx.Client, + query: str, + tenant_id: str, + user_id: str, + kb_id: str, + expected_keywords: List[str], + should_not_contain: List[str] = None, + top_k: int = 5 +) -> Tuple[bool, str]: + """Test retrieval accuracy.""" + try: + # Use GET for search endpoint with headers for dev mode auth + headers = { + "X-Tenant-Id": tenant_id, + "X-User-Id": user_id + } + response = client.get( + f"{BASE_URL}/kb/search", + params={ + "query": query, + "kb_id": kb_id, + "top_k": top_k + }, + headers=headers, + timeout=30 + ) + + if response.status_code != 200: + return False, f"API returned {response.status_code}: {response.text}" + + data = response.json() + results = data.get("results", []) + + if not results: + return False, "No results retrieved" + + # Check tenant isolation + for result in results: + metadata = result.get("metadata", {}) + result_tenant = metadata.get("tenant_id") + if result_tenant != tenant_id: + return False, f"Tenant leak detected! Got tenant_id={result_tenant}, expected {tenant_id}" + + # Check for expected keywords + all_content = " ".join([r.get("content", "") for r in results]).lower() + found_keywords = [kw for kw in expected_keywords if kw.lower() in all_content] + + if not found_keywords: + return False, f"Expected keywords not found: {expected_keywords}" + + # Check for forbidden content + if should_not_contain: + for forbidden in should_not_contain: + if forbidden.lower() in all_content: + return False, f"Forbidden content found: {forbidden}" + + return True, f"Retrieved {len(results)} results, found keywords: {found_keywords}" + + except Exception as e: + return False, f"Error: {str(e)}" + + +def test_chat( + client: httpx.Client, + question: str, + tenant_id: str, + user_id: str, + kb_id: str, + expected_keywords: List[str] = None, + should_refuse: bool = False, + should_not_contain: List[str] = None +) -> Tuple[bool, str, Dict[str, Any]]: + """Test full chat endpoint.""" + try: + # Include headers for dev mode auth + headers = { + "X-Tenant-Id": tenant_id, + "X-User-Id": user_id + } + response = client.post( + f"{BASE_URL}/chat", + json={ + "tenant_id": tenant_id, + "user_id": user_id, + "kb_id": kb_id, + "question": question + }, + headers=headers, + timeout=60 + ) + + if response.status_code != 200: + return False, f"API returned {response.status_code}: {response.text}", {} + + data = response.json() + answer = data.get("answer", "").lower() + citations = data.get("citations", []) + from_kb = data.get("from_knowledge_base", False) + confidence = data.get("confidence", 0.0) + metadata = data.get("metadata", {}) + refused = metadata.get("refused", False) + + # Check refusal behavior (STRICT) + if should_refuse: + # Check if response explicitly indicates refusal + refused = data.get("refused", False) + refusal_keywords = [ + "couldn't find", "don't have", "not available", "contact support", + "not in the knowledge base", "could not verify", "not enough information", + "apologize", "couldn't find relevant information" + ] + has_refusal_keywords = any(kw in answer for kw in refusal_keywords) + + # If answer was generated with citations, it's a FAIL (should have refused) + if citations and len(citations) > 0: + return False, ( + f"Should have refused but generated answer with {len(citations)} citations. " + f"Answer: {answer[:300]}" + ), data + + # If confidence is high and answer exists, it's a FAIL + if confidence >= 0.30 and answer and not has_refusal_keywords: + return False, ( + f"Should have refused but generated answer with confidence {confidence:.2f}. " + f"Answer: {answer[:300]}" + ), data + + # If not refused and no refusal keywords, it's a FAIL + if not refused and not has_refusal_keywords: + return False, ( + f"Should have refused but didn't. " + f"refused={refused}, confidence={confidence:.2f}, citations={len(citations)}. " + f"Answer: {answer[:300]}" + ), data + + # If we got here, it properly refused + return True, f"Properly refused (refused={refused}, confidence={confidence:.2f})", data + + # Check for expected keywords + if expected_keywords: + found = [kw for kw in expected_keywords if kw.lower() in answer] + if not found: + return False, f"Expected keywords not found: {expected_keywords}. Answer: {answer[:200]}", data + + # Check citations + if not should_refuse and from_kb: + if not citations: + return False, "Answer claims to be from KB but has no citations", data + + # Check for forbidden content + if should_not_contain: + for forbidden in should_not_contain: + if forbidden.lower() in answer: + return False, f"Forbidden content found in answer: {forbidden}", data + + # Check citation integrity + if citations and expected_keywords: + citation_text = " ".join([c.get("excerpt", "") for c in citations]).lower() + for kw in expected_keywords: + if kw.lower() in answer and kw.lower() not in citation_text: + # This is a warning, not a failure + pass + + return True, f"Answer generated (confidence: {confidence:.2f}, citations: {len(citations)})", data + + except Exception as e: + return False, f"Error: {str(e)}", {} + + +def main(): + """Run all validation tests.""" + print_header("RAG Pipeline Validation Suite") + + # Check server + if not wait_for_server(): + print("[FAIL] Cannot proceed without server") + return + + client = httpx.Client(timeout=120.0) + + # ========== PHASE 1: Upload Documents ========== + print_header("Phase 1: Upload Test Documents") + + # Upload tenant A doc + print(f"\n๐Ÿ“ค Uploading {TENANT_A_DOC.name} for {TEST_TENANT_A}...") + result = upload_document(client, TENANT_A_DOC, TEST_TENANT_A, TEST_USER_A, TEST_KB_A) + if result["success"]: + print("[OK] Upload successful") + print("โณ Waiting for document processing (10 seconds)...") + time.sleep(10) # Wait longer for processing (parsing, chunking, embedding) + else: + print(f"[FAIL] Upload failed: {result.get('error')}") + return + + # Upload tenant B doc + print(f"\n๐Ÿ“ค Uploading {TENANT_B_DOC.name} for {TEST_TENANT_B}...") + result = upload_document(client, TENANT_B_DOC, TEST_TENANT_B, TEST_USER_B, TEST_KB_B) + if result["success"]: + print("[OK] Upload successful") + print("โณ Waiting for document processing (10 seconds)...") + time.sleep(10) # Wait longer for processing (parsing, chunking, embedding) + else: + print(f"[FAIL] Upload failed: {result.get('error')}") + return + + # ========== PHASE 2: Retrieval Tests ========== + print_header("Phase 2: Retrieval Accuracy Tests") + + # Test 1: Tenant A retrieval + passed, reason = test_retrieval( + client, + "What is the refund window?", + TEST_TENANT_A, + TEST_USER_A, + TEST_KB_A, + expected_keywords=["7 days"], + should_not_contain=["30 days"] + ) + print_test("Retrieval: Tenant A - Refund Window", passed, reason) + + # Test 2: Tenant B retrieval + passed, reason = test_retrieval( + client, + "What is the refund window?", + TEST_TENANT_B, + TEST_USER_B, + TEST_KB_B, + expected_keywords=["30 days"], + should_not_contain=["7 days"] + ) + print_test("Retrieval: Tenant B - Refund Window", passed, reason) + + # Test 3: Tenant isolation (A should not get B's data) + passed, reason = test_retrieval( + client, + "Starter plan price", + TEST_TENANT_A, + TEST_USER_A, + TEST_KB_A, + expected_keywords=["499"], + should_not_contain=["999"] + ) + print_test("Retrieval: Tenant A - Starter Plan Price (Isolation)", passed, reason) + + # Test 4: Tenant isolation (B should not get A's data) + passed, reason = test_retrieval( + client, + "Starter plan price", + TEST_TENANT_B, + TEST_USER_B, + TEST_KB_B, + expected_keywords=["999"], + should_not_contain=["499"] + ) + print_test("Retrieval: Tenant B - Starter Plan Price (Isolation)", passed, reason) + + # ========== PHASE 3: Chat Tests ========== + print_header("Phase 3: Chat Endpoint Tests") + + # Test 5: Tenant A chat - refund window + passed, reason, data = test_chat( + client, + "What is the refund window?", + TEST_TENANT_A, + TEST_USER_A, + TEST_KB_A, + expected_keywords=["7 days"], + should_not_contain=["30 days"] + ) + print_test("Chat: Tenant A - Refund Window", passed, reason) + + # Test 6: Tenant B chat - refund window + passed, reason, data = test_chat( + client, + "What is the refund window?", + TEST_TENANT_B, + TEST_USER_B, + TEST_KB_B, + expected_keywords=["30 days"], + should_not_contain=["7 days"] + ) + print_test("Chat: Tenant B - Refund Window", passed, reason) + + # Test 7: Tenant A chat - Starter plan + passed, reason, data = test_chat( + client, + "What is the Starter plan price?", + TEST_TENANT_A, + TEST_USER_A, + TEST_KB_A, + expected_keywords=["499"], + should_not_contain=["999"] + ) + print_test("Chat: Tenant A - Starter Plan Price", passed, reason) + + # Test 8: Tenant B chat - Starter plan + passed, reason, data = test_chat( + client, + "What is the Starter plan price?", + TEST_TENANT_B, + TEST_USER_B, + TEST_KB_B, + expected_keywords=["999"], + should_not_contain=["499"] + ) + print_test("Chat: Tenant B - Starter Plan Price", passed, reason) + + # Test 9: Hallucination refusal - out of scope + passed, reason, data = test_chat( + client, + "How to integrate ClientSphere with Shopify?", + TEST_TENANT_A, + TEST_USER_A, + TEST_KB_A, + should_refuse=True + ) + print_test("Chat: Hallucination Refusal (Out of Scope)", passed, reason) + + # Test 10: Citation integrity + passed, reason, data = test_chat( + client, + "How long do password reset links last?", + TEST_TENANT_A, + TEST_USER_A, + TEST_KB_A, + expected_keywords=["15"] + ) + if passed: + citations = data.get("citations", []) + if citations: + print_test("Chat: Citation Integrity", True, f"Found {len(citations)} citations") + else: + print_test("Chat: Citation Integrity", False, "No citations provided") + else: + print_test("Chat: Citation Integrity", False, reason) + + # ========== PHASE 4: Summary ========== + print_header("Test Summary") + + total_tests = len(test_results) + passed_tests = sum(1 for r in test_results if r["passed"]) + failed_tests = total_tests - passed_tests + + print(f"\nTotal Tests: {total_tests}") + print(f"[PASS] Passed: {passed_tests}") + print(f"[FAIL] Failed: {failed_tests}") + print(f"Success Rate: {(passed_tests/total_tests*100):.1f}%") + + if failed_tests > 0: + print("\n[FAIL] Failed Tests:") + for result in test_results: + if not result["passed"]: + print(f" - {result['test']}: {result['reason']}") + + # Final verdict + print_header("Final Verdict") + if failed_tests == 0: + print("[PASS] ALL TESTS PASSED - RAG Pipeline is working correctly") + return 0 + else: + print(f"[FAIL] {failed_tests} TEST(S) FAILED - Review issues above") + return 1 + + +if __name__ == "__main__": + exit_code = main() + sys.exit(exit_code) +