# backend_port / deep_research.py
# Aakash jammula
# :deep research
# 5e6bfdb
import re
import json
import operator
from duckduckgo_search import DDGS
from typing_extensions import Literal
from dataclasses import dataclass, field
from langgraph.graph import START, END, StateGraph
from typing_extensions import TypedDict, Annotated
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.messages import HumanMessage, SystemMessage
from typing import Any, List ,Dict, Any, List, Optional
import os
from tavily import TavilyClient
# --- Shared clients and tuning knobs ----------------------------------------

# Tavily web-search client; the API key is read from the environment at import time.
tavily_api_key = os.getenv("TAVILY_API_KEY")
tavily_client = TavilyClient(api_key=tavily_api_key)

# Research-loop settings: route_research compares research_loop_count against
# max_web_research_loops, and web_research forwards fetch_full_page to tavily_search.
max_web_research_loops = 3
fetch_full_page: bool = False

# Both Gemini clients share the same model, temperature and API key.
_gemini_settings = dict(
    model="gemini-2.0-flash",
    temperature=0,
    google_api_key=os.getenv("GOOGLE_API_KEY"),
)
# Client used by the nodes whose prompts ask for a JSON reply.
# NOTE(review): {"format": "json"} is an Ollama-style switch; confirm that
# ChatGoogleGenerativeAI honors it. Either way, replies are passed through
# clean_json_response before json.loads.
llm_json_mode = ChatGoogleGenerativeAI(**_gemini_settings, model_kwargs={"format": "json"})
# Plain-text client used for summarization.
llm = ChatGoogleGenerativeAI(**_gemini_settings)
# System prompt for generate_query. `{research_topic}` is filled in with str.format;
# the doubled braces in the example are format escapes that render as literal `{`/`}`.
# Only the "query" key of the model's JSON reply is read by generate_query.
query_writer_instructions="""Your goal is to generate a targeted web search query.
The query will gather information related to a specific topic.
<TOPIC>
{research_topic}
</TOPIC>
<FORMAT>
Format your response as a JSON object with ALL three of these exact keys:
- "query": The actual search query string
- "aspect": The specific aspect of the topic being researched
- "rationale": Brief explanation of why this query is relevant
</FORMAT>
<EXAMPLE>
Example output:
{{
"query": "machine learning transformer architecture explained",
"aspect": "technical architecture",
"rationale": "Understanding the fundamental structure of transformer models"
}}
</EXAMPLE>
Provide your response in JSON format:"""
# System prompt for summarize_sources: write a new summary on the first pass,
# then extend it with each new batch of search results. The section tags are
# kept well-formed (<X> ... </X>), matching the other prompts in this file, so
# the model can see where each section ends.
summarizer_instructions="""
<GOAL>
Generate a high-quality summary of the web search results and keep it concise / related to the user topic.
</GOAL>
<REQUIREMENTS>
When creating a NEW summary:
1. Highlight the most relevant information related to the user topic from the search results
2. Ensure a coherent flow of information
When EXTENDING an existing summary:
1. Read the existing summary and new search results carefully.
2. Compare the new information with the existing summary.
3. For each piece of new information:
a. If it's related to existing points, integrate it into the relevant paragraph.
b. If it's entirely new but relevant, add a new paragraph with a smooth transition.
c. If it's not relevant to the user topic, skip it.
4. Ensure all additions are relevant to the user's topic.
5. Verify that your final output differs from the input summary.
</REQUIREMENTS>
<FORMATTING>
- Start directly with the updated summary, without preamble or titles. Do not use XML tags in the output.
</FORMATTING>"""
# System prompt for reflect_on_summary. `{research_topic}` is filled in with
# str.format; the doubled braces in the example are format escapes for literal
# JSON braces. reflect_on_summary reads only the "follow_up_query" key of the reply.
reflection_instructions = """You are an expert research assistant analyzing a summary about {research_topic}.
<GOAL>
1. Identify knowledge gaps or areas that need deeper exploration
2. Generate a follow-up question that would help expand your understanding
3. Focus on technical details, implementation specifics, or emerging trends that weren't fully covered
</GOAL>
<REQUIREMENTS>
Ensure the follow-up question is self-contained and includes necessary context for web search.
</REQUIREMENTS>
<FORMAT>
Format your response as a JSON object with these exact keys:
- knowledge_gap: Describe what information is missing or needs clarification
- follow_up_query: Write a specific question to address this gap
</FORMAT>
<EXAMPLE>
Example output:
{{
"knowledge_gap": "The summary lacks information about performance metrics and benchmarks",
"follow_up_query": "What are typical performance benchmarks and metrics used to evaluate [specific technology]?"
}}
</EXAMPLE>
Provide your analysis in JSON format:"""
def deduplicate_and_format_sources(search_response, max_tokens_per_source, include_raw_content=False):
    """
    Merge one or more search responses, drop duplicate URLs, and render them as text.

    Args:
        search_response: Either
            - a dict with a 'results' key holding a list of result dicts, or
            - a list whose items are response dicts (with 'results'), single
              result dicts, or lists of result dicts.
            Each result dict needs 'title', 'url' and 'content' keys and may
            carry 'raw_content'.
        max_tokens_per_source: Approximate token budget for each source's
            raw_content, converted to characters at roughly 4 chars per token.
        include_raw_content: When True, append each source's (truncated)
            raw_content after its snippet.

    Returns:
        str: A "Sources:" header followed by one section per unique URL, in
        first-seen order, with surrounding whitespace stripped.

    Raises:
        ValueError: If search_response is neither a dict nor a list.
    """
    # Flatten the input into a single list of result dicts.
    if isinstance(search_response, dict):
        sources_list = search_response['results']
    elif isinstance(search_response, list):
        sources_list = []
        for response in search_response:
            if isinstance(response, dict) and 'results' in response:
                sources_list.extend(response['results'])
            elif isinstance(response, dict):
                # A bare result dict: extend() would add its *keys*, so append it whole.
                sources_list.append(response)
            else:
                sources_list.extend(response)
    else:
        raise ValueError("Input must be either a dict with 'results' or a list of search results")

    # Keep the first occurrence of each URL (dicts preserve insertion order).
    unique_sources = {}
    for source in sources_list:
        unique_sources.setdefault(source['url'], source)

    char_limit = max_tokens_per_source * 4  # rough estimate: ~4 characters per token
    parts = ["Sources:\n\n"]  # joined once at the end instead of repeated string +=
    for source in unique_sources.values():
        parts.append(f"Source {source['title']}:\n===\n")
        parts.append(f"URL: {source['url']}\n===\n")
        parts.append(f"Most relevant content from source: {source['content']}\n===\n")
        if include_raw_content:
            # Missing key and explicit None both mean "no raw content".
            raw_content = source.get('raw_content')
            if raw_content is None:
                raw_content = ''
                print(f"Warning: No raw_content found for source {source['url']}")
            if len(raw_content) > char_limit:
                raw_content = raw_content[:char_limit] + "... [truncated]"
            parts.append(f"Full source content limited to {max_tokens_per_source} tokens: {raw_content}\n\n")
    return "".join(parts).strip()
