|
|
import os |
|
|
import logging |
|
|
from agents.tools_and_schemas import SearchQueryList, Reflection |
|
|
from dotenv import load_dotenv |
|
|
from langchain_core.messages import AIMessage |
|
|
from langgraph.types import Send |
|
|
from langgraph.graph import StateGraph |
|
|
from langgraph.graph import START, END |
|
|
from langchain_core.runnables import RunnableConfig |
|
|
from google.genai import Client |
|
|
|
|
|
from agents.state import ( |
|
|
OverallState, |
|
|
QueryGenerationState, |
|
|
ReflectionState, |
|
|
WebSearchState, |
|
|
) |
|
|
from agents.configuration import Configuration |
|
|
from agents.prompts import ( |
|
|
get_current_date, |
|
|
query_writer_instructions, |
|
|
web_searcher_instructions, |
|
|
reflection_instructions, |
|
|
answer_instructions, |
|
|
gaia_system_instructions, |
|
|
) |
|
|
from langchain_google_genai import ChatGoogleGenerativeAI |
|
|
from agents.utils import ( |
|
|
get_citations, |
|
|
get_research_topic, |
|
|
insert_citation_markers, |
|
|
resolve_urls, |
|
|
) |
|
|
|
|
|
# Load environment variables (e.g. GEMINI_API_KEY) from a local .env file.
load_dotenv()

# Fail fast at import time: every node in this module needs a Gemini API key.
if os.getenv("GEMINI_API_KEY") is None:
    raise ValueError("GEMINI_API_KEY is not set")

# Native google.genai client — used by web_research, which needs the native
# API's google_search tool and grounding metadata.
genai_client = Client(api_key=os.getenv("GEMINI_API_KEY"))

# Module-wide logging configuration and the module logger.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
def generate_query(state: OverallState, config: RunnableConfig) -> QueryGenerationState:
    """LangGraph node that generates search queries based on the User's question.

    Uses Gemini 2.0 Flash to create optimized search queries for web research
    based on the User's question.

    Args:
        state: Current graph state containing the User's question
        config: Configuration for the runnable, including LLM provider settings

    Returns:
        Dictionary with state update, including query_list key containing the
        generated queries
    """
    configurable = Configuration.from_runnable_config(config)

    # Fall back to the configured default when the caller did not request a
    # specific number of initial queries. A local variable is used instead of
    # mutating the shared input state: this node's only output is the returned
    # update dict.
    query_count = state.get("initial_search_query_count")
    if query_count is None:
        query_count = configurable.number_of_initial_queries

    # NOTE(review): temperature=2.0 is the maximum for Gemini — presumably
    # intentional to diversify the generated queries; confirm.
    llm = ChatGoogleGenerativeAI(
        model=configurable.query_generator_model,
        temperature=2.0,
        max_retries=2,
        api_key=os.getenv("GEMINI_API_KEY"),
    )
    # Constrain the model output to the SearchQueryList schema.
    structured_llm = llm.with_structured_output(SearchQueryList)

    current_date = get_current_date()
    formatted_prompt = query_writer_instructions.format(
        current_date=current_date,
        research_topic=get_research_topic(state["messages"]),
        number_queries=query_count,
    )

    result = structured_llm.invoke(formatted_prompt)
    return {"query_list": result.query}
|
|
|
|
|
|
|
|
def continue_to_web_research(state: QueryGenerationState):
    """LangGraph node that sends the search queries to the web research node.

    Fans out into n parallel web research tasks — one ``Send`` per generated
    search query, tagged with its position as a numeric id.
    """
    dispatches = []
    for position, query in enumerate(state["query_list"]):
        dispatches.append(
            Send("web_research", {"search_query": query, "id": int(position)})
        )
    return dispatches
|
|
|
|
|
|
|
|
def web_research(state: WebSearchState, config: RunnableConfig) -> OverallState:
    """LangGraph node that performs web research using the native Google Search API tool.

    Executes a web search using the native Google Search API tool in combination
    with Gemini 2.0 Flash.

    Args:
        state: Current graph state containing the search query and research loop count
        config: Configuration for the runnable, including search API settings

    Returns:
        Dictionary with state update, including sources_gathered, search_query,
        and web_research_result
    """
    configurable = Configuration.from_runnable_config(config)
    formatted_prompt = web_searcher_instructions.format(
        current_date=get_current_date(),
        research_topic=state["search_query"],
    )

    # Use the google.genai client directly (not the LangChain wrapper): the
    # grounding metadata consumed below is only exposed by the native API.
    response = genai_client.models.generate_content(
        model=configurable.query_generator_model,
        contents=formatted_prompt,
        config={
            "tools": [{"google_search": {}}],
            "temperature": 0,
        },
    )

    # Resolve the grounding chunk URLs to short per-query URLs, keyed by this
    # task's id so parallel searches don't collide.
    resolved_urls = resolve_urls(
        response.candidates[0].grounding_metadata.grounding_chunks, state["id"]
    )

    citations = get_citations(response, resolved_urls)
    modified_text = insert_citation_markers(response.text, citations)
    sources_gathered = [item for citation in citations for item in citation["segments"]]

    return {
        "sources_gathered": sources_gathered,
        "search_query": [state["search_query"]],
        # BUG FIX: this dict previously listed "web_research_result" twice; the
        # second entry (raw response.text) silently overwrote the first and the
        # citation markers were lost. Keep the citation-annotated text.
        "web_research_result": [modified_text],
    }
