|
|
from typing import Literal |
|
|
|
|
|
from langchain.chat_models import init_chat_model |
|
|
from langchain.tools import tool |
|
|
from langchain_core.messages import SystemMessage, ToolMessage, filter_messages, HumanMessage |
|
|
from langgraph.graph import StateGraph, START, END |
|
|
|
|
|
from core.state import MathAgentState, MathAgentOutputState |
|
|
from tools.python_executor import execute_python_code |
|
|
from tools.think_tool import think_tool |
|
|
from utils.prompt_manager import prompt_mgmt |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Tools the math agent may call, plus a name-keyed registry so the
# tool-execution node can dispatch each requested call in O(1).
tools = [execute_python_code, think_tool]
tools_by_name = {t.name: t for t in tools}
|
|
|
|
|
|
|
|
# Primary reasoning model; binding the tools lets it emit tool calls.
math_model = init_chat_model(model="openai:gpt-5")
model_with_tools = math_model.bind_tools(tools)

# Lighter model reserved for summarization work.
summarization_model = init_chat_model(model="openai:gpt-4.1-mini")

# Model for compressing long transcripts; output capped at 32k tokens.
compress_model = init_chat_model(model="openai:gpt-4.1", max_tokens=32000)
|
|
|
|
|
|
|
|
def llm_call(state: MathAgentState):
    """Run the tool-calling model over the conversation so far.

    Renders the base system prompt, prepends it to the accumulated
    message history, and invokes the tool-bound model. The model either
    emits tool calls (continuing the solving loop) or a plain final
    answer.

    Returns a state update containing the model's response message.
    """
    system_prompt = SystemMessage(
        content=prompt_mgmt.render_template("math_agent_base_system", {})
    )
    response = model_with_tools.invoke([system_prompt] + state["messages"])
    return {"messages": [response]}
|
|
|
|
|
|
|
|
def tool_node(state: MathAgentState):
    """Execute every tool call issued by the most recent AI message.

    Each requested tool is looked up in the registry and invoked with
    the model-supplied arguments; the result is wrapped in a
    ToolMessage (carrying the originating tool_call_id) so the model
    sees the outcome on its next turn.

    Returns a state update with one ToolMessage per tool call.
    """
    requested_calls = state["messages"][-1].tool_calls

    results = []
    for call in requested_calls:
        output = tools_by_name[call["name"]].invoke(call["args"])
        results.append(
            ToolMessage(
                content=output,
                name=call["name"],
                tool_call_id=call["id"],
            )
        )

    return {"messages": results}
|
|
|
|
|
|
|
|
def compress_research(state: MathAgentState) -> dict:
    """Condense the solving transcript into the agent's output fields.

    The last message in the history is treated as the final answer,
    while the content of every AI and tool message is joined into a
    single raw-notes string for traceability.

    Returns:
        dict with:
            "compressed_research": text content of the final message.
            "raw_notes": one-element list joining all AI/tool content.
    """
    messages = state.get("messages", [])
    # Bug fix: the previous `state.get("messages", [])[-1]` defeated its
    # own default — an empty history still raised IndexError on [-1].
    # Return empty output fields instead of crashing.
    if not messages:
        return {"compressed_research": "", "raw_notes": [""]}

    last_message = messages[-1]

    raw_notes = [
        str(m.content)
        for m in filter_messages(messages, include_types=["tool", "ai"])
    ]

    return {
        "compressed_research": str(last_message.content),
        "raw_notes": ["\n".join(raw_notes)],
    }
|
|
|
|
|
|
|
|
def should_continue(state: MathAgentState) -> Literal["tool_node", "compress_research"]:
    """Route the graph after an LLM turn.

    If the latest AI message requested any tool calls, the loop
    continues through tool execution; otherwise the answer is final and
    the transcript is handed off for compression.

    Returns:
        "tool_node" when tool calls are pending, else
        "compress_research".
    """
    latest = state["messages"][-1]
    return "tool_node" if latest.tool_calls else "compress_research"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Assemble the solving loop: llm_call decides, tool_node executes the
# requested tools, and compress_research emits the final output state.
agent_builder = StateGraph(MathAgentState, output_schema=MathAgentOutputState)

for node_name, node_fn in (
    ("llm_call", llm_call),
    ("tool_node", tool_node),
    ("compress_research", compress_research),
):
    agent_builder.add_node(node_name, node_fn)

agent_builder.add_edge(START, "llm_call")
# Branch on whether the last model turn requested tool calls.
agent_builder.add_conditional_edges(
    "llm_call",
    should_continue,
    {
        "tool_node": "tool_node",
        "compress_research": "compress_research",
    },
)
agent_builder.add_edge("tool_node", "llm_call")
agent_builder.add_edge("compress_research", END)

math_agent = agent_builder.compile()
|
|
|
|
|
|
|
|
@tool
def math_tool(problem: str):
    """
    Tool for solving a mathematical problem
    :param problem: The problem to be solved
    :return: the solution to the given problem
    """
    # NOTE: the docstring above doubles as the tool description shown
    # to the calling model, so it is deliberately left unchanged.
    initial_state = {
        "messages": [HumanMessage(content=problem)],
        "question": problem,
    }
    result = math_agent.invoke(initial_state)
    return result["compressed_research"]
|
|
|