SparrowAgent / src /graphs /SparrowV2Graph.py
nivakaran's picture
Create SparrowV2Graph.py
630d498 verified
from langgraph.graph import StateGraph, START, END
from src.states.queryState import SparrowAgentState, SparrowInputState
from src.nodes.queryNode import QueryNode
from src.llms.groqllm import GroqLLM
import logging
from src.graphs.masterGraph import master_graph
from langchain_core.messages import AIMessage
class SparrowV2GraphBuilder:
    """Builder for the Sparrow V2 LangGraph workflow.

    Graph shape produced by ``build_query_graph``::

        START -> clarify_with_user -> write_query_brief
              -> route_after_clarification
                   "master_subgraph"   -> master_subgraph -> END
                   "clarify_with_user" -> END

    Attributes:
        llm: Chat model handed to ``QueryNode`` when the graph is built.
        graph: Uncompiled ``StateGraph`` over ``SparrowAgentState``.
        logger: Module logger used for progress, routing and error messages.
    """

    def __init__(self, llm):
        """Create a builder bound to *llm*.

        Args:
            llm: Chat model passed to ``QueryNode`` in ``build_query_graph``.
        """
        self.llm = llm
        # The graph's full state is SparrowAgentState; input_schema declares
        # the (presumably smaller) SparrowInputState accepted as graph input.
        self.graph = StateGraph(SparrowAgentState, input_schema=SparrowInputState)
        self.logger = logging.getLogger(__name__)

    @staticmethod
    def convert_sparrow_to_master(state: SparrowAgentState) -> dict:
        """Build a fresh master-graph input from the current Sparrow state.

        Only ``query_brief`` is carried over. All master-graph bookkeeping
        fields are reset so each master run starts with empty job lists.

        Declared ``@staticmethod`` because it uses no instance data and is
        invoked as ``self.convert_sparrow_to_master(state)``.

        Args:
            state: Sparrow state mapping; ``query_brief`` defaults to ``""``.

        Returns:
            Dict with keys ``query_brief``, ``execution_jobs``,
            ``completed_jobs``, ``worker_outputs`` and ``final_output``.
        """
        return {
            "query_brief": state.get("query_brief", ""),
            "execution_jobs": [],
            "completed_jobs": [],
            "worker_outputs": [],
            "final_output": "",
        }

    @staticmethod
    def update_sparrow_from_master(
        sparrow_state: SparrowAgentState, master_state: dict
    ) -> SparrowAgentState:
        """Fold a master-graph result into a copy of the Sparrow state.

        The input ``sparrow_state`` is not mutated.

        - A non-empty ``final_output`` is appended to ``messages`` as an
          ``AIMessage`` and stored in ``final_message``.
        - Non-empty ``execution_jobs`` / ``completed_jobs`` are summarised as
          text lines appended to ``notes``.

        NOTE(review): the full ``messages``/``notes`` lists are returned. If
        SparrowAgentState declares an additive reducer (e.g. ``operator.add``)
        for ``notes``, LangGraph would append existing notes a second time.
        Confirm against ``src.states.queryState``.

        Args:
            sparrow_state: Current Sparrow state mapping.
            master_state: Result dict returned by ``master_graph.invoke``.

        Returns:
            A new state dict containing the merged fields.
        """
        updated = dict(sparrow_state)

        final_output = master_state.get("final_output") or ""
        if final_output:
            updated["messages"] = (updated.get("messages") or []) + [
                AIMessage(content=final_output)
            ]
            updated["final_message"] = final_output

        # Jobs are stringified so non-str entries don't crash str.join.
        execution_jobs = master_state.get("execution_jobs") or []
        if execution_jobs:
            updated["notes"] = (updated.get("notes") or []) + [
                f"Execution jobs: {', '.join(map(str, execution_jobs))}"
            ]

        completed_jobs = master_state.get("completed_jobs") or []
        if completed_jobs:
            updated["notes"] = (updated.get("notes") or []) + [
                f"Completed: {', '.join(map(str, completed_jobs))}"
            ]

        return updated

    def route_after_clarification(self, state: SparrowAgentState) -> str:
        """Choose the next step after ``write_query_brief``.

        Args:
            state: Sparrow state; reads the ``clarification_complete`` and
                ``max_clarification_reached`` flags (both default ``False``).

        Returns:
            ``"master_subgraph"`` if either flag is set, otherwise
            ``"clarify_with_user"``. In ``build_query_graph`` the latter is
            mapped to END, so the run stops (presumably to await the user's
            reply).
        """
        if state.get("clarification_complete", False):
            self.logger.info("Routing: Clarification marked as complete")
            return "master_subgraph"
        if state.get("max_clarification_reached", False):
            self.logger.info("Routing: Max clarification attempts reached")
            return "master_subgraph"
        self.logger.info("Routing: Clarification still needed")
        return "clarify_with_user"

    def run_master_subgraph(self, state: SparrowAgentState) -> SparrowAgentState:
        """Run the master subgraph synchronously and merge its result.

        The sync ``invoke`` is used to avoid async issues with ``Send``.

        Args:
            state: Current Sparrow state.

        Returns:
            The merged state on success. On any exception, returns a copy of
            *state* with an ``"error"`` message; the exception is logged with
            its traceback.
        """
        self.logger.info("Running master subgraph..")
        try:
            master_input = self.convert_sparrow_to_master(state)
            master_result = master_graph.invoke(master_input)
            return self.update_sparrow_from_master(state, master_result)
        except Exception as e:  # node boundary: report failure in state
            self.logger.exception("Master subgraph failed: %s", e)
            return {**state, "error": f"Master subgraph failed: {str(e)}"}

    def run_query_subgraph(self, state: SparrowAgentState) -> SparrowAgentState:
        """Placeholder for a query subgraph; currently a no-op.

        No query subgraph is invoked yet, and this method is not registered
        as a node in ``build_query_graph``. It returns *state* unchanged.
        The error path mirrors ``run_master_subgraph`` for when a real call
        is added.

        Args:
            state: Current Sparrow state.

        Returns:
            *state* unchanged, or a copy with an ``"error"`` message on failure.
        """
        try:
            self.logger.info("Running query subgraph..")
            # TODO: invoke the query subgraph here once it exists.
            return state
        except Exception as e:  # node boundary: report failure in state
            self.logger.exception("Query subgraph failed: %s", e)
            return {**state, "error": f"Query subgraph failed: {str(e)}"}

    def build_query_graph(self):
        """Register the nodes and edges of the customer-query workflow.

        Nodes added to ``self.graph``:
          - ``clarify_with_user`` -> ``QueryNode.clarify_with_user``
          - ``write_query_brief`` -> ``QueryNode.write_query_brief``
          - ``master_subgraph``   -> ``run_master_subgraph``

        Call once per builder: a second call would re-add the same node
        names to ``self.graph``.

        Returns:
            The (uncompiled) ``StateGraph``; callers compile it.
        """
        self.query_node_obj = QueryNode(self.llm)
        self.logger.debug("Building Sparrow V2 graph with LLM: %s", self.llm)

        self.graph.add_node("clarify_with_user", self.query_node_obj.clarify_with_user)
        self.graph.add_node("write_query_brief", self.query_node_obj.write_query_brief)
        self.graph.add_node("master_subgraph", self.run_master_subgraph)

        self.graph.add_edge(START, "clarify_with_user")
        self.graph.add_edge("clarify_with_user", "write_query_brief")
        self.graph.add_conditional_edges(
            "write_query_brief",
            self.route_after_clarification,
            {
                # More clarification needed: end this run rather than loop.
                "clarify_with_user": END,
                "master_subgraph": "master_subgraph",
            },
        )
        self.graph.add_edge("master_subgraph", END)
        return self.graph
def _build_default_graph():
    """Create the Groq-backed builder and compile its workflow graph.

    Returns:
        Tuple ``(llm, builder, compiled_graph)`` in construction order.
    """
    groq_llm = GroqLLM().get_llm()
    builder = SparrowV2GraphBuilder(groq_llm)
    compiled = builder.build_query_graph().compile()
    return groq_llm, builder, compiled


# Built at import time and exposed as module-level names. ``graph`` is
# presumably what a LangGraph deployment config points at (TODO confirm).
llm, graph_builder, graph = _build_default_graph()