Spaces:
Running
Running
import json
import os
import random
import re
import time
from typing import Any, Dict, Iterator, List, Optional

import numpy as np
import torch
from dotenv import load_dotenv
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from pymongo import ReplaceOne
from rank_bm25 import BM25Okapi
from tqdm import tqdm

from config import VECTOR_INDEX_NAME
from .database import get_mongo_client, get_mongo_collection
from .models import get_clip_model, get_llm, get_groq_client

load_dotenv()
class RAGEngine:
    """Unified RAG engine refactored from search.py.

    Stores embedded chunks and images in a MongoDB collection, retrieves them
    with a ``$vectorSearch`` aggregation combined with an in-memory BM25
    index, and asks the configured LLM to answer from the retrieved context.
    """

    # ``metadata.type`` values whose ``content`` is indexed by BM25.
    _BM25_TYPES = ["text", "table", "list", "header", "code"]
    # Number of write operations sent per ``bulk_write`` call.
    _BULK_BATCH_SIZE = 100
    # Leading bullet ("- ", "* ") or numbering ("1. ", "2) ") on an LLM line.
    _LIST_MARKER = re.compile(r"^\s*(?:[-*\u2022]+\s*|\d+[.)]\s+)")
    # Returned when no questions can be generated.
    _DEFAULT_QUESTION = "What is this document about?"

    def __init__(self, use_hybrid: bool = True, force_clean: bool = False):
        """Acquire the model/database handles and prepare both indexes.

        Args:
            use_hybrid: Stored as ``self.use_hybrid``; no method of this class
                currently reads it.
            force_clean: When true, every document in the collection is
                deleted before the indexes are prepared.
        """
        self.use_hybrid = use_hybrid
        self.clip_model = get_clip_model()
        self.collection = get_mongo_collection()
        self.llm = get_llm()
        self.groq_client = get_groq_client()

        if force_clean:
            self.collection.delete_many({})

        # NOTE(review): the original indentation was lost; index setup is
        # assumed to run unconditionally, not only when force_clean is set.
        self._setup_vector_index()

        self.bm25_index = None
        self.bm25_doc_map = {}
        if self.collection.count_documents({}) > 0:
            self._rebuild_bm25_index()

    def _setup_vector_index(self):
        """Create the vector search index if it does not exist yet.

        The embedding size is read from the model and, if that fails or is
        not an integer, measured by encoding a probe string. Any failure
        while listing or creating the index is reported on stdout together
        with the index definition so it can be created by hand; it is never
        raised. A failure of the probe encoding itself does propagate.
        """
        # 1. Determine the embedding dimensions.
        try:
            dims = self.clip_model.get_sentence_embedding_dimension()
            if not isinstance(dims, int):
                raise ValueError("Model returned invalid dimensions")
        except Exception:
            print("Auto-dim failed, probing model...")
            dims = len(self.clip_model.encode("test"))
        print(f"Vector Dimensions: {dims}")

        # 2. Describe the index: the vector field plus a filter on the type.
        index_model = {
            "definition": {
                "fields": [
                    {
                        "type": "vector",
                        "path": "embedding",
                        "numDimensions": int(dims),
                        "similarity": "cosine",
                    },
                    {
                        "type": "filter",
                        "path": "metadata.type",
                    },
                ]
            },
            "name": VECTOR_INDEX_NAME,
            "type": "vectorSearch",
        }

        # 3. Create it unless an index with that name is already listed.
        #    A single handler is used: a second ``except Exception`` on the
        #    same ``try`` could never run.
        try:
            existing = {
                idx.get("name") for idx in self.collection.list_search_indexes()
            }
            if VECTOR_INDEX_NAME in existing:
                print(f"Index '{VECTOR_INDEX_NAME}' already exists.")
                return
            print(f"Creating Atlas Vector Search Index '{VECTOR_INDEX_NAME}'...")
            self.collection.create_search_index(model=index_model)
            print("Index creation initiated. Please wait 1-2 minutes for Atlas to build it.")
            print("You can check progress in Atlas UI -> Database -> Search -> Vector Search")
        except Exception as e:
            print(f"\nAutomatic Index Creation Failed: {e}")
            print("This is common on Free Tier (M0) or due to permissions.")
            print("PLEASE CREATE MANUALLY IN ATLAS UI (See JSON below)\n")
            print(json.dumps(index_model["definition"], indent=2))

    def _rebuild_bm25_index(self):
        """Rebuild the BM25 index and its row -> document-id map.

        Only documents whose ``metadata.type`` is in ``_BM25_TYPES`` and whose
        content yields at least one token are indexed. ``bm25_doc_map`` is
        keyed by the row number inside the BM25 corpus, so skipped documents
        cannot shift the mapping. When nothing is indexable the index is
        reset to ``None`` instead of keeping a stale one.
        """
        cursor = self.collection.find(
            {"metadata.type": {"$in": self._BM25_TYPES}},
            {"content": 1, "_id": 1},
        )
        corpus = []
        doc_map = {}
        for doc in cursor:
            tokens = (doc.get("content") or "").lower().split()
            if not tokens:
                continue
            # Key by the corpus row, not by the cursor position.
            doc_map[len(corpus)] = str(doc["_id"])
            corpus.append(tokens)

        self.bm25_doc_map = doc_map
        self.bm25_index = BM25Okapi(corpus) if corpus else None

    def _encode_content(self, content: Any, content_type: str) -> Optional[np.ndarray]:
        """Embed ``content`` with the model, requesting normalized embeddings.

        Args:
            content: Text to embed, or a base64-encoded image when
                ``content_type`` is ``"image"``.
            content_type: ``"image"`` selects the image path; every other
                value is passed to the model unchanged.

        Returns:
            The model's embedding, or ``None`` when an image cannot be
            decoded or embedded.
        """
        if content_type == "image":
            from PIL import Image
            from io import BytesIO
            import base64

            try:
                img = Image.open(BytesIO(base64.b64decode(content))).convert("RGB")
                return self.clip_model.encode(img, normalize_embeddings=True)
            except Exception as e:
                # Best effort: a bad image is skipped by the caller, but the
                # reason is reported and KeyboardInterrupt is not swallowed.
                print(f"Skipping image that could not be embedded: {e}")
                return None
        return self.clip_model.encode(content, normalize_embeddings=True)

    def ingest_data(self, data: Dict[str, Any]):
        """Ingests processed document data.

        Each entry of ``data["chunks"]`` and ``data["images"]`` is embedded
        and upserted by its id. Entries whose embedding is ``None`` are
        skipped. The BM25 index is rebuilt after the writes.

        Args:
            data: Mapping with optional ``"chunks"`` (items providing
                ``chunk_id``, ``text`` and optionally ``metadata``/``type``)
                and ``"images"`` (items providing ``image_id`` and
                ``image_base64``, optionally ``description``,
                ``page_number`` and ``section_header``).
        """
        operations = []

        for chunk in data.get("chunks") or []:
            embedding = self._encode_content(chunk["text"], "text")
            if embedding is None:
                continue
            doc = {
                "_id": chunk["chunk_id"],
                "content": chunk["text"],
                "embedding": embedding.tolist(),
                "metadata": {
                    **(chunk.get("metadata") or {}),
                    # The chunk-level type always wins over the metadata one.
                    "type": chunk.get("type", "text"),
                },
            }
            operations.append(ReplaceOne({"_id": doc["_id"]}, doc, upsert=True))

        for img in data.get("images") or []:
            embedding = self._encode_content(img["image_base64"], "image")
            if embedding is None:
                continue
            description = img.get("description", "")
            doc = {
                "_id": img["image_id"],
                "content": description,
                "embedding": embedding.tolist(),
                "metadata": {
                    "page": str(img.get("page_number", 0)),
                    "header": str(img.get("section_header", "")),
                    "type": "image",
                    "description": description,
                    "image_base64": img["image_base64"],
                },
            }
            operations.append(ReplaceOne({"_id": doc["_id"]}, doc, upsert=True))

        if not operations:
            return

        batch = self._BULK_BATCH_SIZE
        for start in range(0, len(operations), batch):
            self.collection.bulk_write(operations[start:start + batch])
        # NOTE(review): assumed to run only when something was written; the
        # original indentation was lost.
        self._rebuild_bm25_index()

    def hybrid_search(self, query: str, top_k: int = 5, alpha: float = 0.5) -> List[Dict]:
        """Rank documents by a weighted mix of vector and BM25 scores.

        Args:
            query: Search text.
            top_k: Maximum number of results; values <= 0 give ``[]``.
            alpha: Weight of the vector score; BM25 gets ``1 - alpha``.

        Returns:
            Up to ``top_k`` documents (``_id``, ``content``, ``metadata``)
            with a combined ``score``, best first. BM25 scores are divided by
            the best BM25 score of the query before mixing.
        """
        if top_k <= 0:
            return []

        query_embedding = self._encode_content(query, "text")

        # Dense side. A failing vector search is reported and the search
        # continues with BM25 only.
        dense_hits = {}
        try:
            pipeline = [
                {"$vectorSearch": {
                    "index": VECTOR_INDEX_NAME,
                    "path": "embedding",
                    "queryVector": query_embedding.tolist(),
                    "numCandidates": top_k * 10,
                    "limit": top_k * 2,
                }},
                {"$project": {"content": 1, "metadata": 1, "score": {"$meta": "vectorSearchScore"}}},
            ]
            for hit in self.collection.aggregate(pipeline):
                dense_hits[str(hit["_id"])] = hit
        except Exception as e:
            print(f"Vector search failed, using BM25 results only: {e}")
            dense_hits = {}

        # Sparse side.
        sparse_scores = {}
        if self.bm25_index is not None:
            scores = np.asarray(self.bm25_index.get_scores(query.lower().split()))
            best = float(scores.max()) if scores.size else 0.0
            if best > 0:
                for pos in np.argsort(scores)[::-1][: top_k * 2]:
                    if scores[pos] > 0:
                        # .get(): a row without a mapped id is skipped
                        # instead of raising KeyError.
                        doc_id = self.bm25_doc_map.get(int(pos))
                        if doc_id is not None:
                            sparse_scores[doc_id] = float(scores[pos]) / best

        # Dense ids first, then sparse-only ids, so equal scores keep a
        # deterministic order.
        ordered_ids = list(dict.fromkeys([*dense_hits, *sparse_scores]))

        # Fetch sparse-only documents in one query, with the same fields the
        # dense hits carry (no embedding payload).
        missing = [doc_id for doc_id in ordered_ids if doc_id not in dense_hits]
        fetched = {}
        if missing:
            cursor = self.collection.find(
                {"_id": {"$in": missing}},
                {"content": 1, "metadata": 1},
            )
            fetched = {str(doc["_id"]): doc for doc in cursor}

        combined = []
        for doc_id in ordered_ids:
            doc = dense_hits.get(doc_id) or fetched.get(doc_id)
            if not doc:
                continue
            dense = dense_hits[doc_id].get("score", 0) if doc_id in dense_hits else 0
            sparse = sparse_scores.get(doc_id, 0)
            combined.append({**doc, "score": (alpha * dense) + ((1 - alpha) * sparse)})

        combined.sort(key=lambda item: item["score"], reverse=True)
        return combined[:top_k]

    def answer_question(self, question: str, top_k: int = 5) -> Iterator[str]:
        """Stream an answer to ``question`` built from retrieved context.

        This is a generator. The "no results" and error messages are yielded
        as stream items: a ``return "..."`` inside a generator is invisible
        to a caller that iterates it.

        Args:
            question: The user question.
            top_k: Number of documents retrieved as context.

        Yields:
            Answer fragments as produced by the LLM chain, or one
            explanatory message when nothing was found or the chain failed.
        """
        results = self.hybrid_search(question, top_k=top_k)
        if not results:
            yield "No relevant info found."
            return

        parts = []
        for i, res in enumerate(results, 1):
            meta = res.get("metadata") or {}
            # Images are stored with "page"; fall back to it.
            page = meta.get("page_number", meta.get("page", "?"))
            parts.append(f"\n[Src {i} | Page {page}] {res.get('content', '')}")
        context = "".join(parts)

        prompt = f"Context:\n{context}\n\nQuestion: {question}\n\nAnswer strictly based on context:"
        try:
            # The prompt is passed as a template variable so braces inside
            # the context are not interpreted as placeholders.
            chain = ChatPromptTemplate.from_template("{p}") | self.llm | StrOutputParser()
            for msg in chain.stream({"p": prompt}):
                time.sleep(0.01)  # short pause between streamed fragments
                yield msg.content if hasattr(msg, "content") else str(msg)
        except Exception as e:
            yield f"Error: {e}"

    def search_images(self, query: str, top_k: int = 3, min_score: float = 0.5) -> List[Dict]:
        """Return stored images whose embedding matches ``query``.

        Args:
            query: Search text.
            top_k: Maximum number of images; values <= 0 give ``[]``.
            min_score: Minimum vector search score an image must reach.

        Returns:
            Dicts with ``description``, ``image_base64`` and ``score``.
            An empty list is returned (and the error printed) on any failure
            of the search itself.
        """
        if top_k <= 0:
            return []

        query_embedding = self._encode_content(f"{query}", "text")
        try:
            pipeline = [
                {"$vectorSearch": {
                    "index": VECTOR_INDEX_NAME,
                    "path": "embedding",
                    "queryVector": query_embedding.tolist(),
                    "numCandidates": top_k * 10,
                    "limit": top_k * 2,
                    "filter": {"metadata.type": "image"},
                }},
                {"$project": {"content": 1, "metadata": 1, "score": {"$meta": "vectorSearchScore"}}},
            ]
            matches = []
            for r in self.collection.aggregate(pipeline):
                if r["score"] < min_score:
                    continue
                matches.append({
                    "description": r.get("content", ""),
                    "image_base64": (r.get("metadata") or {}).get("image_base64"),
                    "score": r["score"],
                })
            return matches[:top_k]
        except Exception as e:
            print("*********error", str(e))
            return []

    def generate_suggested_questions(self, num_questions: int = 4) -> List[str]:
        """Token-efficient question generation using metadata.

        Reads the metadata of up to 100 documents, samples up to 8 section
        headers and 2 image descriptions from it, and asks the LLM for
        questions about them.

        Args:
            num_questions: Number of questions requested and returned at most.

        Returns:
            The generated questions with list markers removed. A single
            default question is returned when the collection has no
            documents or when generation fails.
        """
        print("\nGenerating suggested questions (Efficient Mode)...")
        try:
            # 1. Fetch metadata only; the projection leaves out everything else.
            cursor = self.collection.find({}, {"metadata": 1, "_id": 0}).limit(100)
            metadatas = [doc.get("metadata") or {} for doc in cursor]
            if not metadatas:
                return [self._DEFAULT_QUESTION]

            # 2. Sample the high-level structure.
            random.shuffle(metadatas)
            headers = []
            image_descriptions = []
            for meta in metadatas:
                if "header" in meta and len(headers) < 8:
                    h = str(meta["header"]).strip()
                    if h and h.lower() != "unknown" and len(h) > 5 and h not in headers:
                        headers.append(h)
                if meta.get("type") == "image" and len(image_descriptions) < 2:
                    desc = meta.get("description") or ""
                    if len(desc) > 20:
                        image_descriptions.append(desc[:100] + "...")

            # 3. Build the prompt.
            context_str = "Document Sections:\n" + "\n".join(f"- {h}" for h in headers)
            if image_descriptions:
                context_str += "\n\nVisual Content involves:\n" + "\n".join(
                    f"- {d}" for d in image_descriptions
                )
            prompt = (
                f"Generate {num_questions} short, interesting questions about a "
                "document with these sections and visuals:\n"
                f"{context_str}\n"
                f"Output ONLY the {num_questions} questions, one per line. No numbering."
            )

            # 4. Ask the LLM.
            prompt_tmpl = ChatPromptTemplate.from_messages([
                ("system", "You are a helpful assistant."),
                ("user", "{prompt}"),
            ])
            chain = prompt_tmpl | self.llm | StrOutputParser()
            response = chain.invoke({"prompt": prompt})

            # 5. Remove only a leading list marker, so a question that starts
            #    with a digit ("3D ...") keeps it.
            questions = []
            for line in response.split("\n"):
                cleaned = self._LIST_MARKER.sub("", line.strip(), count=1).strip()
                if cleaned:
                    questions.append(cleaned)
            return questions[:num_questions]
        except Exception as e:
            print(f"Error generating questions: {e}")
            # Keep the declared List[str] contract on failure.
            return [self._DEFAULT_QUESTION]