import json
import os
import time
import base64
import random
from io import BytesIO

import numpy as np
import torch
from typing import List, Dict, Any, Optional
from tqdm import tqdm
from pymongo import ReplaceOne
from rank_bm25 import BM25Okapi
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from dotenv import load_dotenv

from config import VECTOR_INDEX_NAME
from .database import get_mongo_client, get_mongo_collection
from .models import get_clip_model, get_llm, get_groq_client

load_dotenv()


class RAGEngine:
    """
    Unified RAG engine refactored from search.py.

    Combines dense retrieval (MongoDB Atlas ``$vectorSearch`` over CLIP
    embeddings) with sparse retrieval (in-memory BM25 over text chunks),
    and wraps an LLM chain for grounded question answering.
    """

    def __init__(self, use_hybrid: bool = True, force_clean: bool = False):
        """
        Args:
            use_hybrid: Stored flag; hybrid (dense+sparse) search is the
                default retrieval path.
            force_clean: If True, delete every stored document before use.
        """
        self.use_hybrid = use_hybrid
        self.clip_model = get_clip_model()
        self.collection = get_mongo_collection()
        self.llm = get_llm()
        self.groq_client = get_groq_client()

        if force_clean:
            self.collection.delete_many({})

        self._setup_vector_index()

        # BM25 sparse index lives in memory; bm25_doc_map maps the BM25
        # row position -> Mongo document _id (as str).
        self.bm25_index = None
        self.bm25_doc_map: Dict[int, str] = {}
        if self.collection.count_documents({}) > 0:
            self._rebuild_bm25_index()

    def _setup_vector_index(self):
        """
        Attempts to create a vector search index if using MongoDB Atlas.
        Includes robust dimension checking and error handling; on failure
        (common on the free M0 tier) prints the index JSON so the user can
        create it manually in the Atlas UI.
        """
        # 1. Determine embedding dimensions safely. Some CLIP wrappers
        # return None from get_sentence_embedding_dimension(), so fall
        # back to probing the model with a dummy encode.
        try:
            dims = self.clip_model.get_sentence_embedding_dimension()
            if dims is None or not isinstance(dims, int):
                raise ValueError("Model returned invalid dimensions")
        except Exception:
            print("Auto-dim failed, probing model...")
            test_vec = self.clip_model.encode("test")
            dims = len(test_vec)

        print(f"Vector Dimensions: {dims}")

        # 2. Define the Atlas vector-search index model.
        index_model = {
            "definition": {
                "fields": [
                    {
                        "type": "vector",
                        "path": "embedding",
                        "numDimensions": int(dims),  # Ensure strict integer
                        "similarity": "cosine",
                    },
                    {
                        "type": "filter",
                        "path": "metadata.type",
                    },
                ]
            },
            "name": VECTOR_INDEX_NAME,
            "type": "vectorSearch",
        }

        # 3. Create the index, skipping if it already exists. A single
        # try/except suffices here; the previous nested outer handler was
        # unreachable because this one already catches Exception.
        try:
            indexes = list(self.collection.list_search_indexes())
            index_names = [idx.get("name") for idx in indexes]
            if VECTOR_INDEX_NAME not in index_names:
                print(f"Creating Atlas Vector Search Index '{VECTOR_INDEX_NAME}'...")
                self.collection.create_search_index(model=index_model)
                print("Index creation initiated. Please wait 1-2 minutes for Atlas to build it.")
                print("You can check progress in Atlas UI -> Database -> Search -> Vector Search")
            else:
                print(f"Index '{VECTOR_INDEX_NAME}' already exists.")
        except Exception as e:
            print(f"\nAutomatic Index Creation Failed: {e}")
            print("This is common on Free Tier (M0) or due to permissions.")
            print("PLEASE CREATE MANUALLY IN ATLAS UI (See JSON below)\n")
            print(json.dumps(index_model["definition"], indent=2))

    def _rebuild_bm25_index(self):
        """Rebuild the in-memory BM25 index over all textual documents."""
        cursor = self.collection.find(
            {"metadata.type": {"$in": ["text", "table", "list", "header", "code"]}},
            {"content": 1, "_id": 1},
        )
        text_docs = []
        self.bm25_doc_map = {}
        for idx, doc in enumerate(cursor):
            content = doc.get("content", "")
            if content:
                text_docs.append(content.lower().split())
                self.bm25_doc_map[idx] = str(doc["_id"])
        # BUGFIX: always reassign the index. Previously a stale BM25 index
        # survived when no text docs remained, while bm25_doc_map was
        # emptied -> row/id mismatch (KeyError) in hybrid_search.
        self.bm25_index = BM25Okapi(text_docs) if text_docs else None

    def _encode_content(self, content: Any, content_type: str) -> Optional[np.ndarray]:
        """
        Encode text or a base64-encoded image into a normalized embedding.

        Returns None when an image fails to decode (caller skips the doc).
        """
        if content_type == "image":
            # Lazy import: PIL only needed for the image path.
            from PIL import Image

            try:
                img = Image.open(BytesIO(base64.b64decode(content))).convert("RGB")
                return self.clip_model.encode(img, normalize_embeddings=True)
            except Exception:
                # Narrowed from a bare except; undecodable images are
                # deliberately skipped rather than aborting ingestion.
                return None
        return self.clip_model.encode(content, normalize_embeddings=True)

    def ingest_data(self, data: Dict[str, Any]):
        """
        Ingests processed document data.

        Expects ``data["chunks"]`` (text chunks with chunk_id/text/metadata)
        and ``data["images"]`` (base64 images with descriptions). Upserts
        into Mongo in batches of 100, then rebuilds the BM25 index.
        """
        operations = []

        for chunk in data.get("chunks", []):
            embedding = self._encode_content(chunk["text"], "text")
            if embedding is None:
                continue
            doc = {
                "_id": chunk["chunk_id"],
                "content": chunk["text"],
                "embedding": embedding.tolist(),
                "metadata": {
                    **chunk["metadata"],
                    "type": chunk.get("type", "text"),
                },
            }
            operations.append(ReplaceOne({"_id": doc["_id"]}, doc, upsert=True))

        for img in data.get("images", []):
            embedding = self._encode_content(img["image_base64"], "image")
            if embedding is None:
                continue
            doc = {
                "_id": img["image_id"],
                "content": img.get("description", ""),
                "embedding": embedding.tolist(),
                "metadata": {
                    "page": str(img.get("page_number", 0)),
                    "header": str(img.get("section_header", "")),
                    "type": "image",
                    "description": img.get("description", ""),
                    "image_base64": img["image_base64"],
                },
            }
            operations.append(ReplaceOne({"_id": doc["_id"]}, doc, upsert=True))

        if operations:
            # Batch writes to keep individual bulk_write payloads small.
            for i in range(0, len(operations), 100):
                self.collection.bulk_write(operations[i:i + 100])
            self._rebuild_bm25_index()

    def hybrid_search(self, query: str, top_k: int = 5, alpha: float = 0.5) -> List[Dict]:
        """
        Hybrid dense+sparse retrieval.

        Args:
            query: Natural-language query.
            top_k: Number of results to return.
            alpha: Weight of the dense score; (1 - alpha) weights BM25.

        Returns:
            Top-k documents with a combined ``score`` field, best first.
        """
        query_embedding = self._encode_content(query, "text")

        # Dense leg: Atlas $vectorSearch. Degrade gracefully (empty dense
        # results) if the index is missing or still building.
        dense_results = []
        try:
            pipeline = [
                {"$vectorSearch": {
                    "index": VECTOR_INDEX_NAME,
                    "path": "embedding",
                    "queryVector": query_embedding.tolist(),
                    "numCandidates": top_k * 10,
                    "limit": top_k * 2,
                }},
                {"$project": {"content": 1, "metadata": 1,
                              "score": {"$meta": "vectorSearchScore"}}},
            ]
            dense_results = list(self.collection.aggregate(pipeline))
        except Exception:
            # Narrowed from a bare except; BM25 alone still serves results.
            pass

        dense_scores = {str(r["_id"]): {"score": r.get("score", 0), "doc": r}
                        for r in dense_results}

        # Sparse leg: BM25 scores normalized to [0, 1] by the max score.
        sparse_scores = {}
        if self.bm25_index:
            scores = self.bm25_index.get_scores(query.lower().split())
            peak = max(scores) if len(scores) > 0 else 0.0  # compute max once
            max_s = peak if peak > 0 else 1.0
            for i in np.argsort(scores)[::-1][:top_k * 2]:
                if scores[i] > 0:
                    sparse_scores[self.bm25_doc_map[i]] = scores[i] / max_s

        # Fuse: weighted sum of dense and normalized sparse scores.
        combined = []
        all_ids = set(dense_scores.keys()) | set(sparse_scores.keys())
        for did in all_ids:
            d_s = dense_scores.get(did, {}).get("score", 0)
            s_s = sparse_scores.get(did, 0)
            score = (alpha * d_s) + ((1 - alpha) * s_s)
            # Sparse-only hits need a fetch to recover the full document.
            doc = dense_scores.get(did, {}).get("doc") or self.collection.find_one({"_id": did})
            if doc:
                combined.append({**doc, "score": score})

        combined.sort(key=lambda x: x["score"], reverse=True)
        return combined[:top_k]

    def answer_question(self, question: str, top_k: int = 5):
        """
        Stream an answer to ``question`` grounded in retrieved context.

        This is a generator: it yields answer text chunks. BUGFIX: the
        previous ``return "..."`` statements inside this generator silently
        discarded their messages (a generator's return value only surfaces
        as StopIteration.value) -- they are now yielded so consumers see them.
        """
        results = self.hybrid_search(question, top_k=top_k)
        if not results:
            yield "No relevant info found."
            return

        context = ""
        for i, res in enumerate(results, 1):
            m = res["metadata"]
            # NOTE(review): image metadata stores the page under "page",
            # text-chunk metadata presumably carries "page_number" -- verify
            # upstream; a miss here just prints "?".
            context += f"\n[Src {i} | Page {m.get('page_number','?')}] {res['content']}"

        prompt = f"Context:\n{context}\n\nQuestion: {question}\n\nAnswer strictly based on context:"
        try:
            chain = ChatPromptTemplate.from_template("{p}") | self.llm | StrOutputParser()
            for msg in chain.stream({"p": prompt}):
                # Small sleep smooths the perceived streaming cadence.
                time.sleep(0.01)
                if hasattr(msg, "content"):
                    yield msg.content
                else:
                    yield str(msg)
        except Exception as e:
            yield f"Error: {e}"

    def search_images(self, query: str, top_k: int = 3, min_score: float = 0.5) -> List[Dict]:
        """
        Dense-only image search filtered to ``metadata.type == "image"``.

        Returns up to ``top_k`` hits with score >= ``min_score``; empty list
        on any retrieval failure.
        """
        query_embedding = self._encode_content(f"{query}", "text")
        try:
            pipeline = [
                {"$vectorSearch": {
                    "index": VECTOR_INDEX_NAME,
                    "path": "embedding",
                    "queryVector": query_embedding.tolist(),
                    "numCandidates": top_k * 10,
                    "limit": top_k * 2,
                    "filter": {"metadata.type": "image"},
                }},
                {"$project": {"content": 1, "metadata": 1,
                              "score": {"$meta": "vectorSearchScore"}}},
            ]
            results = list(self.collection.aggregate(pipeline))
            return [{"description": r["content"],
                     "image_base64": r["metadata"].get("image_base64"),
                     "score": r["score"]}
                    for r in results if r["score"] >= min_score][:top_k]
        except Exception as e:
            print("*********error", str(e))
            return []

    def generate_suggested_questions(self, num_questions: int = 4) -> List[str]:
        """
        Token-efficient question generation using metadata only.

        Fetches up to 100 metadata records (no embeddings/content), distills
        section headers and image descriptions into a compact prompt, and
        asks the LLM for ``num_questions`` one-line questions.

        Returns:
            A list of question strings; a single fallback question on error.
        """
        print("\nGenerating suggested questions (Efficient Mode)...")
        try:
            # 1. Fetch metadata ONLY (projection excludes embedding and content)
            cursor = self.collection.find({}, {"metadata": 1, "_id": 0}).limit(100)
            metadatas = [doc.get('metadata', {}) for doc in cursor]
            if not metadatas:
                return ["What is this document about?"]

            # 2. Extract high-level structure: up to 8 headers, 2 image blurbs.
            headers = set()
            image_descriptions = []
            random.shuffle(metadatas)  # vary which sections get sampled
            for meta in metadatas:
                if 'header' in meta and len(headers) < 8:
                    h = str(meta['header']).strip()
                    if h and h.lower() != "unknown" and len(h) > 5:
                        headers.add(h)
                if meta.get('type') == 'image' and len(image_descriptions) < 2:
                    desc = meta.get('description', '')
                    if len(desc) > 20:
                        image_descriptions.append(desc[:100] + "...")

            # 3. Construct the compact context block.
            context_str = "Document Sections:\n" + "\n".join([f"- {h}" for h in headers])
            if image_descriptions:
                context_str += "\n\nVisual Content involves:\n" + "\n".join(
                    [f"- {d}" for d in image_descriptions])

            # 4. Prompt the LLM.
            prompt = f"""Generate {num_questions} short, interesting questions about a document with these sections and visuals: {context_str} Output ONLY the {num_questions} questions, one per line. No numbering."""
            prompt_tmpl = ChatPromptTemplate.from_messages([
                ("system", "You are a helpful assistant."),
                ("user", "{prompt}"),
            ])
            chain = prompt_tmpl | self.llm | StrOutputParser()
            response = chain.invoke({"prompt": prompt})

            # Strip list markers the model may prepend despite instructions.
            questions = [q.strip().lstrip('-1234567890. ')
                         for q in response.split('\n') if q.strip()]
            return questions[:num_questions]
        except Exception as e:
            print(f"Error generating questions: {e}")
            # BUGFIX: previously fell through returning None, breaking
            # callers that iterate the declared List[str] result.
            return ["What is this document about?"]