Spaces:
Sleeping
Sleeping
| import json | |
| from pathlib import Path | |
| import yaml | |
| from loguru import logger | |
| from opik import opik_context, track | |
| from smolagents import Tool | |
| from second_brain_online.application.rag import get_retriever | |
class MongoDBRetrieverTool(Tool):
    """Agent tool that answers a query with documents fetched by a retriever.

    The retriever is built once, at construction time, by ``get_retriever``
    from the ``parameters`` section of a YAML config file. Each ``forward``
    call records search settings on the current Opik trace, runs the
    retriever, and renders the returned documents as an XML-like string.
    """

    name = "mongodb_vector_search_retriever"
    description = (
        "Use this tool to search and retrieve relevant documents from a knowledge base using semantic search.\n"
        "This tool performs similarity-based search to find the most relevant documents matching the query.\n"
        "Best used when you need to:\n"
        "- Find specific information from stored documents\n"
        "- Get context about a topic\n"
        "- Research historical data or documentation\n"
        "The tool will return multiple relevant document snippets."
    )
    inputs = {
        "query": {
            "type": "string",
            "description": (
                "The search query to find relevant documents for using semantic search.\n"
                "Should be a clear, specific question or statement about the information you're looking for."
            ),
        }
    }
    output_type = "string"

    # Number of documents requested from the retriever (passed as ``k``).
    _TOP_K = 5
    # Longest content excerpt, in characters, emitted per document before
    # "..." is appended.
    _MAX_CONTENT_CHARS = 500

    def __init__(self, config_path: Path, **kwargs):
        """Build the tool and its retriever.

        Args:
            config_path: Location of the YAML config file. It must contain a
                top-level ``parameters`` mapping with the keys
                ``embedding_model_id``, ``embedding_model_type``,
                ``retriever_type`` and ``device``.
            **kwargs: Forwarded unchanged to ``Tool.__init__``.
        """
        super().__init__(**kwargs)

        self.config_path = config_path
        self.retriever = self.__load_retriever(config_path)

    def __load_retriever(self, config_path: Path):
        """Read the YAML config and return the retriever it describes.

        Raises:
            KeyError: If ``parameters`` or one of its required keys is missing.
        """
        # Coerce so a plain string path works too, and pin the encoding so the
        # result does not depend on the platform's default locale.
        raw_config = Path(config_path).read_text(encoding="utf-8")
        parameters = yaml.safe_load(raw_config)["parameters"]

        return get_retriever(
            embedding_model_id=parameters["embedding_model_id"],
            embedding_model_type=parameters["embedding_model_type"],
            retriever_type=parameters["retriever_type"],
            k=self._TOP_K,
            device=parameters["device"],
        )

    def forward(self, query: str) -> str:
        """Run the search and return the formatted results.

        Args:
            query: Either the plain search text or a JSON object string of the
                form ``{"query": "..."}``.

        Returns:
            A ``<search_results>`` block with one ``<document>`` entry per
            retrieved document, or the literal string
            ``"Error retrieving documents."`` if parsing, retrieval or
            formatting raises.
        """
        opik_context.update_current_trace(
            tags=["agent"],
            metadata={
                "search": self._search_settings(),
                "embedding_model_id": self._embedding_model_id(),
            },
        )

        try:
            search_text = self.__parse_query(query)
            relevant_docs = self.retriever.invoke(search_text)
            documents = "\n".join(
                self._format_document(position, doc)
                for position, doc in enumerate(relevant_docs, 1)
            )
        except Exception:
            # Deliberate boundary: the caller always gets a string back.
            logger.opt(exception=True).debug("Error retrieving documents.")
            return "Error retrieving documents."

        return (
            "\n"
            "<search_results>\n"
            f"{documents}\n"
            "</search_results>\n"
            "When using context from any document, reference the document title and date for attribution.\n"
        )

    def _search_settings(self) -> dict:
        """Return the retriever's search settings for trace metadata.

        Uses ``search_kwargs`` when the retriever has it; otherwise falls back
        to ``fulltext_penalty`` / ``vector_penalty`` / ``top_k``. Returns an
        empty dict (with a warning) when neither shape is available.
        """
        if hasattr(self.retriever, "search_kwargs"):
            return self.retriever.search_kwargs

        try:
            return {
                "fulltext_penalty": self.retriever.fulltext_penalty,
                "vector_score_penalty": self.retriever.vector_penalty,
                "top_k": self.retriever.top_k,
            }
        except AttributeError:
            logger.warning("Could not extract search kwargs from retriever.")
            return {}

    def _embedding_model_id(self):
        """Return ``retriever.vectorstore.embeddings.model``, or ``None``.

        Trace metadata is auxiliary, so a retriever that lacks this attribute
        chain is tolerated the same way missing search settings are.
        """
        try:
            return self.retriever.vectorstore.embeddings.model
        except AttributeError:
            logger.warning("Could not extract embedding model id from retriever.")
            return None

    @staticmethod
    def _format_marketing_insights(insights) -> str:
        """Render the ``quotes`` and ``key_findings`` of *insights* as text.

        Returns an empty string when *insights* is empty or ``None``.
        """
        if not insights:
            return ""

        parts = ["\n<marketing_insights>\n"]

        quotes = insights.get("quotes", [])
        if quotes:
            parts.append("<quotes>\n")
            for quote in quotes:
                text = quote.get("quote", "")
                sentiment = quote.get("sentiment", "Unknown")
                parts.append(f'- "{text}" (Sentiment: {sentiment})\n')
            parts.append("</quotes>\n")

        findings = insights.get("key_findings", [])
        if findings:
            parts.append("<key_findings>\n")
            for finding in findings:
                text = finding.get("finding", "")
                impact = finding.get("impact", "Unknown")
                parts.append(f"- {text} (Impact: {impact})\n")
            parts.append("</key_findings>\n")

        parts.append("</marketing_insights>\n")
        return "".join(parts)

    def _format_document(self, position: int, doc) -> str:
        """Render one retrieved document as a ``<document>`` entry.

        Args:
            position: 1-based rank used as the ``id`` attribute.
            doc: Object exposing a ``metadata`` mapping and a ``page_content``
                string.
        """
        metadata = doc.metadata
        title = metadata.get("title", "Untitled")
        date = metadata.get("datetime", "unknown")
        contextual_summary = metadata.get("contextual_summary", "")
        insights_text = self._format_marketing_insights(
            metadata.get("marketing_insights", {})
        )

        # Truncate long content so a handful of documents cannot flood the
        # caller's context with tokens.
        content = doc.page_content.strip()
        if len(content) > self._MAX_CONTENT_CHARS:
            content = content[: self._MAX_CONTENT_CHARS] + "..."

        return (
            "\n"
            f'<document id="{position}">\n'
            f"<title>{title}</title>\n"
            f"<date>{date}</date>\n"
            "<contextual_summary>\n"
            f"{contextual_summary}\n"
            "</contextual_summary>\n"
            f"{insights_text}\n"
            "<content>\n"
            f"{content}\n"
            "</content>\n"
            "</document>\n"
        )

    def __parse_query(self, query: str) -> str:
        """Extract the search text from *query*.

        If *query* is a JSON object whose ``"query"`` value is a string, that
        string is returned. Anything else is returned unchanged, including
        text that happens to be valid non-object JSON such as ``2024`` or
        ``true``.
        """
        try:
            parsed = json.loads(query)
        except (TypeError, json.JSONDecodeError):
            # Not JSON (or not text at all): treat it as the literal query.
            return query

        # json.loads may return an int, bool, list, str or None; only an
        # object can carry a "query" field, so never index other types.
        if isinstance(parsed, dict):
            inner = parsed.get("query")
            if isinstance(inner, str):
                return inner

        return query