|
|
|
|
|
|
|
|
def reflection(state: OverallState, config: RunnableConfig) -> ReflectionState:
    """LangGraph node that identifies knowledge gaps and generates potential follow-up queries.

    Analyzes the current summary to identify areas for further research and generates
    potential follow-up queries. Uses structured output to extract
    the follow-up query in JSON format.

    Args:
        state: Current graph state containing the running summary and research topic
        config: Configuration for the runnable, including LLM provider settings

    Returns:
        Dictionary with state update, including is_sufficient, knowledge_gap,
        follow_up_queries, research_loop_count, and number_of_ran_queries
    """
    configurable = Configuration.from_runnable_config(config)

    # Increment the loop counter locally; the new count is propagated through
    # the returned state update rather than by mutating the input state.
    research_loop_count = state.get("research_loop_count", 0) + 1
    reasoning_model = state.get("reasoning_model") or configurable.reasoning_model

    current_date = get_current_date()
    formatted_prompt = reflection_instructions.format(
        current_date=current_date,
        research_topic=get_research_topic(state["messages"]),
        summaries="\n\n---\n\n".join(state["web_research_result"]),
    )

    llm = ChatGoogleGenerativeAI(
        model=reasoning_model,
        temperature=0.3,
        max_retries=2,
        api_key=os.getenv("GEMINI_API_KEY"),
    )
    # Lazy %-style logging args: the prompt is only formatted into the record
    # if this log level is actually emitted.
    logger.info("Reflection node invoked with research prompt:\n%s", formatted_prompt)
    result = llm.with_structured_output(Reflection).invoke(formatted_prompt)

    return {
        "is_sufficient": result.is_sufficient,
        "knowledge_gap": result.knowledge_gap,
        "follow_up_queries": result.follow_up_queries,
        "research_loop_count": research_loop_count,
        "number_of_ran_queries": len(state["search_query"]),
    }
|
|
|
|
|
|
|
|
def evaluate_research(
    state: ReflectionState,
    config: RunnableConfig,
) -> OverallState:
    """LangGraph routing function that determines the next step in the research flow.

    Controls the research loop by deciding whether to continue gathering information
    or to finalize the answer based on the configured maximum number of research loops.

    Args:
        state: Current graph state containing the research loop count
        config: Configuration for the runnable, including max_research_loops setting

    Returns:
        The string "finalize_answer" when research is sufficient or the loop
        budget is exhausted; otherwise a list of Send commands fanning the
        follow-up queries back out to the "web_research" node.
    """
    configurable = Configuration.from_runnable_config(config)
    # Per-run override from state takes precedence over the configured default.
    max_research_loops = (
        state.get("max_research_loops")
        if state.get("max_research_loops") is not None
        else configurable.max_research_loops
    )
    if state["is_sufficient"] or state["research_loop_count"] >= max_research_loops:
        return "finalize_answer"
    # Offset each follow-up task id past the already-run queries so resolved
    # URL prefixes from earlier searches are not reused.
    return [
        Send(
            "web_research",
            {
                "search_query": follow_up_query,
                # enumerate already yields ints; no int() cast needed.
                "id": state["number_of_ran_queries"] + idx,
            },
        )
        for idx, follow_up_query in enumerate(state["follow_up_queries"])
    ]
|
|
|
|
|
|
|
|
def finalize_answer(state: OverallState, config: RunnableConfig):
    """LangGraph node that produces the final GAIA-style answer.

    Runs two LLM passes: the first synthesizes the collected web research
    summaries into an answer report; the second feeds that report back as
    context, together with the original question, under the GAIA system
    instructions to produce the final answer message.

    Args:
        state: Current graph state containing the messages and web research results
        config: Configuration for the runnable, including the reasoning model

    Returns:
        Dictionary with state update: messages containing the final GAIA answer
    """
    configurable = Configuration.from_runnable_config(config)
    reasoning_model = state.get("reasoning_model") or configurable.reasoning_model

    current_date = get_current_date()
    formatted_prompt = answer_instructions.format(
        current_date=current_date,
        research_topic=get_research_topic(state["messages"]),
        summaries="\n---\n\n".join(state["web_research_result"]),
    )

    llm = ChatGoogleGenerativeAI(
        model=reasoning_model,
        temperature=0,
        max_retries=5,
        api_key=os.getenv("GEMINI_API_KEY"),
    )
    # First pass: synthesize the research summaries into an answer report.
    result = llm.invoke(formatted_prompt)

    # Second pass: answer the original question using the report as context.
    # NOTE(review): state["sources_gathered"] is never folded into the final
    # message, so citations are dropped from the output — confirm intended.
    gaia_question = state["messages"][-1].content
    messages = [
        ("system", gaia_system_instructions),
        ("user", f"Context: {result.content}\nQuestion: {gaia_question}"),
    ]
    gaia_result = llm.invoke(messages)

    return {
        "messages": [gaia_result],
    }
|
|
|
|
|
|
|
|
def build_graph():
    """Assemble and compile the pro-search research agent graph.

    Wires the nodes into the flow:
    START -> generate_query -> (fan-out) web_research -> reflection
    -> (loop back to web_research, or finalize_answer) -> END
    """
    builder = StateGraph(OverallState, config_schema=Configuration)

    # Register all nodes, table-driven.
    node_table = {
        "generate_query": generate_query,
        "web_research": web_research,
        "reflection": reflection,
        "finalize_answer": finalize_answer,
        "evaluate_research": evaluate_research,
    }
    for node_name, node_fn in node_table.items():
        builder.add_node(node_name, node_fn)

    # Entry point, fan-out, loop, and exit wiring.
    builder.add_edge(START, "generate_query")
    builder.add_conditional_edges(
        "generate_query", continue_to_web_research, ["web_research"]
    )
    builder.add_edge("web_research", "reflection")
    builder.add_conditional_edges(
        "reflection", evaluate_research, ["web_research", "finalize_answer"]
    )
    builder.add_edge("finalize_answer", END)

    return builder.compile(name="pro-search-agent")
|
|
|