"""Define a custom Reasoning and Action agent.
Works with a chat model with tool calling support.
"""
from typing import Dict, List, Literal, cast
from app.agent.configuration import Configuration
from app.agent.state import AgentState, InputState, SQLAgentState
from app.agent.tools import TOOLS
from app.agent.utils import load_chat_model
from dotenv import load_dotenv
from langchain_core.messages import AIMessage
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import StateGraph
from langgraph.prebuilt import ToolNode
# Load environment variables (e.g. model API keys) at import time, before any
# chat model is instantiated below.
load_dotenv()
# Define the function that calls the model
async def call_model(state: AgentState) -> Dict[str, List[AIMessage]]:
    """Call the LLM powering our "agent".

    Loads the configuration from context, binds the available TOOLS to the
    configured chat model, and invokes it with the system prompt followed by
    the full conversation history.

    Args:
        state: The current conversation state. ``state.messages`` holds the
            history and ``state.is_last_step`` flags the final allowed step.

    Returns:
        dict: ``{"messages": [...]}`` with the single response message to be
        appended to the existing conversation by the state reducer.
    """
    configuration = Configuration.from_context()

    # Initialize the model with tool binding. Change the model or add more
    # tools here.
    model = load_chat_model(configuration.model).bind_tools(TOOLS)

    # The system prompt customizes the agent's behavior.
    system_message = configuration.system_prompt

    # Get the model's response to the system prompt plus the history.
    response = cast(
        AIMessage,
        await model.ainvoke(
            [{"role": "system", "content": system_message}, *state.messages]
        ),
    )

    # If this is the last allowed step but the model still wants to use a
    # tool, substitute an apology so the run ends gracefully instead of
    # being cut off mid-loop.
    if state.is_last_step and response.tool_calls:
        return {
            "messages": [
                AIMessage(
                    id=response.id,
                    content="Sorry, I could not find an answer to your question in the specified number of steps.",
                )
            ]
        }

    # Return the response as a one-element list to be added to the
    # existing messages.
    return {"messages": [response]}
def route_model_output(
    state: SQLAgentState, *, max_attempts: int = 3
) -> Literal["__end__", "tools"]:
    """Determine the next node based on the model's output.

    Args:
        state: Current agent state; the last entry of ``state.messages``
            must be the ``AIMessage`` produced by ``call_model``.
        max_attempts: Maximum number of query attempts before the
            conversation is ended. Keyword-only; the default of 3 matches
            the previously hard-coded limit, so existing callers (including
            LangGraph's conditional-edge invocation) are unaffected.

    Returns:
        ``"tools"`` to execute the requested tool calls, or ``"__end__"``
        to finish the conversation.

    Raises:
        ValueError: If the last message is not an ``AIMessage``.
    """
    last_message = state.messages[-1]
    if not isinstance(last_message, AIMessage):
        raise ValueError(
            f"Expected AIMessage in output edges, but got {type(last_message).__name__}"
        )
    # No tool call requested -> the model produced a final answer.
    if not last_message.tool_calls:
        return "__end__"
    # Too many query attempts -> stop rather than loop indefinitely.
    # NOTE(review): the graph below is built with AgentState, while this
    # router reads state.query_attempts via SQLAgentState — confirm that
    # AgentState actually defines query_attempts.
    if state.query_attempts >= max_attempts:
        return "__end__"
    # Otherwise execute the requested actions.
    return "tools"
# Checkpointer so conversation state persists across graph invocations
# (keyed by thread_id at runtime).
memory = MemorySaver()

# Assemble the agent graph from the two nodes we cycle between: the model
# node and the tool-execution node.
builder = StateGraph(AgentState, input=InputState, config_schema=Configuration)
builder.add_node("call_model", call_model)
builder.add_node("tools", ToolNode(TOOLS))

# Execution always begins at the model node.
builder.add_edge("__start__", "call_model")

# After each model turn, route_model_output decides whether to run tools
# or finish the conversation.
builder.add_conditional_edges("call_model", route_model_output)

# Tool results always flow back into the model, closing the ReAct loop.
builder.add_edge("tools", "call_model")

# Compile into an executable graph, attaching the checkpointer.
graph = builder.compile(checkpointer=memory, name="powersim_agent")
if __name__ == "__main__":
    import asyncio

    from langchain_core.messages import HumanMessage

    def _preview(content, limit: int = 500) -> str:
        """Truncate *content* for display, adding '...' only when cut."""
        text = str(content)
        return text[:limit] + "..." if len(text) > limit else text

    async def main():
        """Stream one sample question through the agent, printing each step."""
        # Input uses the standard LangChain message format.
        input_data = {
            "messages": [
                HumanMessage(content="What is the total revenue?"),
            ]
        }
        # thread_id keys the MemorySaver checkpointer to this conversation.
        config = {
            "configurable": {
                "thread_id": "12345",
            }
        }

        print("\n=== STARTING AGENT EXECUTION ===\n")

        # Stream per-node updates to see what's happening inside the graph.
        async for chunk in graph.astream(input_data, config, stream_mode="updates"):
            for node_name, node_output in chunk.items():
                print(f"\n--- OUTPUT FROM NODE: {node_name} ---")
                # Show the latest message emitted by this node, if any.
                if "messages" in node_output and node_output["messages"]:
                    latest_message = node_output["messages"][-1]
                    print(f"MESSAGE TYPE: {type(latest_message).__name__}")
                    if hasattr(latest_message, "content") and latest_message.content:
                        print(f"CONTENT: {_preview(latest_message.content)}")
                    # AIMessages may carry tool calls.
                    if hasattr(latest_message, "tool_calls") and latest_message.tool_calls:
                        print(f"TOOL CALLS: {latest_message.tool_calls}")
                    # ToolMessages carry both a name and a tool_call_id.
                    if hasattr(latest_message, "name") and hasattr(latest_message, "tool_call_id"):
                        print(f"TOOL: {latest_message.name}")
                        print(f"TOOL CALL ID: {latest_message.tool_call_id}")
                        if hasattr(latest_message, "content"):
                            print(f"RESULT: {_preview(latest_message.content)}")
                print("-----------------------------------")
            print("\n==== CHUNK COMPLETE ====\n")

        # Invoke once more on the same thread to print the accumulated
        # final state.
        final_response = await graph.ainvoke(input_data, config)
        print("\n=== FINAL RESPONSE ===\n")
        print(final_response)

    # Run the async main function.
    asyncio.run(main())