# Updated Sparrow Agent with proper routing
import asyncio
import logging
from src.graphs.masterGraph import master_graph
from src.llms.groqllm import GroqLLM
from src.states.queryState import SparrowAgentState, SparrowInputState
from langgraph.graph import StateGraph, START, END
from src.states.masterState import MasterState
from langgraph.checkpoint.memory import MemorySaver
from src.nodes.queryNode import QueryNode
from langchain_core.messages import AIMessage, HumanMessage

logger = logging.getLogger(__name__)

# Constructed at import time: the LLM client and the QueryNode whose bound
# methods are registered as graph nodes further down.
llm = GroqLLM().get_llm()
queryNode = QueryNode(llm)

# Routing limits. _MAX_CLARIFICATION_ATTEMPTS is meant to match
# max_clarification_rounds in QueryNode -- keep the two in sync.
_MAX_CLARIFICATION_ATTEMPTS = 3
# Safety net: stop clarifying once the conversation exceeds this many messages.
_CLARIFY_MESSAGE_LIMIT = 12
# Safety net: end the run after a failed brief once messages exceed this.
_BRIEF_MESSAGE_LIMIT = 15
# A stripped query brief longer than this is treated as usable.
_MIN_QUERY_BRIEF_CHARS = 10
# After a failed brief, go back to clarification only while attempts are below this.
_BRIEF_RETRY_ATTEMPTS = 2
# Phrases in notes that signal clarification is done (fallback for older state).
_COMPLETION_INDICATORS = ("sufficient information", "clarification complete", "proceeding")


def convert_sparrow_to_master(state: SparrowAgentState) -> dict:
    """Build the input for ``master_graph`` from a Sparrow agent state.

    Only ``query_brief`` is carried over. The job lists, worker outputs and
    final output are reset, so each master run starts from a clean slate.

    Args:
        state: Current Sparrow agent state.

    Returns:
        The dict passed to ``master_graph.invoke``.
    """
    return {
        "query_brief": state.get("query_brief", ""),
        "execution_jobs": [],
        "completed_jobs": [],
        "worker_outputs": [],
        "final_output": "",
    }


def _join_items(items) -> str:
    """Render job entries as a comma-separated string.

    Each entry goes through ``str()``, so non-string jobs (dicts, objects)
    do not make ``str.join`` raise ``TypeError``.
    """
    return ", ".join(str(item) for item in items)


def update_sparrow_from_master(sparrow_state: SparrowAgentState, master_state: dict) -> SparrowAgentState:
    """Fold the results of a master-graph run back into the Sparrow state.

    - A non-empty ``final_output`` is appended to ``messages`` as an
      ``AIMessage`` and is also stored in ``final_message``.
    - Non-empty ``execution_jobs`` and ``completed_jobs`` each become one
      summary note in ``notes``.

    Args:
        sparrow_state: Sparrow state to update. It is modified in place; new
            lists are assigned rather than appending to existing ones.
        master_state: Result dict returned by ``master_graph.invoke``.

    Returns:
        The same ``sparrow_state`` object, updated.
    """
    final_output = master_state.get("final_output", "")
    if final_output:
        sparrow_state["messages"] = list(sparrow_state.get("messages") or []) + [
            AIMessage(content=final_output)
        ]
        sparrow_state["final_message"] = final_output

    new_notes = []
    execution_jobs = master_state.get("execution_jobs") or []
    if execution_jobs:
        new_notes.append(f"Execution jobs: {_join_items(execution_jobs)}")
    completed_jobs = master_state.get("completed_jobs") or []
    if completed_jobs:
        new_notes.append(f"Completed: {_join_items(completed_jobs)}")
    if new_notes:
        sparrow_state["notes"] = list(sparrow_state.get("notes") or []) + new_notes

    return sparrow_state


def route_after_clarification(state: SparrowAgentState) -> str:
    """Choose between more clarification and writing the query brief.

    Returns ``"write_query_brief"`` as soon as one of these holds, checked
    in order:

    1. ``clarification_complete`` is set.
    2. ``max_clarification_reached`` is set.
    3. ``information_sufficient`` is set.
    4. ``clarification_attempts`` >= ``_MAX_CLARIFICATION_ATTEMPTS``.
    5. There are more than ``_CLARIFY_MESSAGE_LIMIT`` messages.
    6. A note contains one of ``_COMPLETION_INDICATORS``.

    Otherwise it returns ``"need_clarification"``.

    This is a conditional-edge function: it only picks the next node. Writes
    to ``state`` made here are not applied as state updates, so it leaves
    ``state`` untouched.
    """
    # Primary signals: flags set on the state by QueryNode.
    if state.get("clarification_complete", False):
        print("Routing: Clarification marked as complete")
        return "write_query_brief"
    if state.get("max_clarification_reached", False):
        print("Routing: Max clarification attempts reached")
        return "write_query_brief"
    if state.get("information_sufficient", False):
        print("Routing: Information marked as sufficient")
        return "write_query_brief"

    # Secondary safety checks - prevent infinite clarification loops.
    clarification_attempts = state.get("clarification_attempts", 0)
    if clarification_attempts >= _MAX_CLARIFICATION_ATTEMPTS:
        print(f"Routing: Safety limit reached ({clarification_attempts} attempts)")
        return "write_query_brief"

    messages = state.get("messages") or []
    if len(messages) > _CLARIFY_MESSAGE_LIMIT:
        print(f"Routing: Message count safety limit reached ({len(messages)} messages)")
        return "write_query_brief"

    # Fallback for older state: completion phrases recorded in notes.
    notes = state.get("notes") or []
    if any(
        indicator in str(note).lower()
        for note in notes
        for indicator in _COMPLETION_INDICATORS
    ):
        print("Routing: Completion indicator found in notes")
        return "write_query_brief"

    print(f"Routing: Continue clarification (attempt {clarification_attempts + 1})")
    return "need_clarification"


