Spaces:
Sleeping
Sleeping
| import json | |
| from pathlib import Path | |
| import yaml | |
| from loguru import logger | |
| from opik import opik_context, track | |
| from smolagents import Tool | |
| from pymongo import MongoClient | |
| from second_brain_online.application.rag import get_retriever | |
| from second_brain_online.config import settings | |
class MongoDBRetrieverTool(Tool):
    """Semantic-search tool over a MongoDB-backed vector store.

    Retrieves relevant chunks via the configured retriever, then enriches each
    unique parent document with its ``conversation_insights`` (plus url/source/
    user_id metadata) fetched from the ``test_conversation_documents``
    collection, and renders everything as one formatted string for the agent.
    """

    name = "mongodb_vector_search_retriever"
    description = """Use this tool to search and retrieve relevant documents from a knowledge base using semantic search.
    This tool performs similarity-based search to find the most relevant documents matching the query.
    Best used when you need to:
    - Find specific information from stored documents
    - Get context about a topic
    - Research historical data or documentation
    The tool will return multiple relevant document snippets."""
    inputs = {
        "query": {
            "type": "string",
            "description": """The search query to find relevant documents for using semantic search.
            Should be a clear, specific question or statement about the information you're looking for.""",
        }
    }
    output_type = "string"

    def __init__(self, config_path: Path, **kwargs):
        """Build the retriever from the YAML config and open MongoDB handles.

        Args:
            config_path: Path to a YAML file whose ``parameters`` section
                configures the retriever (model ids, retriever type, device,
                optional reranking settings).
        """
        super().__init__(**kwargs)
        self.config_path = config_path
        self.retriever = self.__load_retriever(config_path)
        # MongoDB client used to fetch conversation insights for retrieved docs.
        # NOTE(review): the client is never closed; fine for a long-lived tool,
        # but consider a close()/context-manager if instances are short-lived.
        self.mongodb_client = MongoClient(settings.MONGODB_URI)
        self.database = self.mongodb_client[settings.MONGODB_DATABASE_NAME]
        self.conversation_docs_collection = self.database["test_conversation_documents"]

    def __load_retriever(self, config_path: Path):
        """Instantiate the retriever from the ``parameters`` section of the config."""
        config = yaml.safe_load(config_path.read_text())
        config = config["parameters"]
        return get_retriever(
            embedding_model_id=config["embedding_model_id"],
            embedding_model_type=config["embedding_model_type"],
            retriever_type=config["retriever_type"],
            k=5,  # Reduced from 10 to 5 for faster processing
            device=config["device"],
            enable_reranking=config.get("enable_reranking", False),
            rerank_model_name=config.get("rerank_model_name", "cross-encoder/ms-marco-MiniLM-L-2-v2"),
            stage1_limit=config.get("stage1_limit", 50),
            final_k=config.get("final_k", 5),  # Reduced from 10 to 5
        )

    def __fetch_conversation_insights(self, document_ids: list[str]) -> dict:
        """
        Fetch conversation_insights and metadata for the given document IDs
        from ``test_conversation_documents``.

        Args:
            document_ids: List of document IDs to fetch insights for.

        Returns:
            Dictionary mapping document_id -> {conversation_insights, url,
            source, user_id}. IDs whose documents are missing or have no
            ``conversation_insights`` are omitted (a warning is logged).
        """
        insights_map: dict = {}
        # Projection keeps the round-trip small: only the insight payload and
        # the three metadata fields the formatter needs.
        cursor = self.conversation_docs_collection.find(
            {"id": {"$in": document_ids}},
            {
                "id": 1,
                "conversation_insights": 1,
                "metadata.url": 1,
                "metadata.source": 1,
                "metadata.user_id": 1,
            },
        )
        for doc in cursor:
            insights = doc.get("conversation_insights")
            if not insights:
                continue
            metadata = doc.get("metadata", {})
            insights_map[doc.get("id")] = {
                "conversation_insights": insights,
                "url": metadata.get("url"),
                "source": metadata.get("source"),
                "user_id": metadata.get("user_id"),
            }
        not_found_count = len(document_ids) - len(insights_map)
        if not_found_count > 0:
            logger.warning(f"Could not find conversation_insights for {not_found_count} out of {len(document_ids)} document IDs")
        return insights_map

    def __extract_search_kwargs(self) -> dict:
        """Best-effort extraction of the retriever's search parameters for tracing."""
        if hasattr(self.retriever, "search_kwargs"):
            return self.retriever.search_kwargs
        try:
            return {
                "fulltext_penalty": self.retriever.fulltext_penalty,
                "vector_score_penalty": self.retriever.vector_penalty,
                "top_k": self.retriever.top_k,
            }
        except AttributeError:
            logger.warning("Could not extract search kwargs from retriever.")
            return {}

    def __group_chunks(self, relevant_docs, insights_map: dict) -> tuple[dict, int]:
        """Group retrieved chunks by parent document id.

        Chunks whose parent document has no conversation insights are skipped
        (counted and debug-logged) so insights are never duplicated per chunk.

        Returns:
            (docs_by_id, skipped_chunks) where each docs_by_id value carries the
            document header fields, its insights, and the list of chunk summaries.
        """
        docs_by_id: dict = {}
        skipped_chunks = 0
        for i, doc in enumerate(relevant_docs, 1):
            doc_id = doc.metadata.get("id")
            if not doc_id or doc_id not in insights_map:
                skipped_chunks += 1
                logger.debug(f"Skipping chunk {i} - no conversation insights available for doc_id: {doc_id}")
                continue
            if doc_id not in docs_by_id:
                doc_insights = insights_map[doc_id]
                docs_by_id[doc_id] = {
                    "title": doc.metadata.get("title", "Untitled"),
                    "datetime": doc.metadata.get("datetime", "unknown"),
                    "source": doc_insights.get("source", "Unknown Source"),
                    "url": doc_insights.get("url", ""),
                    "user_id": doc_insights.get("user_id", ""),
                    "insights": doc_insights["conversation_insights"],
                    "chunks": [],
                }
            # Each chunk contributes only its contextual summary to the parent.
            docs_by_id[doc_id]["chunks"].append(doc.metadata.get("contextual_summary", ""))
        return docs_by_id, skipped_chunks

    def __format_document(self, doc_num: int, doc_id: str, doc_info: dict) -> str:
        """Render one grouped document: header, chunk contexts, then insights."""
        doc_text = f"=== DOCUMENT {doc_num} ===\n"
        doc_text += f"Title: {doc_info['title']}\n"
        doc_text += f"Date: {doc_info['datetime']}\n"
        doc_text += f"Source: {doc_info['source']} | ID: {doc_id}"
        if doc_info["user_id"]:
            doc_text += f" | User: {doc_info['user_id']}"
        if doc_info["url"]:
            doc_text += f"\nURL: {doc_info['url']}"
        # All chunk contexts from this conversation, numbered.
        doc_text += f"\n\nCONTEXT (from {len(doc_info['chunks'])} chunk(s)):\n"
        for chunk_idx, chunk_context in enumerate(doc_info["chunks"], 1):
            doc_text += f"{chunk_idx}. {chunk_context}\n"
        # Conversation insights (for the Sources section only - not for answer generation).
        insights = doc_info["insights"]
        doc_text += "\n[METADATA FOR SOURCES SECTION]\n"
        summary = insights.get("summary", "")
        if summary:
            doc_text += f"Summary: {summary}\n"
        key_findings = insights.get("key_findings", [])
        if key_findings:
            doc_text += "Key Findings:\n"
            for finding in key_findings:
                insight_type = finding.get("insight_type", "Unknown")
                finding_text = finding.get("finding", "")
                impact = finding.get("impact", "Unknown")
                doc_text += f"- [{insight_type}/{impact}] {finding_text}\n"
        doc_text += "\n---\n"
        return doc_text

    def forward(self, query: str) -> str:
        """Run a semantic search and return a formatted, insight-enriched result.

        Args:
            query: Search query; may be a raw string or a JSON object string
                containing a ``"query"`` key.

        Returns:
            The formatted multi-document results, an explicit empty-results
            message, or ``"Error retrieving documents."`` on failure.
        """
        opik_context.update_current_trace(
            tags=["agent"],
            metadata={
                "search": self.__extract_search_kwargs(),
                "embedding_model_id": self.retriever.vectorstore.embeddings.model,
            },
        )
        try:
            query = self.__parse_query(query)
            relevant_docs = self.retriever.invoke(query)
            # Fetch insights once per unique parent document.
            unique_doc_ids = {doc.metadata.get("id") for doc in relevant_docs}
            unique_doc_ids.discard(None)
            insights_map = self.__fetch_conversation_insights(list(unique_doc_ids))
            # Group chunks by document so insights appear once per conversation.
            docs_by_id, skipped_chunks = self.__group_chunks(relevant_docs, insights_map)
            formatted_docs = [
                self.__format_document(doc_num, doc_id, doc_info)
                for doc_num, (doc_id, doc_info) in enumerate(docs_by_id.items(), 1)
            ]
            logger.info(f"Retrieved {len(relevant_docs)} chunks from {len(docs_by_id)} unique conversations, skipped {skipped_chunks} without insights")
            # Robustness fix: the original emitted an empty SEARCH RESULTS shell
            # when nothing usable was retrieved.
            if not formatted_docs:
                return "No relevant documents with conversation insights were found."
            result = "\n".join(formatted_docs)
            result = f"""SEARCH RESULTS
===============
{result}
When using context, reference the document title, date, and ID for attribution.
"""
            return result
        except Exception:
            # Visibility fix: log with traceback at error level instead of the
            # original debug-level message, which hid real failures.
            logger.exception("Error retrieving documents.")
            return "Error retrieving documents."

    def __parse_query(self, query: str) -> str:
        """Unwrap a JSON-encoded ``{"query": ...}`` payload, else return as-is.

        Bug fix: ``json.loads`` can return a non-dict (number, list, string),
        which previously made ``query_dict["query"]`` raise an uncaught
        TypeError; an isinstance check now guards the lookup.
        """
        try:
            parsed = json.loads(query)
        except json.JSONDecodeError:
            # Not JSON at all - treat the input as the literal query.
            return query
        if isinstance(parsed, dict) and "query" in parsed:
            return parsed["query"]
        return query