| """Define a custom Reasoning and Action agent.
|
|
|
| Works with a chat model with tool calling support.
|
| """
|
|
|
| from datetime import UTC, datetime
|
| from typing import Dict, List, Literal, cast
|
|
|
| from langchain_core.messages import AIMessage
|
| from langgraph.graph import StateGraph
|
| from langgraph.prebuilt import ToolNode
|
| from langgraph.runtime import Runtime
|
|
|
| from react_agent.context import Context
|
| from react_agent.state import InputState, State
|
| from react_agent.tools import TOOLS
|
| from react_agent.utils import load_chat_model
|
|
|
|
|
|
|
|
|
async def call_model(
    state: State, runtime: Runtime[Context]
) -> dict[str, list[AIMessage]]:
    """Call the LLM powering our "agent".

    Binds the tools to the chat model, formats the system prompt, invokes the
    model on the conversation so far, and wraps the reply as a state update.

    Args:
        state (State): The current state of the conversation.
        runtime (Runtime[Context]): Runtime whose ``context`` provides the
            ``model`` value passed to ``load_chat_model`` and the
            ``system_prompt`` template.

    Returns:
        dict: A dictionary whose ``"messages"`` entry is a one-element list
        holding either the model's response or, when ``state.is_last_step``
        is set and the model still requests tool calls, a fallback apology.
    """
    # Bind the tools so the model is able to emit tool calls.
    model = load_chat_model(runtime.context.model).bind_tools(TOOLS)

    # The template receives the current timezone-aware UTC time in ISO format.
    system_message = runtime.context.system_prompt.format(
        system_time=datetime.now(tz=UTC).isoformat()
    )

    # The system message always goes first, followed by the stored history.
    response = cast(
        AIMessage,
        await model.ainvoke(
            [{"role": "system", "content": system_message}, *state.messages]
        ),
    )

    # No step is left to run the requested tools, so answer with an apology
    # that carries the same id as the discarded response.
    if state.is_last_step and response.tool_calls:
        return {
            "messages": [
                AIMessage(
                    id=response.id,
                    content="Sorry, I could not find an answer to your question in the specified number of steps.",
                )
            ]
        }

    return {"messages": [response]}
|
|
|
|
|
|
|
|
|
# Graph builder over State, with InputState as the input schema and Context
# as the context schema.
builder = StateGraph(State, input_schema=InputState, context_schema=Context)

# Register the model node (under the function's own name) and the tool node.
builder.add_node("call_model", call_model)
builder.add_node("tools", ToolNode(TOOLS))

# Execution enters the graph at the model node.
builder.add_edge("__start__", "call_model")
|
|
|
|
|
def route_model_output(state: State) -> Literal["__end__", "tools"]:
    """Pick the node that follows the model call.

    Looks at the most recent message in the state: when it carries tool calls
    the graph continues to the tool node, otherwise the run ends.

    Args:
        state (State): The current state of the conversation.

    Returns:
        str: The name of the next node to call ("__end__" or "tools").

    Raises:
        ValueError: If the most recent message is not an AIMessage.
    """
    last_message = state.messages[-1]
    if isinstance(last_message, AIMessage):
        # Any requested tool call sends control to the tool node.
        return "tools" if last_message.tool_calls else "__end__"
    raise ValueError(
        f"Expected AIMessage in output edges, but got {type(last_message).__name__}"
    )
|
|
|
|
|
|
|
# After each model call, route_model_output chooses the next node.
builder.add_conditional_edges("call_model", route_model_output)

# Tool output always goes back to the model node.
builder.add_edge("tools", "call_model")

# Compile the finished builder into the graph object, named "ReAct Agent".
graph = builder.compile(name="ReAct Agent")
|
|
|