Spaces:
Sleeping
Sleeping
File size: 7,845 Bytes
03c1af8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 |
# Updated Sparrow Agent with proper routing
import asyncio
import logging
from src.graphs.masterGraph import master_graph
from src.llms.groqllm import GroqLLM
from src.states.queryState import SparrowAgentState, SparrowInputState
from langgraph.graph import StateGraph, START, END
from src.states.masterState import MasterState
from langgraph.checkpoint.memory import MemorySaver
from src.nodes.queryNode import QueryNode
from langchain_core.messages import HumanMessage
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# The LLM object is obtained once, when this module is imported.
llm = GroqLLM().get_llm()
# Its clarify_with_user / write_query_brief methods are registered as graph nodes below.
queryNode = QueryNode(llm)
def convert_sparrow_to_master(state: SparrowAgentState) -> dict:
    """Build the initial master-graph input from a Sparrow agent state.

    Only ``query_brief`` is carried over from ``state`` (an empty string is
    used when the key is absent). The job/output collections start as empty
    lists and ``final_output`` starts as an empty string.
    """
    master_input = {"query_brief": state.get("query_brief", "")}
    for list_field in ("execution_jobs", "completed_jobs", "worker_outputs"):
        # A fresh list per field so the fields never share one object.
        master_input[list_field] = []
    master_input["final_output"] = ""
    return master_input
def update_sparrow_from_master(sparrow_state: SparrowAgentState, master_state: dict) -> SparrowAgentState:
    """Fold the master graph's results back into the Sparrow state.

    ``sparrow_state`` is updated in place and the same object is returned.

    Args:
        sparrow_state: State to update. ``messages`` and ``final_message``
            are written when the master produced a non-empty final output;
            ``notes`` is extended when job lists are non-empty.
        master_state: Result mapping from the master graph. The keys
            ``final_output``, ``execution_jobs`` and ``completed_jobs`` are
            read; missing or ``None`` values are treated as empty.

    Returns:
        The updated ``sparrow_state``.
    """
    final_output = master_state.get("final_output", "")
    if final_output:
        # Imported here because this is the only branch that builds a message.
        from langchain_core.messages import AIMessage

        previous_messages = sparrow_state.get("messages") or []
        sparrow_state["messages"] = previous_messages + [AIMessage(content=final_output)]
        sparrow_state["final_message"] = final_output

    # Record execution details in the notes, one note per non-empty job list.
    for label, key in (("Execution jobs", "execution_jobs"), ("Completed", "completed_jobs")):
        jobs = master_state.get(key) or []
        if not jobs:
            continue
        # str.join raises TypeError on non-string items, so stringify each job first.
        summary = ", ".join(str(job) for job in jobs)
        sparrow_state["notes"] = (sparrow_state.get("notes") or []) + [f"{label}: {summary}"]
    return sparrow_state
def route_after_clarification(state: SparrowAgentState) -> str:
    """Decide whether to keep clarifying or move on to writing the query brief.

    Returns ``"write_query_brief"`` as soon as any stop condition holds
    (explicit status flag, attempt limit, message-count limit, or a
    completion phrase in the notes); otherwise ``"need_clarification"``.
    """
    brief_node = "write_query_brief"

    # Explicit status flags, checked in priority order.
    flag_announcements = (
        ("clarification_complete", "Routing: Clarification marked as complete"),
        ("max_clarification_reached", "Routing: Max clarification attempts reached"),
        ("information_sufficient", "Routing: Information marked as sufficient"),
    )
    for flag, announcement in flag_announcements:
        if state.get(flag, False):
            print(announcement)
            return brief_node

    # Secondary safety check against endless clarification loops
    # (3 is meant to match max_clarification_rounds in QueryNode).
    attempts = state.get("clarification_attempts", 0)
    if attempts >= 3:
        print(f"Routing: Safety limit reached ({attempts} attempts)")
        # NOTE(review): this writes to the dict handed to the router; whether
        # the graph persists that write is not visible here — confirm.
        state["max_clarification_reached"] = True
        return brief_node

    # Total message count acts as the final safety net.
    message_count = len(state.get("messages", []))
    if message_count > 12:
        print(f"Routing: Message count safety limit reached ({message_count} messages)")
        state["max_clarification_reached"] = True
        return brief_node

    # Fallback for older state: look for completion phrases in the notes.
    completion_phrases = ("sufficient information", "clarification complete", "proceeding")
    for note in state.get("notes", []):
        lowered = note.lower()
        for phrase in completion_phrases:
            if phrase in lowered:
                print("Routing: Completion indicator found in notes")
                return brief_node

    # Nothing signalled completion: ask for more clarification.
    print(f"Routing: Continue clarification (attempt {attempts + 1})")
    return "need_clarification"
