Spaces:
Sleeping
Sleeping
| """Define a custom Reasoning and Action agent. | |
| Works with a chat model with tool calling support. | |
| """ | |
| from typing import Dict, List, Literal, cast | |
| from app.agent.configuration import Configuration | |
| from app.agent.state import AgentState, InputState, SQLAgentState | |
| from app.agent.tools import TOOLS | |
| from app.agent.utils import load_chat_model | |
| from dotenv import load_dotenv | |
| from langchain_core.messages import AIMessage | |
| from langgraph.checkpoint.memory import MemorySaver | |
| from langgraph.graph import StateGraph | |
| from langgraph.prebuilt import ToolNode | |
| load_dotenv() | |
# Define the function that calls the model
async def call_model(state: AgentState) -> Dict[str, List[AIMessage]]:
    """Call the LLM powering the agent and return its reply.

    The prompt is the configured system prompt followed by the conversation
    history in ``state.messages``. The chat model is loaded from the
    configuration and has ``TOOLS`` bound to it. Runtime settings are not
    passed in as a parameter; they are read with ``Configuration.from_context()``.

    Args:
        state: The current agent state. ``messages`` and ``is_last_step`` are read.

    Returns:
        A dict ``{"messages": [message]}`` containing exactly one AIMessage,
        to be appended to the existing messages. Normally this is the model's
        reply. If this is the last allowed step and the model still requested
        tool calls, the reply is replaced by a fixed apology message without
        tool calls (keeping the model's message id) so the run can end.
    """
    configuration = Configuration.from_context()

    # Initialize the model with tool binding. Change the model or add more tools here.
    model = load_chat_model(configuration.model).bind_tools(TOOLS)

    # The system prompt is used verbatim (no template formatting is applied).
    system_message = configuration.system_prompt

    # Get the model's response
    response = cast(
        AIMessage,
        await model.ainvoke([{"role": "system", "content": system_message}, *state.messages]),
    )

    # Handle the case when it's the last step and the model still wants to use a tool:
    # no further tool round is possible, so answer with a plain message instead.
    if state.is_last_step and response.tool_calls:
        return {
            "messages": [
                AIMessage(
                    id=response.id,
                    content="Sorry, I could not find an answer to your question in the specified number of steps.",
                )
            ]
        }

    # Return the model's response as a list to be added to existing messages
    return {"messages": [response]}
# Upper bound on ``state.query_attempts`` before the agent stops running tools.
MAX_QUERY_ATTEMPTS = 3


def route_model_output(state: SQLAgentState) -> Literal["__end__", "tools"]:
    """Determine the next node based on the model's output.

    Args:
        state: The current graph state. Its last message must be an AIMessage.
            ``query_attempts`` is compared against ``MAX_QUERY_ATTEMPTS``.

    Returns:
        ``"tools"`` if the model requested tool calls and fewer than
        ``MAX_QUERY_ATTEMPTS`` attempts have been made, otherwise ``"__end__"``.

    Raises:
        ValueError: If the last message is not an AIMessage.
    """
    last_message = state.messages[-1]
    if not isinstance(last_message, AIMessage):
        raise ValueError(f"Expected AIMessage in output edges, but got {type(last_message).__name__}")

    # If there is no tool call, then we finish
    if not last_message.tool_calls:
        return "__end__"

    # If we've exhausted the attempt budget, end the conversation.
    # NOTE(review): ending here leaves the last AIMessage's tool calls without
    # matching ToolMessages in the checkpointed thread. Chat APIs that require
    # every tool call to be answered may reject the next turn on this thread.
    # Consider routing to a node that replaces the message instead. Also, nothing
    # in this file increments query_attempts; confirm a tool/state update does.
    if state.query_attempts >= MAX_QUERY_ATTEMPTS:
        return "__end__"

    # Otherwise execute the requested actions
    return "tools"
# In-memory checkpointer: per-thread state is kept only inside this process.
memory = MemorySaver()

# Graph over AgentState; callers provide InputState, and runtime settings
# follow the Configuration schema.
builder = StateGraph(AgentState, input=InputState, config_schema=Configuration)

# Nodes: the LLM step and the executor for the tools bound to the model.
builder.add_node("call_model", call_model)
builder.add_node("tools", ToolNode(TOOLS))

# Wiring: __start__ -> call_model -> (tools -> call_model)* -> __end__
# route_model_output decides after each model turn whether to run tools or stop;
# tools always hand control back to the model.
builder.add_edge("__start__", "call_model")
builder.add_conditional_edges("call_model", route_model_output)
builder.add_edge("tools", "call_model")

# Executable graph, persisted through the checkpointer above.
graph = builder.compile(checkpointer=memory, name="powersim_agent")
if __name__ == "__main__":
    import asyncio

    from langchain_core.messages import HumanMessage, ToolMessage

    # Maximum number of characters of message content printed per message.
    PREVIEW_CHARS = 500

    def _preview(content) -> str:
        """Render message content as text, truncated with "..." only when it is too long."""
        text = content if isinstance(content, str) else str(content)
        if len(text) <= PREVIEW_CHARS:
            return text
        return text[:PREVIEW_CHARS] + "..."

    def _print_message(message) -> None:
        """Print one streamed message: its type, then its content or tool details."""
        print(f"MESSAGE TYPE: {type(message).__name__}")
        if isinstance(message, ToolMessage):
            # Tool results: show which tool ran and its output (printed once).
            print(f"TOOL: {message.name}")
            print(f"TOOL CALL ID: {message.tool_call_id}")
            print(f"RESULT: {_preview(message.content)}")
            return
        if message.content:
            print(f"CONTENT: {_preview(message.content)}")
        tool_calls = getattr(message, "tool_calls", None)
        if tool_calls:
            print(f"TOOL CALLS: {tool_calls}")

    async def main():
        """Run one sample question through the agent, printing each node's updates."""
        input_data = {
            "messages": [
                HumanMessage(content="What is the total revenue?"),
            ]
        }
        config = {
            "configurable": {
                "thread_id": "12345",
            }
        }

        print("\n=== STARTING AGENT EXECUTION ===\n")
        async for chunk in graph.astream(input_data, config, stream_mode="updates"):
            for node_name, node_output in chunk.items():
                print(f"\n--- OUTPUT FROM NODE: {node_name} ---")
                # Only dict updates carry a "messages" list; skip anything else.
                messages = node_output.get("messages") if isinstance(node_output, dict) else None
                # A single update may contain several messages (e.g. several tool results).
                for message in messages or []:
                    _print_message(message)
                    print("-----------------------------------")
            print("\n==== CHUNK COMPLETE ====\n")

        # Read the final state from the checkpointer instead of invoking the graph
        # again: a second call on the same thread_id would resubmit the question to
        # the stored conversation and re-run the agent (extra LLM/tool calls).
        snapshot = await graph.aget_state(config)
        print("\n=== FINAL RESPONSE ===\n")
        print(snapshot.values)

    # Run the async main function
    asyncio.run(main())