Spaces:
Sleeping
Sleeping
File size: 4,408 Bytes
630d498 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
from langgraph.graph import StateGraph, START, END
from src.states.queryState import SparrowAgentState, SparrowInputState
from src.nodes.queryNode import QueryNode
from src.llms.groqllm import GroqLLM
import logging
from src.graphs.masterGraph import master_graph
from langchain_core.messages import AIMessage
class SparrowV2GraphBuilder:
    """Assemble the Sparrow v2 LangGraph workflow.

    The built graph runs ``clarify_with_user`` -> ``write_query_brief`` and
    then, based on ``route_after_clarification``, either ends the turn
    (clarification still needed) or runs the imported ``master_graph`` via the
    ``master_subgraph`` node before ending.
    """

    def __init__(self, llm):
        """Store the LLM and create an empty ``StateGraph``.

        Args:
            llm: Model handed to ``QueryNode``; presumably a LangChain chat
                model (e.g. from ``GroqLLM().get_llm()``) — TODO confirm.
        """
        self.llm = llm
        self.graph = StateGraph(SparrowAgentState, input_schema=SparrowInputState)
        self.logger = logging.getLogger(__name__)

    @staticmethod
    def convert_sparrow_to_master(state: SparrowAgentState) -> dict:
        """Convert SparrowAgentState to the master graph input format.

        Only ``query_brief`` is carried over (defaulting to ``""``); the job
        lists, ``worker_outputs`` and ``final_output`` start empty.

        Declared ``@staticmethod`` because it uses no instance state and is
        called as ``self.convert_sparrow_to_master(state)``; without the
        decorator that call passes ``self`` as ``state`` and raises TypeError.
        """
        return {
            "query_brief": state.get("query_brief", ""),
            "execution_jobs": [],
            "completed_jobs": [],
            "worker_outputs": [],
            "final_output": "",
        }

    @staticmethod
    def update_sparrow_from_master(
        sparrow_state: SparrowAgentState, master_state: dict
    ) -> SparrowAgentState:
        """Merge master-graph results into a copy of the sparrow state.

        - A non-empty ``final_output`` is appended to ``messages`` as an
          ``AIMessage`` and also stored under ``final_message``.
        - Non-empty ``execution_jobs`` / ``completed_jobs`` are summarised as
          comma-separated strings appended to ``notes``.

        Args:
            sparrow_state: Current graph state (dict-like).
            master_state: Result dict returned by ``master_graph.invoke``.

        Returns:
            A new state dict; ``sparrow_state`` itself is not mutated.

        NOTE(review): the full state (not just the changed keys) is returned.
        If ``messages``/``notes`` have list-concatenating reducers in
        ``SparrowAgentState``, verify this does not duplicate entries.
        """
        updated = dict(sparrow_state)

        final_output = master_state.get("final_output", "")
        if final_output:
            updated["messages"] = [
                *updated.get("messages", []),
                AIMessage(content=final_output),
            ]
            updated["final_message"] = final_output

        execution_jobs = master_state.get("execution_jobs", [])
        completed_jobs = master_state.get("completed_jobs", [])
        new_notes = []
        # map(str, ...) so non-string job entries do not make join() raise.
        if execution_jobs:
            new_notes.append(f"Execution jobs: {', '.join(map(str, execution_jobs))}")
        if completed_jobs:
            new_notes.append(f"Completed: {', '.join(map(str, completed_jobs))}")
        # Only touch "notes" when there is something to add (as before).
        if new_notes:
            updated["notes"] = [*updated.get("notes", []), *new_notes]

        return updated

    def route_after_clarification(self, state: SparrowAgentState) -> str:
        """Route based on clarification status from the queryNode response.

        Returns:
            ``"master_subgraph"`` if ``clarification_complete`` or
            ``max_clarification_reached`` is truthy in ``state``; otherwise
            ``"clarify_with_user"``, which ``build_query_graph`` maps to END.
        """
        if state.get("clarification_complete", False):
            self.logger.info("Routing: Clarification marked as complete")
            return "master_subgraph"
        if state.get("max_clarification_reached", False):
            self.logger.info("Routing: Max clarification attempts reached")
            return "master_subgraph"
        self.logger.info("Routing: Clarification still needed")
        return "clarify_with_user"

    def run_master_subgraph(self, state: SparrowAgentState) -> SparrowAgentState:
        """Run the master subgraph and merge its result into the state.

        Uses the synchronous ``master_graph.invoke`` (per the original note,
        to avoid async issues with ``Send``). Any exception is logged with
        its traceback and reported as an ``"error"`` entry on a copy of
        ``state`` instead of propagating, so the run still proceeds to END.
        """
        self.logger.info("Running master subgraph..")
        try:
            master_input = self.convert_sparrow_to_master(state)
            master_result = master_graph.invoke(master_input)
            return self.update_sparrow_from_master(state, master_result)
        except Exception as e:
            self.logger.exception("Master subgraph failed: %s", e)
            return {**state, "error": f"Master subgraph failed: {e}"}

    def run_query_subgraph(self, state: SparrowAgentState) -> SparrowAgentState:
        """Placeholder node for running a query subgraph.

        NOTE(review): no query subgraph is invoked yet, and this method is not
        registered as a node in ``build_query_graph``. It does no work and
        returns an empty update (no state changes) rather than implicitly
        returning ``None``.
        """
        self.logger.info("Running query subgraph..")
        # TODO: invoke the query subgraph here; on failure follow
        # run_master_subgraph and return {**state, "error": ...}.
        return {}

    def build_query_graph(self):
        """
        Build a graph for customer query inquiry.

        Registers the ``QueryNode`` steps and the master-subgraph node, wires
        START -> clarify_with_user -> write_query_brief, then routes via
        ``route_after_clarification`` either to END (clarification still
        needed) or to master_subgraph -> END.

        Returns:
            The uncompiled ``StateGraph`` held in ``self.graph``; the caller
            compiles it.

        NOTE(review): nodes are added to ``self.graph`` itself, so calling this
        twice on one builder re-registers the same node names (expected to
        fail).
        """
        self.query_node_obj = QueryNode(self.llm)
        self.logger.debug("Using LLM: %s", self.llm)

        self.graph.add_node("clarify_with_user", self.query_node_obj.clarify_with_user)
        self.graph.add_node("write_query_brief", self.query_node_obj.write_query_brief)
        self.graph.add_node("master_subgraph", self.run_master_subgraph)

        self.graph.add_edge(START, "clarify_with_user")
        self.graph.add_edge("clarify_with_user", "write_query_brief")
        self.graph.add_conditional_edges(
            "write_query_brief",
            self.route_after_clarification,
            {
                # Clarification pending: end this turn and wait for the user.
                "clarify_with_user": END,
                "master_subgraph": "master_subgraph",
            },
        )
        self.graph.add_edge("master_subgraph", END)
        return self.graph
# Module-level wiring: build the LLM, the graph builder, and the compiled
# graph. ``graph`` is presumably the entry point referenced by LangGraph
# tooling — TODO confirm before renaming any of these module attributes.
llm = GroqLLM().get_llm()
graph_builder = SparrowV2GraphBuilder(llm=llm)
graph = (
    graph_builder
    .build_query_graph()
    .compile()
)