# rag-system / main.py
# Author: Jainish1808
# Commit 39ac481: Fix LLM check - remove incorrect model attribute check
"""
Complete RAG (Retrieval-Augmented Generation) QA System with MongoDB Atlas Vector Search
A single-file implementation for document processing, embedding, and question answering.
Updated to use MongoDB Atlas Vector Search for production-ready vector storage.
Requirements:
pip install langchain langchain-community langchain-mongodb pymongo sentence-transformers
pip install faiss-cpu pypdf pandas requests beautifulsoup4 fastapi uvicorn
pip install llama-cpp-python (optional, for GGUF models)
"""
import os
import json
import numpy as np
import logging
from typing import List, Dict, Any, Optional, Tuple
from datetime import datetime
# Load environment variables from .env file
from dotenv import load_dotenv
load_dotenv()
# Set USER_AGENT to avoid warnings
os.environ.setdefault("USER_AGENT", "RAG-System/1.0")
# LangChain imports
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema import Document
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.document_loaders import (
PyPDFLoader,
CSVLoader,
JSONLoader,
WebBaseLoader
)
from langchain_community.document_loaders import RecursiveUrlLoader
from langchain_community.vectorstores import MongoDBAtlasVectorSearch
from langchain.llms.base import LLM
from langchain.callbacks.manager import CallbackManagerForLLMRun
# MongoDB imports
from pymongo import MongoClient
from pymongo.collection import Collection
# FastAPI and other imports
from fastapi import FastAPI, UploadFile, File, Form, Request, Body, HTTPException
from fastapi.responses import JSONResponse, HTMLResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.templating import Jinja2Templates
import shutil
import pathlib
import tempfile
# Ollama Cloud import
import ollama_client
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# FastAPI app setup
app = FastAPI(title="RAG QA System with MongoDB Atlas Vector Search")
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive (and browsers refuse credentialed requests against "*") β€” confirm
# whether a concrete origin list is intended for production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Template directory lives next to this file; created on import if missing so
# Jinja2Templates does not fail on a fresh checkout.
BASE_DIR = pathlib.Path(__file__).parent
TEMPLATES_DIR = BASE_DIR / "templates"
TEMPLATES_DIR.mkdir(exist_ok=True)
templates = Jinja2Templates(directory=str(TEMPLATES_DIR))
# Configuration for MongoDB Atlas
# Configuration for MongoDB Atlas
class Config:
    """Central configuration, read from environment variables where possible."""

    # MongoDB Atlas Configuration.
    # SECURITY FIX: a live Atlas connection string with embedded credentials was
    # previously hard-coded here. Supply MONGODB_URI via the environment (or the
    # .env file loaded at module import); the fallback is a local MongoDB only.
    MONGODB_URI = os.getenv("MONGODB_URI", "mongodb://localhost:27017/")
    # Database and Collection Configuration (as per your structure)
    MONGODB_DB_NAME = "vector_data"  # Your database name
    MONGODB_COLLECTION_NAME = "RAG"  # Your collection name
    # Vector Search Index Configuration
    VECTOR_INDEX_NAME = "vector_search_index"  # This will be created in Atlas
    # Embedding Configuration
    EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
    EMBEDDING_DIMENSION = 384  # Dimension for all-MiniLM-L6-v2
    # LLM Configuration
    GGUF_MODEL_PATH = os.getenv("GGUF_MODEL_PATH",
        "C:\\Users\\jaini\\IntellijIdea\\Jainish PYTHON AI TIGER\\PDF\\mistral-7b-instruct-v0.2.Q4_K_M.gguf")
    OLLAMA_MODEL = os.getenv("OLLAMA_MODEL", "gpt-oss:120b-cloud")
    # Text Splitting Configuration
    CHUNK_SIZE = 1000
    CHUNK_OVERLAP = 200
    # Similarity Metrics Configuration: public name -> Atlas metric identifier
    SUPPORTED_METRICS = {
        "cosine": "cosine",
        "tanh_cosine": "tanh_cosine",
        "dot": "dotProduct",
        "euclidean": "euclidean"
    }
    DEFAULT_METRIC = "cosine"


