QueryMind / src /agent_graph /agent_backend.py
7beshoyarnest's picture
fix IPython import for production
3103e1a
Raw
History Blame Contribute Delete
3.94 kB
import json
from typing import Annotated, Literal
from typing_extensions import TypedDict
from langchain_core.messages import ToolMessage
from langgraph.graph.message import add_messages
try:
from IPython.display import Image, display
except ImportError:
pass
class State(TypedDict):
    """Typed dictionary describing the state handed between graph nodes.

    Attributes:
        messages (list): The conversation messages. The field is annotated
            with `add_messages`; presumably the graph framework uses that
            function to merge newly returned messages into the existing
            list instead of overwriting it (confirm against the langgraph
            documentation).
    """
    # `Annotated` attaches `add_messages` as metadata to the plain `list` type.
    messages: Annotated[list, add_messages]
class BasicToolNode:
    """A node that runs the tools requested by the last message in the state.

    The node looks at the `tool_calls` of the most recent message, invokes the
    matching tool for each call and wraps every result in a `ToolMessage`.

    Attributes:
        tools_by_name (dict): A dictionary mapping tool names to tool instances.
    """

    def __init__(self, tools: list) -> None:
        """Initializes the BasicToolNode with the available tools.

        Args:
            tools (list): A list of tool objects, each having a `name`
                attribute (used as lookup key) and an `invoke` method.
        """
        self.tools_by_name = {tool.name: tool for tool in tools}

    def __call__(self, inputs: dict) -> dict:
        """Executes the tool calls found on the last message.

        Args:
            inputs (dict): The input state; its "messages" entry holds the
                message list.

        Returns:
            dict: `{"messages": [...]}` with one `ToolMessage` per tool call,
            in the order the calls were listed. The list is empty when the
            last message carries no tool calls.

        Raises:
            ValueError: If no messages are found in the input.
            KeyError: If a tool call names a tool that was not registered.
            TypeError: If a tool result cannot be serialized by `json.dumps`.
        """
        messages = inputs.get("messages", [])
        if not messages:
            raise ValueError("No message found in input")
        message = messages[-1]

        # Not every message type carries `tool_calls` (and it may be None);
        # treat both cases as "nothing to run" instead of crashing, which
        # mirrors the guard used when routing to this node.
        tool_calls = getattr(message, "tool_calls", None) or []

        outputs = []
        for tool_call in tool_calls:
            tool = self.tools_by_name[tool_call["name"]]
            tool_result = tool.invoke(tool_call["args"])
            outputs.append(
                ToolMessage(
                    # Results are JSON-encoded so the message content is a string.
                    content=json.dumps(tool_result),
                    name=tool_call["name"],
                    tool_call_id=tool_call["id"],
                )
            )
        return {"messages": outputs}
def route_tools(
    state: State,
) -> Literal["tools", "__end__"]:
    """Decides whether the flow continues to the tool node or ends.

    Intended for use as a conditional edge: the last message in the state is
    inspected, and if it carries at least one tool call the flow is routed to
    the 'tools' node; otherwise it is routed to the end.

    Args:
        state (State): The input state. Either a dict with a "messages" list
            or, alternatively, the bare list of messages itself.

    Returns:
        Literal["tools", "__end__"]: 'tools' if the last message has tool
        calls; '__end__' otherwise.

    Raises:
        ValueError: If the state contains no messages (empty list, missing or
            empty "messages" entry).
    """
    # Accept both shapes of state through one code path so that an empty
    # list is reported as ValueError too, rather than leaking an IndexError.
    messages = state if isinstance(state, list) else state.get("messages", [])
    if not messages:
        raise ValueError(
            f"No messages found in input state to tool_edge: {state}")
    ai_message = messages[-1]

    # Truthiness covers a missing attribute, None, and an empty list alike.
    if getattr(ai_message, "tool_calls", None):
        return "tools"
    return "__end__"
def plot_agent_schema(graph):
    """Renders the agent graph as a Mermaid PNG, on a best-effort basis.

    The PNG is produced via `graph.get_graph().draw_mermaid_png()` and shown
    with IPython's `display`/`Image`. Rendering is optional: any failure
    (for example missing optional dependencies) is swallowed and a short
    notice is printed instead.

    Args:
        graph: An object exposing `get_graph()`, whose result supports
            `draw_mermaid_png()`.

    Returns:
        None
    """
    try:
        # Kept as a single expression so name lookups and calls happen in the
        # same order as before.
        display(Image(graph.get_graph().draw_mermaid_png()))
    except Exception:
        # Drawing needs extra dependencies and is optional, so never fail here.
        print("Graph could not be displayed.")
    return None