Spaces:
Sleeping
Sleeping
| """ | |
| FastAPI application for ClientSphere RAG Backend. | |
| Provides endpoints for knowledge base management and chat. | |
| """ | |
| from fastapi import FastAPI, File, UploadFile, HTTPException, Form, BackgroundTasks, Request, Depends | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.exceptions import RequestValidationError | |
| from fastapi.responses import JSONResponse | |
| from pathlib import Path | |
| import shutil | |
| import uuid | |
| from datetime import datetime | |
| from typing import Optional | |
| import logging | |
| from app.config import settings | |
| from app.middleware.auth import get_auth_context, require_auth | |
| from app.middleware.rate_limit import ( | |
| limiter, | |
| get_tenant_rate_limit_key, | |
| RateLimitExceeded, | |
| _rate_limit_exceeded_handler | |
| ) | |
| from app.models.schemas import ( | |
| UploadResponse, | |
| ChatRequest, | |
| ChatResponse, | |
| KnowledgeBaseStats, | |
| HealthResponse, | |
| DocumentStatus, | |
| Citation, | |
| ) | |
| from app.models.billing_schemas import ( | |
| UsageResponse, | |
| PlanLimitsResponse, | |
| CostReportResponse, | |
| SetPlanRequest | |
| ) | |
| from app.rag.ingest import parser | |
| from app.rag.chunking import chunker | |
| from app.rag.embeddings import get_embedding_service | |
| from app.rag.vectorstore import get_vector_store | |
| from app.rag.retrieval import get_retrieval_service | |
| from app.rag.answer import get_answer_service | |
| from app.db.database import get_db, init_db | |
| from app.billing.quota import check_quota, ensure_tenant_exists | |
| from app.billing.usage_tracker import track_usage | |
# Root logging is configured once at import time; INFO and above are emitted.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# The FastAPI application object that the rest of this module configures.
app = FastAPI(
    version="1.0.0",
    description="RAG-based customer support chatbot API",
    title=settings.APP_NAME,
)
# Database bootstrap hook.
async def startup_event():
    """Call ``init_db()`` and log that the database is initialized.

    NOTE(review): no registration of this coroutine on ``app`` (for example a
    startup-event decorator) is visible in this view of the file; confirm how
    it is attached.
    """
    init_db()
    logger.info("Database initialized")
# Configure CORS - SECURITY: Restrict in production
# ALLOWED_ORIGINS is either "*" or a comma-separated list of origins.
allowed_origins = [
    origin.strip()
    for origin in settings.ALLOWED_ORIGINS.split(",")
    if origin.strip()
]
# An empty setting, or any "*" entry, means "allow every origin" (dev mode).
if not allowed_origins or "*" in allowed_origins:
    allowed_origins = ["*"]

allow_all_origins = allowed_origins == ["*"]
if allow_all_origins and settings.ENV == "prod":
    logger.warning(
        "CORS allows all origins in production; set ALLOWED_ORIGINS to an explicit list"
    )

logger.info(f"CORS configured with origins: {allowed_origins}")
app.add_middleware(
    CORSMiddleware,
    allow_origins=allowed_origins,
    # A wildcard origin must not be combined with credentials: the CORS spec
    # forbids it and FastAPI's CORS documentation says origins must be listed
    # explicitly when allow_credentials is True. Credentials are therefore
    # only enabled for an explicit origin list.
    allow_credentials=not allow_all_origins,
    allow_methods=["GET", "POST", "DELETE", "OPTIONS"],  # Include OPTIONS for preflight
    allow_headers=["Content-Type", "Authorization", "X-Tenant-Id", "X-User-Id"],  # Include auth headers
)

# Configure rate limiting
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
# Add exception handler for validation errors
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    """Log a request validation failure and answer with HTTP 422.

    This is the single definition of the handler; the file previously defined
    it twice, so the later copy silently replaced the earlier one.

    Args:
        request: The incoming request whose payload failed validation.
        exc: The validation error raised by FastAPI.

    Returns:
        A ``JSONResponse`` with status 422 whose content holds the validation
        errors under ``"detail"`` and the request body (decoded as UTF-8,
        undecodable bytes dropped) under ``"body"``.

    NOTE(review): no registration of this handler on ``app`` is visible in
    this view of the file; confirm how it is attached.
    """
    # Read the body once and reuse it for both logging and the response.
    raw_body = await request.body()
    errors = exc.errors()

    # Never write credentials to the log: mask headers that carry secrets.
    sensitive_headers = ("authorization", "cookie", "proxy-authorization", "x-api-key")
    safe_headers = {
        name: ("<redacted>" if name.lower() in sensitive_headers else value)
        for name, value in request.headers.items()
    }

    logger.error(f"Request validation error: {errors}")
    logger.error(f"Request body (raw): {raw_body}")
    logger.error(f"Request headers: {safe_headers}")

    return JSONResponse(
        status_code=422,
        content={"detail": errors, "body": raw_body.decode("utf-8", errors="ignore")},
    )