def format_sources(search_results):
    """Render each search result as a "* <title> : <url>" bullet, one per line.

    Args:
        search_results (dict): Search response whose 'results' list holds
            dicts with 'title' and 'url' keys.

    Returns:
        str: Newline-separated bullets; an empty string when there are no results.
    """
    bullet_lines = []
    for item in search_results['results']:
        bullet_lines.append(f"* {item['title']} : {item['url']}")
    return '\n'.join(bullet_lines)
def tavily_search(query: str, max_results: int = 3, fetch_full_page: bool = False, client=None) -> Dict[str, List[Dict[str, str]]]:
    """Search the web using Tavily.

    Args:
        query (str): The search query to execute.
        max_results (int): Maximum number of results to return.
        fetch_full_page (bool): Ask Tavily for each page's raw content
            (passed through as ``include_raw_content``).
        client: Optional object with a Tavily-compatible ``search`` method;
            defaults to the module-level ``tavily_client``.

    Returns:
        dict: Search response containing:
            - results (list): List of search result dictionaries, each containing:
                - title (str): Title of the search result
                - url (str): URL of the search result
                - content (str): Snippet/summary of the content
                - raw_content (str): Full content if available, else same as content

        On any error the failure is printed and ``{"results": []}`` is returned,
        so one failed search does not abort the research graph.
    """
    try:
        search_client = tavily_client if client is None else client
        response = search_client.search(query=query, max_results=max_results, include_raw_content=fetch_full_page)
        results = [
            {
                "title": r.get("title"),
                "url": r.get("url"),
                "content": r.get("content"),
                # Fall back on any empty value, not just a missing key: Tavily
                # may return the key with a None value (e.g. when raw content
                # was not requested).
                "raw_content": r.get("raw_content") or r.get("content"),
            }
            for r in response["results"]
        ]
        return {"results": results}
    except Exception as e:
        # Deliberate best-effort boundary around the external API call.
        print(f"Error in Tavily search: {str(e)}")
        return {"results": []}
