Spaces:
Running
Running
| from fastapi import FastAPI, HTTPException, Request, UploadFile, File, Depends, status | |
| from fastapi.responses import StreamingResponse | |
| from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials | |
| from pydantic import BaseModel, Field | |
| from typing import List, Optional, Dict, Any, AsyncGenerator | |
| import asyncio | |
| import json | |
| import uuid | |
| from datetime import datetime | |
| import os | |
| from contextlib import asynccontextmanager | |
| import tempfile | |
| import shutil | |
| import random | |
| import hashlib | |
| import secrets | |
| from functools import wraps | |
| # Third-party imports | |
| from openai import OpenAI, AsyncOpenAI | |
| from qdrant_client import AsyncQdrantClient | |
| from qdrant_client.models import Distance, VectorParams, PointStruct, Filter, FieldCondition, MatchValue | |
| from sentence_transformers import SentenceTransformer | |
| import torch | |
| import asyncio | |
| from concurrent.futures import ThreadPoolExecutor | |
| import PyPDF2 | |
# Model IDs offered when the OpenRouter provider is selected.
OPENROUTER_MODELS = [
    "deepseek/deepseek-chat-v3-0324:free",
    "deepseek/deepseek-r1-0528:free",
    "qwen/qwen3-235b-a22b:free",
    "google/gemini-2.0-flash-exp:free",
]
# Model IDs offered when the Groq provider is selected.
GROQ_MODELS = [
    "llama-3.3-70b-versatile",
    "openai/gpt-oss-120b",
]
# Models for OpenAI-compatible API
# NOTE: comments rather than class docstrings are used on these pydantic models,
# because a docstring becomes the model's JSON-schema description.

# One chat turn. chat_completions looks for role == "user" and forwards both
# fields to the provider unchanged.
class Message(BaseModel):
    role: str = Field(..., description="The role of the message author")
    content: str = Field(..., description="The content of the message")


# Request body for chat completions: the OpenAI request shape plus a
# non-standard `provider` field that is handed to
# DynamicOpenAIService.get_async_client. DynamicOpenAIService treats an empty
# model or "auto" as "pick a random model from the provider's list".
class ChatCompletionRequest(BaseModel):
    model: str = Field(default="auto", description="Model to use (auto for dynamic selection)")
    messages: List[Message] = Field(..., description="List of messages")
    max_tokens: Optional[int] = Field(default=1024, description="Maximum tokens to generate")
    temperature: Optional[float] = Field(default=0.7, description="Temperature for sampling")
    stream: Optional[bool] = Field(default=False, description="Whether to stream responses")
    top_p: Optional[float] = Field(default=1.0, description="Top-p sampling parameter")
    provider: Optional[str] = Field(default="random", description="Provider to use (random, openrouter, groq)")


# Non-streaming completion envelope ("chat.completion" object); it is the
# annotated return type of create_chat_completion.
class ChatCompletionResponse(BaseModel):
    id: str
    object: str = "chat.completion"
    created: int
    model: str
    choices: List[Dict[str, Any]]
    usage: Optional[Dict[str, int]] = None


# One streamed delta ("chat.completion.chunk" object).
# NOTE(review): not referenced in the visible section; presumably emitted by the
# streaming path (stream_chat_completion) — confirm.
class ChatCompletionChunk(BaseModel):
    id: str
    object: str = "chat.completion.chunk"
    created: int
    model: str
    choices: List[Dict[str, Any]]
# Optional metadata for an uploaded document. DocumentManager.add_document stores
# such a dict under the "metadata" key of every chunk payload.
# NOTE(review): this model is not referenced in the visible section — confirm its caller.
class DocumentUploadRequest(BaseModel):
    metadata: Optional[Dict[str, Any]] = None


# Search parameters; the field names and defaults match
# DocumentManager.search_documents(query, limit=5, min_score=0.1).
class DocumentSearchRequest(BaseModel):
    query: str = Field(..., description="Search query")
    limit: int = Field(default=5, description="Maximum number of results")
    min_score: float = Field(default=0.1, description="Minimum similarity score")
# Configuration
class Config:
    """Process-wide settings read from environment variables at import time.

    All settings are class attributes. The helpers are classmethods, so they are
    called on the class itself, e.g. ``Config.validate_api_key(key)``.
    """

    # Provider API Keys
    OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY")
    GROQ_API_KEY = os.getenv("GROQ_API_KEY")
    # Vector DB Configuration
    QDRANT_URL = os.getenv("QDRANT_URL", "http://localhost:6333")
    QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
    COLLECTION_NAME = os.getenv("COLLECTION_NAME", "documents")
    # Embedding Configuration
    EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
    TOP_K = int(os.getenv("TOP_K", "10"))
    SIMILARITY_THRESHOLD = float(os.getenv("SIMILARITY_THRESHOLD", "0.1"))
    DEVICE = os.getenv("DEVICE", "cuda" if torch.cuda.is_available() else "cpu")
    # Security Configuration
    # Comma-separated list. Surrounding whitespace and empty entries are dropped,
    # so "a, b," yields ["a", "b"] instead of ["a", " b", ""] (" b" never matched).
    API_KEYS = [key.strip() for key in os.getenv("API_KEYS", "").split(",") if key.strip()]
    MASTER_KEY = os.getenv("MASTER_KEY", "")
    # Fail closed: security stays on unless it is explicitly switched off.
    # Previously any value other than "true" (e.g. "1", "yes") disabled it.
    ENABLE_SECURITY = os.getenv("ENABLE_SECURITY", "true").strip().lower() not in {"false", "0", "no", "off"}
    RATE_LIMIT_PER_MINUTE = int(os.getenv("RATE_LIMIT_PER_MINUTE", "60"))

    @classmethod
    def generate_api_key(cls) -> str:
        """Generate a new random API key of the form ``sk-<urlsafe token>``."""
        return f"sk-{secrets.token_urlsafe(32)}"

    @classmethod
    def validate_api_key(cls, api_key: str) -> bool:
        """Return True if ``api_key`` is the master key or one of API_KEYS.

        Always True when security is disabled; always False for an empty key.
        Comparisons use ``secrets.compare_digest`` so response timing does not
        leak how much of a guessed key was correct.
        """
        if not cls.ENABLE_SECURITY:
            return True
        if not api_key:
            return False
        # Compare bytes: compare_digest rejects non-ASCII str arguments.
        candidate = api_key.encode("utf-8")
        # Check master key
        if cls.MASTER_KEY and secrets.compare_digest(candidate, cls.MASTER_KEY.encode("utf-8")):
            return True
        # Check configured API keys
        return any(
            secrets.compare_digest(candidate, key.encode("utf-8"))
            for key in cls.API_KEYS
            if key
        )
# Security Models
# NOTE(review): none of these models is referenced in the visible section; they
# presumably back admin / key-management endpoints defined elsewhere — confirm.
# (Comments rather than docstrings: a docstring would become the schema description.)

# Presumably the request body for issuing a new API key (see Config.generate_api_key).
class APIKeyRequest(BaseModel):
    description: Optional[str] = Field(None, description="Description for the API key")


# An issued API key. `created_at` is a plain string; its format is chosen by the caller.
class APIKeyResponse(BaseModel):
    api_key: str
    description: Optional[str] = None
    created_at: str
    status: str = "active"


# Summary of the security setup. The fields appear to mirror Config.ENABLE_SECURITY,
# Config.RATE_LIMIT_PER_MINUTE, Config.MASTER_KEY and Config.API_KEYS — confirm.
class SecurityInfo(BaseModel):
    security_enabled: bool
    rate_limit_per_minute: int
    has_master_key: bool
    configured_keys_count: int
