| """ | |
| Complete RAG (Retrieval-Augmented Generation) QA System with MongoDB Atlas Vector Search | |
| A single-file implementation for document processing, embedding, and question answering. | |
| Updated to use MongoDB Atlas Vector Search for production-ready vector storage. | |
| Requirements: | |
| pip install langchain langchain-community langchain-mongodb pymongo sentence-transformers | |
| pip install faiss-cpu pypdf pandas requests beautifulsoup4 fastapi uvicorn | |
| pip install llama-cpp-python (optional, for GGUF models) | |
| """ | |
| import os | |
| import json | |
| import numpy as np | |
| import logging | |
| from typing import List, Dict, Any, Optional, Tuple | |
| from datetime import datetime | |
| # Load environment variables from .env file | |
| from dotenv import load_dotenv | |
| load_dotenv() | |
| # Set USER_AGENT to avoid warnings | |
| os.environ.setdefault("USER_AGENT", "RAG-System/1.0") | |
| # LangChain imports | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from langchain.schema import Document | |
| from langchain_community.embeddings import HuggingFaceEmbeddings | |
| from langchain_community.document_loaders import ( | |
| PyPDFLoader, | |
| CSVLoader, | |
| JSONLoader, | |
| WebBaseLoader | |
| ) | |
| from langchain_community.document_loaders import RecursiveUrlLoader | |
| from langchain_community.vectorstores import MongoDBAtlasVectorSearch | |
| from langchain.llms.base import LLM | |
| from langchain.callbacks.manager import CallbackManagerForLLMRun | |
| # MongoDB imports | |
| from pymongo import MongoClient | |
| from pymongo.collection import Collection | |
| # FastAPI and other imports | |
| from fastapi import FastAPI, UploadFile, File, Form, Request, Body, HTTPException | |
| from fastapi.responses import JSONResponse, HTMLResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.templating import Jinja2Templates | |
| import shutil | |
| import pathlib | |
| import tempfile | |
| # Ollama Cloud import | |
| import ollama_client | |
| # Configure logging | |
| logging.basicConfig(level=logging.INFO) | |
| logger = logging.getLogger(__name__) | |
| # FastAPI app setup | |
| app = FastAPI(title="RAG QA System with MongoDB Atlas Vector Search") | |
| app.add_middleware( | |
| CORSMiddleware, | |
| allow_origins=["*"], | |
| allow_credentials=True, | |
| allow_methods=["*"], | |
| allow_headers=["*"], | |
| ) | |
| BASE_DIR = pathlib.Path(__file__).parent | |
| TEMPLATES_DIR = BASE_DIR / "templates" | |
| TEMPLATES_DIR.mkdir(exist_ok=True) | |
| templates = Jinja2Templates(directory=str(TEMPLATES_DIR)) | |
# Configuration for MongoDB Atlas
class Config:
    """Central configuration, read once at import time.

    Values come from environment variables (loaded from .env) with
    development-friendly fallbacks.
    """

    # MongoDB Atlas Configuration.
    # SECURITY FIX: the previous default embedded a live Atlas username and
    # password directly in source. Credentials must come from the environment
    # (MONGODB_URI); the fallback is a credential-free local instance for
    # development only. Rotate the leaked credential in Atlas.
    MONGODB_URI = os.getenv("MONGODB_URI", "mongodb://localhost:27017/")

    # Database and collection names
    MONGODB_DB_NAME = "vector_data"
    MONGODB_COLLECTION_NAME = "RAG"

    # Atlas Vector Search index name (must be created in Atlas)
    VECTOR_INDEX_NAME = "vector_search_index"

    # Embedding model and its output dimension
    EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
    EMBEDDING_DIMENSION = 384  # all-MiniLM-L6-v2 emits 384-dimensional vectors

    # LLM configuration
    GGUF_MODEL_PATH = os.getenv(
        "GGUF_MODEL_PATH",
        "C:\\Users\\jaini\\IntellijIdea\\Jainish PYTHON AI TIGER\\PDF\\mistral-7b-instruct-v0.2.Q4_K_M.gguf",
    )
    OLLAMA_MODEL = os.getenv("OLLAMA_MODEL", "gpt-oss:120b-cloud")

    # Text splitting configuration
    CHUNK_SIZE = 1000
    CHUNK_OVERLAP = 200

    # Similarity metrics: public name -> Atlas metric identifier
    SUPPORTED_METRICS = {
        "cosine": "cosine",
        "tanh_cosine": "tanh_cosine",
        "dot": "dotProduct",
        "euclidean": "euclidean",
    }
    DEFAULT_METRIC = "cosine"


