|
|
"""Router Node - Decides which specialized agent to use""" |
|
|
from typing import Dict, Any, Literal |
|
|
from langchain_core.messages import SystemMessage, HumanMessage |
|
|
from langchain_groq import ChatGroq |
|
|
from src.tracing import get_langfuse_callback_handler |
|
|
|
|
|
|
|
|
def load_router_prompt(path: str = "./prompts/router_prompt.txt") -> str:
    """Load the router system prompt from a text file.

    Args:
        path: File to read. Defaults to ``./prompts/router_prompt.txt``; a
            relative path is resolved against the current working directory,
            not against this module's location.

    Returns:
        The file contents with surrounding whitespace stripped. If the file
        cannot be read (missing, a directory, permission denied, not valid
        UTF-8) or contains only whitespace, a built-in prompt asking for
        RETRIEVAL, EXECUTION or CRITIC is returned instead, so the router
        always gets a non-empty system prompt.
    """
    fallback = """You are an intelligent agent router. Analyze the query and respond with exactly one of: RETRIEVAL, EXECUTION, or CRITIC"""
    try:
        with open(path, "r", encoding="utf-8") as f:
            prompt = f.read().strip()
    except (OSError, UnicodeDecodeError):
        # OSError covers FileNotFoundError as well as IsADirectoryError and
        # PermissionError; a corrupt file should not break routing either.
        return fallback
    # A blank prompt file would otherwise produce an empty SystemMessage.
    return prompt or fallback
|
|
|
|
|
|
|
|
# Router labels the LLM is asked to emit, mapped to the agent names used in state.
_ROUTE_LABELS = {"RETRIEVAL": "retrieval", "EXECUTION": "execution", "CRITIC": "critic"}


def _latest_user_query(messages: Any) -> Any:
    """Return the content of the most recent human message, or None if there is none.

    Items without a ``type`` attribute are skipped rather than raising.
    """
    for msg in reversed(messages or []):
        if getattr(msg, "type", None) == "human":
            return msg.content
    return None


def _strip_reasoning(text: str) -> str:
    """Drop everything up to and including the last ``</think>`` tag, if present."""
    # rpartition returns ("", "", text) when the tag is absent, so text is kept as-is.
    return text.rpartition("</think>")[2]


def _select_agent(decision: str) -> str:
    """Map an upper-cased routing decision to an agent name.

    The label appearing earliest in ``decision`` wins; "retrieval" is the
    default when no label is present.
    """
    hits = [(decision.find(label), agent) for label, agent in _ROUTE_LABELS.items() if label in decision]
    return min(hits)[1] if hits else "retrieval"


def router_node(state: Dict[str, Any]) -> Dict[str, Any]:
    """
    Router node that analyzes the user query and determines which agent should handle it.

    The most recent human message in ``state["messages"]`` is sent to the
    router LLM together with the router prompt, and the reply is mapped onto
    an agent. Any exception is caught and the query is routed to "retrieval",
    with the error recorded in ``routing_reason``.

    Returns: a copy of ``state`` with next_agent = 'retrieval' | 'execution' | 'critic'
    and ``routing_reason``; after a successful LLM call also ``routing_decision``
    (the reply with any reasoning prefix removed) and ``current_step``.
    """
    print("Router Node: Analyzing query for agent selection")

    try:
        # Check for a query first so no LLM client is built when there is nothing to route.
        user_query = _latest_user_query(state.get("messages"))
        if not user_query:
            print("Router Node: No user query found, defaulting to retrieval")
            return {
                **state,
                "next_agent": "retrieval",
                "routing_reason": "No user query found",
            }

        router_prompt = load_router_prompt()
        llm = ChatGroq(model="qwen-qwq-32b", temperature=0.0)

        callback_handler = get_langfuse_callback_handler()
        callbacks = [callback_handler] if callback_handler else []

        routing_messages = [
            SystemMessage(content=router_prompt),
            HumanMessage(content=f"Query to route: {user_query}"),
        ]
        response = llm.invoke(routing_messages, config={"callbacks": callbacks})

        # Reasoning models (qwen-qwq-32b presumably among them) may prefix the
        # answer with a <think>...</think> block that mentions several labels;
        # matching on the raw reply would then always pick RETRIEVAL first.
        routing_decision = _strip_reasoning(response.content).strip().upper()
        next_agent = _select_agent(routing_decision)

        print(f"Router Node: Routing to {next_agent} agent (decision: {routing_decision})")

        return {
            **state,
            "next_agent": next_agent,
            "routing_decision": routing_decision,
            "routing_reason": f"Query analysis resulted in: {routing_decision}",
            "current_step": next_agent,
        }

    except Exception as e:
        # Deliberate best-effort: a routing failure must not abort the graph.
        print(f"Router Node Error: {e}")
        return {
            **state,
            "next_agent": "retrieval",
            "routing_reason": f"Error in routing: {e}",
        }
|
|
|
|
|
|
|
|
def should_route_to_agent(state: Dict[str, Any]) -> Literal["retrieval", "execution", "critic"]:
    """
    Conditional edge function that determines which agent to route to.

    Reads ``state["next_agent"]`` (set by ``router_node``) and returns it
    stripped and lower-cased. A missing, ``None`` or unrecognized value falls
    back to "retrieval", so the edge never returns a name outside the declared
    Literal.
    """
    valid_agents = ("retrieval", "execution", "critic")
    raw = state.get("next_agent")
    next_agent = "retrieval" if raw is None else str(raw).strip().lower()
    if next_agent not in valid_agents:
        print(f"Unknown next_agent {raw!r}, defaulting to retrieval")
        next_agent = "retrieval"
    print(f"Routing to: {next_agent}")
    return next_agent