# Rate Limiting
class RateLimiter:
    """Per-identifier request counter with one fixed window per calendar minute.

    Counts are kept in plain dicts on the instance, so they are process-local.
    Minute buckets older than ``_RETENTION_SECONDS`` are discarded. Once per
    minute every identifier is swept, so clients that stop calling do not
    leave entries behind forever (previously they were only cleaned when the
    same identifier called again, so the dict grew without bound).
    """

    # strftime/strptime format of a bucket key: one bucket per minute.
    _KEY_FORMAT = "%Y-%m-%d %H:%M"
    # Buckets whose minute started more than this many seconds ago are dropped.
    _RETENTION_SECONDS = 120

    def __init__(self):
        # identifier -> {minute_key: request_count}
        self.requests = {}
        self.blocked_ips = set()
        # Minute key of the most recent global sweep (None before the first request).
        self._last_sweep_key = None

    def _is_stale(self, key: str, now: datetime) -> bool:
        """Return True if bucket ``key`` is unparsable or older than the retention window."""
        try:
            key_time = datetime.strptime(key, self._KEY_FORMAT)
        except ValueError:
            return True
        return (now - key_time).total_seconds() > self._RETENTION_SECONDS

    def _prune(self, buckets: dict, now: datetime) -> None:
        """Delete stale minute buckets from one identifier's bucket dict, in place."""
        for key in [k for k in buckets if self._is_stale(k, now)]:
            del buckets[key]

    def _sweep(self, now: datetime, minute_key: str) -> None:
        """At most once per minute, prune every identifier and drop empty ones."""
        if self._last_sweep_key == minute_key:
            return
        self._last_sweep_key = minute_key
        for identifier in list(self.requests):
            buckets = self.requests[identifier]
            self._prune(buckets, now)
            if not buckets:
                del self.requests[identifier]

    def is_allowed(self, identifier: str, limit_per_minute: int = Config.RATE_LIMIT_PER_MINUTE) -> bool:
        """Count one request for ``identifier`` and report whether it is within the limit.

        Returns True unconditionally when security is disabled and False for
        blocked identifiers. A rejected request is not added to the count.
        """
        if not Config.ENABLE_SECURITY:
            return True
        if identifier in self.blocked_ips:
            return False
        now = datetime.now()
        minute_key = now.strftime(self._KEY_FORMAT)
        self._sweep(now, minute_key)
        buckets = self.requests.setdefault(identifier, {})
        self._prune(buckets, now)
        # Check current minute limit
        current_requests = buckets.get(minute_key, 0)
        if current_requests >= limit_per_minute:
            return False
        buckets[minute_key] = current_requests + 1
        return True

    def block_ip(self, ip: str):
        """Block an IP address; is_allowed() rejects it while security is enabled."""
        self.blocked_ips.add(ip)

    def unblock_ip(self, ip: str):
        """Unblock an IP address (no-op if it was not blocked)."""
        self.blocked_ips.discard(ip)
# Security Dependencies
# Module-wide limiter shared by every request that passes through verify_api_key.
rate_limiter: RateLimiter = RateLimiter()
# auto_error=False: a missing Authorization header is delivered as None, letting
# the verify_* dependencies raise their own 401 messages.
security: HTTPBearer = HTTPBearer(auto_error=False)
async def verify_api_key(
    request: Request,
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(security)
) -> str:
    """FastAPI dependency: apply the per-IP rate limit and require a valid API key.

    Returns:
        The presented API key, or ``"security_disabled"`` when security is off.

    Raises:
        HTTPException: 429 when the client IP exceeds its per-minute limit;
            401 when no bearer token is sent or the key is not recognised.
    """
    if not Config.ENABLE_SECURITY:
        return "security_disabled"
    # Starlette types request.client as Optional (no peer address reported);
    # share one bucket in that case instead of failing with AttributeError/500.
    client_ip = request.client.host if request.client else "unknown"
    # Check rate limit
    if not rate_limiter.is_allowed(client_ip):
        raise HTTPException(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            detail="Rate limit exceeded"
        )
    # Check API key
    if not credentials:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="API key required. Please provide a valid API key in the Authorization header as 'Bearer <your-api-key>'"
        )
    api_key = credentials.credentials
    if not Config.validate_api_key(api_key):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid API key"
        )
    return api_key
async def verify_master_key(
    request: Request,
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(security)
) -> str:
    """FastAPI dependency for admin operations: require the configured master key.

    Admin requests are rate-limited per client IP like ordinary API calls, so
    the master key cannot be brute-forced at unlimited speed.

    Returns:
        The master key, or ``"security_disabled"`` when security is off.

    Raises:
        HTTPException: 429 when the client IP exceeds its per-minute limit;
            401 when no bearer token is sent; 403 when no master key is
            configured or the token does not match it.
    """
    if not Config.ENABLE_SECURITY:
        return "security_disabled"
    client_ip = request.client.host if request.client else "unknown"
    if not rate_limiter.is_allowed(client_ip):
        raise HTTPException(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            detail="Rate limit exceeded"
        )
    if not credentials:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Master key required for admin operations"
        )
    api_key = credentials.credentials
    master_key = Config.MASTER_KEY
    # Constant-time comparison on bytes (compare_digest rejects non-ASCII str).
    if not master_key or not secrets.compare_digest(api_key.encode("utf-8"), master_key.encode("utf-8")):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Invalid master key"
        )
    return api_key
class DynamicOpenAIService:
    """Pick OpenRouter or Groq per call and talk to it through the OpenAI SDK.

    Both providers expose OpenAI-compatible endpoints, so one client class
    serves both; only the API key, base URL and model list differ.
    """

    # OpenAI-compatible base URL of each supported provider.
    _BASE_URLS = {
        "openrouter": "https://openrouter.ai/api/v1",
        "groq": "https://api.groq.com/openai/v1",
    }

    def __init__(self):
        self.validate_api_keys()

    def validate_api_keys(self):
        """Ensure at least one provider key is configured; warn about a missing one.

        Raises:
            ValueError: if neither OPENROUTER_API_KEY nor GROQ_API_KEY is set.
        """
        if not Config.OPENROUTER_API_KEY and not Config.GROQ_API_KEY:
            raise ValueError("At least one API key (OPENROUTER_API_KEY or GROQ_API_KEY) must be provided")
        if not Config.OPENROUTER_API_KEY:
            print("Warning: OPENROUTER_API_KEY not found, will only use Groq")
        if not Config.GROQ_API_KEY:
            print("Warning: GROQ_API_KEY not found, will only use OpenRouter")

    @staticmethod
    def _available_providers() -> List[str]:
        """Return the providers whose API key is configured, OpenRouter first."""
        providers = []
        if Config.OPENROUTER_API_KEY:
            providers.append("openrouter")
        if Config.GROQ_API_KEY:
            providers.append("groq")
        return providers

    @staticmethod
    def _choose_provider(provider: str, available_providers: List[str]) -> str:
        """Resolve a requested provider name against the configured ones.

        "random" picks uniformly among ``available_providers``; any name that is
        not available falls back to the first available provider.

        Raises:
            ValueError: if ``available_providers`` is empty.
        """
        if not available_providers:
            raise ValueError("No API keys available for any provider")
        if provider == "random":
            provider = random.choice(available_providers)
        elif provider not in available_providers:
            # Fallback to available provider
            provider = available_providers[0]
            print(f"Requested provider not available, using {provider}")
        print(f"Selected provider: {provider}")
        return provider

    def _resolve_provider(self, provider: str):
        """Return ``(provider, api_key, base_url, models)`` for a requested provider."""
        selected = self._choose_provider(provider, self._available_providers())
        if selected == "openrouter":
            return selected, Config.OPENROUTER_API_KEY, self._BASE_URLS["openrouter"], OPENROUTER_MODELS
        return selected, Config.GROQ_API_KEY, self._BASE_URLS["groq"], GROQ_MODELS

    def get_client(self, provider="random"):
        """Return ``(OpenAI client, model list, provider name)`` for the chosen provider."""
        selected, api_key, base_url, models = self._resolve_provider(provider)
        return OpenAI(api_key=api_key, base_url=base_url), models, selected

    async def get_async_client(self, provider="random"):
        """Async counterpart of get_client that returns an AsyncOpenAI client."""
        selected, api_key, base_url, models = self._resolve_provider(provider)
        return AsyncOpenAI(api_key=api_key, base_url=base_url), models, selected

    @staticmethod
    def _pick_model(models: List[str], model: Optional[str]) -> str:
        """Return ``model``, or a random entry of ``models`` when it is empty or "auto"."""
        if not model or model == "auto":
            model = random.choice(models)
        print(f"Using model: {model}")
        return model

    def get_text_response(self, prompt, provider="random", model=None, max_tokens=1024, temperature=0.7):
        """Send ``prompt`` as a single user message and return the reply text.

        ``max_tokens`` and ``temperature`` default to the previously hard-coded values.
        """
        client, models, _ = self.get_client(provider)
        model = self._pick_model(models, model)
        response = client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=max_tokens,
            temperature=temperature,
        )
        return response.choices[0].message.content

    def get_text_response_streaming(self, prompt, provider="random", model=None, max_tokens=1024, temperature=0.7):
        """Yield the reply to ``prompt`` piece by piece as the provider streams it."""
        client, models, _ = self.get_client(provider)
        model = self._pick_model(models, model)
        stream = client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=max_tokens,
            temperature=temperature,
            stream=True,
        )
        for chunk in stream:
            # A chunk may carry an empty `choices` list (e.g. a trailing
            # usage-only chunk); indexing [0] would raise IndexError.
            if not chunk.choices:
                continue
            content = chunk.choices[0].delta.content
            if content is not None:
                yield content