def route_after_query_brief(state: SparrowAgentState) -> str:
    """Route after ``write_query_brief``.

    Returns:
        ``"master_subgraph"`` if ``query_creation_success`` is set, or if a
        stripped ``query_brief`` longer than ``_MIN_QUERY_BRIEF_CHARS``
        exists.
        ``"__end__"`` once attempts or messages exceed their limits.
        ``"clarify_with_user"`` if an ``error`` is present and attempts are
        below ``_BRIEF_RETRY_ATTEMPTS``.
        ``"__end__"`` in every other case.

    Like every conditional-edge function, this must not write to ``state``.
    """
    if state.get("query_creation_success", False):
        print("Query brief created successfully, proceeding to master subgraph")
        return "master_subgraph"

    # `or ""` also covers a query_brief key that is present but None.
    query_brief = (state.get("query_brief") or "").strip()
    if len(query_brief) > _MIN_QUERY_BRIEF_CHARS:
        print(f"Query brief exists ({len(query_brief)} chars), proceeding to master subgraph")
        return "master_subgraph"

    total_attempts = state.get("clarification_attempts", 0)
    messages = state.get("messages") or []
    if total_attempts >= _MAX_CLARIFICATION_ATTEMPTS or len(messages) > _BRIEF_MESSAGE_LIMIT:
        print("Too many attempts, ending conversation")
        return "__end__"

    if state.get("error") and total_attempts < _BRIEF_RETRY_ATTEMPTS:
        print("Query brief creation failed, requesting more clarification")
        # NOTE(review): this router previously set clarification_complete=False
        # and needs_clarification=True, and appended a note. Writes made in a
        # routing function are not applied as state updates, and the note
        # append mutated the stored list in place, so they were removed. If
        # clarify_with_user needs those flags reset, a node must do it.
        # Termination of this retry loop depends on QueryNode advancing
        # clarification_attempts -- confirm.
        return "clarify_with_user"

    print("Unable to create adequate query brief, ending conversation")
    return "__end__"


def need_clarification(state: SparrowAgentState) -> SparrowAgentState:
    """Node for turns where more clarification is needed.

    Appends a note saying clarification was requested and returns the state.
    The edge out of this node ends the run, so the user can reply.
    """
    print("Additional clarification needed.")
    # NOTE(review): this returns the whole state with a concatenated notes
    # list. If SparrowAgentState declares a reducer on `notes` (e.g.
    # operator.add), existing notes would be duplicated -- confirm against
    # the state definition.
    state["notes"] = list(state.get("notes") or []) + ["Requested additional clarification from user"]
    return state


def run_master_subgraph(state: SparrowAgentState) -> SparrowAgentState:
    """Run ``master_graph`` on the current query brief and merge its results.

    ``invoke`` is used rather than ``ainvoke``, to avoid issues seen with
    ``Send`` in the async path.

    Returns:
        The updated state. If the master graph or the merge raises, the
        exception is logged with its traceback and the original state is
        returned with an added ``"error"`` key.
    """
    try:
        print("Running master subgraph...")
        master_input = convert_sparrow_to_master(state)
        master_result = master_graph.invoke(master_input)
        return update_sparrow_from_master(state, master_result)
    except Exception as e:
        # Top-level boundary for the node: keep the traceback in the logs and
        # surface the failure through the state instead of crashing the graph.
        logger.exception("Master subgraph failed: %s", e)
        return {**state, "error": str(e)}


def route_after_need_clarification(state: SparrowAgentState) -> str:
    """Route after need_clarification node - always end to wait for user input"""
    return "__end__"


# Build the graph
sparrowAgentBuilder = StateGraph(SparrowAgentState, input_schema=SparrowInputState)
# Node table: (name, callable), registered in this order.
_SPARROW_NODES = (
    ("clarify_with_user", queryNode.clarify_with_user),
    ("need_clarification", need_clarification),
    ("write_query_brief", queryNode.write_query_brief),
    ("master_subgraph", run_master_subgraph),
)
for _node_name, _node_fn in _SPARROW_NODES:
    sparrowAgentBuilder.add_node(_node_name, _node_fn)
del _node_name, _node_fn

# Every run starts by trying to clarify the user's request.
sparrowAgentBuilder.add_edge(START, "clarify_with_user")

# Conditional edges: (source node, router, router-return -> destination).
_SPARROW_CONDITIONAL_EDGES = (
    (
        "clarify_with_user",
        route_after_clarification,
        # route_after_clarification never returns "__end__" as written.
        {
            "need_clarification": "need_clarification",
            "write_query_brief": "write_query_brief",
            "__end__": END,
        },
    ),
    (
        "need_clarification",
        route_after_need_clarification,
        # route_after_need_clarification always returns "__end__", so a
        # clarification turn ends the run to wait for the user's reply.
        {
            "clarify_with_user": "clarify_with_user",
            "__end__": END,
        },
    ),
    (
        "write_query_brief",
        route_after_query_brief,
        {
            "clarify_with_user": "clarify_with_user",
            "master_subgraph": "master_subgraph",
            "__end__": END,
        },
    ),
)
for _source, _router, _path_map in _SPARROW_CONDITIONAL_EDGES:
    sparrowAgentBuilder.add_conditional_edges(_source, _router, _path_map)
del _source, _router, _path_map

# The master subgraph is the last step.
sparrowAgentBuilder.add_edge("master_subgraph", END)

sparrowAgent = sparrowAgentBuilder.compile()