Spaces:
Sleeping
Sleeping
| from langgraph.graph import StateGraph, START, END | |
| from src.states.queryState import SparrowAgentState, SparrowInputState | |
| from src.nodes.queryNode import QueryNode | |
| from src.llms.groqllm import GroqLLM | |
| import logging | |
| from src.graphs.masterGraph import master_graph | |
| from langchain_core.messages import AIMessage | |
class SparrowV2GraphBuilder:
    """Build the Sparrow v2 LangGraph workflow.

    The graph runs ``clarify_with_user`` -> ``write_query_brief`` and then,
    via ``route_after_clarification``, either ends the run (clarification
    still needed) or hands off to ``master_graph`` through the
    ``master_subgraph`` node.
    """

    def __init__(self, llm):
        """Store the LLM and create an empty ``StateGraph``.

        Args:
            llm: Chat model handed to ``QueryNode`` in ``build_query_graph``.
        """
        self.llm = llm
        # Graph state is SparrowAgentState; input is restricted to the
        # SparrowInputState schema.
        self.graph = StateGraph(SparrowAgentState, input_schema=SparrowInputState)
        self.logger = logging.getLogger(__name__)

    @staticmethod
    def convert_sparrow_to_master(state: SparrowAgentState) -> dict:
        """Build the initial ``master_graph`` input from a Sparrow state.

        Only ``query_brief`` is carried over (``""`` if absent); every other
        master field starts empty.

        Declared ``staticmethod`` because it takes no ``self`` but is called
        as ``self.convert_sparrow_to_master(state)``; without the decorator
        that call raises ``TypeError``.
        """
        return {
            "query_brief": state.get("query_brief", ""),
            "execution_jobs": [],
            "completed_jobs": [],
            "worker_outputs": [],
            "final_output": "",
        }

    @staticmethod
    def update_sparrow_from_master(
        sparrow_state: SparrowAgentState, master_state: dict
    ) -> SparrowAgentState:
        """Merge ``master_graph`` results back into a Sparrow state.

        Args:
            sparrow_state: State the master subgraph was launched from. It is
                not mutated; a shallow copy is updated and returned.
            master_state: Final state returned by ``master_graph.invoke``.

        Returns:
            A copy of ``sparrow_state`` where, when present in
            ``master_state``:

            - ``final_output`` is appended to ``messages`` as an ``AIMessage``
              and stored in ``final_message``;
            - ``execution_jobs`` / ``completed_jobs`` are summarized as
              entries appended to ``notes``.
        """
        updated = dict(sparrow_state)

        final_output = master_state.get("final_output", "")
        if final_output:
            messages = list(updated.get("messages") or [])
            messages.append(AIMessage(content=final_output))
            updated["messages"] = messages
            updated["final_message"] = final_output

        # NOTE(review): full notes/messages lists are returned, not just the new
        # entries. If SparrowAgentState declares an additive reducer (e.g.
        # operator.add) on these keys, existing items would be duplicated —
        # confirm against the state definition.
        execution_jobs = master_state.get("execution_jobs") or []
        completed_jobs = master_state.get("completed_jobs") or []
        if execution_jobs or completed_jobs:
            notes = list(updated.get("notes") or [])
            # str() each job: nothing guarantees the job lists hold strings,
            # and str.join raises TypeError on any other type.
            if execution_jobs:
                notes.append(f"Execution jobs: {', '.join(map(str, execution_jobs))}")
            if completed_jobs:
                notes.append(f"Completed: {', '.join(map(str, completed_jobs))}")
            updated["notes"] = notes

        return updated

    def route_after_clarification(self, state: SparrowAgentState) -> str:
        """Pick the next node after ``write_query_brief``.

        Returns:
            ``"master_subgraph"`` when ``clarification_complete`` or
            ``max_clarification_reached`` is truthy, otherwise
            ``"clarify_with_user"`` (mapped to END in ``build_query_graph``).
        """
        if state.get("clarification_complete", False):
            print("Routing: Clarification marked as complete")
            return "master_subgraph"
        if state.get("max_clarification_reached", False):
            print("Routing: Max clarification attempts reached")
            return "master_subgraph"
        print("Routing: Clarification still needed")
        return "clarify_with_user"

    def run_master_subgraph(self, state: SparrowAgentState) -> SparrowAgentState:
        """Run ``master_graph`` synchronously on the current query brief.

        The sync ``invoke`` is used to avoid async issues with ``Send``.

        Returns:
            The Sparrow state merged with the master results or, if anything
            raises, the input state plus an ``"error"`` message.
        """
        try:
            print("Running master subgraph..")
            master_input = self.convert_sparrow_to_master(state)
            master_result = master_graph.invoke(master_input)
            return self.update_sparrow_from_master(state, master_result)
        except Exception as e:
            # Node boundary: keep the parent graph alive, record the failure in
            # state and log the traceback. NOTE(review): "error" is presumably
            # only kept if it is a declared SparrowAgentState key — confirm.
            self.logger.exception("Master subgraph failed: %s", e)
            return {**state, "error": f"Master subgraph failed: {str(e)}"}

    def run_query_subgraph(self, state: SparrowAgentState) -> SparrowAgentState:
        """Placeholder query-subgraph node; currently a no-op.

        Not wired into the graph by ``build_query_graph``. Returns ``state``
        unchanged on success, or ``state`` plus an ``"error"`` message.
        """
        try:
            print("Running query subgraph..")
            # TODO: invoke the query subgraph here once it exists.
            return state
        except Exception as e:
            self.logger.exception("Query subgraph failed: %s", e)
            return {**state, "error": f"Query subgraph failed: {str(e)}"}

    def build_query_graph(self):
        """Wire the customer-query workflow into ``self.graph``.

        Flow: START -> clarify_with_user -> write_query_brief, then
        ``route_after_clarification`` either ends the run or goes through
        ``master_subgraph`` to END.

        Returns:
            The uncompiled ``StateGraph``; the caller compiles it.
        """
        self.query_node_obj = QueryNode(self.llm)
        self.logger.debug("Building query graph with LLM: %r", self.llm)

        self.graph.add_node("clarify_with_user", self.query_node_obj.clarify_with_user)
        self.graph.add_node("write_query_brief", self.query_node_obj.write_query_brief)
        self.graph.add_node("master_subgraph", self.run_master_subgraph)

        self.graph.add_edge(START, "clarify_with_user")
        self.graph.add_edge("clarify_with_user", "write_query_brief")
        self.graph.add_conditional_edges(
            "write_query_brief",
            self.route_after_clarification,
            {
                # Clarification still needed: end this run, presumably so the
                # user can answer in the next turn.
                "clarify_with_user": END,
                "master_subgraph": "master_subgraph",
            },
        )
        self.graph.add_edge("master_subgraph", END)
        return self.graph
# Build and compile the graph once at import time and expose it as `graph`
# (presumably the entry point a LangGraph deployment loads — confirm).
llm = GroqLLM().get_llm()
graph_builder = SparrowV2GraphBuilder(llm=llm)
graph = graph_builder.build_query_graph().compile()