class ApplicationState:
    """Mutable holder for the service singletons shared across requests.

    Every slot starts as ``None``; the lifespan startup code assigns the real
    services.
    """

    def __init__(self):
        self.openai_service = self.qdrant_client = None
        self.embedding_service = self.document_manager = None


# Global state instance
app_state = ApplicationState()
class EmbeddingService:
    """Produce sentence embeddings with a local sentence-transformers model.

    ``encode`` is blocking, so the async entry points run it on a private
    4-worker thread pool instead of on the event loop.
    """

    # Output size of the default all-MiniLM-L6-v2 model; used only if the
    # loaded model cannot report its own dimension.
    DEFAULT_DIMENSION = 384

    def __init__(self):
        self.model_name = Config.EMBEDDING_MODEL
        self.device = Config.DEVICE
        self.dimension = self.DEFAULT_DIMENSION
        self.executor = ThreadPoolExecutor(max_workers=4)
        # Load the model
        print(f"Loading embedding model: {self.model_name}")
        self.model = SentenceTransformer(self.model_name, device=self.device)
        # Ask the model for its real output size, so a non-default EMBEDDING_MODEL
        # does not disagree with self.dimension (and the fallback vectors below).
        self.dimension = self.model.get_sentence_embedding_dimension() or self.DEFAULT_DIMENSION
        print(f"Model loaded successfully on device: {self.device}")

    async def get_embedding(self, text: str) -> List[float]:
        """Embed one text. On failure, log and return a constant fallback vector.

        NOTE(review): the fallback ``[0.1] * dimension`` is indistinguishable from
        a real vector once stored in Qdrant; consider raising instead.
        """
        try:
            loop = asyncio.get_running_loop()
            embedding = await loop.run_in_executor(self.executor, self._encode_text, text)
            return embedding.tolist()
        except Exception as e:
            print(f"Error generating embedding: {e}")
            return [0.1] * self.dimension

    def _encode_text(self, text: str):
        """Blocking single-text encode; runs on the thread pool."""
        return self.model.encode([text])[0]

    async def get_document_embedding(self, text: str) -> List[float]:
        """Embed document text (same model and path as queries)."""
        return await self.get_embedding(text)

    async def get_query_embedding(self, text: str) -> List[float]:
        """Embed query text (same model and path as documents)."""
        return await self.get_embedding(text)

    async def batch_embed(self, texts: List[str]) -> List[List[float]]:
        """Embed many texts in one encode call; fall back to one constant vector per text on failure."""
        try:
            loop = asyncio.get_running_loop()
            embeddings = await loop.run_in_executor(self.executor, self._batch_encode_texts, texts)
            return embeddings.tolist()
        except Exception as e:
            print(f"Error in batch embedding: {e}")
            return [[0.1] * self.dimension for _ in texts]

    def _batch_encode_texts(self, texts: List[str]):
        """Blocking batch encode; runs on the thread pool."""
        return self.model.encode(texts)

    def health_check(self) -> dict:
        """Encode a probe string and report model, device and output shape (or the error)."""
        try:
            test_embedding = self.model.encode(["test"])
            return {
                "status": "healthy",
                "model": self.model_name,
                "device": self.device,
                "dimension": self.dimension,
                "test_embedding_shape": test_embedding.shape
            }
        except Exception as e:
            return {
                "status": "unhealthy",
                "model": self.model_name,
                "error": str(e)
            }
