# Source: HF_Agents_Course_GAIA_Agent / src/agents/langgraph_agent_v2.py
# Author: agercas — commit "add agents" (1ffaf53)
from collections.abc import Sequence
from typing import Annotated, Literal
from langchain.chat_models import init_chat_model
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.runnables import RunnableConfig
from langgraph.graph import END, StateGraph
from langgraph.graph.message import add_messages
from pydantic import BaseModel, Field
from src.agents.models import FeasibilityCheck, FinalAnswer, FinalConclusion, NextStep
from src.agents.prompts import GAIAPrompts
from src.agents.tools import tools
# Initialize the shared LLM, tool bindings, and prompt templates used by every node below.
model = init_chat_model("gemini-2.0-flash", model_provider="google_genai")
# Model with every available tool bound.
# NOTE(review): `model_with_tools` appears unused in this module — nodes bind
# their own tool subsets (see executor_node); confirm before removing.
model_with_tools = model.bind_tools(tools)
# Name -> tool lookup so tool_node can dispatch a tool call by name.
tools_by_name = {tool.name: tool for tool in tools}
prompts = GAIAPrompts()
# Graph state
class GraphState(BaseModel):
    """Shared state threaded through every node of the agent graph.

    Message channels:
      * ``history`` — complete transcript across all nodes; merged with
        ``add_messages`` so nodes may return partial updates.
      * ``coordinator_messages`` — the coordinator's own conversation,
        also merged via ``add_messages``.
      * ``executor_messages`` — the executor's scratch conversation; a plain
        sequence, so nodes replace it wholesale instead of merging.
    """

    # History
    # Complete history with node info attached to each message.
    history: Annotated[Sequence[BaseMessage], add_messages] = Field(default_factory=list)
    # Coordinator-specific messages.
    coordinator_messages: Annotated[Sequence[BaseMessage], add_messages] = Field(default_factory=list)
    # Executor-specific messages (overwritten per step, not merged).
    executor_messages: Sequence[BaseMessage] = Field(default_factory=list)

    # Input question to be answered.
    question: str

    # Feasibility check result (None until check_feasibility has run).
    feasibility: FeasibilityCheck | None = None

    # Coordinator state.
    next_step: NextStep | None = None
    coordinator_conclusion: FinalConclusion | None = None
    coordinator_iterations: int = 0  # default so callers need not seed the counter
    coordinator_max_iterations: int = 5  # matches run_agent's default budget

    # Executor state.
    executor_conclusion: FinalConclusion | None = None
    executor_iterations: int = 0
    executor_max_iterations: int = 3  # matches run_agent's default budget

    # Final answer produced by the finalise node.
    final_answer: FinalAnswer | None = None

    def __getitem__(self, item):
        """Allow dict-style access (``state["question"]``) used by the nodes."""
        return getattr(self, item)
# Nodes
def check_feasibility(state: GraphState, config: RunnableConfig):
    """Decide whether the question can be answered with the available tools."""
    prompt_messages = [
        SystemMessage(content=prompts.get_feasibility_check_prompt(tools), node="feasibility"),
        HumanMessage(content=state["question"], node="feasibility"),
    ]
    checker = model.with_structured_output(FeasibilityCheck)
    check = checker.invoke(prompt_messages, config)
    # Record the structured verdict in the transcript alongside the prompts.
    return {
        "history": prompt_messages + [AIMessage(content=str(check), node="feasibility")],
        "feasibility": check,
    }
def coordinator_node(state: GraphState, config: RunnableConfig):
    """Determine the next step in the plan and select appropriate tools.

    Returns partial state updates: new messages for the history/coordinator
    channels plus either a FinalConclusion (when the plan is finished or the
    iteration budget is spent) or the NextStep to hand to the executor.

    Bug fix vs. original: on the first visit ``new_messages`` aliased
    ``coordinator_messages``, so a pending executor conclusion was appended
    twice to the same list; ``+=`` also mutated the list stored in graph
    state in place. Both lists are now kept independent copies.
    """
    # Work on a copy so the list held in the graph state is never mutated.
    coordinator_messages = list(state["coordinator_messages"])
    new_messages = []
    if not coordinator_messages:
        # First visit: seed the coordinator conversation.
        system_message = SystemMessage(content=prompts.get_coordinator_system_prompt(tools), node="coordinator")
        human_message = HumanMessage(
            content=prompts.get_coordinator_context_prompt(state["question"]), node="coordinator"
        )
        new_messages = [system_message, human_message]
        coordinator_messages = list(new_messages)  # independent copy — no aliasing
    if state["executor_conclusion"]:
        # Surface the executor's result back to the coordinator.
        executor_message = AIMessage(
            content=f"Executor conclusion: {state['executor_conclusion'].conclusion}. Complete text: {str(state['executor_conclusion'])}",
            node="executor",
        )
        coordinator_messages.append(executor_message)
        new_messages.append(executor_message)
    # Wrap up when the previous step was marked final or the budget is spent.
    if (state["next_step"] and state["next_step"].is_final) or (
        state["coordinator_iterations"] >= state["coordinator_max_iterations"]
    ):
        # Generate final conclusion instead of next step.
        human_message = HumanMessage(
            content=prompts.get_coordinator_max_iterations_prompt(state["question"]), node="coordinator"
        )
        structured_model = model.with_structured_output(FinalConclusion)
        response = structured_model.invoke(coordinator_messages + [human_message], config)
        response_message = AIMessage(content=str(response), node="coordinator")
        new_messages += [human_message, response_message]
        return {
            "history": new_messages,
            "coordinator_messages": new_messages,
            "coordinator_conclusion": response,
            "coordinator_iterations": state["coordinator_iterations"] + 1,
        }
    structured_model = model.with_structured_output(NextStep)
    response = structured_model.invoke(coordinator_messages, config)
    response_message = AIMessage(content=str(response), node="coordinator")
    new_messages += [response_message]
    return {
        "history": new_messages,
        "coordinator_messages": new_messages,
        "coordinator_iterations": state["coordinator_iterations"] + 1,
        "next_step": response,
        # Reset the executor channel for the next step.
        "executor_messages": [],
        "executor_conclusion": None,
        "executor_iterations": 0,
    }
