cryogenic22 commited on
Commit
9759420
·
verified ·
1 Parent(s): 61abc82

Update graph/workflow.py

Browse files
Files changed (1) hide show
  1. graph/workflow.py +42 -15
graph/workflow.py CHANGED
@@ -134,21 +134,48 @@ def create_agent_graph(anthropic_client, db):
134
  # Set the entry point to understanding_agent
135
  workflow.add_edge(START, "understanding_agent")
136
 
137
- # Modify conditional routing to handle the new Annotated current_agent type
138
- def route_understanding(state: AgentState) -> Union[str, None]:
139
- """Route from understanding agent based on the last added current_agent value."""
140
- current_agents = state['current_agent']
141
- return current_agents[-1] if current_agents else "understanding_agent"
142
-
143
- def route_planning(state: AgentState) -> Union[str, None]:
144
- """Route from planning agent based on the last added current_agent value."""
145
- current_agents = state['current_agent']
146
- return current_agents[-1] if current_agents else "planning_agent"
147
-
148
- def route_sql_generator(state: AgentState) -> Union[str, None]:
149
- """Route from SQL generator agent based on the last added current_agent value."""
150
- current_agents = state['current_agent']
151
- return current_agents[-1] if current_agents else "sql_generator_agent"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
 
153
  workflow.add_conditional_edges(
154
  "understanding_agent",
 
134
  # Set the entry point to understanding_agent
135
  workflow.add_edge(START, "understanding_agent")
136
 
137
+ # Modify routing functions to handle str or list current_agent
138
+ def route_understanding(state: AgentState):
139
+ """Route from understanding agent."""
140
+ current_agent = state['current_agent']
141
+ # If it's a list, take the last element
142
+ if isinstance(current_agent, list):
143
+ current_agent = current_agent[-1]
144
+
145
+ # Routing logic
146
+ routing_map = {
147
+ "understanding_agent": "understanding_agent",
148
+ "planning_agent": "planning_agent"
149
+ }
150
+ return routing_map.get(current_agent, "understanding_agent")
151
+
152
+ def route_planning(state: AgentState):
153
+ """Route from planning agent."""
154
+ current_agent = state['current_agent']
155
+ # If it's a list, take the last element
156
+ if isinstance(current_agent, list):
157
+ current_agent = current_agent[-1]
158
+
159
+ # Routing logic
160
+ routing_map = {
161
+ "planning_agent": "planning_agent",
162
+ "sql_generator_agent": "sql_generator_agent"
163
+ }
164
+ return routing_map.get(current_agent, "planning_agent")
165
+
166
+ def route_sql_generator(state: AgentState):
167
+ """Route from SQL generator agent."""
168
+ current_agent = state['current_agent']
169
+ # If it's a list, take the last element
170
+ if isinstance(current_agent, list):
171
+ current_agent = current_agent[-1]
172
+
173
+ # Routing logic
174
+ routing_map = {
175
+ "sql_generator_agent": "sql_generator_agent",
176
+ "executor_agent": "executor_agent"
177
+ }
178
+ return routing_map.get(current_agent, "sql_generator_agent")
179
 
180
  workflow.add_conditional_edges(
181
  "understanding_agent",