File size: 4,408 Bytes
630d498
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
from langgraph.graph import StateGraph, START, END 
from src.states.queryState import SparrowAgentState, SparrowInputState

from src.nodes.queryNode import QueryNode
from src.llms.groqllm import GroqLLM

import logging 
from src.graphs.masterGraph import master_graph
from langchain_core.messages import AIMessage



class SparrowV2GraphBuilder:
    """Assemble the Sparrow v2 LangGraph workflow.

    The graph runs the ``QueryNode`` steps ``clarify_with_user`` and
    ``write_query_brief``. Then, depending on ``route_after_clarification``,
    it either ends or continues into the master subgraph.
    """

    def __init__(self, llm):
        """Create an empty ``StateGraph`` for the Sparrow agent.

        Args:
            llm: Chat model passed to ``QueryNode`` in ``build_query_graph``.
        """
        self.llm = llm
        # Invocations are validated against SparrowInputState, while nodes
        # operate on the full SparrowAgentState.
        self.graph = StateGraph(SparrowAgentState, input_schema=SparrowInputState)
        self.logger = logging.getLogger(__name__)

    @staticmethod
    def convert_sparrow_to_master(state: SparrowAgentState) -> dict:
        """Build the master-graph input from a Sparrow state.

        Only ``query_brief`` is carried over, defaulting to ``""``. The job
        lists, ``worker_outputs`` and ``final_output`` are always initialised
        empty.

        Args:
            state: Current Sparrow agent state.

        Returns:
            The dict passed to ``master_graph.invoke``.
        """
        return {
            "query_brief": state.get("query_brief", ""),
            "execution_jobs": [],
            "completed_jobs": [],
            "worker_outputs": [],
            "final_output": "",
        }

    @staticmethod
    def update_sparrow_from_master(sparrow_state: SparrowAgentState, master_state: dict) -> SparrowAgentState:
        """Fold a master-graph result back into the Sparrow state.

        Updates applied:

        * A non-empty ``final_output`` is appended to ``messages`` as an
          ``AIMessage`` and also stored in ``final_message``.
        * A non-empty ``execution_jobs`` list adds one summary line to ``notes``.
        * A non-empty ``completed_jobs`` list adds one summary line to ``notes``.

        The input state is not mutated; a shallow copy carrying the updates
        is returned.

        Args:
            sparrow_state: Sparrow state the master graph was started from.
            master_state: Result returned by ``master_graph.invoke``.

        Returns:
            The updated Sparrow state.
        """
        updated = dict(sparrow_state)

        final_output = master_state.get("final_output", "")
        if final_output:
            # `or []` also covers a key that is present but set to None.
            updated["messages"] = list(updated.get("messages") or []) + [AIMessage(content=final_output)]
            updated["final_message"] = final_output

        execution_jobs = master_state.get("execution_jobs", [])
        completed_jobs = master_state.get("completed_jobs", [])

        if execution_jobs or completed_jobs:
            notes = list(updated.get("notes") or [])
            # NOTE(review): the element type of the job lists is defined by the
            # master graph (not visible here). str() keeps the join from
            # failing on non-string entries.
            if execution_jobs:
                notes.append(f"Execution jobs: {', '.join(map(str, execution_jobs))}")
            if completed_jobs:
                notes.append(f"Completed: {', '.join(map(str, completed_jobs))}")
            updated["notes"] = notes

        return updated

    def route_after_clarification(self, state: SparrowAgentState) -> str:
        """Choose the node that follows ``write_query_brief``.

        Args:
            state: Sparrow state after the query brief step.

        Returns:
            ``"master_subgraph"`` when ``clarification_complete`` or
            ``max_clarification_reached`` is truthy. Otherwise
            ``"clarify_with_user"``, which ``build_query_graph`` maps to
            ``END`` so the run stops there.
        """
        if state.get("clarification_complete", False):
            self.logger.info("Routing: Clarification marked as complete")
            return "master_subgraph"

        if state.get("max_clarification_reached", False):
            self.logger.info("Routing: Max clarification attempts reached")
            return "master_subgraph"

        self.logger.info("Routing: Clarification still needed")
        return "clarify_with_user"

    def run_master_subgraph(self, state: SparrowAgentState) -> SparrowAgentState:
        """Run the master graph on the current query brief.

        The synchronous ``master_graph.invoke`` is used. The original author
        notes this avoids async issues with ``Send``.

        Args:
            state: Current Sparrow agent state.

        Returns:
            The state updated with the master result. On any failure, a copy
            of *state* with an ``"error"`` message instead.
        """
        self.logger.info("Running master subgraph..")
        try:
            master_input = self.convert_sparrow_to_master(state)
            master_result = master_graph.invoke(master_input)
            return self.update_sparrow_from_master(state, master_result)
        except Exception as e:
            # Node boundary: record the failure in the state (with a logged
            # traceback) rather than letting it propagate out of this node.
            self.logger.exception("Master subgraph failed: %s", e)
            return {**state, "error": f"Master subgraph failed: {str(e)}"}

    def run_query_subgraph(self, state: SparrowAgentState) -> SparrowAgentState:
        """Placeholder for running a query subgraph.

        NOTE(review): not implemented and not registered as a node in
        ``build_query_graph``. It only logs and returns ``None``, i.e. no
        state update.
        """
        self.logger.info("Running query subgraph..")
        # TODO: invoke the query subgraph and map its result back onto the
        # Sparrow state, mirroring run_master_subgraph.
        return None

    def build_query_graph(self):
        """Register the query workflow's nodes and edges on ``self.graph``.

        Flow: START -> clarify_with_user -> write_query_brief. From there,
        ``route_after_clarification`` either goes to END or to
        master_subgraph -> END.

        Nodes are added to the single ``self.graph`` created in ``__init__``,
        so build once per builder instance.

        Returns:
            The uncompiled ``StateGraph``; call ``.compile()`` on it.
        """
        self.query_node_obj = QueryNode(self.llm)
        self.logger.debug("Building query graph with LLM: %s", self.llm)

        self.graph.add_node("clarify_with_user", self.query_node_obj.clarify_with_user)
        self.graph.add_node("write_query_brief", self.query_node_obj.write_query_brief)
        self.graph.add_node("master_subgraph", self.run_master_subgraph)

        self.graph.add_edge(START, "clarify_with_user")
        self.graph.add_edge("clarify_with_user", "write_query_brief")
        self.graph.add_conditional_edges(
            "write_query_brief",
            self.route_after_clarification,
            {
                # Still unclear: end this run (no further node in this graph).
                "clarify_with_user": END,
                "master_subgraph": "master_subgraph",
            },
        )
        self.graph.add_edge("master_subgraph", END)

        return self.graph
    
# Module-level wiring (executed at import time): obtain the LLM from
# GroqLLM().get_llm(), build the Sparrow v2 workflow around it, and expose
# the compiled graph as the module attribute ``graph``.
llm = GroqLLM().get_llm()

graph_builder = SparrowV2GraphBuilder(llm)

graph = (
    graph_builder
    .build_query_graph()
    .compile()
)