| # ============== Health & Status Endpoints ============== | |
async def root():
    """Return a static ``HealthResponse`` describing the service.

    ``vector_db_connected`` is reported as ``True`` without probing the store;
    ``llm_configured`` reflects whether a Gemini or OpenAI API key is set.
    """
    has_llm_key = bool(settings.GEMINI_API_KEY or settings.OPENAI_API_KEY)
    return HealthResponse(
        status="ok",
        version="1.0.0",
        vector_db_connected=True,
        llm_configured=has_llm_key,
    )
async def health_check():
    """Report service health by probing the vector store.

    Returns:
        A ``HealthResponse`` with status ``"healthy"`` when the vector store
        answers ``get_stats()``, otherwise ``"unhealthy"`` with
        ``vector_db_connected=False``. ``llm_configured`` always reflects
        whether a Gemini or OpenAI API key is set, independent of the vector
        store outcome.
    """
    # LLM configuration does not depend on the vector store, so compute it
    # once and report it truthfully in both outcomes.
    llm_configured = bool(settings.GEMINI_API_KEY or settings.OPENAI_API_KEY)
    try:
        # Only success or failure of the call matters; the stats are not used.
        get_vector_store().get_stats()
    except Exception as e:
        logger.error(f"Health check failed: {e}")
        return HealthResponse(
            status="unhealthy",
            version="1.0.0",
            vector_db_connected=False,
            llm_configured=llm_configured,
        )
    return HealthResponse(
        status="healthy",
        version="1.0.0",
        vector_db_connected=True,
        llm_configured=llm_configured,
    )
async def liveness():
    """Liveness probe: unconditionally report that the process is alive."""
    return dict(status="alive")
async def readiness():
    """Readiness probe: verify the vector store responds and an LLM key is set.

    Returns:
        ``{"status": "ready", "checks": {...}}`` when every check passes.

    Raises:
        HTTPException: 503 with the individual check results when any check
            fails.
    """
    llm_ok = bool(settings.GEMINI_API_KEY or settings.OPENAI_API_KEY)

    # The vector DB counts as reachable when get_stats() completes.
    try:
        get_vector_store().get_stats()
    except Exception as e:
        logger.warning(f"Vector DB check failed: {e}")
        vector_ok = False
    else:
        vector_ok = True

    checks = {"vector_db": vector_ok, "llm_configured": llm_ok}

    # All checks must pass
    if not all(checks.values()):
        raise HTTPException(status_code=503, detail={"status": "not_ready", "checks": checks})
    return {"status": "ready", "checks": checks}
| # ============== Knowledge Base Endpoints ============== | |
def _storage_file_name(doc_id: str, client_filename: str) -> str:
    """Return a flat, filesystem-safe file name for storing an upload.

    Only the final path component of the client-supplied name is kept, and
    every character other than letters, digits, ``.``, ``_`` and ``-`` is
    replaced with ``_`` so the result cannot contain path separators.
    """
    base_name = Path(client_filename.replace("\\", "/")).name
    raw_name = f"{doc_id}_{base_name}"
    return "".join(ch if ch.isalnum() or ch in "._-" else "_" for ch in raw_name)


async def upload_document(
    background_tasks: BackgroundTasks,
    request: Request,
    file: UploadFile = File(...),
    tenant_id: Optional[str] = Form(None),  # Optional in dev, ignored in prod
    user_id: Optional[str] = Form(None),  # Optional in dev, ignored in prod
    kb_id: str = Form(...)
):
    """
    Upload a document to the knowledge base.

    - Validates file name, type and size
    - Saves the file to the uploads directory
    - Schedules ``process_document`` as a background task

    Args:
        background_tasks: FastAPI background task queue.
        request: Incoming request, used for authentication in production.
        file: The uploaded file.
        tenant_id: Tenant from the form (dev only; replaced by the token's
            tenant in production).
        user_id: User from the form (dev only; replaced by the token's user
            in production).
        kb_id: Target knowledge base identifier.

    Returns:
        An ``UploadResponse`` with status ``PROCESSING`` and the generated
        document id.

    Raises:
        HTTPException: 403 when production auth lacks tenant/user, 400 for a
            missing tenant, missing file name, unsupported type or oversized
            file, 500 when the file cannot be saved.
    """
    # SECURITY: identity comes from the auth token in production; form values
    # for tenant_id and user_id are ignored there.
    if settings.ENV == "prod":
        auth_context = await require_auth(request)
        tenant_id = auth_context.get("tenant_id")
        user_id = auth_context.get("user_id")
        if not tenant_id:
            raise HTTPException(
                status_code=403,
                detail="tenant_id must come from authentication token in production mode"
            )
        if not user_id:
            raise HTTPException(
                status_code=403,
                detail="tenant_id and user_id must come from authentication token in production mode"
            )
    elif not tenant_id:
        raise HTTPException(
            status_code=400,
            detail="tenant_id is required"
        )

    # A multipart part may arrive without a file name; reject it explicitly
    # instead of failing inside Path().
    if not file.filename:
        raise HTTPException(status_code=400, detail="File name is required")

    # Validate file type
    file_ext = Path(file.filename).suffix.lower()
    if file_ext not in parser.SUPPORTED_EXTENSIONS:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported file type: {file_ext}. Supported: {parser.SUPPORTED_EXTENSIONS}"
        )

    # Validate file size (SECURITY)
    file.file.seek(0, 2)  # Seek to end
    file_size = file.file.tell()
    file.file.seek(0)  # Reset to start
    max_size_bytes = settings.MAX_FILE_SIZE_MB * 1024 * 1024
    if file_size > max_size_bytes:
        raise HTTPException(
            status_code=400,
            detail=f"File too large. Maximum size: {settings.MAX_FILE_SIZE_MB}MB"
        )

    # Generate document ID
    doc_id = f"{tenant_id}_{kb_id}_{uuid.uuid4().hex[:8]}"

    # SECURITY: tenant_id, kb_id and the file name are client-controlled (in
    # dev) and must never be able to steer the write outside UPLOADS_DIR.
    upload_path = settings.UPLOADS_DIR / _storage_file_name(doc_id, file.filename)
    try:
        settings.UPLOADS_DIR.mkdir(parents=True, exist_ok=True)
        with open(upload_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)
        logger.info(f"Saved file: {upload_path}")
    except Exception as e:
        logger.error(f"Error saving file: {e}")
        raise HTTPException(status_code=500, detail="Failed to save file")

    # Process document in background
    background_tasks.add_task(
        process_document,
        upload_path,
        tenant_id,  # CRITICAL: Multi-tenant isolation
        user_id,
        kb_id,
        file.filename,
        doc_id
    )
    return UploadResponse(
        success=True,
        message="Document upload started. Processing in background.",
        document_id=doc_id,
        file_name=file.filename,
        chunks_created=0,
        status=DocumentStatus.PROCESSING
    )