def executor_node(state: GraphState, config: RunnableConfig):
    """Plan the execution of the current step using the ReAct pattern.

    Three outcomes: (a) no step to run — return a stub conclusion; (b) the
    executor's iteration budget is spent — ask the model for a FinalConclusion;
    (c) otherwise invoke the model with only the tools the coordinator selected.

    Bug fix vs. original: ``messages += [...]`` mutated the executor_messages
    list stored in graph state in place; a local copy is used instead.
    """
    if not state["next_step"]:
        # Nothing to execute; report back to the coordinator immediately.
        return {
            "executor_conclusion": FinalConclusion(conclusion="No next step", partial_results=""),
            "executor_iterations": state["executor_iterations"] + 1,
        }
    messages = list(state["executor_messages"])  # copy — never mutate graph state
    if not messages:
        # First executor turn for this step: seed the conversation.
        system_message = SystemMessage(
            content=prompts.get_executor_system_prompt(state["next_step"].tools),
            node="executor",
        )
        human_message = HumanMessage(content=prompts.get_executor_task_prompt(state["next_step"].step), node="executor")
        messages = [system_message, human_message]
    if state["executor_iterations"] >= state["executor_max_iterations"]:
        # Generate final conclusion and return to coordinator.
        human_message = HumanMessage(
            content=prompts.get_executor_max_iterations_prompt(state["next_step"].step),
            node="executor",
        )
        messages.append(human_message)  # safe: `messages` is a local copy
        structured_model = model.with_structured_output(FinalConclusion)
        response = structured_model.invoke(messages, config)
        response_message = AIMessage(
            content=f"Executor conclusion: {str(response)}",
            node="executor",
        )
        return {
            "history": [human_message, response_message],
            # Fall back to a stub if structured output came back falsy.
            "executor_conclusion": response
            or FinalConclusion(conclusion="Failed to generate conclusion", partial_results=""),
            "executor_iterations": state["executor_iterations"] + 1,
        }
    # Bind only the tools the coordinator selected for this step.
    selected_tools = [tool for tool in tools if tool.name in state["next_step"].tools]
    model_with_selected_tools = model.bind_tools(selected_tools)
    response_message = model_with_selected_tools.invoke(messages, config)
    response_message.node = "executor"
    return {
        # add_messages accepts a single message as well as a list.
        "history": response_message,
        "executor_messages": messages + [response_message],
        "executor_iterations": state["executor_iterations"] + 1,
    }
def tool_node(state: GraphState):
    """Run every tool call found on the newest executor message."""
    executor_messages = state["executor_messages"]
    newest = executor_messages[-1]
    results = []
    for call in newest.tool_calls:
        call_name = call["name"]
        try:
            # Tool errors are reported back to the model, never raised.
            content = str(tools_by_name[call_name].invoke(call["args"]))
        except Exception as e:
            content = f"Error executing tool {call_name}: {str(e)}"
        results.append(
            ToolMessage(
                content=content,
                name=call_name,
                tool_call_id=call["id"],
                node="tools",
            )
        )
    return {
        "history": results,
        "executor_messages": executor_messages + results,
    }
def finalise(state: GraphState, config: RunnableConfig):
    """Produce the structured FinalAnswer from the coordinator's transcript."""
    system_message = SystemMessage(content=prompts.get_finalizer_prompt(), node="finalise")
    context = [system_message] + state["coordinator_messages"]
    answer = model.with_structured_output(FinalAnswer).invoke(context, config)
    answer_message = AIMessage(content=str(answer), node="finalise")
    return {"history": answer_message, "final_answer": answer}
# Edges
def should_continue_after_feasibility(state: GraphState) -> Literal["coordinator", END]:
    """Route to the coordinator only when the question was judged feasible."""
    feasibility = state["feasibility"]
    if not (feasibility and feasibility.feasible):
        return END
    return "coordinator"
