nivakaran commited on
Commit
630d498
·
verified ·
1 Parent(s): a5052b0

Create SparrowV2Graph.py

Browse files
Files changed (1) hide show
  1. src/graphs/SparrowV2Graph.py +126 -0
src/graphs/SparrowV2Graph.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langgraph.graph import StateGraph, START, END
2
+ from src.states.queryState import SparrowAgentState, SparrowInputState
3
+
4
+ from src.nodes.queryNode import QueryNode
5
+ from src.llms.groqllm import GroqLLM
6
+
7
+ import logging
8
+ from src.graphs.masterGraph import master_graph
9
+ from langchain_core.messages import AIMessage
10
+
11
+
12
+
13
+ class SparrowV2GraphBuilder:
14
+ def __init__(self, llm):
15
+ self.llm = llm
16
+ self.graph = StateGraph(SparrowAgentState, input_schema=SparrowInputState)
17
+ self.logger = logging.getLogger(__name__)
18
+
19
+ def convert_sparrow_to_master(state: SparrowAgentState) -> dict:
20
+ """Convert SparrowAgentState to master graph input format"""
21
+ return {
22
+ "query_brief": state.get("query_brief", ""),
23
+ "execution_jobs": [],
24
+ "completed_jobs": [],
25
+ "worker_outputs": [],
26
+ "final_output": ''
27
+ }
28
+
29
+ def update_sparrow_from_master(sparrow_state: SparrowAgentState, master_state:dict) -> SparrowAgentState:
30
+ """Update sparrow state with master results"""
31
+
32
+ final_output = master_state.get("final_output", "")
33
+ if final_output:
34
+ sparrow_state["messages"] = sparrow_state.get("messages", []) + [AIMessage(content=final_output)]
35
+ sparrow_state["final_message"] = final_output
36
+
37
+ execution_jobs = master_state.get("execution_jobs", [])
38
+ completed_jobs = master_state.get("completed_jobs", [])
39
+
40
+ if execution_jobs:
41
+ sparrow_state["notes"] = sparrow_state.get("notes", []) + [f"Execution jobs: {', '.join(execution_jobs)}"]
42
+
43
+ if completed_jobs:
44
+ sparrow_state["notes"] = sparrow_state.get("notes", []) + [f"Completed: {', '.join(completed_jobs)}"]
45
+
46
+ return sparrow_state
47
+
48
+
49
+
50
+
51
+
52
+ def route_after_clarification(self, state: SparrowAgentState) -> str:
53
+ """Route based on clarification status from queryNode response"""
54
+
55
+ if state.get("clarification_complete", False):
56
+ print("Routing: Clarification marked as complete")
57
+ return "master_subgraph"
58
+
59
+ if state.get("max_clarification_reached", False):
60
+ print("Routing: Max clarification attempts reached")
61
+ return "master_subgraph"
62
+
63
+ print("Routing: Clarification still needed")
64
+
65
+ return "clarify_with_user"
66
+
67
+ def run_master_subgraph(self, state: SparrowAgentState) -> SparrowAgentState:
68
+ """Run the master subgraph - using sync version to avoid async issues with Send"""
69
+
70
+ try:
71
+ print("Running master subgraph..")
72
+ master_input = self.convert_sparrow_to_master(state)
73
+
74
+ master_result = master_graph.invoke(master_input)
75
+
76
+ return self.update_sparrow_from_master(state, master_result)
77
+
78
+ except Exception as e:
79
+ self.logger.error(f"Master subgraph failed: {e}")
80
+ return {**state, "error": f"Master subgraph failed: {str(e)}"}
81
+
82
+
83
+
84
+
85
+ def run_query_subgraph(self, state:SparrowAgentState) -> SparrowAgentState:
86
+ """Run the query subgraph - using sync version to avoid async issues with Send"""
87
+ try:
88
+ print("Running query subgraph..")
89
+
90
+ except Exception as e:
91
+ self.logger.error(f"Query subgraph failed: {e}")
92
+ return {**state, "error": f"Query subgraph failed: {str(e)}"}
93
+
94
+
95
+ def build_query_graph(self):
96
+ """
97
+ Build a graph for customer query inquiry
98
+
99
+ """
100
+ self.query_node_obj= QueryNode(self.llm)
101
+ print(self.llm)
102
+
103
+ self.graph.add_node("clarify_with_user", self.query_node_obj.clarify_with_user)
104
+ self.graph.add_node("write_query_brief", self.query_node_obj.write_query_brief)
105
+ self.graph.add_node("master_subgraph", self.run_master_subgraph)
106
+
107
+ self.graph.add_edge(START, "clarify_with_user")
108
+ self.graph.add_edge("clarify_with_user", "write_query_brief")
109
+ self.graph.add_conditional_edges(
110
+ "write_query_brief",
111
+ self.route_after_clarification,
112
+ {
113
+ "clarify_with_user": END,
114
+ "master_subgraph": "master_subgraph"
115
+ }
116
+ )
117
+
118
+ self.graph.add_edge("master_subgraph", END)
119
+
120
+ return self.graph
121
+
122
+ llm = GroqLLM().get_llm()
123
+
124
+ graph_builder=SparrowV2GraphBuilder(llm)
125
+
126
+ graph=graph_builder.build_query_graph().compile()