async def process_document(
    file_path: Path,
    tenant_id: str,  # CRITICAL: Multi-tenant isolation
    user_id: str,
    kb_id: str,
    original_filename: str,
    document_id: str
):
    """
    Background task: parse, chunk, embed and store one uploaded document.

    Pipeline: ``parser.parse`` -> ``chunker.chunk_text`` -> per-chunk metadata
    -> ``embed_texts`` -> ``vector_store.add_documents``. Returns early (with a
    warning) when chunking yields nothing. Any exception is logged and
    re-raised.
    """
    try:
        logger.info(f"Processing document: {original_filename}")

        # Parse
        parsed_doc = parser.parse(file_path)
        logger.info(f"Parsed document: {len(parsed_doc.text)} characters")

        # Chunk
        chunks = chunker.chunk_text(parsed_doc.text, page_numbers=parsed_doc.page_map)
        logger.info(f"Created {len(chunks)} chunks")
        if not chunks:
            logger.warning(f"No chunks created from {original_filename}")
            return

        # Per-chunk metadata, then the ids and texts derived from it.
        chunk_total = len(chunks)
        metadatas = [
            chunker.create_chunk_metadata(
                chunk=piece,
                tenant_id=tenant_id,  # CRITICAL: Multi-tenant isolation
                kb_id=kb_id,
                user_id=user_id,
                file_name=original_filename,
                file_type=parsed_doc.file_type,
                total_chunks=chunk_total,
                document_id=document_id,
            )
            for piece in chunks
        ]
        chunk_ids = [meta["chunk_id"] for meta in metadatas]
        chunk_texts = [piece.content for piece in chunks]

        # Embed
        embeddings = get_embedding_service().embed_texts(chunk_texts)
        logger.info(f"Generated {len(embeddings)} embeddings")

        # Store
        get_vector_store().add_documents(
            documents=chunk_texts,
            embeddings=embeddings,
            metadatas=metadatas,
            ids=chunk_ids,
        )
        logger.info(f"Successfully processed {original_filename}: {len(chunks)} chunks stored")
    except Exception as e:
        logger.error(f"Error processing document {original_filename}: {e}")
        raise
async def get_kb_stats(
    request: Request,
    tenant_id: Optional[str] = None,  # Optional in dev, ignored in prod
    kb_id: Optional[str] = None,
    user_id: Optional[str] = None  # Optional in dev, ignored in prod
):
    """Return document/chunk statistics for one knowledge base.

    In production the tenant and user always come from the auth context; in
    dev the query parameters win and the auth context is the fallback.

    Raises:
        HTTPException: 403 when production auth lacks tenant or user, 400 when
            tenant_id, kb_id or user_id is missing, 500 on a vector store
            error.
    """
    # SECURITY: resolve identity from the auth context
    auth = await get_auth_context(request)
    auth_tenant, auth_user = auth.get("tenant_id"), auth.get("user_id")

    if settings.ENV != "prod":
        tenant_id = tenant_id or auth_tenant
        user_id = user_id or auth_user
    elif auth_tenant and auth_user:
        tenant_id, user_id = auth_tenant, auth_user
    else:
        raise HTTPException(
            status_code=403,
            detail="tenant_id and user_id must come from authentication token in production mode"
        )

    if not (tenant_id and kb_id and user_id):
        raise HTTPException(
            status_code=400,
            detail="tenant_id, kb_id, and user_id are required"
        )

    try:
        kb_stats = get_vector_store().get_stats(
            tenant_id=tenant_id, kb_id=kb_id, user_id=user_id
        )
        names = kb_stats.get("file_names", [])
        return KnowledgeBaseStats(
            tenant_id=tenant_id,  # CRITICAL: Multi-tenant isolation
            kb_id=kb_id,
            user_id=user_id,
            total_documents=len(names),
            total_chunks=kb_stats.get("total_chunks", 0),
            file_names=names,
            last_updated=datetime.utcnow(),
        )
    except Exception as e:
        logger.error(f"Error getting KB stats: {e}")
        raise HTTPException(status_code=500, detail=str(e))
