arthi.kasturirangan@informa.com commited on
Commit
8be7bcb
·
1 Parent(s): a014f29

add checkpointers

Browse files
Files changed (1) hide show
  1. app/agent/graph.py +17 -16
app/agent/graph.py CHANGED
@@ -59,18 +59,6 @@ async def call_model(state: AgentState) -> Dict[str, List[AIMessage]]:
59
  return {"messages": [response]}
60
 
61
 
62
- # Define a new graph
63
- builder = StateGraph(AgentState, input=InputState, config_schema=Configuration)
64
-
65
- # Define the two nodes we will cycle between
66
- builder.add_node(call_model)
67
- builder.add_node("tools", ToolNode(TOOLS))
68
-
69
- # Set the entrypoint as `call_model`
70
- # This means that this node is the first one called
71
- builder.add_edge("__start__", "call_model")
72
-
73
-
74
  def route_model_output(state: SQLAgentState) -> Literal["__end__", "tools"]:
75
  """Determine the next node based on the model's output."""
76
  last_message = state.messages[-1]
@@ -89,6 +77,20 @@ def route_model_output(state: SQLAgentState) -> Literal["__end__", "tools"]:
89
  return "tools"
90
 
91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  # Add a conditional edge to determine the next step after `call_model`
93
  builder.add_conditional_edges(
94
  "call_model",
@@ -101,9 +103,8 @@ builder.add_conditional_edges(
101
  # This creates a cycle: after using tools, we always return to the model
102
  builder.add_edge("tools", "call_model")
103
 
104
- # Compile the builder into an executable graph
105
- memory = MemorySaver()
106
- graph = builder.compile(name="powersim_agent")
107
 
108
  if __name__ == "__main__":
109
  import asyncio
@@ -164,4 +165,4 @@ if __name__ == "__main__":
164
  print(final_response)
165
 
166
  # Run the async main function
167
- asyncio.run(main())
 
59
  return {"messages": [response]}
60
 
61
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  def route_model_output(state: SQLAgentState) -> Literal["__end__", "tools"]:
63
  """Determine the next node based on the model's output."""
64
  last_message = state.messages[-1]
 
77
  return "tools"
78
 
79
 
80
+ # Initialize the checkpointer
81
+ memory = MemorySaver()
82
+
83
+ # Define a new graph
84
+ builder = StateGraph(AgentState, input=InputState, config_schema=Configuration)
85
+
86
+ # Define the two nodes we will cycle between
87
+ builder.add_node(call_model)
88
+ builder.add_node("tools", ToolNode(TOOLS))
89
+
90
+ # Set the entrypoint as `call_model`
91
+ # This means that this node is the first one called
92
+ builder.add_edge("__start__", "call_model")
93
+
94
  # Add a conditional edge to determine the next step after `call_model`
95
  builder.add_conditional_edges(
96
  "call_model",
 
103
  # This creates a cycle: after using tools, we always return to the model
104
  builder.add_edge("tools", "call_model")
105
 
106
+ # Compile the builder into an executable graph WITH checkpointer
107
+ graph = builder.compile(checkpointer=memory, name="powersim_agent")
 
108
 
109
  if __name__ == "__main__":
110
  import asyncio
 
165
  print(final_response)
166
 
167
  # Run the async main function
168
+ asyncio.run(main())