class DocumentManager:
    """Store PDF documents in Qdrant as embedded text chunks and search them.

    Each PDF is split into overlapping text chunks. Every chunk becomes one
    Qdrant point whose payload carries the owning ``document_id``. PDF parsing
    runs on a small private thread pool.
    """

    # Points fetched per scroll request when walking the whole collection.
    _SCROLL_PAGE_SIZE = 1000

    def __init__(self, qdrant_client: "AsyncQdrantClient", embedding_service: "EmbeddingService"):
        self.qdrant_client = qdrant_client
        self.embedding_service = embedding_service
        self.collection_name = Config.COLLECTION_NAME
        # Size new collections to the embedding model's output
        # (384 = all-MiniLM-L6-v2, the previous hard-coded value).
        self.vector_size = getattr(embedding_service, "dimension", 384)
        self.executor = ThreadPoolExecutor(max_workers=2)

    async def _read_pdf(self, file_path: str) -> str:
        """Extract a PDF's text on the worker pool; return "" on any failure."""
        try:
            loop = asyncio.get_running_loop()
            return await loop.run_in_executor(self.executor, self._sync_read_pdf, file_path)
        except Exception as e:
            print(f"Error reading PDF {file_path}: {e}")
            return ""

    def _sync_read_pdf(self, file_path: str) -> str:
        """Blocking PDF extraction: every page's text followed by a newline."""
        try:
            with open(file_path, 'rb') as file:
                pdf_reader = PyPDF2.PdfReader(file)
                # `or ''` guards against a page whose extraction yields None.
                return "".join(f"{page.extract_text() or ''}\n" for page in pdf_reader.pages)
        except Exception as e:
            print(f"Error reading PDF {file_path}: {e}")
            return ""

    def _chunk_text(self, text: str, chunk_size: int = 500, overlap: int = 50) -> List[str]:
        """Split ``text`` into stripped chunks of at most ``chunk_size`` characters.

        Each window is cut after its last '.', else at its last space, else hard
        at ``chunk_size``. Consecutive chunks overlap by up to ``overlap``
        characters. Whitespace-only input yields ``[]``.
        """
        text_length = len(text)
        if text_length <= chunk_size:
            stripped = text.strip()
            return [stripped] if stripped else []
        chunks = []
        start = 0
        while start < text_length:
            end = start + chunk_size
            if end < text_length:
                sentence_end = text.rfind('.', start, end)
                if sentence_end > start:
                    end = sentence_end + 1
                else:
                    word_end = text.rfind(' ', start, end)
                    if word_end > start:
                        end = word_end
            else:
                end = text_length
            chunk = text[start:end].strip()
            if chunk:
                chunks.append(chunk)
            if end >= text_length:
                # Tail covered; stepping back would only re-emit a suffix of it.
                break
            # Step back by `overlap` but always move forward: a cut closer than
            # `overlap` to `start` would otherwise rewind into the same window
            # and loop forever.
            next_start = end - overlap
            start = next_start if next_start > start else end
        return chunks

    async def _ensure_collection_exists(self):
        """Create the collection (cosine distance, ``vector_size`` dims) if it is missing.

        Errors are printed, not raised.
        """
        try:
            collections = await self.qdrant_client.get_collections()
            collection_names = [c.name for c in collections.collections]
            if self.collection_name not in collection_names:
                print(f"Creating collection '{self.collection_name}' on-demand...")
                await self.qdrant_client.create_collection(
                    collection_name=self.collection_name,
                    vectors_config=VectorParams(
                        size=self.vector_size,
                        distance=Distance.COSINE
                    )
                )
                print(f"✓ Collection '{self.collection_name}' created successfully!")
        except Exception as e:
            print(f"Warning: Could not ensure collection exists: {e}")

    async def _scroll_all(self, scroll_filter=None, with_payload=True) -> AsyncGenerator[Any, None]:
        """Yield every point in the collection (optionally filtered), following scroll pagination."""
        offset = None
        while True:
            points, offset = await self.qdrant_client.scroll(
                collection_name=self.collection_name,
                scroll_filter=scroll_filter,
                limit=self._SCROLL_PAGE_SIZE,
                offset=offset,
                with_payload=with_payload,
                with_vectors=False,
            )
            for point in points:
                yield point
            # scroll returns the next page's offset, or None after the last page.
            if offset is None:
                break

    async def add_document(self, file_path: str, metadata: Optional[Dict[str, Any]] = None) -> str:
        """Index one PDF and return its new document_id, or "" if anything fails."""
        try:
            await self._ensure_collection_exists()
            text = await self._read_pdf(file_path)
            if not text:
                print(f"Could not extract text from {file_path}")
                return ""
            chunks = self._chunk_text(text)
            if not chunks:
                print(f"No chunks created from {file_path}")
                return ""
            document_id = str(uuid.uuid4())
            embeddings = await self.embedding_service.batch_embed(chunks)
            # One timestamp per document, so every chunk of it agrees.
            timestamp = datetime.now().isoformat()
            points = []
            for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
                payload = {
                    "document_id": document_id,
                    "file_path": file_path,
                    "chunk_index": i,
                    "content": chunk,  # Use 'content' as the main field
                    "chunk_text": chunk,  # Keep for compatibility
                    "total_chunks": len(chunks),
                    "timestamp": timestamp
                }
                if metadata:
                    payload["metadata"] = metadata
                points.append(PointStruct(id=str(uuid.uuid4()), vector=embedding, payload=payload))
            await self.qdrant_client.upsert(collection_name=self.collection_name, points=points)
            print(f"✓ Added document: {file_path}")
            print(f" Document ID: {document_id}")
            print(f" Chunks: {len(chunks)}")
            return document_id
        except Exception as e:
            print(f"Error adding document {file_path}: {e}")
            return ""

    async def search_documents(self, query: str, limit: int = 5, min_score: float = 0.1) -> List[Dict[str, Any]]:
        """Return up to ``limit`` chunks scoring at least ``min_score`` for ``query`` ([] on error)."""
        try:
            await self._ensure_collection_exists()
            print(f"Document Search - Query: '{query}', Limit: {limit}, Min Score: {min_score}")
            query_embedding = await self.embedding_service.get_query_embedding(query)
            print(f"Document Search - Generated embedding vector of size: {len(query_embedding)}")
            search_results = await self.qdrant_client.search(
                collection_name=self.collection_name,
                query_vector=query_embedding,
                limit=limit,
                score_threshold=min_score
            )
            print(f"Document Search - Qdrant returned {len(search_results)} results")
            results = []
            for i, result in enumerate(search_results):
                payload = result.payload or {}
                content = payload.get("content", payload.get("chunk_text", ""))
                print(f"Document Search - Result {i+1}: Score={result.score:.4f}, Content preview: {content[:100]}...")
                results.append({
                    "score": result.score,
                    "text": content,
                    "file_path": payload.get("file_path", ""),
                    "document_id": payload.get("document_id", ""),
                    "chunk_index": payload.get("chunk_index", 0)
                })
            print(f"✓ Document Search - Found {len(results)} results for query: '{query}'")
            return results
        except Exception as e:
            print(f"Error searching: {e}")
            import traceback
            traceback.print_exc()
            return []

    async def list_documents(self) -> List[Dict[str, Any]]:
        """Return one summary dict per distinct document_id across the whole collection."""
        try:
            await self._ensure_collection_exists()
            # Walk every page; a single scroll call only returns its first page.
            documents = {}
            async for point in self._scroll_all():
                payload = point.payload or {}
                doc_id = payload.get("document_id")
                if doc_id and doc_id not in documents:
                    documents[doc_id] = {
                        "document_id": doc_id,
                        "file_path": payload.get("file_path", ""),
                        "total_chunks": payload.get("total_chunks", 0),
                        "timestamp": payload.get("timestamp", ""),
                        "metadata": payload.get("metadata", {})
                    }
            doc_list = list(documents.values())
            print(f"✓ Found {len(doc_list)} documents")
            return doc_list
        except Exception as e:
            print(f"Error listing documents: {e}")
            return []

    async def delete_document(self, document_id: str) -> bool:
        """Delete every chunk of ``document_id``; return False if none exist or on error."""
        try:
            await self._ensure_collection_exists()
            # Let Qdrant filter by document_id and follow pagination, instead of
            # scanning only the first 10 000 points client-side.
            doc_filter = Filter(must=[FieldCondition(key="document_id", match=MatchValue(value=document_id))])
            points_to_delete = [
                point.id async for point in self._scroll_all(scroll_filter=doc_filter, with_payload=False)
            ]
            if not points_to_delete:
                print(f"No document found with ID: {document_id}")
                return False
            await self.qdrant_client.delete(
                collection_name=self.collection_name,
                points_selector=points_to_delete
            )
            print(f"✓ Deleted document: {document_id} ({len(points_to_delete)} chunks)")
            return True
        except Exception as e:
            print(f"Error deleting document: {e}")
            return False
class RAGService:
    """Retrieval-augmented generation helpers: fetch relevant chunks, then build the prompt.

    Both methods are static. retrieve_relevant_chunks reads the shared
    document manager from ``app_state``.
    """

    @staticmethod
    async def retrieve_relevant_chunks(query: str, top_k: int = Config.TOP_K) -> List[Dict[str, Any]]:
        """Return up to ``top_k`` chunks for ``query``; [] if unavailable or on error.

        Searches with a 0.1 score floor first. If that finds nothing, it retries
        with no floor, so the prompt still gets the nearest chunks.
        """
        try:
            if app_state.document_manager is None:
                print("Error: Document manager is not initialized")
                return []
            # Use a lower similarity threshold for RAG to get more results
            min_score = 0.1
            print(f"RAG Search - Query: '{query}', Limit: {top_k}, Min Score: {min_score}")
            results = await app_state.document_manager.search_documents(
                query=query,
                limit=top_k,
                min_score=min_score
            )
            print(f"RAG Search - Found {len(results)} results")
            # If no results with low threshold, try even lower
            if not results:
                print("No results with min_score=0.1, trying with min_score=0.0")
                results = await app_state.document_manager.search_documents(
                    query=query,
                    limit=top_k,
                    min_score=0.0
                )
                print(f"RAG Search - Found {len(results)} results with min_score=0.0")
            return results
        except Exception as e:
            print(f"Error retrieving chunks: {e}")
            return []

    @staticmethod
    def build_context_prompt(query: str, results: List[Dict[str, Any]]) -> str:
        """Wrap ``query`` in an instruction prompt containing the retrieved chunks.

        Returns ``query`` unchanged when ``results`` is empty. Missing
        ``file_path`` / ``text`` keys are rendered as empty strings instead of
        raising KeyError.
        """
        if not results:
            return query
        context_parts = [
            f"Source: {result.get('file_path', '')}\n{result.get('text', '')}"
            for result in results
        ]
        combined_context = "\n\n---\n\n".join(context_parts)
        prompt = f"""
You are a helpful assistant answering questions about Subhrajit based on the provided text.
**Instructions:**
1. Answer the user's question using ONLY the information from the context below.
2. Answer directly and naturally, as if you know the information yourself. Do NOT mention the context (e.g., avoid phrases like "Based on the context...").
3. The user's question may use pronouns like "he," "him," or "his." These always refer to Subhrajit. In your answer, use the name "Subhrajit" for clarity instead of using pronouns.
4. If the context does not contain the answer to the question, respond with: "I'm sorry, I don't have enough information to answer that question."
5. Format your entire response in Markdown.
**Context:**
{combined_context}
**Question:**
{query}
**Answer:**
"""
        return prompt