config = Config()
# Custom LLM wrapper for Ollama Cloud
# Custom LLM wrapper for Ollama Cloud
class CustomOllamaLLM(LLM):
    """LangChain LLM wrapper around the Ollama Cloud client.

    Delegates generation to ``ollama_client.generate_from_ollama`` and never
    raises from ``_call``: failures are returned as error strings so the
    calling endpoint can still produce a response.
    """

    # Pydantic field on the LangChain LLM base class; holds the model id.
    model_name: str

    def __init__(self, model_name: Optional[str] = None, **kwargs):
        """Create the wrapper, defaulting to ``config.OLLAMA_MODEL``.

        FIX: parameter annotated Optional[str] (it accepts None), and the
        redundant ``self.model_name = model_name`` after ``super().__init__``
        is removed β€” the pydantic base already stores the field.
        """
        if model_name is None:
            model_name = config.OLLAMA_MODEL
        super().__init__(model_name=model_name, **kwargs)

    @property
    def _llm_type(self) -> str:
        """Identifier LangChain uses for this LLM implementation."""
        return "custom_ollama"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Generate a response from Ollama Cloud for *prompt*.

        The prompt is wrapped in Mistral-style ``[INST]`` tags. *stop* and
        *run_manager* are accepted for interface compatibility but unused.
        """
        try:
            formatted_prompt = f"<s>[INST] {prompt} [/INST]"
            response = ollama_client.generate_from_ollama(
                model=self.model_name,
                prompt=formatted_prompt,
                max_tokens=512
            )
            answer = response.strip()
            return answer if answer else "I couldn't generate a relevant answer."
        except Exception as e:
            logger.error(f"Error generating response from Ollama Cloud: {e}")
            return f"Error generating answer: {e}"
# MongoDB Atlas Vector Store Manager
# MongoDB Atlas Vector Store Manager
class MongoDBAtlasVectorStore:
    """Manages MongoDB Atlas vector storage and search operations."""

    def __init__(self):
        """Connect to Atlas and build the embedding/splitting/search stack.

        Raises if the MongoDB ping in ``_test_connection`` fails.
        """
        # MongoDB connection handles
        self.client = MongoClient(config.MONGODB_URI)
        self.db = self.client[config.MONGODB_DB_NAME]
        self.collection = self.db[config.MONGODB_COLLECTION_NAME]
        # Sentence-transformer embeddings, CPU-only
        self.embeddings = HuggingFaceEmbeddings(
            model_name=config.EMBEDDING_MODEL,
            model_kwargs={'device': 'cpu'}
        )
        # Recursive splitter: prefer paragraph, then line, then word breaks
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=config.CHUNK_SIZE,
            chunk_overlap=config.CHUNK_OVERLAP,
            separators=["\n\n", "\n", " ", ""]
        )
        # LangChain wrapper over the Atlas $vectorSearch index
        self.vector_store = MongoDBAtlasVectorSearch(
            collection=self.collection,
            embedding=self.embeddings,
            index_name=config.VECTOR_INDEX_NAME
        )
        self._test_connection()
        logger.info("MongoDB Atlas Vector Store initialized successfully")

    def _test_connection(self):
        """Ping MongoDB Atlas and log whether the collection already exists.

        Raises the underlying pymongo error if the ping fails.
        """
        try:
            self.client.admin.command('ping')
            logger.info("βœ… Successfully connected to MongoDB Atlas")
            if config.MONGODB_COLLECTION_NAME in self.db.list_collection_names():
                logger.info(f"βœ… Collection '{config.MONGODB_COLLECTION_NAME}' exists")
            else:
                logger.info(f"πŸ“ Collection '{config.MONGODB_COLLECTION_NAME}' will be created")
        except Exception as e:
            logger.error(f"❌ MongoDB Atlas connection failed: {e}")
            raise

    def add_documents(self, documents: List[Document], source_info: Dict[str, Any] = None) -> int:
        """Split *documents* into chunks, tag metadata, store them in Atlas.

        Returns the number of chunks written; re-raises storage errors.
        """
        try:
            # Use the module logger instead of a stray print().
            logger.info("πŸ“„ Original document count: %d", len(documents))
            text_chunks = self.text_splitter.split_documents(documents)
            if source_info:
                # BUGFIX: previously every chunk received the identical id
                # "<file>_<len(text_chunks)>"; use each chunk's index so ids
                # are unique within an upload.
                for idx, chunk in enumerate(text_chunks):
                    chunk.metadata.update(source_info)
                    chunk.metadata.update({
                        "timestamp": datetime.utcnow().isoformat(),
                        "chunk_id": f"{source_info.get('source_file', 'unknown')}_{idx}"
                    })
            ids = self.vector_store.add_documents(text_chunks)
            logger.info(f"βœ… Added {len(ids)} document chunks to MongoDB Atlas")
            return len(ids)
        except Exception as e:
            logger.error(f"❌ Error adding documents: {e}")
            raise

    def similarity_search(self, query: str, k: int = 3, metric: str = None, score_threshold: float = 0.0) -> List[Tuple[Document, float]]:
        """Perform similarity search via the Atlas vector index.

        NOTE(review): *metric* and *score_threshold* are only logged/accepted
        here β€” the metric actually used is fixed by the Atlas index definition.
        Returns (Document, score) tuples; falls back to score-less search
        (score 0.0) and finally to [] on errors.
        """
        try:
            # Mojibake ("οΏ½") in these log strings replaced with a real glyph.
            if metric and metric in config.SUPPORTED_METRICS:
                logger.info(f"πŸ” Using similarity metric: {metric}")
            else:
                logger.info(f"πŸ” Using default similarity metric: cosine")
            results = self.vector_store.similarity_search_with_score(
                query=query,
                k=k
            )
            logger.info(f"πŸ“Š Found {len(results)} relevant documents for query")
            return results
        except Exception as e:
            logger.error(f"❌ Error performing similarity search: {e}")
            # Fallback to basic similarity search without score
            try:
                docs = self.vector_store.similarity_search(query=query, k=k)
                # Convert to tuple format with dummy scores
                results = [(doc, 0.0) for doc in docs]
                logger.info(f"πŸ“Š Fallback search returned {len(results)} results")
                return results
            except Exception as fallback_e:
                logger.error(f"❌ Fallback search also failed: {fallback_e}")
                return []

    def get_document_count(self) -> int:
        """Return the total document count in the collection (0 on error)."""
        try:
            count = self.collection.count_documents({})
            logger.info(f"πŸ“š Total documents in collection: {count}")
            return count
        except Exception as e:
            logger.error(f"❌ Error getting document count: {e}")
            return 0

    def delete_all_documents(self) -> int:
        """Delete every document in the collection (testing helper).

        Returns the number of deleted documents (0 on error).
        """
        try:
            result = self.collection.delete_many({})
            logger.info(f"πŸ—‘οΈ Deleted {result.deleted_count} documents")
            return result.deleted_count
        except Exception as e:
            logger.error(f"❌ Error deleting documents: {e}")
            return 0
# Document Processor using LangChain
# Document Processor using LangChain
class DocumentProcessor:
    """Handles document loading and processing using LangChain loaders."""

    def __init__(self):
        """Create a scratch directory for uploaded files."""
        self.temp_dir = tempfile.mkdtemp()
        logger.info(f"πŸ“ Created temporary directory: {self.temp_dir}")

    def process_json(self, file_path: str) -> List[Document]:
        """Convert a JSON file into a single Document.

        The whole payload is pretty-printed into ``page_content``; the text
        splitter chunks it later. Raises Exception on invalid JSON or I/O
        errors (kept as plain Exception for caller compatibility).
        """
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            # Pretty-print so the chunker has line boundaries to split on.
            json_content = json.dumps(data, indent=2, ensure_ascii=False)
            document = Document(
                page_content=json_content,
                metadata={
                    "source": file_path,
                    "file_type": "json",
                    "total_chars": len(json_content)
                }
            )
            return [document]
        except json.JSONDecodeError as e:
            raise Exception(f"Invalid JSON format: {e}")
        except Exception as e:
            raise Exception(f"Error processing JSON: {e}")

    def process_uploaded_file(self, file_path: str, filename: str) -> List[Document]:
        """Load an uploaded .pdf/.csv/.json file into Documents.

        Raises ValueError for unsupported extensions; re-raises loader errors.
        """
        _, ext = os.path.splitext(filename.lower())
        try:
            if ext == '.pdf':
                loader = PyPDFLoader(file_path)
                documents = loader.load()
                logger.info(f"πŸ“„ Loaded PDF with {len(documents)} pages")
            elif ext == '.csv':
                loader = CSVLoader(file_path)
                documents = loader.load()
                logger.info(f"πŸ“Š Loaded CSV with {len(documents)} rows")
            elif ext == '.json':
                # Use simple JSON processing instead of JSONLoader
                documents = self.process_json(file_path)
                logger.info(f"πŸ“ Loaded JSON with {len(documents)} document(s)")
            else:
                raise ValueError(f"Unsupported file type: {ext}")
            # Tag every document with upload provenance.
            for doc in documents:
                doc.metadata.update({
                    "source_file": filename,
                    "file_type": ext,
                    "processed_at": datetime.utcnow().isoformat()
                })
            return documents
        except Exception as e:
            # BUGFIX: log the actual filename instead of the literal "(unknown)".
            logger.error(f"❌ Error processing file {filename}: {e}")
            raise

    def process_url(self, url: str) -> List[Document]:
        """Fetch *url* with WebBaseLoader and return provenance-tagged Documents."""
        try:
            loader = WebBaseLoader([url])
            documents = loader.load()
            logger.info(f"🌐 Loaded webpage with {len(documents)} documents")
            for doc in documents:
                doc.metadata.update({
                    "source_url": url,
                    "source_type": "web",
                    "processed_at": datetime.utcnow().isoformat()
                })
            return documents
        except Exception as e:
            logger.error(f"❌ Error processing URL {url}: {e}")
            raise
# RAG System
# RAG System
class RAGSystem:
    """Main RAG system combining vector store, document processing and LLM."""

    def __init__(self):
        """Wire up the vector store, the document processor and the LLM."""
        self.vector_store = MongoDBAtlasVectorStore()
        self.document_processor = DocumentProcessor()
        # Initialize LLM; on failure we degrade to context-only answers.
        try:
            self.llm = CustomOllamaLLM()
            logger.info("πŸ€– LLM initialized successfully")
        except Exception as e:
            # BUGFIX: the old message referred to a "GGUF model" although the
            # LLM here is the Ollama Cloud wrapper.
            logger.warning(f"⚠️ Could not initialize Ollama LLM: {e}")
            self.llm = None

    def add_documents(self, documents: List[Document], source_info: Dict[str, Any] = None) -> int:
        """Add documents to the vector store; returns number of chunks stored."""
        return self.vector_store.add_documents(documents, source_info)

    def query(self, question: str, k: int = 3, metric: str = None) -> Dict[str, Any]:
        """Answer *question*, retrieving up to *k* chunks as context.

        Returns a dict with keys answer, sources, scores, context_count,
        metric_used and mode ("rag", "conversation" or "error").
        """
        try:
            doc_count = self.vector_store.get_document_count()
            docs_with_scores = self.vector_store.similarity_search(question, k=k, metric=metric)
            # Empty database or no hits: answer as a plain chatbot.
            if not docs_with_scores or doc_count == 0:
                logger.info("πŸ“ No relevant documents found, using LLM for direct conversation")
                if self.llm:
                    prompt = f"""You're a helpful, smart, and friendly AI assistant. Answer the user's question naturally and conversationally.