def should_continue_after_coordinator(state: GraphState) -> Literal["executor", "finalise"]:
    """Finalise once a conclusion exists or the iteration budget is spent; else execute."""
    done = bool(state["coordinator_conclusion"]) or (
        state["coordinator_iterations"] >= state["coordinator_max_iterations"]
    )
    return "finalise" if done else "executor"
def should_continue_after_executor(state: GraphState) -> Literal["tools", "coordinator", "executor"]:
    """Run tools if the newest message requests them, return to the coordinator
    once a conclusion exists, otherwise loop the executor."""
    newest = state["executor_messages"][-1]
    if getattr(newest, "tool_calls", None):
        return "tools"
    return "coordinator" if state["executor_conclusion"] else "executor"
def should_continue_after_tools(state: GraphState) -> Literal["executor"]:
    """Tool execution always hands control back to the executor."""
    return "executor"
# Graph
def build_graph():
    """Assemble and compile the agent state graph.

    Topology: check_feasibility -> coordinator <-> executor <-> tools,
    with the coordinator handing off to finalise -> END when done.
    """
    graph = StateGraph(GraphState)

    # Register all nodes.
    for node_name, node_fn in (
        ("check_feasibility", check_feasibility),
        ("coordinator", coordinator_node),
        ("executor", executor_node),
        ("tools", tool_node),
        ("finalise", finalise),
    ):
        graph.add_node(node_name, node_fn)

    graph.set_entry_point("check_feasibility")

    # Conditional routing between nodes.
    graph.add_conditional_edges(
        "check_feasibility", should_continue_after_feasibility, {"coordinator": "coordinator", END: END}
    )
    graph.add_conditional_edges(
        "coordinator", should_continue_after_coordinator, {"executor": "executor", "finalise": "finalise"}
    )
    graph.add_conditional_edges(
        "executor",
        should_continue_after_executor,
        {"executor": "executor", "tools": "tools", "coordinator": "coordinator"},
    )
    graph.add_conditional_edges("tools", should_continue_after_tools, {"executor": "executor"})

    # Finalise node terminates the run.
    graph.add_edge("finalise", END)
    return graph.compile()
def _print_history(messages):
    """Print the full transcript, tagging each message with its originating node."""
    print("\nComplete History (with node info):")
    for msg in messages:
        # getattr with a default replaces the original's redundant
        # `getattr(...) if hasattr(...) else ...` construction.
        node_info = getattr(msg, "node", "unknown")
        content = getattr(msg, "content", str(msg))
        print(f"[{node_info}] {msg.__class__.__name__}: {content}")


def _print_channel(title, messages):
    """Print one message channel as 'ClassName: content' lines under a title."""
    print(f"\n{title}:")
    for msg in messages:
        if hasattr(msg, "content"):
            print(f"{msg.__class__.__name__}: {msg.content}")


def _print_step_output(node, output):
    """Pretty-print a single node's partial state update from graph.stream()."""
    print(f"\n--- {node.upper()} ---")
    # `in` + truthiness kept from the original: only print non-empty channels.
    if "history" in output and output["history"]:
        _print_history(output["history"])
    if "coordinator_messages" in output and output["coordinator_messages"]:
        _print_channel("Coordinator Messages", output["coordinator_messages"])
    if "executor_messages" in output and output["executor_messages"]:
        _print_channel("Executor Messages", output["executor_messages"])
    if "executor_conclusion" in output and output["executor_conclusion"]:
        conclusion = output["executor_conclusion"]
        print("\n=== EXECUTOR CONCLUSION ===")
        print(f"Conclusion: {conclusion.conclusion}")
        print(f"Partial Results: {conclusion.partial_results}")
        print(f"Confidence: {conclusion.confidence}")
    if "final_answer" in output and output["final_answer"]:
        answer = output["final_answer"]
        print("\n=== FINAL ANSWER ===")
        print(f"Answer: {answer.answer}")
        print(f"Confidence: {answer.confidence}")
        print(f"Reasoning: {answer.reasoning}")


def run_agent(question: str, coordinator_max_iterations: int = 5, executor_max_iterations: int = 3):
    """Run the agent on *question*, streaming and pretty-printing each step.

    Args:
        question: The question to answer.
        coordinator_max_iterations: Budget for coordinator planning rounds.
        executor_max_iterations: Budget for executor ReAct rounds per step.
    """
    graph = build_graph()
    initial_state = {
        "question": question,
        "history": [],
        "coordinator_messages": [],
        "executor_messages": [],
        "coordinator_iterations": 0,
        "executor_iterations": 0,
        "coordinator_max_iterations": coordinator_max_iterations,
        "executor_max_iterations": executor_max_iterations,
    }
    # Stream the execution, printing every node's output as it arrives.
    print(f"Question: {question}")
    print("=" * 50)
    for step in graph.stream(initial_state):
        for node, output in step.items():
            _print_step_output(node, output)