@dataclass(kw_only=True)
class SummaryState:
research_topic: str = field(default=None) # Report topic
search_query: str = field(default=None) # Search query
web_research_results: Annotated[list, operator.add] = field(default_factory=list)
sources_gathered: Annotated[list, operator.add] = field(default_factory=list)
research_loop_count: int = field(default=0) # Research loop count
running_summary: str = field(default=None) # Final report
@dataclass(kw_only=True)
class SummaryStateInput:
research_topic: str = field(default=None) # Report topic
@dataclass(kw_only=True)
class SummaryStateOutput:
running_summary: str = field(default=None) # Final report
# Nodes
def clean_json_response(response_content: str) -> str:
    """Strip outer whitespace and any Markdown code fence from an LLM JSON reply.

    Handles a fence opened with ```json, ```JSON or a bare ```, even when the
    model pads the fence with leading/trailing whitespace or newlines. Text
    without a fence is only whitespace-stripped.

    Args:
        response_content: Raw text content returned by the model.

    Returns:
        str: The reply with fence markers and surrounding whitespace removed.
    """
    # Strip first so the ^ / $ anchors land on the fence itself.
    text = response_content.strip()
    # Opening fence with an optional (case-insensitive) "json" language tag.
    text = re.sub(r'^```(?:json)?\s*', '', text, flags=re.IGNORECASE)
    # Closing fence.
    text = re.sub(r'\s*```$', '', text)
    return text.strip()
def generate_query(state: SummaryState):
    """Ask the JSON-mode LLM for the first web-search query on the research topic.

    Args:
        state: Current graph state; only ``research_topic`` is read.

    Returns:
        dict: ``{"search_query": str}``. If the reply is not a JSON object with
        a "query" key, the problem is printed and the fallback
        "Tell me more about <topic>" is used instead.
    """
    query_writer_instructions_formatted = query_writer_instructions.format(research_topic=state.research_topic)
    # Reuse the module-level JSON-mode client; it has exactly the settings that
    # were previously re-created here on every call.
    result = llm_json_mode.invoke(
        [
            SystemMessage(content=query_writer_instructions_formatted),
            HumanMessage(content="Generate a query for web search:"),
        ]
    )
    cleaned_content = clean_json_response(result.content)
    try:
        query = json.loads(cleaned_content)
        return {"search_query": query['query']}
    except (json.JSONDecodeError, KeyError, TypeError) as e:
        # KeyError: valid JSON without "query"; TypeError: JSON that is not an object.
        print(f"Failed to parse JSON: {e}")
        print(f"Response content: {cleaned_content}")
        return {"search_query": f"Tell me more about {state.research_topic}"}  # Fallback query
def reflect_on_summary(state: SummaryState):
    """Ask the JSON-mode LLM for a knowledge gap in the summary and a follow-up query.

    Args:
        state: Current graph state; reads ``research_topic`` and ``running_summary``.

    Returns:
        dict: ``{"search_query": str}``. If the reply is not a JSON object with
        a "follow_up_query" key, the problem is printed and the fallback
        "Tell me more about <topic>" is used instead.
    """
    result = llm_json_mode.invoke(
        [
            SystemMessage(content=reflection_instructions.format(research_topic=state.research_topic)),
            HumanMessage(content=f"Identify a knowledge gap and generate a follow-up web search query based on our existing knowledge: {state.running_summary}"),
        ]
    )
    cleaned_content = clean_json_response(result.content)
    try:
        follow_up_query = json.loads(cleaned_content)
        return {"search_query": follow_up_query['follow_up_query']}
    except (json.JSONDecodeError, KeyError, TypeError) as e:
        # KeyError: valid JSON without "follow_up_query"; TypeError: JSON that is not an object.
        print(f"Failed to parse JSON: {e}")
        print(f"Response content: {cleaned_content}")
        return {"search_query": f"Tell me more about {state.research_topic}"}  # Fallback query
def web_research(state: SummaryState):
    """Run one Tavily search for ``state.search_query`` and record the results.

    Args:
        state: Current graph state; reads ``search_query`` and ``research_loop_count``.

    Returns:
        dict with:
            - "sources_gathered": one-element list holding a bullet list of titles/URLs
            - "research_loop_count": the incoming count plus one
            - "web_research_results": one-element list holding the formatted search text
    """
    search_results = tavily_search(state.search_query, max_results=3, fetch_full_page=fetch_full_page)
    # Tavily is only asked for raw page content when fetch_full_page is True
    # (include_raw_content=fetch_full_page), so only then is there a full page
    # worth adding to the formatted text.
    search_str = deduplicate_and_format_sources(search_results, max_tokens_per_source=1000, include_raw_content=fetch_full_page)
    return {
        "sources_gathered": [format_sources(search_results)],
        "research_loop_count": state.research_loop_count + 1,
        "web_research_results": [search_str],
    }
