Spaces:
Sleeping
Sleeping
arthi.kasturirangan@informa.com commited on
Commit ·
8be7bcb
1
Parent(s): a014f29
add checkpointers
Browse files- app/agent/graph.py +17 -16
app/agent/graph.py
CHANGED
|
@@ -59,18 +59,6 @@ async def call_model(state: AgentState) -> Dict[str, List[AIMessage]]:
|
|
| 59 |
return {"messages": [response]}
|
| 60 |
|
| 61 |
|
| 62 |
-
# Define a new graph
|
| 63 |
-
builder = StateGraph(AgentState, input=InputState, config_schema=Configuration)
|
| 64 |
-
|
| 65 |
-
# Define the two nodes we will cycle between
|
| 66 |
-
builder.add_node(call_model)
|
| 67 |
-
builder.add_node("tools", ToolNode(TOOLS))
|
| 68 |
-
|
| 69 |
-
# Set the entrypoint as `call_model`
|
| 70 |
-
# This means that this node is the first one called
|
| 71 |
-
builder.add_edge("__start__", "call_model")
|
| 72 |
-
|
| 73 |
-
|
| 74 |
def route_model_output(state: SQLAgentState) -> Literal["__end__", "tools"]:
|
| 75 |
"""Determine the next node based on the model's output."""
|
| 76 |
last_message = state.messages[-1]
|
|
@@ -89,6 +77,20 @@ def route_model_output(state: SQLAgentState) -> Literal["__end__", "tools"]:
|
|
| 89 |
return "tools"
|
| 90 |
|
| 91 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
# Add a conditional edge to determine the next step after `call_model`
|
| 93 |
builder.add_conditional_edges(
|
| 94 |
"call_model",
|
|
@@ -101,9 +103,8 @@ builder.add_conditional_edges(
|
|
| 101 |
# This creates a cycle: after using tools, we always return to the model
|
| 102 |
builder.add_edge("tools", "call_model")
|
| 103 |
|
| 104 |
-
# Compile the builder into an executable graph
|
| 105 |
-
|
| 106 |
-
graph = builder.compile(name="powersim_agent")
|
| 107 |
|
| 108 |
if __name__ == "__main__":
|
| 109 |
import asyncio
|
|
@@ -164,4 +165,4 @@ if __name__ == "__main__":
|
|
| 164 |
print(final_response)
|
| 165 |
|
| 166 |
# Run the async main function
|
| 167 |
-
asyncio.run(main())
|
|
|
|
| 59 |
return {"messages": [response]}
|
| 60 |
|
| 61 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
def route_model_output(state: SQLAgentState) -> Literal["__end__", "tools"]:
|
| 63 |
"""Determine the next node based on the model's output."""
|
| 64 |
last_message = state.messages[-1]
|
|
|
|
| 77 |
return "tools"
|
| 78 |
|
| 79 |
|
| 80 |
+
# Initialize the checkpointer
|
| 81 |
+
memory = MemorySaver()
|
| 82 |
+
|
| 83 |
+
# Define a new graph
|
| 84 |
+
builder = StateGraph(AgentState, input=InputState, config_schema=Configuration)
|
| 85 |
+
|
| 86 |
+
# Define the two nodes we will cycle between
|
| 87 |
+
builder.add_node(call_model)
|
| 88 |
+
builder.add_node("tools", ToolNode(TOOLS))
|
| 89 |
+
|
| 90 |
+
# Set the entrypoint as `call_model`
|
| 91 |
+
# This means that this node is the first one called
|
| 92 |
+
builder.add_edge("__start__", "call_model")
|
| 93 |
+
|
| 94 |
# Add a conditional edge to determine the next step after `call_model`
|
| 95 |
builder.add_conditional_edges(
|
| 96 |
"call_model",
|
|
|
|
| 103 |
# This creates a cycle: after using tools, we always return to the model
|
| 104 |
builder.add_edge("tools", "call_model")
|
| 105 |
|
| 106 |
+
# Compile the builder into an executable graph WITH checkpointer
|
| 107 |
+
graph = builder.compile(checkpointer=memory, name="powersim_agent")
|
|
|
|
| 108 |
|
| 109 |
if __name__ == "__main__":
|
| 110 |
import asyncio
|
|
|
|
| 165 |
print(final_response)
|
| 166 |
|
| 167 |
# Run the async main function
|
| 168 |
+
asyncio.run(main())
|