from loguru import logger
from openai import OpenAI
from opik import track
from smolagents import Tool

from second_brain_online.config import settings


class HuggingFaceEndpointSummarizerTool(Tool):
    name = "huggingface_summarizer"
    description = """Use this tool to summarize a piece of text. Especially useful when you need to summarize a document."""
    inputs = {
        "text": {
            "type": "string",
            "description": """The text to summarize.""",
        }
    }
    output_type = "string"

    SYSTEM_PROMPT = """
Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
You are a helpful assistant specialized in summarizing documents. Generate a concise TL;DR summary in markdown format (maximum 512 characters) of the key findings from the provided documents, highlighting the most significant insights.

### Input:
{content}

### Response:
"""

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        assert settings.HUGGINGFACE_ACCESS_TOKEN is not None, (
            "HUGGINGFACE_ACCESS_TOKEN is required to use the dedicated endpoint. Add it to the .env file."
        )
        assert settings.HUGGINGFACE_DEDICATED_ENDPOINT is not None, (
            "HUGGINGFACE_DEDICATED_ENDPOINT is required to use the dedicated endpoint. Add it to the .env file."
        )

        self.__client = OpenAI(
            base_url=settings.HUGGINGFACE_DEDICATED_ENDPOINT,
            api_key=settings.HUGGINGFACE_ACCESS_TOKEN,
        )

    @track
    def forward(self, text: str) -> str:
        # Hugging Face dedicated endpoints expose an OpenAI-compatible API;
        # "tgi" is the fixed model name served by Text Generation Inference.
        result = self.__client.chat.completions.create(
            model="tgi",
            messages=[
                {
                    "role": "user",
                    "content": self.SYSTEM_PROMPT.format(content=text),
                },
            ],
        )

        return result.choices[0].message.content
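

# Usage sketch (illustrative, not part of the original module). Assumes
# HUGGINGFACE_ACCESS_TOKEN and HUGGINGFACE_DEDICATED_ENDPOINT are present in
# the .env file consumed by `settings`:
#
#   summarizer = HuggingFaceEndpointSummarizerTool()
#   tldr = summarizer.forward(text="<long document text>")
#   print(tldr)  # markdown TL;DR, capped at ~512 characters by the prompt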


class OpenAISummarizerTool(Tool):
    name = "answer_with_sources"
    description = """Use this tool to generate the FINAL answer to the user's question based on search results.
After retrieving documents with mongodb_vector_search_retriever, use this tool to synthesize a comprehensive answer with a Sources section.
CRITICAL: This tool generates the final answer that will be returned to the user. Do NOT modify or reformat its output in any way."""
    inputs = {
        "search_results": {
            "type": "string",
            "description": """The complete search results from mongodb_vector_search_retriever to analyze and synthesize into an answer. Pass the ENTIRE output from the retriever tool.""",
        }
    }
    output_type = "string"

    SYSTEM_PROMPT = """Based on the context below, create a comprehensive answer to the user's question.

{content}

IMPORTANT INSTRUCTIONS:
The context contains lightweight information from retrieved documents in this format:
- Doc X: Title | Date | User ID
- Contextual summaries as bullet points

Generate ONLY the ANSWER section with inline citations:

**ANSWER** (with inline citations):
- Base your answer ONLY on the provided context
- Focus on the core issues, concerns, or highlights identified
- DO NOT mention specific customer names or personal identifiers
- Group related insights by topic with bullet points
- Be concise and general, highlighting the problem/concern rather than individuals
- Add INLINE CITATIONS at the end of each point using ONLY this format: [Doc X]
- CRITICAL: Citations must be EXACTLY "[Doc 1]", "[Doc 2]", etc. - nothing else
- DO NOT add any other information in citations (no titles, dates, IDs, or sources)
- Number each unique document sequentially (Doc 1, Doc 2, etc.)

CORRECT Example:
• Organizations are planning phone number porting transitions, but custom porting is expensive (~$1,000) and should be done in bulk [Doc 1]
• Questions about additional license requirements for integrations ($45 per user) [Doc 1]
• Ringtone volume issues in the embedded Salesforce app [Doc 2]

WRONG Example (DO NOT DO THIS):
• Custom porting costs around $1,000 [Source: JustCall Checkin, Document ID: abc123]
• License fees are $45 per user [JustCall, 2025-10-07]

CRITICAL RULES:
- Generate the answer ONLY from the exact context provided above
- Use ONLY [Doc X] format for citations
- DO NOT add a Sources section (it will be appended automatically)
- Keep the answer focused and well-structured with bullet points"""

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        self.__client = OpenAI(
            base_url="https://api.openai.com/v1",
            api_key=settings.OPENAI_API_KEY,
        )

    def forward(self, search_results: str) -> str:
        """Generate the final answer with sources based on search results.

        This method:
        1. Sends ONLY lightweight context to the LLM for answer generation
        2. Retrieves pre-formatted sources from MongoDBRetrieverTool._cached_sources
        3. Appends those sources directly to the LLM's answer

        Args:
            search_results: Lightweight context from the retriever (NOT including sources)

        Returns:
            Complete answer with a Sources section appended
        """
        # Import here to avoid a circular dependency
        from second_brain_online.application.agents.tools.mongodb_retriever import (
            MongoDBRetrieverTool,
        )

        # Step 1: Generate the answer from the LLM using ONLY the lightweight context.
        # This significantly reduces token usage compared to sending full sources.
        result = self.__client.chat.completions.create(
            model=settings.OPENAI_MODEL_ID,
            messages=[
                {
                    "role": "system",
                    "content": "You are an expert analyst. Follow the formatting instructions EXACTLY. Use only [Doc X] citations in the answer section, never include titles, dates, or IDs in citations.",
                },
                {
                    "role": "user",
                    "content": self.SYSTEM_PROMPT.format(content=search_results),
                },
            ],
            temperature=0.0,  # Deterministic output
            max_tokens=1500,  # Bounded output length for faster responses
            timeout=45.0,  # Fail fast instead of hanging on slow requests
        )
        llm_answer = result.choices[0].message.content

        # Step 2: Retrieve pre-formatted sources from the retriever's class variable.
        # These sources were cached during the retrieval step and are NOT sent to the LLM.
        cached_sources = MongoDBRetrieverTool._cached_sources

        # Step 3: Append the sources directly to the answer, so they appear in the
        # final output without ever being sent to the LLM.
        if cached_sources:
            final_answer = f"{llm_answer}\n\n{cached_sources}"
            logger.info(
                f"Appended {len(cached_sources)} chars of sources to {len(llm_answer)} chars of LLM answer"
            )
        else:
            final_answer = llm_answer
            logger.warning("No cached sources found - returning LLM answer only")

        return final_answer
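

# Usage sketch (illustrative, not part of the original module). The call below
# assumes MongoDBRetrieverTool exposes a `forward(query: str) -> str` method and
# that running it populates MongoDBRetrieverTool._cached_sources, as described
# in the docstring above:
#
#   retriever = MongoDBRetrieverTool()
#   context = retriever.forward(query="What porting issues came up this week?")
#   answer_tool = OpenAISummarizerTool()
#   print(answer_tool.forward(search_results=context))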