def _print_security_summary() -> None:
    """Print the active security settings, or a notice that security is off."""
    if not Config.ENABLE_SECURITY:
        print("\n🔓 Security: DISABLED")
        print(" All endpoints are publicly accessible")
        return
    # Count only non-blank keys; blank entries can never authenticate.
    configured_keys = [k for k in Config.API_KEYS if k.strip()]
    print("\n🔒 Security Configuration:")
    print(" Security: ENABLED")
    print(f" Rate Limit: {Config.RATE_LIMIT_PER_MINUTE} requests/minute")
    print(f" Master Key: {'✓ Configured' if Config.MASTER_KEY else '✗ Not configured'}")
    print(f" API Keys: {len(configured_keys)} configured")
    if not Config.MASTER_KEY and not configured_keys:
        print(" ⚠️ WARNING: No API keys configured! Set MASTER_KEY or API_KEYS environment variable.")


async def _shutdown_services() -> None:
    """Close the Qdrant client and stop the worker pools, for whichever were created."""
    print("Shutting down services...")
    if app_state.qdrant_client:
        await app_state.qdrant_client.close()
        print("✓ Qdrant client closed")
    if app_state.embedding_service and hasattr(app_state.embedding_service, 'executor'):
        app_state.embedding_service.executor.shutdown(wait=True)
        print("✓ Embedding service executor shutdown")
    if app_state.document_manager and hasattr(app_state.document_manager, 'executor'):
        app_state.document_manager.executor.shutdown(wait=True)
        print("✓ Document manager executor shutdown")
    print("✓ Shutdown complete")


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Build the shared services before serving and release them afterwards.

    A startup error is printed and re-raised so the app never serves
    half-initialised. Anything already created is still cleaned up, because
    shutdown runs in ``finally``.
    """
    try:
        print("Initializing services...")
        # Initialize dynamic OpenAI service
        try:
            app_state.openai_service = DynamicOpenAIService()
            print("✓ Dynamic OpenAI service initialized")
        except Exception as e:
            print(f"✗ Error initializing OpenAI service: {e}")
            raise
        # Initialize Qdrant client
        try:
            app_state.qdrant_client = AsyncQdrantClient(
                url=Config.QDRANT_URL,
                api_key=Config.QDRANT_API_KEY
            )
            print("✓ Qdrant client initialized")
        except Exception as e:
            print(f"✗ Error initializing Qdrant client: {e}")
            raise
        # Initialize embedding service
        try:
            print("Loading embedding model...")
            app_state.embedding_service = EmbeddingService()
            print(f"✓ Embedding model loaded: {Config.EMBEDDING_MODEL}")
            print(f"✓ Model device: {Config.DEVICE}")
            print(f"✓ Vector dimension: {app_state.embedding_service.dimension}")
        except Exception as e:
            print(f"✗ Error initializing embedding service: {e}")
            raise
        # Initialize document manager
        try:
            app_state.document_manager = DocumentManager(
                qdrant_client=app_state.qdrant_client,
                embedding_service=app_state.embedding_service
            )
            print("✓ Document manager initialized")
        except Exception as e:
            print(f"✗ Error initializing document manager: {e}")
            raise
        print("🚀 All services initialized successfully!")
        _print_security_summary()
        yield
    finally:
        await _shutdown_services()
# Initialize FastAPI app
# `lifespan` runs service construction before serving and teardown afterwards.
app = FastAPI(
    lifespan=lifespan,
    version="1.0.0",
    title="Enhanced RAG API with Dynamic Provider Selection",
    description=(
        "OpenAI-compatible API for RAG with dynamic provider selection "
        "(OpenRouter/Groq) and document management"
    ),
)
# NOTE(review): no route decorator (e.g. `@app.get(...)`) is visible on this
# handler, so as written it is never registered with `app` — confirm.
async def root():
    """Return a small service banner: name, run state, auth flag and version."""
    return dict(
        message="Enhanced RAG API with Dynamic Provider Selection",
        status="running",
        security_enabled=Config.ENABLE_SECURITY,
        version="1.0.0",
    )
async def _probe_provider(provider: str) -> Dict[str, Any]:
    """Send a 1-token completion to ``provider`` and report whether it answered.

    Uses the async client so the probe does not block the event loop, and
    closes that client afterwards.
    """
    try:
        client, models, _ = await app_state.openai_service.get_async_client(provider)
        try:
            await client.chat.completions.create(
                model=models[0],
                messages=[{"role": "user", "content": "test"}],
                max_tokens=1,
            )
        finally:
            await client.close()
        return {"status": "healthy", "model": models[0]}
    except Exception as e:
        return {"status": "error", "error": str(e)}


# NOTE(review): no route decorator (e.g. `@app.get(...)`) is visible on this
# handler, so as written it is never registered with `app` — confirm.
async def health_check(api_key: str = Depends(verify_api_key)):
    """Report the state of Qdrant, the embedding model and each configured LLM provider.

    The top-level ``status`` only reflects whether the embedding service exists.
    Per-component details are returned alongside it. The provider section is
    "degraded" if any probe fails.
    """
    # Test Qdrant connection
    try:
        if app_state.qdrant_client:
            await app_state.qdrant_client.get_collections()
            qdrant_status = "connected"
        else:
            qdrant_status = "not_initialized"
    except Exception as e:
        qdrant_status = f"error: {str(e)}"
    # Test embedding service; model.encode is blocking, so run it off the event loop
    if app_state.embedding_service is None:
        embedding_health = {"status": "not_initialized", "error": "EmbeddingService is None"}
    else:
        try:
            loop = asyncio.get_running_loop()
            embedding_health = await loop.run_in_executor(None, app_state.embedding_service.health_check)
        except Exception as e:
            embedding_health = {"status": "error", "error": str(e)}
    # Test OpenAI service: probe every configured provider concurrently
    if app_state.openai_service is None:
        openai_health = {"status": "not_initialized", "error": "OpenAI service is None"}
    else:
        try:
            providers = [
                name
                for name, key in (("openrouter", Config.OPENROUTER_API_KEY), ("groq", Config.GROQ_API_KEY))
                if key
            ]
            probes = await asyncio.gather(*(_probe_provider(name) for name in providers))
            test_results = dict(zip(providers, probes))
            all_healthy = all(probe["status"] == "healthy" for probe in probes)
            openai_health = {"status": "healthy" if all_healthy else "degraded", "providers": test_results}
        except Exception as e:
            openai_health = {"status": "error", "error": str(e)}
    return {
        "status": "healthy" if app_state.embedding_service is not None else "unhealthy",
        "openai_service": openai_health,
        "qdrant": qdrant_status,
        "embedding_service": embedding_health,
        "document_manager": "initialized" if app_state.document_manager else "not_initialized",
        "collection": Config.COLLECTION_NAME,
        "embedding_model": Config.EMBEDDING_MODEL,
        "available_providers": {
            "openrouter": bool(Config.OPENROUTER_API_KEY),
            "groq": bool(Config.GROQ_API_KEY)
        }
    }
# NOTE(review): no route decorator (e.g. `@app.post(...)`) is visible on this
# handler, so as written it is never registered with `app` — confirm.
async def chat_completions(request: ChatCompletionRequest, api_key: str = Depends(verify_api_key)):
    """OpenAI-compatible chat completions endpoint with enhanced RAG and dynamic provider selection.

    The most recent user message is the retrieval query. When chunks are found,
    that same message is replaced by a context-augmented prompt; all other
    messages are forwarded unchanged and in order.

    Raises:
        HTTPException: 500 if the OpenAI service is not initialised or an
            unexpected error occurs; 400 if the request has no user message.
    """
    if not app_state.openai_service:
        raise HTTPException(status_code=500, detail="OpenAI service not initialized")
    try:
        # Index of the most recent user turn; it is not necessarily the final message.
        last_user_index = next(
            (i for i in reversed(range(len(request.messages))) if request.messages[i].role == "user"),
            None,
        )
        if last_user_index is None:
            raise HTTPException(status_code=400, detail="No user message found")
        last_user_message = request.messages[last_user_index].content
        print(f"Processing query: {last_user_message[:100]}...")
        # Retrieve relevant chunks using enhanced search
        try:
            relevant_results = await RAGService.retrieve_relevant_chunks(last_user_message)
            print(f"Retrieved {len(relevant_results)} chunks")
        except Exception as e:
            print(f"Error in retrieval: {e}")
            relevant_results = []
        # Build context-aware prompt
        if relevant_results:
            context_prompt = RAGService.build_context_prompt(last_user_message, relevant_results)
            # Replace the user turn that was searched for. Previously the final
            # message was dropped even when it was not that user turn.
            enhanced_messages = list(request.messages)
            enhanced_messages[last_user_index] = Message(role="user", content=context_prompt)
            print("Using context-enhanced prompt")
        else:
            enhanced_messages = request.messages
            print("Using original prompt (no context)")
        # Convert to OpenAI format
        openai_messages = [{"role": msg.role, "content": msg.content} for msg in enhanced_messages]
        print(f"Sending {len(openai_messages)} messages to OpenAI API")
        if request.stream:
            return StreamingResponse(
                stream_chat_completion(openai_messages, request),
                media_type="text/event-stream"
            )
        return await create_chat_completion(openai_messages, request)
    except HTTPException:
        raise
    except Exception as e:
        print(f"Unexpected error in chat_completions: {e}")
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
async def create_chat_completion(messages: List[Dict], request: ChatCompletionRequest) -> ChatCompletionResponse:
    """Create a non-streaming chat completion using a dynamically selected provider.

    Args:
        messages: Messages already converted to OpenAI ``{"role", "content"}`` dicts.
        request: The original request; supplies ``provider``, ``model``,
            ``max_tokens``, ``temperature`` and ``top_p``.

    Returns:
        A ``ChatCompletionResponse`` whose ``model`` is ``"<provider>:<model>"``
        and whose ``usage`` is ``None`` when the upstream response has none.

    Raises:
        HTTPException: re-raised unchanged if raised while selecting the client
            or model; otherwise 500 wrapping any upstream/API error.
    """
    try:
        # Get async client with dynamic provider selection
        client, models, selected_provider = await app_state.openai_service.get_async_client(request.provider)
        if request.model == "auto" or not request.model:
            # random.choice([]) would raise an opaque IndexError; fail clearly instead.
            if not models:
                raise HTTPException(
                    status_code=500,
                    detail=f"No models configured for provider: {selected_provider}",
                )
            selected_model = random.choice(models)
        else:
            # NOTE(review): an explicit model is sent to whichever provider
            # get_async_client returned — confirm it exists on that provider.
            selected_model = request.model
        print(f"Using provider: {selected_provider}, model: {selected_model}")
        response = await client.chat.completions.create(
            model=selected_model,
            messages=messages,
            max_tokens=request.max_tokens,
            temperature=request.temperature,
            top_p=request.top_p,
            stream=False
        )
        return ChatCompletionResponse(
            id=response.id,
            created=response.created,
            model=f"{selected_provider}:{response.model}",  # Include provider in model name
            choices=[{
                "index": choice.index,
                "message": {
                    "role": choice.message.role,
                    "content": choice.message.content
                },
                "finish_reason": choice.finish_reason
            } for choice in response.choices],
            usage={
                "prompt_tokens": response.usage.prompt_tokens,
                "completion_tokens": response.usage.completion_tokens,
                "total_tokens": response.usage.total_tokens
            } if response.usage else None
        )
    except HTTPException:
        # Preserve deliberate HTTP errors (status code and detail) instead of
        # re-wrapping them as a generic 500.
        raise
    except Exception as e:
        print(f"Error in create_chat_completion: {e}")
        raise HTTPException(status_code=500, detail=f"Error calling OpenAI API: {str(e)}") from e
async def stream_chat_completion(messages: List[Dict], request: ChatCompletionRequest) -> AsyncGenerator[str, None]:
    """Yield a chat completion as OpenAI-style server-sent events.

    Each upstream chunk that has at least one choice is re-emitted as a
    ``data: <ChatCompletionChunk JSON>`` line built from its first choice, with
    the provider prefixed to the model name. A final ``data: [DONE]`` marks the
    end of a successful stream. On any exception a single ``data:`` line with an
    ``{"error": {...}}`` object is emitted instead, and the generator ends.

    Args:
        messages: Messages already converted to OpenAI ``{"role", "content"}`` dicts.
        request: The original request; supplies provider, model and sampling params.
    """
    try:
        client, models, selected_provider = await app_state.openai_service.get_async_client(request.provider)
        wants_auto_model = request.model == "auto" or not request.model
        selected_model = random.choice(models) if wants_auto_model else request.model
        print(f"Using provider: {selected_provider}, model: {selected_model}")

        upstream = await client.chat.completions.create(
            model=selected_model,
            messages=messages,
            max_tokens=request.max_tokens,
            temperature=request.temperature,
            top_p=request.top_p,
            stream=True,
        )
        async for piece in upstream:
            if not piece.choices:
                continue
            head = piece.choices[0]
            delta = head.delta
            if not delta:
                continue
            event = ChatCompletionChunk(
                id=piece.id,
                created=piece.created,
                model=f"{selected_provider}:{piece.model}",  # Include provider in model name
                choices=[{
                    "index": head.index,
                    "delta": {
                        # Empty/falsy values are normalized to None.
                        "role": delta.role or None,
                        "content": delta.content or None,
                    },
                    "finish_reason": head.finish_reason,
                }],
            )
            yield f"data: {event.model_dump_json()}\n\n"
        yield "data: [DONE]\n\n"
    except Exception as exc:
        print(f"Error in streaming: {exc}")
        failure = {"error": {"message": str(exc), "type": "internal_error"}}
        yield f"data: {json.dumps(failure)}\n\n"
| # Document management endpoints | |
async def upload_document(
    file: UploadFile = File(...),
    metadata: str = None,
    api_key: str = Depends(verify_api_key)
):
    """Upload a PDF document and add it to the document collection.

    The upload is copied to a temporary ``.pdf`` file, which is passed to
    ``document_manager.add_document`` and removed afterwards in all cases.

    Args:
        file: The uploaded file; its filename must end in ``.pdf``
            (case-insensitive).
        metadata: Optional JSON *object* encoded as a string. It is merged into
            the stored metadata along with ``original_filename`` and
            ``upload_timestamp``. NOTE(review): declared as a plain ``str``, so
            FastAPI reads it from the query string rather than the multipart
            form — confirm this is intended.
        api_key: The caller's API key, resolved by ``verify_api_key``.

    Returns:
        dict with ``message``, ``document_id`` and ``filename``.

    Raises:
        HTTPException: 400 for a missing/non-PDF filename or invalid metadata;
            500 if the document manager is not initialized, the document could
            not be added, or any other error occurs.
    """
    try:
        if not app_state.document_manager:
            raise HTTPException(status_code=500, detail="Document manager not initialized")
        # UploadFile.filename may be None/empty; reject it as an unsupported
        # type instead of crashing with AttributeError (which surfaced as a 500).
        filename = file.filename or ""
        if not filename.lower().endswith('.pdf'):
            raise HTTPException(status_code=400, detail="Only PDF files are supported")

        parsed_metadata = {}
        if metadata:
            try:
                parsed_metadata = json.loads(metadata)
            except json.JSONDecodeError:
                raise HTTPException(status_code=400, detail="Invalid metadata JSON") from None
            # The metadata is unpacked with ** below, so it must be a JSON object.
            if not isinstance(parsed_metadata, dict):
                raise HTTPException(status_code=400, detail="Metadata must be a JSON object")

        # Create the temp file and remember its path *before* copying, so a failed
        # copy still gets cleaned up. NOTE(review): copyfileobj is blocking I/O
        # inside an async handler.
        tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.pdf')
        tmp_path = tmp_file.name
        try:
            with tmp_file:
                shutil.copyfileobj(file.file, tmp_file)
            # Add document to the collection
            document_id = await app_state.document_manager.add_document(
                file_path=tmp_path,
                metadata={
                    **parsed_metadata,
                    "original_filename": filename,
                    "upload_timestamp": datetime.now().isoformat()
                }
            )
            if not document_id:
                raise HTTPException(status_code=500, detail="Failed to add document")
            return {
                "message": "Document uploaded successfully",
                "document_id": document_id,
                "filename": filename
            }
        finally:
            # Clean up temporary file; tolerate it already being gone.
            try:
                os.unlink(tmp_path)
            except FileNotFoundError:
                pass
    except HTTPException:
        raise
    except Exception as e:
        print(f"Error uploading document: {e}")
        raise HTTPException(status_code=500, detail=f"Error uploading document: {str(e)}")
async def search_documents(request: DocumentSearchRequest, api_key: str = Depends(verify_api_key)):
    """Search the document collection.

    Args:
        request: Search parameters; ``query``, ``limit`` and ``min_score`` are
            forwarded to ``document_manager.search_documents``.
        api_key: The caller's API key, resolved by ``verify_api_key``.

    Returns:
        dict with the echoed ``query``, the ``results`` and their ``count``.

    Raises:
        HTTPException: 500 if the document manager is not initialized or the
            search fails.
    """
    try:
        if not app_state.document_manager:
            raise HTTPException(status_code=500, detail="Document manager not initialized")
        results = await app_state.document_manager.search_documents(
            query=request.query,
            limit=request.limit,
            min_score=request.min_score
        )
        return {
            "query": request.query,
            "results": results,
            "count": len(results)
        }
    except HTTPException:
        # Re-raise deliberate HTTP errors unchanged instead of re-wrapping them
        # as "Error searching documents: 500: ...".
        raise
    except Exception as e:
        print(f"Error searching documents: {e}")
        raise HTTPException(status_code=500, detail=f"Error searching documents: {str(e)}")
async def list_documents(api_key: str = Depends(verify_api_key)):
    """List all documents known to the document manager.

    Args:
        api_key: The caller's API key, resolved by ``verify_api_key``.

    Returns:
        dict with ``documents`` (as returned by ``document_manager.list_documents``)
        and their ``count``.

    Raises:
        HTTPException: 500 if the document manager is not initialized or
            listing fails.
    """
    try:
        if not app_state.document_manager:
            raise HTTPException(status_code=500, detail="Document manager not initialized")
        documents = await app_state.document_manager.list_documents()
        return {
            "documents": documents,
            "count": len(documents)
        }
    except HTTPException:
        # Re-raise deliberate HTTP errors unchanged instead of re-wrapping them.
        raise
    except Exception as e:
        print(f"Error listing documents: {e}")
        raise HTTPException(status_code=500, detail=f"Error listing documents: {str(e)}")
async def delete_document(document_id: str, api_key: str = Depends(verify_api_key)):
    """Delete the document identified by ``document_id``.

    Args:
        document_id: Identifier passed to ``document_manager.delete_document``.
        api_key: The caller's API key, resolved by ``verify_api_key``.

    Returns:
        dict with a confirmation ``message`` and the ``document_id``.

    Raises:
        HTTPException: 404 when the manager reports nothing was deleted; 500 if
            the manager is not initialized or deletion raises.
    """
    try:
        manager = app_state.document_manager
        if not manager:
            raise HTTPException(status_code=500, detail="Document manager not initialized")
        deleted = await manager.delete_document(document_id)
        if deleted:
            return {"message": "Document deleted successfully", "document_id": document_id}
        raise HTTPException(status_code=404, detail="Document not found")
    except HTTPException:
        raise
    except Exception as err:
        print(f"Error deleting document: {err}")
        raise HTTPException(status_code=500, detail=f"Error deleting document: {str(err)}")
| # Legacy compatibility endpoints | |
async def add_document_legacy(content: str, metadata: Optional[Dict] = None, api_key: str = Depends(verify_api_key)):
    """Legacy endpoint: embed raw text ``content`` and upsert it as a single point.

    Args:
        content: Text to embed; stored verbatim in the point payload.
        metadata: Optional metadata stored under the payload key ``"metadata"``
            (``{}`` when omitted).
        api_key: The caller's API key, resolved by ``verify_api_key``.

    Returns:
        dict with a confirmation ``message`` and the new point ``id`` (a UUID4 string).

    Raises:
        HTTPException: 500 if a required service is not initialized or
            embedding/upserting fails.
    """
    try:
        # document_manager is used below to ensure the collection exists, so it
        # is checked together with the embedding service and the Qdrant client.
        if (
            not app_state.embedding_service
            or not app_state.qdrant_client
            or not app_state.document_manager
        ):
            raise HTTPException(status_code=500, detail="Services not initialized")
        await app_state.document_manager._ensure_collection_exists()
        embedding = await app_state.embedding_service.get_document_embedding(content)
        point = PointStruct(
            id=str(uuid.uuid4()),
            vector=embedding,
            payload={
                "content": content,
                "metadata": metadata or {},
                "timestamp": datetime.now().isoformat()
            }
        )
        await app_state.qdrant_client.upsert(
            collection_name=Config.COLLECTION_NAME,
            points=[point]
        )
        return {"message": "Document added successfully", "id": point.id}
    except HTTPException:
        # Keep the original status/detail instead of re-wrapping it.
        raise
    except Exception as e:
        print(f"Error adding document: {e}")
        raise HTTPException(status_code=500, detail=f"Error adding document: {str(e)}")
async def get_collection_info(api_key: str = Depends(verify_api_key)):
    """Return basic information about the configured Qdrant collection.

    Ensures the collection exists (via the document manager) before querying it.

    Args:
        api_key: The caller's API key, resolved by ``verify_api_key``.

    Returns:
        dict with the collection ``name``, ``vectors_count``, ``status`` and the
        embedding ``vector_size`` (``"unknown"`` without an embedding service).
        NOTE(review): newer qdrant-client releases deprecate
        ``CollectionInfo.vectors_count`` — confirm it is populated for the
        installed version.

    Raises:
        HTTPException: 500 if the Qdrant client or document manager is not
            initialized, or the lookup fails.
    """
    try:
        if app_state.qdrant_client is None:
            raise HTTPException(status_code=500, detail="Qdrant client is not initialized")
        # The document manager is dereferenced below; check it explicitly rather
        # than failing with an AttributeError.
        if not app_state.document_manager:
            raise HTTPException(status_code=500, detail="Document manager not initialized")
        await app_state.document_manager._ensure_collection_exists()
        collection_info = await app_state.qdrant_client.get_collection(Config.COLLECTION_NAME)
        return {
            "name": Config.COLLECTION_NAME,
            "vectors_count": collection_info.vectors_count,
            "status": collection_info.status,
            "vector_size": app_state.embedding_service.dimension if app_state.embedding_service else "unknown"
        }
    except HTTPException:
        # Keep the original status/detail instead of re-wrapping it.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error getting collection info: {str(e)}")
| # New endpoint to get available providers and models | |
async def get_providers(api_key: str = Depends(verify_api_key)):
    """List the LLM providers, whether each is usable, and their models.

    A provider is ``"available"`` (with its model list) when its API key is
    configured, otherwise ``"unavailable"`` with a ``reason``.

    Args:
        api_key: The caller's API key, resolved by ``verify_api_key``.

    Returns:
        dict with ``providers`` (keyed ``"openrouter"``, then ``"groq"``) and
        ``default_selection`` (always ``"random"``).

    Raises:
        HTTPException: 500 if the OpenAI service is not initialized.
    """
    try:
        if not app_state.openai_service:
            raise HTTPException(status_code=500, detail="OpenAI service not initialized")
        # (provider name, configured API key, model list) — order defines output order.
        provider_specs = (
            ("openrouter", Config.OPENROUTER_API_KEY, OPENROUTER_MODELS),
            ("groq", Config.GROQ_API_KEY, GROQ_MODELS),
        )
        available_providers = {}
        for name, key, models in provider_specs:
            if key:
                available_providers[name] = {"status": "available", "models": models}
            else:
                available_providers[name] = {"status": "unavailable", "reason": "API key not provided"}
        return {
            "providers": available_providers,
            "default_selection": "random"
        }
    except HTTPException:
        # Keep the original status/detail instead of re-wrapping it.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error getting providers: {str(e)}")
| # Security Management Endpoints | |
async def get_security_info() -> SecurityInfo:
    """Report the security configuration (public endpoint; no API key dependency).

    Only aggregate facts are returned: whether security is enabled, the
    per-minute rate limit, whether a master key is set, and how many non-blank
    entries ``Config.API_KEYS`` contains. Key values themselves are not exposed.
    """
    non_blank_key_count = sum(1 for candidate in Config.API_KEYS if candidate.strip())
    return SecurityInfo(
        security_enabled=Config.ENABLE_SECURITY,
        rate_limit_per_minute=Config.RATE_LIMIT_PER_MINUTE,
        has_master_key=bool(Config.MASTER_KEY),
        configured_keys_count=non_blank_key_count,
    )
async def generate_api_key(
    request: APIKeyRequest,
    master_key: str = Depends(verify_master_key)
) -> APIKeyResponse:
    """Generate a new API key (requires the master key).

    NOTE(review): nothing in this function stores or registers the new key;
    confirm that ``Config.generate_api_key`` makes it acceptable to
    ``verify_api_key``, since the response always reports ``status="active"``.

    Args:
        request: Supplies the ``description`` echoed back in the response.
        master_key: Resolved by the ``verify_master_key`` dependency.

    Returns:
        An ``APIKeyResponse`` with the key, description and creation timestamp.

    Raises:
        HTTPException: 500 if key generation or response construction fails.
    """
    try:
        fresh_key = Config.generate_api_key()
        issued = APIKeyResponse(
            api_key=fresh_key,
            description=request.description,
            created_at=datetime.now().isoformat(),
            status="active",
        )
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Error generating API key: {str(exc)}")
    return issued
async def validate_api_key_endpoint(
    api_key: str = Depends(verify_api_key)
) -> Dict[str, Any]:
    """Confirm that the caller's API key was accepted.

    Validation itself happens in the ``verify_api_key`` dependency; presumably
    it rejects invalid keys before this body runs (confirm there). The key is
    reported as ``"master"`` when it equals ``Config.MASTER_KEY``, otherwise
    ``"standard"``.
    """
    is_master_key = api_key == Config.MASTER_KEY
    return {
        "valid": True,
        "key_type": "master" if is_master_key else "standard",
        "validated_at": datetime.now().isoformat(),
    }
async def get_rate_limit_status(
    request: Request,
    api_key: str = Depends(verify_api_key)
) -> Dict[str, Any]:
    """Report the caller's request count and remaining quota for the current minute.

    The count is read from ``rate_limiter.requests[client_ip][minute_key]``,
    where ``minute_key`` is the current local time formatted as
    ``"%Y-%m-%d %H:%M"``. The membership check avoids creating an entry for
    IPs the limiter has not seen.

    Args:
        request: The incoming request; its client host identifies the caller.
        api_key: The caller's API key, resolved by ``verify_api_key``.

    Returns:
        dict with ``client_ip``, ``current_requests``, ``limit_per_minute``,
        ``remaining_requests`` (never negative), ``reset_at`` (start of the next
        minute, formatted ``"%Y-%m-%d %H:%M:00"``) and ``is_blocked``.
    """
    from datetime import timedelta

    # Starlette's request.client can be None (e.g. some test clients / ASGI
    # servers); avoid an AttributeError. NOTE(review): confirm the rate limiter
    # uses a compatible key for such requests.
    client_ip = request.client.host if request.client else "unknown"
    # Get current minute requests
    now = datetime.now()
    minute_key = now.strftime("%Y-%m-%d %H:%M")
    current_requests = 0
    if client_ip in rate_limiter.requests:
        current_requests = rate_limiter.requests[client_ip].get(minute_key, 0)
    # Counts are bucketed per minute_key (as the lookup above implies), so the
    # quota resets when the minute changes: at the start of the *next* minute,
    # not the start of the current one (which is already in the past).
    reset_time = now.replace(second=0, microsecond=0) + timedelta(minutes=1)
    return {
        "client_ip": client_ip,
        "current_requests": current_requests,
        "limit_per_minute": Config.RATE_LIMIT_PER_MINUTE,
        "remaining_requests": max(0, Config.RATE_LIMIT_PER_MINUTE - current_requests),
        "reset_at": reset_time.strftime("%Y-%m-%d %H:%M:00"),
        "is_blocked": client_ip in rate_limiter.blocked_ips
    }
| # Admin endpoints for IP management | |
async def block_ip(
    ip: str,
    master_key: str = Depends(verify_master_key)
) -> Dict[str, str]:
    """Ask the rate limiter to block ``ip`` (requires the master key).

    NOTE(review): ``ip`` is passed through as-is; it is not validated as a
    well-formed address here.

    Returns:
        dict with a confirmation ``message`` and a ``blocked_at`` ISO timestamp.
    """
    rate_limiter.block_ip(ip)
    blocked_at = datetime.now().isoformat()
    return {"message": f"IP {ip} has been blocked", "blocked_at": blocked_at}
async def unblock_ip(
    ip: str,
    master_key: str = Depends(verify_master_key)
) -> Dict[str, str]:
    """Ask the rate limiter to unblock ``ip`` (requires the master key).

    Returns:
        dict with a confirmation ``message`` and an ``unblocked_at`` ISO timestamp.
    """
    rate_limiter.unblock_ip(ip)
    unblocked_at = datetime.now().isoformat()
    return {"message": f"IP {ip} has been unblocked", "unblocked_at": unblocked_at}
async def get_blocked_ips(
    master_key: str = Depends(verify_master_key)
) -> Dict[str, Any]:
    """Return the rate limiter's blocked IPs and how many there are (requires the master key)."""
    blocked = rate_limiter.blocked_ips
    return {"blocked_ips": list(blocked), "count": len(blocked)}
if __name__ == "__main__":
    import uvicorn

    # Allow the bind address and port to be overridden at deploy time via the
    # HOST / PORT environment variables; defaults match the previous values.
    uvicorn.run(
        app,
        host=os.environ.get("HOST", "0.0.0.0"),
        port=int(os.environ.get("PORT", "8000")),
    )