File size: 3,759 Bytes
"""Define a custom Reasoning and Action (ReAct) agent.

Works with a chat model that supports tool calling.
"""
from datetime import UTC, datetime
from typing import Dict, List, Literal, cast
from langchain_core.messages import AIMessage
from langgraph.graph import StateGraph
from langgraph.prebuilt import ToolNode
from langgraph.runtime import Runtime
from react_agent.context import Context
from react_agent.state import InputState, State
from react_agent.tools import TOOLS
from react_agent.utils import load_chat_model
# Define the function that calls the model
async def call_model(
    state: State, runtime: Runtime[Context]
) -> Dict[str, List[AIMessage]]:
    """Run one model turn of the agent and return its reply as a state update.

    The chat model named by ``runtime.context.model`` is loaded and bound to
    ``TOOLS``, then invoked with a system message followed by every message
    already in ``state.messages``. The system message is
    ``runtime.context.system_prompt`` formatted with the current UTC time
    (ISO 8601) as ``system_time``.

    Args:
        state (State): The current state of the conversation.
        runtime (Runtime[Context]): Runtime whose ``context`` supplies the
            model identifier and the system prompt template.

    Returns:
        dict: ``{"messages": [message]}``, where ``message`` is the model's
        reply, or a fixed apology message (reusing the reply's id) when
        ``state.is_last_step`` is set and the reply still asks for tool calls.
    """
    # Bind the tools to the configured model. Swap the model or extend TOOLS here.
    llm = load_chat_model(runtime.context.model).bind_tools(TOOLS)

    # Stamp the system prompt with the current time. Edit the prompt template
    # to change how the agent behaves.
    timestamp = datetime.now(tz=UTC).isoformat()
    system_text = runtime.context.system_prompt.format(system_time=timestamp)

    # The system message always goes first, ahead of the existing conversation.
    conversation = [{"role": "system", "content": system_text}, *state.messages]
    reply = cast(  # type: ignore[redundant-cast]
        AIMessage,
        await llm.ainvoke(conversation),
    )

    # Common path: hand the reply back as a one-element list so it is added
    # to the existing messages.
    if not (state.is_last_step and reply.tool_calls):
        return {"messages": [reply]}

    # Last step, yet the model still wants a tool: answer with an apology
    # instead of returning a tool request.
    apology = AIMessage(
        id=reply.id,
        content="Sorry, I could not find an answer to your question in the specified number of steps.",
    )
    return {"messages": [apology]}
# Define a new graph builder: `State` is the graph state, `InputState` is
# passed as the input schema and `Context` as the context schema.
builder = StateGraph(State, input_schema=InputState, context_schema=Context)
# Define the two nodes we will cycle between: the model-calling node and a
# `ToolNode` built from TOOLS.
# NOTE(review): `call_model` is added without an explicit name; the edges
# below refer to it as "call_model", so the name is presumably taken from the
# function -- confirm against the StateGraph.add_node documentation.
builder.add_node(call_model)
builder.add_node("tools", ToolNode(TOOLS))
# Set the entrypoint as `call_model`
# This means that this node is the first one called
builder.add_edge("__start__", "call_model")
def route_model_output(state: State) -> Literal["__end__", "tools"]:
    """Choose the node that follows the model based on its latest message.

    Only the final entry of ``state.messages`` is inspected: if it carries
    tool calls the graph proceeds to the tool node, otherwise the run ends.

    Args:
        state (State): The current state of the conversation.

    Returns:
        str: ``"tools"`` when the last message requests tool calls,
        ``"__end__"`` when it does not.

    Raises:
        ValueError: If the last message is not an ``AIMessage``.
    """
    latest = state.messages[-1]
    if isinstance(latest, AIMessage):
        # Pending tool calls mean the requested actions run next; with none,
        # the conversation is finished.
        return "tools" if latest.tool_calls else "__end__"
    raise ValueError(
        f"Expected AIMessage in output edges, but got {type(latest).__name__}"
    )
# Add a conditional edge to determine the next step after `call_model`.
# `route_model_output` returns either "__end__" or "tools".
builder.add_conditional_edges(
    "call_model",
    # After call_model finishes running, the next node(s) are scheduled
    # based on the output from route_model_output
    route_model_output,
)
# Add a normal edge from `tools` to `call_model`
# This creates a cycle: after using tools, we always return to the model
builder.add_edge("tools", "call_model")
# Compile the builder into an executable graph, named "ReAct Agent"
graph = builder.compile(name="ReAct Agent")
|