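"""smolagents Tool that answers queries against a MongoDB-backed vector index.

Retrieved chunks are grouped by their parent conversation document, enriched
with conversation insights fetched from MongoDB, and formatted into a single
string for the agent to consume.
"""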
import json
from pathlib import Path
import yaml
from loguru import logger
from opik import opik_context, track
from smolagents import Tool
from pymongo import MongoClient
from second_brain_online.application.rag import get_retriever
from second_brain_online.config import settings
class MongoDBRetrieverTool(Tool):
name = "mongodb_vector_search_retriever"
description = """Use this tool to search and retrieve relevant documents from a knowledge base using semantic search.
This tool performs similarity-based search to find the most relevant documents matching the query.
Best used when you need to:
- Find specific information from stored documents
- Get context about a topic
- Research historical data or documentation
The tool will return multiple relevant document snippets."""
inputs = {
"query": {
"type": "string",
"description": """The search query to find relevant documents for using semantic search.
Should be a clear, specific question or statement about the information you're looking for.""",
}
}
output_type = "string"
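    # Callers may pass the query either as a plain string or as a JSON payload
    # such as '{"query": "how is reranking configured?"}'; __parse_query below
    # accepts both forms.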
def __init__(self, config_path: Path, **kwargs):
super().__init__(**kwargs)
self.config_path = config_path
self.retriever = self.__load_retriever(config_path)
# Setup MongoDB client for fetching conversation insights
self.mongodb_client = MongoClient(settings.MONGODB_URI)
self.database = self.mongodb_client[settings.MONGODB_DATABASE_NAME]
self.conversation_docs_collection = self.database["test_conversation_documents"]
def __load_retriever(self, config_path: Path):
config = yaml.safe_load(config_path.read_text())
config = config["parameters"]
return get_retriever(
embedding_model_id=config["embedding_model_id"],
embedding_model_type=config["embedding_model_type"],
retriever_type=config["retriever_type"],
k=5, # Reduced from 10 to 5 for faster processing
device=config["device"],
enable_reranking=config.get("enable_reranking", False),
rerank_model_name=config.get("rerank_model_name", "cross-encoder/ms-marco-MiniLM-L-2-v2"),
stage1_limit=config.get("stage1_limit", 50),
final_k=config.get("final_k", 5), # Reduced from 10 to 5
)
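    # The YAML file read above is expected to look roughly like this
    # (values are illustrative, not the project's actual settings):
    #
    # parameters:
    #   embedding_model_id: sentence-transformers/all-MiniLM-L6-v2
    #   embedding_model_type: huggingface
    #   retriever_type: contextual
    #   device: cpu
    #   enable_reranking: true
    #   rerank_model_name: cross-encoder/ms-marco-MiniLM-L-2-v2
    #   stage1_limit: 50
    #   final_k: 5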
def __fetch_conversation_insights(self, document_ids: list[str]) -> dict:
"""
Fetch conversation_insights and metadata for the given document IDs from test_conversation_documents.
Args:
document_ids: List of document IDs to fetch insights for
Returns:
Dictionary mapping document_id -> {conversation_insights, url, source, user_id}
"""
        insights_map = {}
# Fetch documents from MongoDB with additional metadata
cursor = self.conversation_docs_collection.find(
{"id": {"$in": document_ids}},
{
"id": 1,
"conversation_insights": 1,
"metadata.url": 1,
"metadata.source": 1,
"metadata.user_id": 1
}
)
for doc in cursor:
doc_id = doc.get("id")
insights = doc.get("conversation_insights")
metadata = doc.get("metadata", {})
if insights:
insights_map[doc_id] = {
"conversation_insights": insights,
"url": metadata.get("url"),
"source": metadata.get("source"),
"user_id": metadata.get("user_id")
}
# Track mismatches
not_found_count = len(document_ids) - len(insights_map)
if not_found_count > 0:
logger.warning(f"Could not find conversation_insights for {not_found_count} out of {len(document_ids)} document IDs")
return insights_map
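    # Illustrative shape of the mapping returned above (values hypothetical):
    #
    # {
    #     "doc-123": {
    #         "conversation_insights": {"summary": "...", "key_findings": [...]},
    #         "url": "https://example.com/thread/123",
    #         "source": "slack",
    #         "user_id": "user-42",
    #     }
    # }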
@track(name="MongoDBRetrieverTool.forward")
def forward(self, query: str) -> str:
if hasattr(self.retriever, "search_kwargs"):
search_kwargs = self.retriever.search_kwargs
else:
try:
search_kwargs = {
"fulltext_penalty": self.retriever.fulltext_penalty,
"vector_score_penalty": self.retriever.vector_penalty,
"top_k": self.retriever.top_k,
}
except AttributeError:
logger.warning("Could not extract search kwargs from retriever.")
search_kwargs = {}
opik_context.update_current_trace(
tags=["agent"],
metadata={
"search": search_kwargs,
"embedding_model_id": self.retriever.vectorstore.embeddings.model,
},
)
try:
query = self.__parse_query(query)
relevant_docs = self.retriever.invoke(query)
# Step 1: Extract unique document IDs from chunks
document_ids = []
for doc in relevant_docs:
doc_id = doc.metadata.get("id")
if doc_id:
document_ids.append(doc_id)
# Step 2: Fetch conversation insights for unique IDs
unique_doc_ids = list(set(document_ids)) # De-duplicate
insights_map = self.__fetch_conversation_insights(unique_doc_ids)
# Step 3: Group chunks by document ID to avoid duplicating insights
docs_by_id = {}
skipped_chunks = 0
for i, doc in enumerate(relevant_docs, 1):
doc_id = doc.metadata.get("id")
# Skip chunks without conversation insights
if not doc_id or doc_id not in insights_map:
skipped_chunks += 1
logger.debug(f"Skipping chunk {i} - no conversation insights available for doc_id: {doc_id}")
continue
# Group chunks by document ID
if doc_id not in docs_by_id:
docs_by_id[doc_id] = {
"title": doc.metadata.get("title", "Untitled"),
"datetime": doc.metadata.get("datetime", "unknown"),
"source": insights_map[doc_id].get("source", "Unknown Source"),
"url": insights_map[doc_id].get("url", ""),
"user_id": insights_map[doc_id].get("user_id", ""),
"insights": insights_map[doc_id]["conversation_insights"],
"chunks": []
}
# Add this chunk's contextual summary to the document
docs_by_id[doc_id]["chunks"].append(doc.metadata.get("contextual_summary", ""))
# Step 4: Format unique documents with their insights
formatted_docs = []
for doc_num, (doc_id, doc_info) in enumerate(docs_by_id.items(), 1):
doc_text = f"=== DOCUMENT {doc_num} ===\n"
doc_text += f"Title: {doc_info['title']}\n"
doc_text += f"Date: {doc_info['datetime']}\n"
doc_text += f"Source: {doc_info['source']} | ID: {doc_id}"
if doc_info['user_id']:
doc_text += f" | User: {doc_info['user_id']}"
if doc_info['url']:
doc_text += f"\nURL: {doc_info['url']}"
# Add all chunk contexts from this conversation
doc_text += f"\n\nCONTEXT (from {len(doc_info['chunks'])} chunk(s)):\n"
for chunk_idx, chunk_context in enumerate(doc_info['chunks'], 1):
doc_text += f"{chunk_idx}. {chunk_context}\n"
# Add conversation insights (for Sources section only - not for answer generation)
insights = doc_info['insights']
doc_text += f"\n[METADATA FOR SOURCES SECTION]\n"
summary = insights.get("summary", "")
if summary:
doc_text += f"Summary: {summary}\n"
key_findings = insights.get("key_findings", [])
if key_findings:
doc_text += "Key Findings:\n"
for finding in key_findings:
insight_type = finding.get("insight_type", "Unknown")
finding_text = finding.get("finding", "")
impact = finding.get("impact", "Unknown")
doc_text += f"- [{insight_type}/{impact}] {finding_text}\n"
doc_text += "\n---\n"
formatted_docs.append(doc_text)
# Log statistics
logger.info(f"Retrieved {len(relevant_docs)} chunks from {len(docs_by_id)} unique conversations, skipped {skipped_chunks} without insights")
result = "\n".join(formatted_docs)
result = f"""SEARCH RESULTS
===============
{result}
When using context, reference the document title, date, and ID for attribution.
"""
return result
        except Exception:
            # Log at error level so retrieval failures are visible in default
            # log configurations instead of being silently swallowed.
            logger.opt(exception=True).error("Error retrieving documents.")
            return "Error retrieving documents."
@track(name="MongoDBRetrieverTool.parse_query")
def __parse_query(self, query: str) -> str:
try:
# Try to parse as JSON first
query_dict = json.loads(query)
return query_dict["query"]
        except (json.JSONDecodeError, KeyError, TypeError):
            # Not a JSON object with a "query" key (or not JSON at all);
            # return the query as-is. TypeError covers valid JSON that is
            # not a dict, e.g. a bare number or string.
            return query
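

# Minimal usage sketch (the config path below is hypothetical; point it at
# your own retriever config). In practice the tool is registered with a
# smolagents agent, but it can also be exercised directly:
if __name__ == "__main__":
    tool = MongoDBRetrieverTool(config_path=Path("configs/rag.yaml"))
    print(tool.forward("What did we decide about reranking?"))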