async def delete_document(
    request: Request,
    tenant_id: Optional[str] = None,  # Optional in dev, ignored in prod
    kb_id: Optional[str] = None,
    user_id: Optional[str] = None,  # Optional in dev, ignored in prod
    file_name: Optional[str] = None
):
    """Delete every stored chunk of one file from a knowledge base.

    Identity is taken from the auth context in production; in dev, query
    parameters take precedence over the auth context.

    Raises:
        HTTPException: 403 when production auth lacks tenant or user, 400 when
            a required value is missing, 500 on a vector store error.
    """
    # SECURITY: resolve identity from the auth context
    auth = await get_auth_context(request)
    token_tenant = auth.get("tenant_id")
    token_user = auth.get("user_id")

    running_in_prod = settings.ENV == "prod"
    if running_in_prod and not (token_tenant and token_user):
        raise HTTPException(
            status_code=403,
            detail="tenant_id and user_id must come from authentication token in production mode"
        )
    if running_in_prod:
        tenant_id = token_tenant
        user_id = token_user
    else:
        tenant_id = tenant_id or token_tenant
        user_id = user_id or token_user

    required = (tenant_id, kb_id, user_id, file_name)
    if not all(required):
        raise HTTPException(
            status_code=400,
            detail="tenant_id, kb_id, user_id, and file_name are required (provide via headers or query params)"
        )

    try:
        chunk_filter = {
            "tenant_id": tenant_id,  # CRITICAL: Multi-tenant isolation
            "kb_id": kb_id,
            "user_id": user_id,
            "file_name": file_name,
        }
        deleted = get_vector_store().delete_by_filter(chunk_filter)
        return {
            "success": True,
            "message": f"Deleted {deleted} chunks",
            "file_name": file_name
        }
    except Exception as e:
        logger.error(f"Error deleting document: {e}")
        raise HTTPException(status_code=500, detail=str(e))
async def clear_kb(
    request: Request,
    tenant_id: Optional[str] = None,  # Optional in dev, ignored in prod
    kb_id: Optional[str] = None,
    user_id: Optional[str] = None  # Optional in dev, ignored in prod
):
    """Remove every chunk belonging to one tenant/user knowledge base.

    Identity is taken from the auth context in production; in dev, query
    parameters take precedence over the auth context.

    Raises:
        HTTPException: 403 when production auth lacks tenant or user, 400 when
            tenant_id, kb_id or user_id is missing, 500 on a vector store
            error.
    """
    # SECURITY: resolve identity from the auth context
    context = await get_auth_context(request)
    ctx_tenant = context.get("tenant_id")
    ctx_user = context.get("user_id")

    if settings.ENV == "prod":
        if not (ctx_tenant and ctx_user):
            raise HTTPException(
                status_code=403,
                detail="tenant_id and user_id must come from authentication token in production mode"
            )
        tenant_id, user_id = ctx_tenant, ctx_user
    else:
        tenant_id, user_id = tenant_id or ctx_tenant, user_id or ctx_user

    if not tenant_id or not kb_id or not user_id:
        raise HTTPException(
            status_code=400,
            detail="tenant_id, kb_id, and user_id are required"
        )

    try:
        store = get_vector_store()
        deleted = store.delete_by_filter(
            {
                "tenant_id": tenant_id,  # CRITICAL: Multi-tenant isolation
                "kb_id": kb_id,
                "user_id": user_id,
            }
        )
        return {
            "success": True,
            "message": f"Cleared knowledge base. Deleted {deleted} chunks.",
            "kb_id": kb_id
        }
    except Exception as e:
        logger.error(f"Error clearing KB: {e}")
        raise HTTPException(status_code=500, detail=str(e))
| # ============== Chat Endpoints ============== | |
def _chat_failure(conversation_id: str, answer: str, metadata: dict):
    """Build the unsuccessful ``ChatResponse`` shared by all error paths."""
    return ChatResponse(
        success=False,
        answer=answer,
        citations=[],
        confidence=0.0,
        from_knowledge_base=False,
        escalation_suggested=True,
        conversation_id=conversation_id,
        metadata=metadata
    )