def summarize_sources(state: SummaryState):
    """Fold the most recent web-research results into the running summary.

    On the first pass a new summary is written; afterwards the existing
    summary is extended with ``state.web_research_results[-1]``.

    Args:
        state: Current graph state. ``web_research_results`` must be non-empty;
            in this file's graph the web_research node always runs first.

    Returns:
        dict: ``{"running_summary": <updated summary text>}``.
    """
    existing_summary = state.running_summary
    most_recent_web_research = state.web_research_results[-1]

    # Clearly delimited sections so the model can tell the topic, the prior
    # summary and the new results apart.
    if existing_summary:
        human_message_content = (
            f"<User Input> \n {state.research_topic} \n </User Input>\n\n"
            f"<Existing Summary> \n {existing_summary} \n </Existing Summary>\n\n"
            f"<New Search Results> \n {most_recent_web_research} \n </New Search Results>"
        )
    else:
        human_message_content = (
            f"<User Input> \n {state.research_topic} \n </User Input>\n\n"
            f"<Search Results> \n {most_recent_web_research} \n </Search Results>"
        )

    result = llm.invoke(
        [SystemMessage(content=summarizer_instructions),
         HumanMessage(content=human_message_content)]
    )

    # Reasoning models (e.g. DeepSeek) can emit <think>...</think> blocks that are
    # hard to prompt away. A non-greedy regex removes each well-formed pair and,
    # unlike a find()-based loop, cannot spin forever when a stray "</think>"
    # appears before a "<think>".
    running_summary = re.sub(r"<think>.*?</think>", "", result.content, flags=re.DOTALL)
    return {"running_summary": running_summary}
def finalize_summary(state: SummaryState):
    """Render the final Markdown report: the running summary followed by all sources.

    Args:
        state: Current graph state; reads ``running_summary`` and ``sources_gathered``.

    Returns:
        dict: ``{"running_summary": "## Summary ... ### Sources: ..."}``. The
        state object itself is not modified; the update is returned for the
        graph to apply.
    """
    # Each entry of sources_gathered is already a newline-separated bullet list
    # (produced by format_sources in web_research).
    all_sources = "\n".join(state.sources_gathered)
    final_report = f"## Summary\n\n{state.running_summary}\n\n ### Sources:\n{all_sources}"
    return {"running_summary": final_report}
def route_research(state: SummaryState) -> Literal["finalize_summary", "web_research"]:
    """Decide whether to run another web search or write the final report.

    Research continues while ``research_loop_count`` (searches done so far) is
    at most ``max_web_research_loops``. Because the comparison is inclusive,
    the graph performs the initial search plus up to ``max_web_research_loops``
    follow-up searches before finalizing.

    Returns:
        "web_research" to keep going, or "finalize_summary" to stop.
    """
    loop_budget = int(max_web_research_loops)
    if state.research_loop_count > loop_budget:
        return "finalize_summary"
    return "web_research"
# Assemble the research graph:
#   START -> generate_query -> web_research -> summarize_sources -> reflect_on_summary
#   reflect_on_summary --route_research--> web_research | finalize_summary
#   finalize_summary -> END
builder = StateGraph(SummaryState, input=SummaryStateInput, output=SummaryStateOutput)

for _node_name, _node_fn in (
    ("generate_query", generate_query),
    ("web_research", web_research),
    ("summarize_sources", summarize_sources),
    ("reflect_on_summary", reflect_on_summary),
    ("finalize_summary", finalize_summary),
):
    builder.add_node(_node_name, _node_fn)

# Linear part of the pipeline.
for _edge_from, _edge_to in (
    (START, "generate_query"),
    ("generate_query", "web_research"),
    ("web_research", "summarize_sources"),
    ("summarize_sources", "reflect_on_summary"),
):
    builder.add_edge(_edge_from, _edge_to)

# Loop back for more research, or finish.
builder.add_conditional_edges("reflect_on_summary", route_research)
builder.add_edge("finalize_summary", END)

# Keep the module namespace free of the loop temporaries.
del _node_name, _node_fn, _edge_from, _edge_to

graph = builder.compile()
if __name__ == "__main__":
    # Run the full research graph once for a sample topic and print the
    # Markdown report that finalize_summary leaves in running_summary.
    topic = "gemma3 architecture"
    final_state = graph.invoke(SummaryStateInput(research_topic=topic))
    print(final_state['running_summary'])