config = Config()
# Custom LLM wrapper for Ollama Cloud
class CustomOllamaLLM(LLM):
    """LangChain LLM wrapper around the Ollama Cloud client."""

    # Ollama model identifier (pydantic field on the LLM base class)
    model_name: str

    def __init__(self, model_name: str = None, **kwargs):
        """Create the wrapper; defaults to ``config.OLLAMA_MODEL``."""
        if model_name is None:
            model_name = config.OLLAMA_MODEL
        # The pydantic base class stores model_name; no manual re-assignment
        # after super().__init__ is needed (the original's extra assignment
        # was redundant).
        super().__init__(model_name=model_name, **kwargs)

    @property
    def _llm_type(self) -> str:
        # FIX: LangChain's LLM base declares _llm_type as a @property.
        # Defining it as a plain method makes ``self._llm_type`` evaluate to
        # a bound method wherever the framework reads it (e.g. in _identifying
        # params / serialization), not the string.
        return "custom_ollama"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Generate a completion from Ollama Cloud.

        Returns the model's answer, or an in-band error string on failure —
        callers rely on always receiving a string rather than an exception.
        """
        try:
            # Mistral-style instruction wrapping expected by the model.
            formatted_prompt = f"<s>[INST] {prompt} [/INST]"
            response = ollama_client.generate_from_ollama(
                model=self.model_name,
                prompt=formatted_prompt,
                max_tokens=512,
            )
            answer = response.strip()
            return answer if answer else "I couldn't generate a relevant answer."
        except Exception as e:
            logger.error(f"Error generating response from Ollama Cloud: {e}")
            return f"Error generating answer: {e}"
# MongoDB Atlas Vector Store Manager
class MongoDBAtlasVectorStore:
    """Manages MongoDB Atlas vector storage and search operations."""

    def __init__(self):
        """Connect to Atlas and build the embedding/splitting pipeline.

        Raises if the MongoDB connection ping fails.
        """
        self.client = MongoClient(config.MONGODB_URI)
        self.db = self.client[config.MONGODB_DB_NAME]
        self.collection = self.db[config.MONGODB_COLLECTION_NAME]

        # CPU-only embeddings keep the service deployable without a GPU.
        self.embeddings = HuggingFaceEmbeddings(
            model_name=config.EMBEDDING_MODEL,
            model_kwargs={'device': 'cpu'}
        )

        # Recursive splitter: prefer paragraph breaks, then lines, then words.
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=config.CHUNK_SIZE,
            chunk_overlap=config.CHUNK_OVERLAP,
            separators=["\n\n", "\n", " ", ""]
        )

        # LangChain wrapper over the Atlas collection and its search index.
        self.vector_store = MongoDBAtlasVectorSearch(
            collection=self.collection,
            embedding=self.embeddings,
            index_name=config.VECTOR_INDEX_NAME
        )

        self._test_connection()
        logger.info("MongoDB Atlas Vector Store initialized successfully")

    def _test_connection(self):
        """Ping Atlas and log whether the target collection already exists."""
        try:
            self.client.admin.command('ping')
            logger.info("β Successfully connected to MongoDB Atlas")
            if config.MONGODB_COLLECTION_NAME in self.db.list_collection_names():
                logger.info(f"β Collection '{config.MONGODB_COLLECTION_NAME}' exists")
            else:
                logger.info(f"π Collection '{config.MONGODB_COLLECTION_NAME}' will be created")
        except Exception as e:
            logger.error(f"β MongoDB Atlas connection failed: {e}")
            raise

    def add_documents(self, documents: List[Document], source_info: Dict[str, Any] = None) -> int:
        """Split ``documents`` into chunks, tag metadata, and store them.

        Returns the number of chunks written; re-raises on storage failure.
        """
        try:
            print("π Original document count:", len(documents))
            text_chunks = self.text_splitter.split_documents(documents)
            if source_info:
                # BUGFIX: chunk_id previously used len(text_chunks) for every
                # chunk, so all chunks from one source shared a single id.
                # Use each chunk's own index so ids are unique per source.
                for idx, chunk in enumerate(text_chunks):
                    chunk.metadata.update(source_info)
                    chunk.metadata.update({
                        "timestamp": datetime.utcnow().isoformat(),
                        "chunk_id": f"{source_info.get('source_file', 'unknown')}_{idx}"
                    })
            ids = self.vector_store.add_documents(text_chunks)
            logger.info(f"β Added {len(ids)} document chunks to MongoDB Atlas")
            return len(ids)
        except Exception as e:
            logger.error(f"β Error adding documents: {e}")
            raise

    def similarity_search(self, query: str, k: int = 3, metric: str = None, score_threshold: float = 0.0) -> List[Tuple[Document, float]]:
        """Return up to ``k`` (document, score) pairs relevant to ``query``.

        NOTE(review): ``metric`` and ``score_threshold`` are currently only
        logged / unused — Atlas applies the metric baked into the search
        index definition. Parameters are kept for interface compatibility.
        """
        try:
            if metric and metric in config.SUPPORTED_METRICS:
                logger.info(f"οΏ½ Using similarity metric: {metric}")
            else:
                logger.info(f"οΏ½ Using default similarity metric: cosine")
            results = self.vector_store.similarity_search_with_score(
                query=query,
                k=k
            )
            logger.info(f"π Found {len(results)} relevant documents for query")
            return results
        except Exception as e:
            logger.error(f"β Error performing similarity search: {e}")
            # Fallback: scoreless search, padded with 0.0 scores so the
            # return shape stays consistent for callers.
            try:
                docs = self.vector_store.similarity_search(query=query, k=k)
                results = [(doc, 0.0) for doc in docs]
                logger.info(f"π Fallback search returned {len(results)} results")
                return results
            except Exception as fallback_e:
                logger.error(f"β Fallback search also failed: {fallback_e}")
                return []

    def get_document_count(self) -> int:
        """Return the total number of stored chunks (0 on error)."""
        try:
            count = self.collection.count_documents({})
            logger.info(f"π Total documents in collection: {count}")
            return count
        except Exception as e:
            logger.error(f"β Error getting document count: {e}")
            return 0

    def delete_all_documents(self) -> int:
        """Delete every document in the collection; returns the count removed."""
        try:
            result = self.collection.delete_many({})
            logger.info(f"ποΈ Deleted {result.deleted_count} documents")
            return result.deleted_count
        except Exception as e:
            logger.error(f"β Error deleting documents: {e}")
            return 0
# Document Processor using LangChain
class DocumentProcessor:
    """Loads uploaded files and URLs into LangChain ``Document`` objects."""

    def __init__(self):
        # Scratch directory for uploaded files; individual files are removed
        # by callers after processing.
        self.temp_dir = tempfile.mkdtemp()
        logger.info(f"π Created temporary directory: {self.temp_dir}")

    def process_json(self, file_path: str) -> List[Document]:
        """Load a JSON file as a single pretty-printed Document.

        Raises Exception with a descriptive message on invalid JSON or
        other I/O errors.
        """
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            # Store the whole JSON as one formatted string document.
            json_content = json.dumps(data, indent=2, ensure_ascii=False)
            document = Document(
                page_content=json_content,
                metadata={
                    "source": file_path,
                    "file_type": "json",
                    "total_chars": len(json_content)
                }
            )
            return [document]
        except json.JSONDecodeError as e:
            raise Exception(f"Invalid JSON format: {e}")
        except Exception as e:
            raise Exception(f"Error processing JSON: {e}")

    def process_uploaded_file(self, file_path: str, filename: str) -> List[Document]:
        """Dispatch on extension (.pdf/.csv/.json) and load documents.

        Raises ValueError for unsupported extensions; re-raises loader errors.
        """
        _, ext = os.path.splitext(filename.lower())
        try:
            if ext == '.pdf':
                loader = PyPDFLoader(file_path)
                documents = loader.load()
                logger.info(f"π Loaded PDF with {len(documents)} pages")
            elif ext == '.csv':
                loader = CSVLoader(file_path)
                documents = loader.load()
                logger.info(f"π Loaded CSV with {len(documents)} rows")
            elif ext == '.json':
                # Simple whole-file processing; JSONLoader would need a jq schema.
                documents = self.process_json(file_path)
                logger.info(f"π Loaded JSON with {len(documents)} document(s)")
            else:
                raise ValueError(f"Unsupported file type: {ext}")
            # Stamp provenance onto every loaded document.
            for doc in documents:
                doc.metadata.update({
                    "source_file": filename,
                    "file_type": ext,
                    "processed_at": datetime.utcnow().isoformat()
                })
            return documents
        except Exception as e:
            # BUGFIX: previously logged the literal "(unknown)" instead of
            # the actual filename being processed.
            logger.error(f"β Error processing file {filename}: {e}")
            raise

    def process_url(self, url: str) -> List[Document]:
        """Fetch a web page and return it as Documents with source metadata."""
        try:
            loader = WebBaseLoader([url])
            documents = loader.load()
            logger.info(f"π Loaded webpage with {len(documents)} documents")
            for doc in documents:
                doc.metadata.update({
                    "source_url": url,
                    "source_type": "web",
                    "processed_at": datetime.utcnow().isoformat()
                })
            return documents
        except Exception as e:
            logger.error(f"β Error processing URL {url}: {e}")
            raise
# RAG System
class RAGSystem:
    """Main RAG system tying together vector storage, document processing,
    and LLM answer generation."""

    def __init__(self):
        self.vector_store = MongoDBAtlasVectorStore()
        self.document_processor = DocumentProcessor()
        # The LLM is optional: query() degrades to extractive answers without it.
        try:
            self.llm = CustomOllamaLLM()
            logger.info("π€ LLM initialized successfully")
        except Exception as e:
            # BUGFIX: the old warning blamed a "GGUF model" although the
            # failure is in the Ollama Cloud LLM wrapper.
            logger.warning(f"β οΈ Could not initialize Ollama LLM: {e}")
            self.llm = None

    def add_documents(self, documents: List[Document], source_info: Dict[str, Any] = None) -> int:
        """Add documents to the vector store; returns the chunk count."""
        return self.vector_store.add_documents(documents, source_info)

    def query(self, question: str, k: int = 3, metric: str = None) -> Dict[str, Any]:
        """Answer ``question``, using retrieved document context when available.

        Modes:
          * ``"rag"``          — relevant chunks found; LLM answers with context.
          * ``"conversation"`` — empty store / no matches; LLM answers directly.
        Returns a dict with answer, sources, scores, context_count,
        metric_used and mode. Errors are reported in-band via ``answer``.
        """
        try:
            doc_count = self.vector_store.get_document_count()
            docs_with_scores = self.vector_store.similarity_search(question, k=k, metric=metric)

            # No retrievable context: fall back to plain conversation.
            if not docs_with_scores or doc_count == 0:
                logger.info("π No relevant documents found, using LLM for direct conversation")
                if self.llm:
                    prompt = f"""You're a helpful, smart, and friendly AI assistant. Answer the user's question naturally and conversationally.