def route_after_query_brief(state: SparrowAgentState) -> str:
    """Pick the next step after the query-brief node has run.

    Keys stored with a ``None`` value are treated like missing keys, so a
    ``None`` brief, attempt counter, message list or notes list does not
    raise.

    Returns:
        ``"master_subgraph"`` when ``query_creation_success`` is set or a
        brief longer than 10 characters (after stripping) exists;
        ``"clarify_with_user"`` when an error is recorded and fewer than 2
        clarification attempts were made; ``"__end__"`` otherwise.
    """
    if state.get("query_creation_success", False):
        print("Query brief created successfully, proceeding to master subgraph")
        return "master_subgraph"

    # `or ""` guards against a key that is present but None, which would
    # otherwise fail on .strip().
    query_brief = (state.get("query_brief") or "").strip()
    if len(query_brief) > 10:  # Lower threshold, more forgiving
        print(f"Query brief exists ({len(query_brief)} chars), proceeding to master subgraph")
        return "master_subgraph"

    # Give up once too many attempts or messages have accumulated.
    total_attempts = state.get("clarification_attempts") or 0
    messages = state.get("messages") or []
    if total_attempts >= 3 or len(messages) > 15:
        print("Too many attempts, ending conversation")
        return "__end__"

    # Brief creation failed but limits are not exceeded: ask the user again.
    if state.get("error") and total_attempts < 2:
        print("Query brief creation failed, requesting more clarification")
        # Reset flags so another clarification round is allowed.
        state["clarification_complete"] = False
        state["needs_clarification"] = True
        notes = state.get("notes")
        if notes is None:
            notes = state["notes"] = []
        notes.append("Query brief creation failed, requesting additional clarification")
        return "clarify_with_user"

    # Final fallback: end the conversation.
    print("Unable to create adequate query brief, ending conversation")
    return "__end__"
def need_clarification(state: SparrowAgentState) -> SparrowAgentState:
    """Record that additional clarification was requested.

    Prints a status line, appends a note to ``state["notes"]`` (a missing or
    ``None`` notes value is treated as an empty list) and returns the same
    ``state`` object.
    """
    print("Additional clarification needed.")
    existing_notes = state.get("notes") or []
    state["notes"] = existing_notes + ["Requested additional clarification from user"]
    return state
def run_master_subgraph(state: SparrowAgentState) -> SparrowAgentState:
    """Run the master graph synchronously and merge its result into ``state``.

    The Sparrow state is converted to the master input format, the master
    graph is invoked, and the result is folded back into ``state``.

    Returns:
        The updated state on success. If any step raises, the failure is
        logged with its traceback and a copy of ``state`` with an ``error``
        key holding the exception text is returned instead.
    """
    try:
        print("Running master subgraph...")
        master_input = convert_sparrow_to_master(state)
        # Use invoke instead of ainvoke to avoid issues with Send.
        master_result = master_graph.invoke(master_input)
        return update_sparrow_from_master(state, master_result)
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error would drop.
        logger.exception("Master subgraph failed: %s", e)
        return {**state, "error": str(e)}
def route_after_need_clarification(state: SparrowAgentState) -> str:
    """Router for the ``need_clarification`` node.

    The state is not inspected: the run always ends here so the user can
    reply before anything else happens.
    """
    return "__end__"
# Build the graph
sparrowAgentBuilder = StateGraph(SparrowAgentState, input_schema=SparrowInputState)

# Nodes
sparrowAgentBuilder.add_node("clarify_with_user", queryNode.clarify_with_user)
sparrowAgentBuilder.add_node("need_clarification", need_clarification)
sparrowAgentBuilder.add_node("write_query_brief", queryNode.write_query_brief)
sparrowAgentBuilder.add_node("master_subgraph", run_master_subgraph)

# Edges
sparrowAgentBuilder.add_edge(START, "clarify_with_user")
sparrowAgentBuilder.add_conditional_edges(
    "clarify_with_user",
    route_after_clarification,
    {
        "need_clarification": "need_clarification",
        "write_query_brief": "write_query_brief",
        "__end__": END,
    },
)

# Improved clarification flow: route_after_need_clarification currently
# always returns "__end__"; the "clarify_with_user" target stays mapped.
sparrowAgentBuilder.add_conditional_edges(
    "need_clarification",
    route_after_need_clarification,
    {
        "clarify_with_user": "clarify_with_user",
        "__end__": END,
    },
)

sparrowAgentBuilder.add_conditional_edges(
    "write_query_brief",
    route_after_query_brief,
    {
        "clarify_with_user": "clarify_with_user",
        "master_subgraph": "master_subgraph",
        "__end__": END,
    },
)
sparrowAgentBuilder.add_edge("master_subgraph", END)

# Compiled without arguments; the stray trailing "|" that followed this call
# was a dangling operator and has been removed.
sparrowAgent = sparrowAgentBuilder.compile()