File size: 5,962 Bytes
560d5c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8be7bcb
 
 
 
 
 
 
 
 
 
 
 
 
 
560d5c2
 
 
 
 
 
 
 
 
 
 
 
8be7bcb
 
560d5c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8be7bcb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
"""Define a custom Reasoning and Action agent.

Works with a chat model with tool calling support.
"""

from typing import Dict, List, Literal, cast

from app.agent.configuration import Configuration
from app.agent.state import AgentState, InputState, SQLAgentState
from app.agent.tools import TOOLS
from app.agent.utils import load_chat_model
from dotenv import load_dotenv
from langchain_core.messages import AIMessage
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import StateGraph
from langgraph.prebuilt import ToolNode

# Load variables from a local .env file into the process environment at import
# time, so they are already set when call_model later builds the chat model.
load_dotenv()

# Graph node: one LLM reasoning step.
async def call_model(state: AgentState) -> Dict[str, List[AIMessage]]:
    """Call the LLM powering our "agent" for a single step.

    Loads the runtime ``Configuration``, binds ``TOOLS`` to the configured chat
    model, and sends the configured system prompt followed by the conversation
    history in ``state.messages``.

    Args:
        state: The current graph state. ``state.messages`` is the conversation
            so far; ``state.is_last_step`` is true when this is the last step
            the graph will run.

    Returns:
        A dict ``{"messages": [message]}`` with exactly one ``AIMessage`` to be
        added to the existing messages. If this is the last step and the model
        still asked for tool calls, the message is a fixed apology (without
        tool calls) instead of the model's raw response.
    """
    configuration = Configuration.from_context()

    # Bind the tool schemas so the model is able to emit tool calls.
    model = load_chat_model(configuration.model).bind_tools(TOOLS)

    # The system prompt is used verbatim; no template variables are filled in.
    system_message = configuration.system_prompt

    # Get the model's response
    response = cast(
        AIMessage,
        await model.ainvoke([{"role": "system", "content": system_message}, *state.messages]),
    )

    # On the last step any requested tool call could never be executed, so
    # answer with an apology instead; the model response's id is carried over.
    if state.is_last_step and response.tool_calls:
        return {
            "messages": [
                AIMessage(
                    id=response.id,
                    content="Sorry, I could not find an answer to your question in the specified number of steps.",
                )
            ]
        }

    # Return the model's response as a list to be added to existing messages
    return {"messages": [response]}


# Number of recorded query attempts at which the agent stops calling tools.
MAX_QUERY_ATTEMPTS = 3


def route_model_output(state: SQLAgentState) -> Literal["__end__", "tools"]:
    """Determine the next node based on the model's output.

    Args:
        state: Graph state whose last message must be the model's ``AIMessage``.

    Returns:
        ``"tools"`` when the last message requests tool calls and fewer than
        ``MAX_QUERY_ATTEMPTS`` query attempts are recorded; ``"__end__"``
        otherwise.

    Raises:
        ValueError: If the last message is not an ``AIMessage``.
    """
    last_message = state.messages[-1]
    if not isinstance(last_message, AIMessage):
        raise ValueError(f"Expected AIMessage in output edges, but got {type(last_message).__name__}")

    # If there is no tool call, then we finish
    if not last_message.tool_calls:
        return "__end__"

    # The graph is compiled with AgentState rather than SQLAgentState, so the
    # ``query_attempts`` field may not exist; treat a missing counter as 0
    # instead of crashing with AttributeError.
    # NOTE(review): nothing in this module increments query_attempts —
    # confirm it is updated elsewhere (e.g. by a tool), or this cap never fires.
    attempts = getattr(state, "query_attempts", 0)
    if attempts >= MAX_QUERY_ATTEMPTS:
        # NOTE(review): ending here leaves an AIMessage with unanswered tool
        # calls in the checkpointed thread; some chat APIs reject such history
        # on the next turn — confirm this is acceptable.
        return "__end__"

    # Otherwise execute the requested actions
    return "tools"


# In-memory checkpointer: per-thread conversation state is kept only for the
# lifetime of this process.
memory = MemorySaver()

# State schema AgentState, input schema InputState, config schema Configuration.
builder = StateGraph(AgentState, input=InputState, config_schema=Configuration)

# Nodes: the LLM step (named explicitly with the same name the function would
# be registered under) and a prebuilt executor for the requested TOOLS.
builder.add_node("call_model", call_model)
builder.add_node("tools", ToolNode(TOOLS))

# Wiring: __start__ -> call_model; call_model routes to "tools" or "__end__"
# via route_model_output; tools always loops back to call_model.
builder.add_edge("__start__", "call_model")
builder.add_conditional_edges("call_model", route_model_output)
builder.add_edge("tools", "call_model")

# Compile with the checkpointer so state persists per thread_id.
graph = builder.compile(checkpointer=memory, name="powersim_agent")

if __name__ == "__main__":
    import asyncio

    from langchain_core.messages import HumanMessage

    # Maximum characters of message content shown per printed field.
    _PREVIEW_CHARS = 500

    def _preview(content) -> str:
        """Return the first ``_PREVIEW_CHARS`` characters of ``content`` plus "...".

        Message content may be a plain string or a list of content blocks, so it
        is converted with ``str`` before truncating.
        """
        return f"{str(content)[:_PREVIEW_CHARS]}..."

    def _print_message(message) -> None:
        """Print type, content, tool calls and tool-result details of one message."""
        print(f"MESSAGE TYPE: {type(message).__name__}")

        if getattr(message, "content", None):
            print(f"CONTENT: {_preview(message.content)}")

        tool_calls = getattr(message, "tool_calls", None)
        if tool_calls:
            print(f"TOOL CALLS: {tool_calls}")

        # Tool results (ToolMessage) expose ``tool_call_id`` alongside ``name``.
        if hasattr(message, "name") and hasattr(message, "tool_call_id"):
            print(f"TOOL: {message.name}")
            print(f"TOOL CALL ID: {message.tool_call_id}")
            if hasattr(message, "content"):
                print(f"RESULT: {_preview(message.content)}")

    async def main():
        """Run the agent once on a sample question, printing each node's update."""
        input_data = {"messages": [HumanMessage(content="What is the total revenue?")]}
        config = {"configurable": {"thread_id": "12345"}}

        print("\n=== STARTING AGENT EXECUTION ===\n")

        # stream_mode="updates" yields {node_name: node_output} after each step.
        async for chunk in graph.astream(input_data, config, stream_mode="updates"):
            for node_name, node_output in chunk.items():
                print(f"\n--- OUTPUT FROM NODE: {node_name} ---")

                # A node may emit several messages (e.g. results of parallel
                # tool calls), so print all of them, not just the last one.
                messages = node_output.get("messages") if isinstance(node_output, dict) else None
                for message in messages or []:
                    _print_message(message)

                print("-----------------------------------")

            print("\n==== CHUNK COMPLETE ====\n")

        # Read the final state from the checkpointer instead of calling
        # ``ainvoke`` again: a second run on the same thread_id would re-append
        # the question and execute the whole agent a second time.
        final_state = await graph.aget_state(config)

        print("\n=== FINAL RESPONSE ===\n")
        print(final_state.values)

    # Run the async main function
    asyncio.run(main())