Question: {question}
Answer:"""
                    answer = self.llm(prompt)
                    logger.info("π€ Generated conversational answer using LLM")
                else:
                    answer = "I'm ready to help! However, I need the LLM to be properly configured. You can upload documents and I'll help you find information from them."
                return {
                    "answer": answer,
                    "sources": [],
                    "scores": [],
                    "context_count": 0,
                    "metric_used": metric or config.DEFAULT_METRIC,
                    "mode": "conversation"
                }

            # Collect context text plus per-chunk provenance for the UI.
            contexts = []
            sources = []
            scores = []
            for doc, score in docs_with_scores:
                contexts.append(doc.page_content)
                sources.append({
                    "source_file": doc.metadata.get("source_file", "Unknown"),
                    "page": doc.metadata.get("page", "N/A"),
                    "chunk_id": doc.metadata.get("chunk_id", "N/A"),
                    "content": doc.page_content[:500]  # trimmed for UI display
                })
                scores.append(float(score))

            if self.llm:
                context_text = "\n\n".join(contexts)
                prompt = f"""You're a helpful AI assistant. Answer the user's question based on the context provided below.
If the context contains relevant information, use it to provide a detailed and accurate answer.
If the context doesn't contain enough information, you can supplement with general knowledge but mention what came from the documents.
Context from documents:
{context_text}
Question: {question}
Answer:"""
                answer = self.llm(prompt)
                logger.info("π€ Generated RAG answer using LLM with document context")
            else:
                # Extractive fallback when the LLM is unavailable.
                context_text = "\n\n".join(contexts[:2])
                answer = f"Based on the retrieved documents:\n\n{context_text[:800]}..."
                logger.info("π Generated fallback answer")

            return {
                "answer": answer,
                "sources": sources,
                "scores": scores,
                "context_count": len(contexts),
                "metric_used": metric or config.DEFAULT_METRIC,
                "mode": "rag"
            }
        except Exception as e:
            logger.error(f"β Error querying RAG system: {e}")
            # FIX: keep the response schema consistent with the success paths
            # (the error return previously omitted context_count and mode).
            return {
                "answer": f"Error processing query: {str(e)}",
                "sources": [],
                "scores": [],
                "context_count": 0,
                "metric_used": metric or config.DEFAULT_METRIC,
                "mode": "error"
            }

    def get_status(self) -> Dict[str, Any]:
        """Report configuration and connectivity status for the UI."""
        llm_available = True  # Ollama Cloud assumed reachable when configured
        document_count = self.vector_store.get_document_count()
        return {
            "documents_count": document_count,
            "documents_loaded": document_count,  # legacy alias for old clients
            "llm_available": llm_available,
            "embedding_model": config.EMBEDDING_MODEL,
            "mongodb_atlas_connected": True,
            "database_name": config.MONGODB_DB_NAME,
            "collection_name": config.MONGODB_COLLECTION_NAME,
            "vector_index_name": config.VECTOR_INDEX_NAME,
            "ollama_model": config.OLLAMA_MODEL,
            "ollama_cloud_available": True,
            "supported_metrics": list(config.SUPPORTED_METRICS.keys()),
            "default_metric": config.DEFAULT_METRIC
        }
# Global RAG system instance, created on application startup.
rag_system = None


@app.on_event("startup")
async def startup_event():
    """Initialize the RAG system when the server starts.

    NOTE(review): this handler was defined but never registered with the
    app; without the decorator ``rag_system`` stays None and every endpoint
    returns 500.
    """
    global rag_system
    try:
        rag_system = RAGSystem()
        logger.info("π RAG System initialized successfully")
    except Exception as e:
        logger.error(f"β Failed to initialize RAG system: {e}")
        raise
@app.get("/", response_class=HTMLResponse)
async def serve_index(request: Request):
    """Serve the main HTML interface.

    NOTE(review): route decorator was missing; path "/" inferred — confirm
    against the original client.
    """
    return templates.TemplateResponse("index.html", {"request": request})
@app.post("/upload")
async def upload_document(file: UploadFile = File(...)):
    """Upload a document, chunk it, and store it in the vector store.

    NOTE(review): route decorator was missing; path "/upload" inferred —
    confirm against the original client. The temp file is always removed.
    """
    if not rag_system:
        raise HTTPException(status_code=500, detail="RAG system not initialized")
    temp_dir = rag_system.document_processor.temp_dir
    os.makedirs(temp_dir, exist_ok=True)
    # SECURITY FIX: basename() strips any directory components from the
    # client-supplied filename, preventing path traversal out of temp_dir.
    temp_path = os.path.join(temp_dir, os.path.basename(file.filename))
    try:
        with open(temp_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)
        logger.info(f"π€ Saved uploaded file to: {temp_path}")
        # Load and chunk the document, then persist the chunks.
        documents = rag_system.document_processor.process_uploaded_file(temp_path, file.filename)
        chunks_added = rag_system.add_documents(documents, {"upload_filename": file.filename})
        return {
            "status": "success",
            "chunks": chunks_added,
            "filename": file.filename,
            "message": f"Successfully processed {file.filename} with {chunks_added} chunks"
        }
    except Exception as e:
        logger.error(f"β Error uploading document: {e}")
        raise HTTPException(status_code=400, detail=str(e))
    finally:
        # Clean up the temporary file regardless of outcome.
        if os.path.exists(temp_path):
            os.unlink(temp_path)
@app.post("/add_url")
async def add_url(request: dict = Body(...)):
    """Ingest a web page by URL; expects a JSON body like {"url": "..."}.

    NOTE(review): route decorator was missing; path "/add_url" inferred —
    confirm against the original client.
    """
    if not rag_system:
        raise HTTPException(status_code=500, detail="RAG system not initialized")
    url = request.get("url")
    if not url:
        raise HTTPException(status_code=400, detail="URL is required")
    try:
        logger.info(f"π Processing URL: {url}")
        documents = rag_system.document_processor.process_url(url)
        chunks_added = rag_system.add_documents(documents, {"source_url": url})
        return {
            "status": "success",
            "chunks": chunks_added,
            "url": url,
            "message": f"Successfully processed URL with {chunks_added} chunks"
        }
    except Exception as e:
        logger.error(f"β Error processing URL: {e}")
        raise HTTPException(status_code=400, detail=str(e))
@app.post("/ask")
async def ask_question(
    question: str = Form(...),
    k: int = Form(3),
    metric: str = Form(None)
):
    """Answer a question via the RAG system.

    NOTE(review): route decorator was missing; path "/ask" inferred —
    confirm against the original client. Validates the question and the
    optional similarity metric before querying.
    """
    if not rag_system:
        raise HTTPException(status_code=500, detail="RAG system not initialized")
    if not question.strip():
        raise HTTPException(status_code=400, detail="Question is required")
    if metric and metric not in config.SUPPORTED_METRICS:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported metric: {metric}. Supported metrics: {list(config.SUPPORTED_METRICS.keys())}"
        )
    try:
        logger.info(f"β Processing question: {question}")
        if metric:
            logger.info(f"π§ Using similarity metric: {metric}")
        result = rag_system.query(question, k=k, metric=metric)
        # Legacy "pages" field: page number when present, else source file.
        pages = []
        for source in result["sources"]:
            if "page" in source:
                pages.append(source["page"])
            elif "source_file" in source:
                pages.append(source["source_file"])
        return {
            "status": "success",
            "answer": result["answer"],
            "pages": pages,
            "scores": result["scores"],
            "sources": result["sources"],
            "context_count": result.get("context_count", 0),
            "metric_used": result.get("metric_used", config.DEFAULT_METRIC)
        }
    except Exception as e:
        logger.error(f"β Error processing question: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/status")
async def get_status():
    """Return system status.

    NOTE(review): route decorator was missing; path "/status" inferred —
    confirm against the original client.
    """
    if not rag_system:
        return {"error": "RAG system not initialized"}
    return rag_system.get_status()
@app.get("/metrics")
async def get_supported_metrics():
    """List supported similarity metrics and human-readable descriptions.

    NOTE(review): route decorator was missing; path "/metrics" inferred —
    confirm against the original client.
    """
    return {
        "supported_metrics": list(config.SUPPORTED_METRICS.keys()),
        "default_metric": config.DEFAULT_METRIC,
        "metric_descriptions": {
            "cosine": "Cosine Similarity - Measures angle between vectors (0-1, higher is better)",
            "tanh_cosine": "Tanh(Cosine) - Hyperbolic tangent of cosine similarity",
            "dot": "Dot Product - Direct dot product of vectors",
            "euclidean": "Euclidean Distance - L2 distance between vectors (lower is better)"
        }
    }
@app.get("/health")
async def health_check():
    """Liveness probe.

    NOTE(review): route decorator was missing; path "/health" inferred —
    confirm against the original client.
    """
    return {
        "status": "healthy",
        "timestamp": datetime.utcnow().isoformat(),
        "service": "MongoDB Atlas RAG System"
    }
@app.delete("/documents")
async def clear_database():
    """Delete every stored document (testing only).

    NOTE(review): route decorator was missing; DELETE /documents inferred —
    confirm against the original client.
    """
    if not rag_system:
        raise HTTPException(status_code=500, detail="RAG system not initialized")
    try:
        deleted_count = rag_system.vector_store.delete_all_documents()
        return {
            "status": "success",
            "message": f"Deleted {deleted_count} documents",
            "deleted_count": deleted_count
        }
    except Exception as e:
        logger.error(f"β Error clearing database: {e}")
        raise HTTPException(status_code=500, detail=str(e))
| if __name__ == "__main__": | |
| import uvicorn | |
| logger.info("π Starting MongoDB Atlas RAG System") | |
| uvicorn.run(app, host="0.0.0.0", port=8000) |