async def chat(chat_request: ChatRequest, request: Request):
    """
    Process a chat message using RAG.

    - Resolves tenant/user (from the auth context in production)
    - Checks the tenant's quota before any LLM call
    - Retrieves relevant context from the knowledge base
    - Generates an answer and records token usage when reported
    - Returns the answer with citations

    Returns:
        A ``ChatResponse``. Non-HTTP failures are reported as a response with
        ``success=False`` rather than raised.

    Raises:
        HTTPException: 401 on an authentication failure, 403 when production
            auth lacks tenant or user, 400 when tenant/user are missing,
            402 when the quota is exhausted, 500 on a database connection
            error.
    """
    conversation_id = "unknown"
    try:
        logger.info(f"=== CHAT REQUEST RECEIVED ===")
        logger.info(f"Request body: tenant_id={chat_request.tenant_id}, user_id={chat_request.user_id}, kb_id={chat_request.kb_id}, question_length={len(chat_request.question)}")
        # SECURITY: never write bearer tokens or cookies to the log.
        sensitive_headers = ("authorization", "cookie", "proxy-authorization", "x-api-key")
        safe_headers = {
            name: ("<redacted>" if name.lower() in sensitive_headers else value)
            for name, value in request.headers.items()
        }
        logger.info(f"Request headers: {safe_headers}")

        # SECURITY: Get tenant_id and user_id from auth context
        # In PROD: MUST come from JWT token (never from request body)
        try:
            auth_context = await get_auth_context(request)
        except HTTPException:
            # Keep the status code chosen by the auth layer.
            raise
        except Exception as e:
            logger.error(f"Error getting auth context: {e}", exc_info=True)
            raise HTTPException(status_code=401, detail=f"Authentication error: {str(e)}")

        tenant_id_from_auth = auth_context.get("tenant_id")
        user_id_from_auth = auth_context.get("user_id")
        if settings.ENV == "prod":
            if not tenant_id_from_auth or not user_id_from_auth:
                raise HTTPException(
                    status_code=403,
                    detail="tenant_id and user_id must come from authentication token in production mode"
                )
            # Override request values with auth context (security enforcement)
            chat_request.tenant_id = tenant_id_from_auth
            chat_request.user_id = user_id_from_auth
        else:
            # DEV mode: use from request if provided, otherwise from auth context
            if not chat_request.tenant_id:
                chat_request.tenant_id = tenant_id_from_auth
            if not chat_request.user_id:
                chat_request.user_id = user_id_from_auth

        if not chat_request.tenant_id or not chat_request.user_id:
            raise HTTPException(
                status_code=400,
                detail="tenant_id and user_id are required (provide via X-Tenant-Id/X-User-Id headers or request body)"
            )

        # Log without PII in production
        if settings.ENV == "prod":
            logger.info(f"Chat request: tenant={chat_request.tenant_id}, user={chat_request.user_id}, kb={chat_request.kb_id}, q_length={len(chat_request.question)}")
        else:
            logger.info(f"Chat request: tenant={chat_request.tenant_id}, user={chat_request.user_id}, kb={chat_request.kb_id}, q={chat_request.question[:50]}...")

        # Generate conversation ID if not provided
        conversation_id = chat_request.conversation_id or f"conv_{uuid.uuid4().hex[:12]}"

        # Get database session. The generator is kept so it can be closed in
        # the ``finally`` below, which lets get_db() run its own cleanup.
        try:
            db_gen = get_db()
            db = next(db_gen)
        except Exception as e:
            logger.error(f"Database connection error: {e}", exc_info=True)
            raise HTTPException(status_code=500, detail=f"Database error: {str(e)}")

        try:
            # Ensure tenant exists in billing DB
            ensure_tenant_exists(db, chat_request.tenant_id)

            # Check quota BEFORE making LLM call
            has_quota, quota_error = check_quota(db, chat_request.tenant_id)
            if not has_quota:
                logger.warning(f"Quota exceeded for tenant {chat_request.tenant_id}")
                raise HTTPException(
                    status_code=402,
                    detail=quota_error or "AI quota exceeded. Upgrade your plan."
                )

            # Retrieve relevant context
            retrieval_service = get_retrieval_service()
            results, confidence, has_relevant = retrieval_service.retrieve(
                query=chat_request.question,
                tenant_id=chat_request.tenant_id,  # CRITICAL: Multi-tenant isolation
                kb_id=chat_request.kb_id,
                user_id=chat_request.user_id
            )
            logger.info(f"Retrieval results: {len(results)} results, confidence={confidence:.3f}, has_relevant={has_relevant}")

            # Format context for LLM
            context, citations_info = retrieval_service.get_context_for_llm(results)
            logger.info(f"Formatted context length: {len(context)} chars, citations: {len(citations_info)}")

            # Generate answer
            answer_service = get_answer_service()
            answer_result = answer_service.generate_answer(
                question=chat_request.question,
                context=context,
                citations_info=citations_info,
                confidence=confidence,
                has_relevant_results=has_relevant
            )

            # Track usage if LLM was called (usage info present)
            usage_info = answer_result.get("usage")
            if usage_info:
                try:
                    track_usage(
                        db=db,
                        tenant_id=chat_request.tenant_id,
                        user_id=chat_request.user_id,
                        kb_id=chat_request.kb_id,
                        provider=settings.LLM_PROVIDER,
                        model=usage_info.get("model_used", settings.GEMINI_MODEL if settings.LLM_PROVIDER == "gemini" else settings.OPENAI_MODEL),
                        prompt_tokens=usage_info.get("prompt_tokens", 0),
                        completion_tokens=usage_info.get("completion_tokens", 0)
                    )
                except Exception as e:
                    # Don't fail the request if usage tracking fails
                    logger.error(f"Failed to track usage: {e}", exc_info=True)

            # Build metadata with refusal info
            metadata = {
                "chunks_retrieved": len(results),
                "kb_id": chat_request.kb_id
            }
            for optional_key in ("refused", "refusal_reason", "verifier_passed"):
                if optional_key in answer_result:
                    metadata[optional_key] = answer_result[optional_key]

            return ChatResponse(
                success=True,
                answer=answer_result["answer"],
                citations=answer_result["citations"],
                confidence=answer_result["confidence"],
                from_knowledge_base=answer_result["from_knowledge_base"],
                escalation_suggested=answer_result["escalation_suggested"],
                conversation_id=conversation_id,
                refused=answer_result.get("refused", False),
                metadata=metadata
            )
        except ValueError as e:
            # API key or configuration error
            error_msg = str(e)
            logger.error(f"Configuration error: {error_msg}")
            # Compare in lowercase on both sides; the previous check looked
            # for "API key" inside a lowercased string and could never match.
            if "api key" in error_msg.lower():
                return _chat_failure(
                    conversation_id,
                    "⚠️ LLM API key not configured. Please set GEMINI_API_KEY in your .env file. Retrieval is working, but answer generation requires an API key.",
                    {"error": error_msg, "error_type": "configuration"},
                )
            return _chat_failure(
                conversation_id,
                f"Configuration error: {error_msg}",
                {"error": error_msg},
            )
        except HTTPException:
            # Re-raise HTTP exceptions (they have proper status codes)
            raise
        except Exception as e:
            logger.error(f"Chat error: {e}", exc_info=True)
            logger.error(f"Error type: {type(e).__name__}")
            return _chat_failure(
                conversation_id,
                f"I encountered an error processing your request: {str(e)}. Please check the server logs for details.",
                {"error": str(e), "error_type": type(e).__name__},
            )
        finally:
            # Release the database session on every exit path.
            db_gen.close()
    except HTTPException:
        # Re-raise HTTP exceptions from outer try block
        raise
    except Exception as e:
        logger.error(f"Outer chat error: {e}", exc_info=True)
        return _chat_failure(
            conversation_id,
            f"I encountered an error processing your request: {str(e)}. Please check the server logs for details.",
            {"error": str(e), "error_type": type(e).__name__},
        )