Question: {question}
Answer:"""
                    answer = self.llm(prompt)
                    logger.info("πŸ€– Generated conversational answer using LLM")
                else:
                    answer = "I'm ready to help! However, I need the LLM to be properly configured. You can upload documents and I'll help you find information from them."
                return {
                    "answer": answer,
                    "sources": [],
                    "scores": [],
                    "context_count": 0,
                    "metric_used": metric or config.DEFAULT_METRIC,
                    "mode": "conversation"
                }
            # Extract context and metadata from retrieved documents
            contexts = []
            sources = []
            scores = []
            for doc, score in docs_with_scores:
                contexts.append(doc.page_content)
                sources.append({
                    "source_file": doc.metadata.get("source_file", "Unknown"),
                    "page": doc.metadata.get("page", "N/A"),
                    "chunk_id": doc.metadata.get("chunk_id", "N/A"),
                    "content": doc.page_content[:500]  # Trimmed for UI display (adjust as needed)
                })
                scores.append(float(score))
            # Generate answer using LLM with document context (RAG mode)
            if self.llm:
                context_text = "\n\n".join(contexts)
                prompt = f"""You're a helpful AI assistant. Answer the user's question based on the context provided below.
If the context contains relevant information, use it to provide a detailed and accurate answer.
If the context doesn't contain enough information, you can supplement with general knowledge but mention what came from the documents.
Context from documents:
{context_text}
Question: {question}
Answer:"""
                answer = self.llm(prompt)
                logger.info("πŸ€– Generated RAG answer using LLM with document context")
            else:
                # Fallback when LLM is not available: show raw context excerpt.
                context_text = "\n\n".join(contexts[:2])
                answer = f"Based on the retrieved documents:\n\n{context_text[:800]}..."
                logger.info("πŸ“ Generated fallback answer")
            return {
                "answer": answer,
                "sources": sources,
                "scores": scores,
                "context_count": len(contexts),
                "metric_used": metric or config.DEFAULT_METRIC,
                "mode": "rag"
            }
        except Exception as e:
            logger.error(f"❌ Error querying RAG system: {e}")
            # Consistency fix: the error payload now carries the same keys as
            # the success payloads so callers can index it uniformly.
            return {
                "answer": f"Error processing query: {str(e)}",
                "sources": [],
                "scores": [],
                "context_count": 0,
                "metric_used": metric or config.DEFAULT_METRIC,
                "mode": "error"
            }

    def get_status(self) -> Dict[str, Any]:
        """Return a status snapshot for the /status/ endpoint."""
        # BUGFIX: previously hard-coded True; report whether the LLM actually
        # initialized in __init__.
        llm_available = self.llm is not None
        document_count = self.vector_store.get_document_count()
        return {
            "documents_count": document_count,
            "documents_loaded": document_count,  # For compatibility
            "llm_available": llm_available,
            "embedding_model": config.EMBEDDING_MODEL,
            "mongodb_atlas_connected": True,
            "database_name": config.MONGODB_DB_NAME,
            "collection_name": config.MONGODB_COLLECTION_NAME,
            "vector_index_name": config.VECTOR_INDEX_NAME,
            "ollama_model": config.OLLAMA_MODEL,
            "ollama_cloud_available": True,
            "supported_metrics": list(config.SUPPORTED_METRICS.keys()),
            "default_metric": config.DEFAULT_METRIC
        }
# Global RAG system instance; endpoints treat a None value as "not yet
# initialized" and reply with HTTP 500.
rag_system = None

@app.on_event("startup")
async def startup_event():
    """Initialize RAG system on startup.

    Builds the global RAGSystem (MongoDB connection, embeddings, LLM) and
    re-raises on failure so the server does not start half-configured.
    """
    global rag_system
    try:
        rag_system = RAGSystem()
        logger.info("πŸš€ RAG System initialized successfully")
    except Exception as e:
        logger.error(f"❌ Failed to initialize RAG system: {e}")
        raise
@app.get("/", response_class=HTMLResponse)
async def serve_index(request: Request):
    """Render the single-page UI from the templates directory."""
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
@app.post("/upload/")
async def upload_document(file: UploadFile = File(...)):
    """Upload and process a document.

    Accepts .pdf, .csv or .json uploads; saves the file to the processor's
    temp directory, splits/embeds it into MongoDB Atlas, then removes the
    temp copy. Raises 400 on processing errors, 500 before startup finished.
    """
    if not rag_system:
        raise HTTPException(status_code=500, detail="RAG system not initialized")
    # Ensure temp directory exists (mkdtemp dir may have been cleaned up).
    temp_dir = rag_system.document_processor.temp_dir
    os.makedirs(temp_dir, exist_ok=True)
    temp_path = os.path.join(temp_dir, file.filename)
    try:
        # Stream the upload to disk so the file loaders can read it by path.
        with open(temp_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)
        logger.info(f"πŸ“€ Saved uploaded file to: {temp_path}")
        # Process document into LangChain Documents.
        documents = rag_system.document_processor.process_uploaded_file(temp_path, file.filename)
        # Add to vector store; returns the number of chunks written.
        chunks_added = rag_system.add_documents(documents, {"upload_filename": file.filename})
        return {
            "status": "success",
            "chunks": chunks_added,
            "filename": file.filename,
            "message": f"Successfully processed {file.filename} with {chunks_added} chunks"
        }
    except Exception as e:
        logger.error(f"❌ Error uploading document: {e}")
        raise HTTPException(status_code=400, detail=str(e))
    finally:
        # Clean up temporary file regardless of success or failure.
        if os.path.exists(temp_path):
            os.unlink(temp_path)
@app.post("/add_url/")
async def add_url(request: dict = Body(...)):
    """Fetch a web page and ingest its content into the vector store.

    Expects a JSON body with a "url" key; replies 400 when it is missing
    or when processing fails, 500 before startup finished.
    """
    if not rag_system:
        raise HTTPException(status_code=500, detail="RAG system not initialized")
    url = request.get("url")
    if not url:
        raise HTTPException(status_code=400, detail="URL is required")
    try:
        logger.info(f"🌐 Processing URL: {url}")
        # Download and convert the page, then embed it into Atlas.
        docs = rag_system.document_processor.process_url(url)
        chunk_total = rag_system.add_documents(docs, {"source_url": url})
    except Exception as e:
        logger.error(f"❌ Error processing URL: {e}")
        raise HTTPException(status_code=400, detail=str(e))
    return {
        "status": "success",
        "chunks": chunk_total,
        "url": url,
        "message": f"Successfully processed URL with {chunk_total} chunks"
    }
@app.post("/ask/")
async def ask_question(
    question: str = Form(...),
    k: int = Form(3),
    metric: str = Form(None)
):
    """Answer a question via the RAG system.

    Form fields:
        question: the user's question (required, non-blank).
        k: number of chunks to retrieve (must be >= 1; default 3).
        metric: optional similarity metric name; must be one of
            config.SUPPORTED_METRICS when provided.
    Raises 400 for invalid input, 500 for processing failures.
    """
    if not rag_system:
        raise HTTPException(status_code=500, detail="RAG system not initialized")
    if not question.strip():
        raise HTTPException(status_code=400, detail="Question is required")
    # Robustness fix: reject non-positive k before hitting the vector store.
    if k < 1:
        raise HTTPException(status_code=400, detail="k must be a positive integer")
    # Validate metric if provided
    if metric and metric not in config.SUPPORTED_METRICS:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported metric: {metric}. Supported metrics: {list(config.SUPPORTED_METRICS.keys())}"
        )
    try:
        logger.info(f"❓ Processing question: {question}")
        if metric:
            logger.info(f"πŸ”§ Using similarity metric: {metric}")
        result = rag_system.query(question, k=k, metric=metric)
        # Collect page numbers (falling back to the source file name) for the
        # legacy "pages" field the UI expects.
        pages = []
        for source in result["sources"]:
            if "page" in source:
                pages.append(source["page"])
            elif "source_file" in source:
                pages.append(source["source_file"])
        return {
            "status": "success",
            "answer": result["answer"],
            "pages": pages,
            "scores": result["scores"],
            "sources": result["sources"],
            "context_count": result.get("context_count", 0),
            "metric_used": result.get("metric_used", config.DEFAULT_METRIC)
        }
    except Exception as e:
        logger.error(f"❌ Error processing question: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/status/")
async def get_status():
    """Report system status, or an error payload before startup completes."""
    if rag_system is None:
        return {"error": "RAG system not initialized"}
    return rag_system.get_status()
@app.get("/metrics/")
async def get_supported_metrics():
    """List the similarity metrics the /ask/ endpoint accepts."""
    descriptions = {
        "cosine": "Cosine Similarity - Measures angle between vectors (0-1, higher is better)",
        "tanh_cosine": "Tanh(Cosine) - Hyperbolic tangent of cosine similarity",
        "dot": "Dot Product - Direct dot product of vectors",
        "euclidean": "Euclidean Distance - L2 distance between vectors (lower is better)"
    }
    return {
        "supported_metrics": list(config.SUPPORTED_METRICS.keys()),
        "default_metric": config.DEFAULT_METRIC,
        "metric_descriptions": descriptions
    }
@app.get("/health/")
async def health_check():
    """Liveness probe: always reports healthy with a UTC timestamp."""
    timestamp = datetime.utcnow().isoformat()
    return {
        "status": "healthy",
        "timestamp": timestamp,
        "service": "MongoDB Atlas RAG System"
    }
@app.delete("/clear/")
async def clear_database():
    """Remove every document from the collection (testing helper)."""
    if not rag_system:
        raise HTTPException(status_code=500, detail="RAG system not initialized")
    try:
        removed = rag_system.vector_store.delete_all_documents()
    except Exception as e:
        logger.error(f"❌ Error clearing database: {e}")
        raise HTTPException(status_code=500, detail=str(e))
    return {
        "status": "success",
        "message": f"Deleted {removed} documents",
        "deleted_count": removed
    }
if __name__ == "__main__":
    # Development entry point: serve the API on all interfaces, port 8000.
    import uvicorn
    logger.info("πŸš€ Starting MongoDB Atlas RAG System")
    uvicorn.run(app, host="0.0.0.0", port=8000)