Spaces:
Sleeping
Sleeping
| import json | |
| from typing import Annotated, Literal | |
| from typing_extensions import TypedDict | |
| from langchain_core.messages import ToolMessage | |
| from langgraph.graph.message import add_messages | |
| try: | |
| from IPython.display import Image, display | |
| except ImportError: | |
| pass | |
class State(TypedDict):
    """Shared state passed between the nodes of the agent graph.

    Attributes:
        messages (list): The running conversation. The ``add_messages``
            annotation registers that function as the reducer for this key,
            so message updates produced by a node are combined with the
            existing list through ``add_messages``.
    """

    # Reducer-annotated field: the second ``Annotated`` argument is the
    # function used to combine a node's update with the current value.
    messages: Annotated[list, add_messages]
class BasicToolNode:
    """Graph node that executes the tool calls attached to the latest message.

    The most recent message in the incoming state is read for its
    ``tool_calls``. Each call is dispatched by name to a registered tool,
    and every result is wrapped in a ``ToolMessage`` for the conversation.

    Attributes:
        tools_by_name (dict): Lookup table from a tool's ``name`` to the tool.
    """

    def __init__(self, tools: list) -> None:
        """Index the supplied tools by name.

        Args:
            tools (list): Tool objects; each must expose a ``name`` attribute
                and an ``invoke`` method. If two tools share a name, the one
                listed last wins.
        """
        self.tools_by_name = {t.name: t for t in tools}

    def __call__(self, inputs: dict):
        """Run every tool call found on the last message of ``inputs``.

        Args:
            inputs (dict): Graph state; its ``"messages"`` entry holds the
                conversation so far.

        Returns:
            dict: ``{"messages": [...]}`` holding one ``ToolMessage`` per tool
            call, in the order the calls appear on the message.

        Raises:
            ValueError: If ``inputs`` has no ``"messages"`` entry or it is
                empty.
            KeyError: If a tool call names a tool that was not registered.
        """
        history = inputs.get("messages", [])
        if not history:
            raise ValueError("No message found in input")
        latest = history[-1]

        responses = []
        for call in latest.tool_calls:
            tool_name = call["name"]
            result = self.tools_by_name[tool_name].invoke(call["args"])
            # The tool output is JSON-encoded, so it must be JSON-serializable.
            responses.append(
                ToolMessage(
                    content=json.dumps(result),
                    name=tool_name,
                    tool_call_id=call["id"],
                )
            )
        return {"messages": responses}
def route_tools(
    state: State,
) -> Literal["tools", "__end__"]:
    """Choose the next node: the tool node or the end of the flow.

    Intended as the condition of a conditional edge. It inspects the last
    message and routes to ``"tools"`` when that message carries at least one
    tool call; otherwise it routes to ``"__end__"``.

    Args:
        state (State | list): Either the graph state (a mapping with a
            ``"messages"`` list) or a bare list of messages.

    Returns:
        Literal["tools", "__end__"]: ``"tools"`` if the last message has a
        non-empty ``tool_calls``; ``"__end__"`` otherwise.

    Raises:
        ValueError: If no messages are found. This includes an empty list
            passed as ``state``, which previously raised ``IndexError``.
    """
    # Accept both a raw message list and a state mapping, then validate once
    # so both shapes fail the same documented way when empty.
    messages = state if isinstance(state, list) else state.get("messages", [])
    if not messages:
        raise ValueError(
            f"No messages found in input state to tool_edge: {state}")
    ai_message = messages[-1]
    # A missing, None, or empty ``tool_calls`` all mean "nothing to run".
    # The old ``len(...)`` check raised TypeError on None.
    if getattr(ai_message, "tool_calls", None):
        return "tools"
    return "__end__"
def plot_agent_schema(graph):
    """Show the agent's graph as a Mermaid PNG via IPython, best effort.

    Any exception raised while rendering or displaying the image is caught
    and replaced by a short printed notice, so this helper never raises.
    This covers ``display``/``Image`` being undefined because the optional
    IPython import did not succeed.

    Args:
        graph: Object exposing ``get_graph()``, whose result must provide
            ``draw_mermaid_png()``.

    Returns:
        None
    """
    try:
        # Resolve the display helpers before doing any rendering work. This
        # matches the original evaluation order, so nothing is drawn when
        # the helpers are unavailable.
        show, to_image = display, Image
        png = graph.get_graph().draw_mermaid_png()
        show(to_image(png))
    except Exception:
        # Rendering depends on optional extras; report the failure and move on.
        print("Graph could not be displayed.")
    return None