| # ============== Utility Endpoints ============== | |
async def search_kb(
    request: Request,
    query: str,
    tenant_id: Optional[str] = None,  # Optional in dev, ignored in prod
    kb_id: Optional[str] = None,
    user_id: Optional[str] = None,  # Optional in dev, ignored in prod
    top_k: int = 5
):
    """
    Search the knowledge base without generating an answer.
    Useful for debugging and testing retrieval.

    Args:
        request: Incoming request, used for authentication in production.
        query: Search text.
        tenant_id: Tenant (dev only; replaced by the token's tenant in prod).
        kb_id: Knowledge base to search; required in every mode.
        user_id: User (dev only; replaced by the token's user in prod).
        top_k: Number of results requested from the retrieval service.

    Returns:
        A dict with ``success``, ``results`` (content truncated to 500
        characters plus ``...``), ``confidence`` and ``has_relevant_results``.

    Raises:
        HTTPException: 403 when production auth lacks tenant or user, 400 when
            tenant_id, kb_id or user_id is missing, 500 on a retrieval error.
    """
    # SECURITY: Extract tenant_id and user_id from auth token in production
    if settings.ENV == "prod":
        auth_context = await require_auth(request)
        tenant_id = auth_context.get("tenant_id")
        user_id = auth_context.get("user_id")
        if not tenant_id or not user_id:
            raise HTTPException(
                status_code=403,
                detail="tenant_id and user_id must come from authentication token in production mode"
            )

    # Validated in every mode: previously kb_id was only checked in dev, so a
    # production request without kb_id reached the retrieval service.
    if not tenant_id or not kb_id or not user_id:
        raise HTTPException(
            status_code=400,
            detail="tenant_id, kb_id, and user_id are required"
        )

    try:
        retrieval_service = get_retrieval_service()
        results, confidence, has_relevant = retrieval_service.retrieve(
            query=query,
            tenant_id=tenant_id,  # CRITICAL: Multi-tenant isolation
            kb_id=kb_id,
            user_id=user_id,
            top_k=top_k
        )
        return {
            "success": True,
            "results": [
                {
                    "chunk_id": r.chunk_id,
                    "content": (r.content[:500] + "...") if len(r.content) > 500 else r.content,
                    "metadata": r.metadata,
                    "similarity_score": r.similarity_score
                }
                for r in results
            ],
            "confidence": confidence,
            "has_relevant_results": has_relevant
        }
    except Exception as e:
        logger.error(f"Search error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
| # ============== Billing & Usage Endpoints ============== | |
async def get_usage(
    request: Request,
    range: str = "month",  # "day" or "month"; name is part of the public query API
    year: Optional[int] = None,
    month: Optional[int] = None,
    day: Optional[int] = None
):
    """
    Get usage statistics for the current tenant.

    Args:
        request: Incoming request; the tenant is read from its auth context.
        range: "day" or "month" (any value other than "day" is treated as
            "month").
        year: Year (optional, defaults to current)
        month: Month 1-12 (optional, defaults to current)
        day: Day 1-31 (optional, defaults to current, only for range="day")

    Returns:
        A ``UsageResponse`` for the requested period; zeroed totals when no
        usage row exists.

    Raises:
        HTTPException: 403 when no tenant is authenticated, 400 when the
            requested year/month/day is not a valid date, 500 on any other
            error.
    """
    # Get tenant from auth
    auth_context = await get_auth_context(request)
    tenant_id = auth_context.get("tenant_id")
    if not tenant_id:
        raise HTTPException(status_code=403, detail="tenant_id required")

    # Keep the generator so the session can be released in ``finally``.
    db_gen = get_db()
    db = next(db_gen)
    try:
        from app.db.models import UsageDaily, UsageMonthly
        from calendar import monthrange

        now = datetime.utcnow()
        target_year = year or now.year
        target_month = month or now.month

        if range == "day":
            target_day = day or now.day
            try:
                date_start = datetime(target_year, target_month, target_day)
            except ValueError as e:
                # Bad client input (e.g. month=13 or Feb 31), not a server fault.
                raise HTTPException(status_code=400, detail=f"Invalid date: {e}")

            daily = db.query(UsageDaily).filter(
                UsageDaily.tenant_id == tenant_id,
                UsageDaily.date == date_start
            ).first()
            if not daily:
                return UsageResponse(
                    tenant_id=tenant_id,
                    period="day",
                    total_requests=0,
                    total_tokens=0,
                    total_cost_usd=0.0,
                    start_date=date_start,
                    end_date=date_start
                )
            return UsageResponse(
                tenant_id=tenant_id,
                period="day",
                total_requests=daily.total_requests,
                total_tokens=daily.total_tokens,
                total_cost_usd=daily.total_cost_usd,
                gemini_requests=daily.gemini_requests,
                openai_requests=daily.openai_requests,
                start_date=daily.date,
                end_date=daily.date
            )

        # month: the period bounds depend only on the requested year/month,
        # which are also the values any matching row carries.
        try:
            _, last_day = monthrange(target_year, target_month)
            start_date = datetime(target_year, target_month, 1)
            end_date = datetime(target_year, target_month, last_day)
        except ValueError as e:
            raise HTTPException(status_code=400, detail=f"Invalid date: {e}")

        monthly = db.query(UsageMonthly).filter(
            UsageMonthly.tenant_id == tenant_id,
            UsageMonthly.year == target_year,
            UsageMonthly.month == target_month
        ).first()
        if not monthly:
            return UsageResponse(
                tenant_id=tenant_id,
                period="month",
                total_requests=0,
                total_tokens=0,
                total_cost_usd=0.0,
                start_date=start_date,
                end_date=end_date
            )
        return UsageResponse(
            tenant_id=tenant_id,
            period="month",
            total_requests=monthly.total_requests,
            total_tokens=monthly.total_tokens,
            total_cost_usd=monthly.total_cost_usd,
            gemini_requests=monthly.gemini_requests,
            openai_requests=monthly.openai_requests,
            start_date=start_date,
            end_date=end_date
        )
    except HTTPException:
        # Preserve deliberate status codes (e.g. the 400 above).
        raise
    except Exception as e:
        logger.error(f"Error getting usage: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        db_gen.close()
async def get_limits(request: Request):
    """Get current plan limits and usage for the tenant.

    When the tenant has no plan, the ``"starter"`` plan with a monthly limit
    of 500 is reported. A monthly limit of ``-1`` yields
    ``remaining_chats=None``.

    Raises:
        HTTPException: 403 when no tenant is authenticated, 500 on any error
            while reading the plan or usage.
    """
    # Get tenant from auth
    auth_context = await get_auth_context(request)
    tenant_id = auth_context.get("tenant_id")
    if not tenant_id:
        raise HTTPException(status_code=403, detail="tenant_id required")

    # Keep the generator so the session can be released in ``finally``.
    db_gen = get_db()
    db = next(db_gen)
    try:
        from app.billing.quota import get_tenant_plan, get_monthly_usage

        plan = get_tenant_plan(db, tenant_id)
        if not plan:
            # Default to starter
            plan_name = "starter"
            monthly_limit = 500
        else:
            plan_name = plan.plan_name
            monthly_limit = plan.monthly_chat_limit

        # Get current month usage
        now = datetime.utcnow()
        monthly_usage = get_monthly_usage(db, tenant_id, now.year, now.month)
        current_usage = monthly_usage.total_requests if monthly_usage else 0
        remaining = None if monthly_limit == -1 else max(0, monthly_limit - current_usage)

        return PlanLimitsResponse(
            tenant_id=tenant_id,
            plan_name=plan_name,
            monthly_chat_limit=monthly_limit,
            current_month_usage=current_usage,
            remaining_chats=remaining
        )
    except Exception as e:
        logger.error(f"Error getting limits: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        db_gen.close()
async def set_plan(request_body: SetPlanRequest, http_request: Request):
    """
    Set tenant's subscription plan (admin only in production).

    In dev mode, allows any tenant to set their plan.
    In prod mode, the authenticated tenant may only change its own plan; an
    admin-role check is still to be implemented.

    Returns:
        A dict with ``success``, ``tenant_id``, ``plan_name`` and
        ``monthly_chat_limit``.

    Raises:
        HTTPException: 403 when a production caller targets another tenant,
            400 when ``set_tenant_plan`` raises ``ValueError``, 500 on any
            other error.
    """
    # Get tenant from auth
    auth_context = await get_auth_context(http_request)
    auth_tenant_id = auth_context.get("tenant_id")

    # In prod, verify admin role (placeholder - implement actual admin check)
    if settings.ENV == "prod":
        # TODO: Add admin role check
        if auth_tenant_id != request_body.tenant_id:
            raise HTTPException(status_code=403, detail="Cannot set plan for other tenants")

    # Keep the generator so the session can be released in ``finally``.
    db_gen = get_db()
    db = next(db_gen)
    try:
        from app.billing.quota import set_tenant_plan

        plan = set_tenant_plan(db, request_body.tenant_id, request_body.plan_name)
        return {
            "success": True,
            "tenant_id": request_body.tenant_id,
            "plan_name": plan.plan_name,
            "monthly_chat_limit": plan.monthly_chat_limit
        }
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.error(f"Error setting plan: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        db_gen.close()
async def get_cost_report(
    request: Request,
    range: str = "month",  # name is part of the public query API
    year: Optional[int] = None,
    month: Optional[int] = None
):
    """Get cost report with breakdown by provider and model.

    Args:
        request: Incoming request; the tenant is read from its auth context.
        range: ``"month"`` restricts to the target year/month; any other value
            reports over all of the tenant's usage events.
        year: Year (optional, defaults to current).
        month: Month (optional, defaults to current).

    Returns:
        A ``CostReportResponse`` with totals plus per-provider and per-model
        request, token and cost breakdowns.

    Raises:
        HTTPException: 403 when no tenant is authenticated, 500 on any error
            while building the report.
    """
    # Get tenant from auth
    auth_context = await get_auth_context(request)
    tenant_id = auth_context.get("tenant_id")
    if not tenant_id:
        raise HTTPException(status_code=403, detail="tenant_id required")

    # Keep the generator so the session can be released in ``finally``.
    db_gen = get_db()
    db = next(db_gen)
    try:
        from app.db.models import UsageEvent
        from sqlalchemy import func, and_

        now = datetime.utcnow()
        target_year = year or now.year
        target_month = month or now.month

        # Query usage events for the period
        if range == "month":
            query = db.query(UsageEvent).filter(
                and_(
                    UsageEvent.tenant_id == tenant_id,
                    func.extract('year', UsageEvent.request_timestamp) == target_year,
                    func.extract('month', UsageEvent.request_timestamp) == target_month
                )
            )
        else:  # all time
            query = db.query(UsageEvent).filter(UsageEvent.tenant_id == tenant_id)
        events = query.all()

        # Calculate totals
        total_cost = sum(e.estimated_cost_usd for e in events)
        total_requests = len(events)
        total_tokens = sum(e.total_tokens for e in events)

        # Accumulate both breakdowns in a single pass over the events.
        breakdown_by_provider = {}
        breakdown_by_model = {}
        for event in events:
            for table, key in (
                (breakdown_by_provider, event.provider),
                (breakdown_by_model, event.model),
            ):
                bucket = table.setdefault(
                    key, {"requests": 0, "tokens": 0, "cost_usd": 0.0}
                )
                bucket["requests"] += 1
                bucket["tokens"] += event.total_tokens
                bucket["cost_usd"] += event.estimated_cost_usd

        return CostReportResponse(
            tenant_id=tenant_id,
            period=range,
            total_cost_usd=total_cost,
            total_requests=total_requests,
            total_tokens=total_tokens,
            breakdown_by_provider=breakdown_by_provider,
            breakdown_by_model=breakdown_by_model
        )
    except Exception as e:
        logger.error(f"Error getting cost report: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        db_gen.close()
if __name__ == "__main__":
    # Direct execution: serve the app with uvicorn on all interfaces, port 8000.
    import uvicorn

    uvicorn.run(app, port=8000, host="0.0.0.0")