Spaces:
Sleeping
Sleeping
| import json | |
| from pathlib import Path | |
| import yaml | |
| from loguru import logger | |
| from opik import opik_context, track | |
| from smolagents import Tool | |
| from pymongo import MongoClient | |
| from second_brain_online.application.rag import get_retriever | |
| from second_brain_online.config import settings | |
class MongoDBRetrieverTool(Tool):
    """Semantic-search tool over the MongoDB knowledge base for smolagents agents.

    ``forward`` returns a compact, token-light context string (title, date,
    user ID and per-chunk contextual summaries) that is handed to the agent's
    LLM. As a side effect it stores a fully formatted sources section
    (conversation summary, key findings, URLs) in the class attribute
    ``_cached_sources`` so that a summarizer tool can append it to the final
    answer without routing it through the LLM.
    """

    name = "mongodb_vector_search_retriever"
    description = (
        "Use this tool to search and retrieve relevant documents from a knowledge base using semantic search.\n"
        "This tool performs similarity-based search to find the most relevant documents matching the query.\n"
        "Best used when you need to:\n"
        "- Find specific information from stored documents\n"
        "- Get context about a topic\n"
        "- Research historical data or documentation\n"
        "The tool will return multiple relevant document snippets."
    )
    inputs = {
        "query": {
            "type": "string",
            "description": (
                "The search query to find relevant documents for using semantic search.\n"
                "Should be a clear, specific question or statement about the information you're looking for."
            ),
        }
    }
    output_type = "string"

    # Formatted sources section produced by the most recent ``forward`` call.
    # The summarizer tool reads it and appends it to the final answer, which
    # keeps the full sources out of the LLM prompt. It is a class attribute, so
    # every instance shares it. ``forward`` clears it at the start of each call.
    _cached_sources = ""

    def __init__(self, config_path: Path, **kwargs):
        """Build the retriever from a YAML config and connect to MongoDB.

        Args:
            config_path: YAML file whose top-level ``parameters`` mapping holds
                the retriever settings read by ``__load_retriever``.
            **kwargs: Forwarded unchanged to ``smolagents.Tool.__init__``.
        """
        super().__init__(**kwargs)
        self.config_path = config_path
        self.retriever = self.__load_retriever(config_path)

        # MongoDB client used to enrich retrieved chunks with the per-conversation
        # insights and metadata stored in "test_conversation_documents".
        self.mongodb_client = MongoClient(settings.MONGODB_URI)
        self.database = self.mongodb_client[settings.MONGODB_DATABASE_NAME]
        self.conversation_docs_collection = self.database["test_conversation_documents"]

    def __load_retriever(self, config_path: Path):
        """Create the retriever described by the ``parameters`` section of *config_path*.

        ``embedding_model_id``, ``embedding_model_type``, ``retriever_type`` and
        ``device`` are required keys. The reranking keys are optional and use
        the defaults shown below.
        """
        config = yaml.safe_load(config_path.read_text())["parameters"]

        return get_retriever(
            embedding_model_id=config["embedding_model_id"],
            embedding_model_type=config["embedding_model_type"],
            retriever_type=config["retriever_type"],
            k=5,  # Reduced from 10 to 5 for faster processing
            device=config["device"],
            enable_reranking=config.get("enable_reranking", False),
            rerank_model_name=config.get("rerank_model_name", "cross-encoder/ms-marco-MiniLM-L-2-v2"),
            stage1_limit=config.get("stage1_limit", 50),
            final_k=config.get("final_k", 5),  # Reduced from 10 to 5
        )

    def __fetch_conversation_insights(self, document_ids: list[str]) -> dict[str, dict]:
        """Fetch conversation_insights and metadata for the given document IDs.

        Args:
            document_ids: Document IDs (the ``id`` field) to look up in
                ``test_conversation_documents``.

        Returns:
            Mapping ``document_id -> {"conversation_insights", "url", "source",
            "user_id"}``. Documents that are missing, or whose
            ``conversation_insights`` is empty or absent, are left out. ``url``,
            ``source`` and ``user_id`` may be ``None`` when the stored document
            lacks them.
        """
        if not document_ids:
            # Nothing to look up, so skip the DB round trip.
            return {}

        cursor = self.conversation_docs_collection.find(
            {"id": {"$in": document_ids}},
            {
                "id": 1,
                "conversation_insights": 1,
                "metadata.url": 1,
                "metadata.source": 1,
                "metadata.user_id": 1,
            },
        )

        insights_map: dict[str, dict] = {}
        for doc in cursor:
            insights = doc.get("conversation_insights")
            if not insights:
                continue
            # `or {}` also covers an explicit null stored in "metadata".
            metadata = doc.get("metadata") or {}
            insights_map[doc.get("id")] = {
                "conversation_insights": insights,
                "url": metadata.get("url"),
                "source": metadata.get("source"),
                "user_id": metadata.get("user_id"),
            }

        # Track mismatches: IDs that are missing entirely or lack insights.
        not_found_count = len(document_ids) - len(insights_map)
        if not_found_count > 0:
            logger.warning(f"Could not find conversation_insights for {not_found_count} out of {len(document_ids)} document IDs")

        return insights_map

    def __get_search_kwargs(self) -> dict:
        """Return a best-effort snapshot of the retriever's search parameters for tracing."""
        if hasattr(self.retriever, "search_kwargs"):
            return self.retriever.search_kwargs
        try:
            return {
                "fulltext_penalty": self.retriever.fulltext_penalty,
                "vector_score_penalty": self.retriever.vector_penalty,
                "top_k": self.retriever.top_k,
            }
        except AttributeError:
            logger.warning("Could not extract search kwargs from retriever.")
            return {}

    @staticmethod
    def __group_chunks_by_document(relevant_docs, insights_map: dict) -> tuple[dict, int]:
        """Group retrieved chunks under their parent document.

        Chunks with no ``id`` metadata, or whose document has no entry in
        *insights_map*, are skipped.

        Returns:
            ``(docs_by_id, skipped_chunks)``. ``docs_by_id`` preserves
            retrieval order and maps each document ID to its title, datetime,
            source, url, user_id, insights and the list of chunk contextual
            summaries.
        """
        docs_by_id: dict[str, dict] = {}
        skipped_chunks = 0

        for chunk_num, doc in enumerate(relevant_docs, 1):
            doc_id = doc.metadata.get("id")
            doc_insights = insights_map.get(doc_id) if doc_id else None
            if doc_insights is None:
                skipped_chunks += 1
                logger.debug(f"Skipping chunk {chunk_num} - no conversation insights available for doc_id: {doc_id}")
                continue

            if doc_id not in docs_by_id:
                docs_by_id[doc_id] = {
                    "title": doc.metadata.get("title", "Untitled"),
                    "datetime": doc.metadata.get("datetime", "unknown"),
                    # Use `or`, not dict.get defaults. The insights entry always
                    # contains these keys, but the values are None when the
                    # Mongo document lacks them, so `.get(key, default)` would
                    # never fall back (the output would read "Source: None").
                    "source": doc_insights.get("source") or "Unknown Source",
                    "url": doc_insights.get("url") or "",
                    "user_id": doc_insights.get("user_id") or "",
                    "insights": doc_insights["conversation_insights"],
                    "chunks": [],
                }

            docs_by_id[doc_id]["chunks"].append(doc.metadata.get("contextual_summary", ""))

        return docs_by_id, skipped_chunks

    @staticmethod
    def __format_llm_context(doc_num: int, doc_info: dict) -> str:
        """Format one document for the LLM.

        The entry is ``Doc N: title | datetime [| User: id]`` followed by one
        ``- summary`` bullet per chunk and a trailing blank line.
        """
        header = f"Doc {doc_num}: {doc_info['title']} | {doc_info['datetime']}"
        if doc_info["user_id"]:
            header += f" | User: {doc_info['user_id']}"
        bullets = "".join(f"- {chunk_context}\n" for chunk_context in doc_info["chunks"])
        return f"{header}\n{bullets}\n"

    @staticmethod
    def __format_source_entry(doc_num: int, doc_id: str, doc_info: dict) -> str:
        """Format one document's full sources entry, which is never sent to the LLM.

        ``doc_info["insights"]`` is read as a mapping with optional ``summary``
        and ``key_findings`` keys. Each finding is a mapping with optional
        ``insight_type``, ``finding`` and ``impact`` keys.
        """
        parts = [
            f"Doc {doc_num}: {doc_info['title']} ({doc_info['datetime']})\n",
            f"Source: {doc_info['source']} | Document ID: {doc_id}",
        ]
        if doc_info["url"]:
            parts.append(f" | [View Chat]({doc_info['url']})")
        if doc_info["user_id"]:
            parts.append(f" | User ID: {doc_info['user_id']}")
        parts.append("\n\n")

        insights = doc_info["insights"]
        summary = insights.get("summary", "")
        if summary:
            parts.append(f"Summary: {summary}\n\n")

        key_findings = insights.get("key_findings", [])
        if key_findings:
            parts.append("Key Findings:\n")
            for finding in key_findings:
                insight_type = finding.get("insight_type", "Unknown")
                finding_text = finding.get("finding", "")
                impact = finding.get("impact", "Unknown")
                parts.append(f"- [{insight_type}/{impact}] {finding_text}\n")

        parts.append("\n---\n\n")
        return "".join(parts)

    @track(name="MongoDBRetrieverTool.forward")
    def forward(self, query: str) -> str:
        """Retrieve documents for *query* and return a lightweight LLM context.

        Args:
            query: Search text, or a JSON object string of the form
                ``{"query": "..."}``.

        Returns:
            The concatenated per-document context entries (see
            ``__format_llm_context``) for retrieved documents that have
            conversation insights. Returns ``"No relevant documents found."``
            when none qualify, and ``"Error retrieving documents."`` if
            retrieval or formatting raises.

        Side effects:
            Sets ``MongoDBRetrieverTool._cached_sources`` to this query's
            formatted sources section. It is ``""`` when nothing was found or
            an error occurred.
        """
        # Clear first so that an empty result or an exception below can never
        # leave the previous query's sources for the summarizer to append.
        MongoDBRetrieverTool._cached_sources = ""

        # update_current_trace acts on the trace active in the current context.
        # @track above opens one, so this call does not depend on the caller
        # being traced.
        opik_context.update_current_trace(
            tags=["agent"],
            metadata={
                "search": self.__get_search_kwargs(),
                "embedding_model_id": self.retriever.vectorstore.embeddings.model,
            },
        )

        try:
            query = self.__parse_query(query)
            relevant_docs = self.retriever.invoke(query)

            # Unique parent-document IDs, de-duplicated in retrieval order.
            chunk_doc_ids = (doc.metadata.get("id") for doc in relevant_docs)
            unique_doc_ids = list(dict.fromkeys(doc_id for doc_id in chunk_doc_ids if doc_id))
            insights_map = self.__fetch_conversation_insights(unique_doc_ids)

            # Group chunks per document so each document's insights appear once.
            docs_by_id, skipped_chunks = self.__group_chunks_by_document(relevant_docs, insights_map)
            logger.info(f"Retrieved {len(relevant_docs)} chunks from {len(docs_by_id)} unique conversations, skipped {skipped_chunks} without insights")

            if not docs_by_id:
                return "No relevant documents found."

            # Two outputs keep LLM token usage low: a lightweight context that
            # goes to the LLM, and the full sources, cached for the summarizer.
            context_section = "".join(
                self.__format_llm_context(doc_num, doc_info)
                for doc_num, doc_info in enumerate(docs_by_id.values(), 1)
            )
            metadata_section = "".join(
                self.__format_source_entry(doc_num, doc_id, doc_info)
                for doc_num, (doc_id, doc_info) in enumerate(docs_by_id.items(), 1)
            )
            # NOTE(review): "π" looks like a mis-encoded emoji from the paste.
            # It is kept verbatim here; confirm the intended glyph.
            sources = f"π Sources\n{metadata_section}"

            logger.info(f"Returning {len(context_section)} chars of context to LLM, {len(sources)} chars cached for sources")
            MongoDBRetrieverTool._cached_sources = sources
            return context_section
        except Exception:
            # Error level with traceback: the agent only sees a generic string.
            logger.exception("Error retrieving documents.")
            return "Error retrieving documents."

    def __parse_query(self, query: str) -> str:
        """Extract the search text from *query*.

        Agents sometimes pass the tool input as a JSON object such as
        ``{"query": "..."}``. In that case the string ``"query"`` value is
        returned. Any other input is returned unchanged, including plain text,
        invalid JSON, and valid JSON that is not an object with a string
        ``"query"`` field (e.g. ``"2024"`` or ``"[1, 2]"``). Previously such
        JSON raised ``TypeError`` and failed the whole search.
        """
        try:
            parsed = json.loads(query)
        except json.JSONDecodeError:
            return query
        if isinstance(parsed, dict) and isinstance(parsed.get("query"), str):
            return parsed["query"]
        return query