chinmayjha's picture
feat: optimize RAG agent with token reduction and separate context/sources
a697e1b unverified
import json
from pathlib import Path
import yaml
from loguru import logger
from opik import opik_context, track
from smolagents import Tool
from pymongo import MongoClient
from second_brain_online.application.rag import get_retriever
from second_brain_online.config import settings
class MongoDBRetrieverTool(Tool):
    """smolagents tool that runs a semantic search and returns compact LLM context.

    Each ``forward`` call builds two views of the retrieved documents:

    * a lightweight context string (title, date, user ID and the per-chunk
      contextual summaries) that is returned to the LLM, and
    * a fuller "📚 Sources" section (source, document ID, chat URL,
      conversation summary, key findings) that is stored in the class
      attribute ``_cached_sources`` instead of being returned.

    Only retrieved chunks whose document has ``conversation_insights`` in the
    ``test_conversation_documents`` collection are kept.
    """

    name = "mongodb_vector_search_retriever"
    description = """Use this tool to search and retrieve relevant documents from a knowledge base using semantic search.
This tool performs similarity-based search to find the most relevant documents matching the query.
Best used when you need to:
- Find specific information from stored documents
- Get context about a topic
- Research historical data or documentation
The tool will return multiple relevant document snippets."""
    inputs = {
        "query": {
            "type": "string",
            "description": """The search query to find relevant documents for using semantic search.
Should be a clear, specific question or statement about the information you're looking for.""",
        }
    }
    output_type = "string"

    # Formatted sources from the MOST RECENT ``forward`` call, kept out of the
    # LLM context so a summarizer tool can append them to the final answer.
    # ``forward`` resets it on every call, so a failed or empty search never
    # leaves sources from an earlier query behind.
    # NOTE(review): this is class-level state shared by all instances; it is
    # not safe if several queries run concurrently — confirm that never happens.
    _cached_sources = ""

    def __init__(self, config_path: Path, **kwargs):
        """Load the retriever from a YAML config and connect to MongoDB.

        Args:
            config_path: YAML file whose ``parameters`` mapping configures the
                retriever (see ``__load_retriever``).
            **kwargs: Forwarded unchanged to ``Tool.__init__``.
        """
        super().__init__(**kwargs)
        self.config_path = config_path
        self.retriever = self.__load_retriever(config_path)
        # MongoDB client used to fetch conversation insights and metadata for
        # the documents the retriever returns.
        self.mongodb_client = MongoClient(settings.MONGODB_URI)
        self.database = self.mongodb_client[settings.MONGODB_DATABASE_NAME]
        self.conversation_docs_collection = self.database["test_conversation_documents"]

    def __load_retriever(self, config_path: Path):
        """Build the retriever from the ``parameters`` section of the YAML config.

        ``k`` is fixed at 5 regardless of the config; ``final_k`` defaults to 5
        and re-ranking is off unless the config enables it.
        """
        config = yaml.safe_load(config_path.read_text())
        config = config["parameters"]
        return get_retriever(
            embedding_model_id=config["embedding_model_id"],
            embedding_model_type=config["embedding_model_type"],
            retriever_type=config["retriever_type"],
            k=5,  # Reduced from 10 to 5 for faster processing
            device=config["device"],
            enable_reranking=config.get("enable_reranking", False),
            rerank_model_name=config.get("rerank_model_name", "cross-encoder/ms-marco-MiniLM-L-2-v2"),
            stage1_limit=config.get("stage1_limit", 50),
            final_k=config.get("final_k", 5),  # Reduced from 10 to 5
        )

    def __fetch_conversation_insights(self, document_ids: list[str]) -> dict:
        """Fetch conversation insights and metadata for the given document IDs.

        Args:
            document_ids: Document IDs (matched against the ``id`` field of
                ``test_conversation_documents``).

        Returns:
            Mapping ``document_id -> {"conversation_insights", "url", "source",
            "user_id"}``. Documents with missing/empty ``conversation_insights``
            are omitted; the metadata values may be ``None``.
        """
        insights_map: dict = {}
        if not document_ids:
            # Nothing to look up — skip the database round trip.
            return insights_map

        cursor = self.conversation_docs_collection.find(
            {"id": {"$in": document_ids}},
            {
                "id": 1,
                "conversation_insights": 1,
                "metadata.url": 1,
                "metadata.source": 1,
                "metadata.user_id": 1,
            },
        )
        for doc in cursor:
            insights = doc.get("conversation_insights")
            if not insights:
                continue
            # ``metadata`` may be absent or explicitly null in a stored document.
            metadata = doc.get("metadata") or {}
            insights_map[doc.get("id")] = {
                "conversation_insights": insights,
                "url": metadata.get("url"),
                "source": metadata.get("source"),
                "user_id": metadata.get("user_id"),
            }

        # Counts both IDs not found and IDs found without insights.
        not_found_count = len(document_ids) - len(insights_map)
        if not_found_count > 0:
            logger.warning(f"Could not find conversation_insights for {not_found_count} out of {len(document_ids)} document IDs")
        return insights_map

    def _search_kwargs(self) -> dict:
        """Best-effort extraction of the retriever's search settings for tracing."""
        if hasattr(self.retriever, "search_kwargs"):
            return self.retriever.search_kwargs
        try:
            return {
                "fulltext_penalty": self.retriever.fulltext_penalty,
                "vector_score_penalty": self.retriever.vector_penalty,
                "top_k": self.retriever.top_k,
            }
        except AttributeError:
            logger.warning("Could not extract search kwargs from retriever.")
            return {}

    @staticmethod
    def _group_chunks_by_document(relevant_docs, insights_map: dict) -> tuple[dict, int]:
        """Group retrieved chunks by document ID, dropping chunks without insights.

        Returns:
            ``(docs_by_id, skipped_chunks)``. ``docs_by_id`` preserves retrieval
            order; each entry holds the document's display metadata, its
            conversation insights and the list of chunk contextual summaries.
        """
        docs_by_id: dict = {}
        skipped_chunks = 0
        for i, doc in enumerate(relevant_docs, 1):
            doc_id = doc.metadata.get("id")
            if not doc_id or doc_id not in insights_map:
                skipped_chunks += 1
                logger.debug(f"Skipping chunk {i} - no conversation insights available for doc_id: {doc_id}")
                continue
            if doc_id not in docs_by_id:
                info = insights_map[doc_id]
                docs_by_id[doc_id] = {
                    "title": doc.metadata.get("title", "Untitled"),
                    "datetime": doc.metadata.get("datetime", "unknown"),
                    # The insights map always contains these keys, but their
                    # values may be None, so ``.get(key, default)`` would never
                    # fall back — use ``or`` instead.
                    "source": info.get("source") or "Unknown Source",
                    "url": info.get("url") or "",
                    "user_id": info.get("user_id") or "",
                    "insights": info["conversation_insights"],
                    "chunks": [],
                }
            docs_by_id[doc_id]["chunks"].append(doc.metadata.get("contextual_summary", ""))
        return docs_by_id, skipped_chunks

    @staticmethod
    def _format_llm_context(doc_num: int, doc_info: dict) -> str:
        """Render one document as compact LLM context.

        Format::

            Doc N: Title | Date | User: id
            - contextual summary 1
            - contextual summary 2
        """
        header = f"Doc {doc_num}: {doc_info['title']} | {doc_info['datetime']}"
        if doc_info["user_id"]:
            header += f" | User: {doc_info['user_id']}"
        bullets = "".join(f"- {chunk_context}\n" for chunk_context in doc_info["chunks"])
        return f"{header}\n{bullets}\n"

    @staticmethod
    def _format_source_entry(doc_num: int, doc_id: str, doc_info: dict) -> str:
        """Render one document's full source entry (metadata, summary, key findings)."""
        parts = [
            f"Doc {doc_num}: {doc_info['title']} ({doc_info['datetime']})\n",
            f"Source: {doc_info['source']} | Document ID: {doc_id}",
        ]
        if doc_info["url"]:
            parts.append(f" | [View Chat]({doc_info['url']})")
        if doc_info["user_id"]:
            parts.append(f" | User ID: {doc_info['user_id']}")
        parts.append("\n\n")

        # NOTE(review): assumes conversation_insights is a dict with optional
        # "summary" and "key_findings" (list of dicts) — confirm against schema.
        insights = doc_info["insights"]
        summary = insights.get("summary", "")
        if summary:
            parts.append(f"Summary: {summary}\n\n")
        key_findings = insights.get("key_findings", [])
        if key_findings:
            parts.append("Key Findings:\n")
            for finding in key_findings:
                insight_type = finding.get("insight_type", "Unknown")
                finding_text = finding.get("finding", "")
                impact = finding.get("impact", "Unknown")
                parts.append(f"- [{insight_type}/{impact}] {finding_text}\n")
        parts.append("\n---\n\n")
        return "".join(parts)

    @track(name="MongoDBRetrieverTool.forward")
    def forward(self, query: str) -> str:
        """Search the knowledge base and return lightweight context for the LLM.

        Side effect: replaces ``MongoDBRetrieverTool._cached_sources`` with the
        formatted sources for this query. It is cleared first, so it stays empty
        when nothing usable is found or retrieval fails.

        Args:
            query: Plain search text, or a JSON object with a ``"query"`` field.

        Returns:
            The context string for the LLM; ``"No relevant documents found."``
            when no retrieved chunk has conversation insights; or
            ``"Error retrieving documents."`` if retrieval raises.
        """
        # Reset before doing anything that can fail so stale sources from a
        # previous query can never be appended to this query's answer.
        MongoDBRetrieverTool._cached_sources = ""

        opik_context.update_current_trace(
            tags=["agent"],
            metadata={
                "search": self._search_kwargs(),
                "embedding_model_id": self.retriever.vectorstore.embeddings.model,
            },
        )

        try:
            query = self.__parse_query(query)
            relevant_docs = self.retriever.invoke(query)

            # Unique document IDs in retrieval order (several chunks can share one).
            chunk_ids = (doc.metadata.get("id") for doc in relevant_docs)
            unique_doc_ids = list(dict.fromkeys(doc_id for doc_id in chunk_ids if doc_id))
            insights_map = self.__fetch_conversation_insights(unique_doc_ids)

            docs_by_id, skipped_chunks = self._group_chunks_by_document(relevant_docs, insights_map)
            logger.info(f"Retrieved {len(relevant_docs)} chunks from {len(docs_by_id)} unique conversations, skipped {skipped_chunks} without insights")

            if not docs_by_id:
                return "No relevant documents found."

            # Only the lightweight context goes to the LLM; the full sources are
            # cached for the summarizer to append to the final answer.
            context_section = "".join(
                self._format_llm_context(doc_num, doc_info)
                for doc_num, doc_info in enumerate(docs_by_id.values(), 1)
            )
            metadata_section = "".join(
                self._format_source_entry(doc_num, doc_id, doc_info)
                for doc_num, (doc_id, doc_info) in enumerate(docs_by_id.items(), 1)
            )
            MongoDBRetrieverTool._cached_sources = f"📚 Sources\n{metadata_section}"

            logger.info(f"Returning {len(context_section)} chars of context to LLM, {len(MongoDBRetrieverTool._cached_sources)} chars cached for sources")
            return context_section
        except Exception:
            # Tool boundary: report failure to the agent instead of crashing it,
            # but log at ERROR level with the traceback so failures are visible.
            logger.exception("Error retrieving documents.")
            return "Error retrieving documents."

    @track(name="MongoDBRetrieverTool.parse_query")
    def __parse_query(self, query: str) -> str:
        """Extract the search text from ``query``.

        A JSON object with a string ``"query"`` field yields that field. Any
        other input — plain text, or valid JSON that is not such an object
        (e.g. ``"2024"``, ``"[1, 2]"``, ``"true"``) — is returned unchanged
        instead of raising.
        """
        try:
            parsed = json.loads(query)
        except (json.JSONDecodeError, TypeError):
            return query
        if isinstance(parsed, dict) and isinstance(parsed.get("query"), str):
            return parsed["query"]
        return query