File size: 3,541 Bytes
fe36046
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
"""Router Node - Decides which specialized agent to use"""
import re
from typing import Dict, Any, Literal

from langchain_core.messages import SystemMessage, HumanMessage
from langchain_groq import ChatGroq

from src.tracing import get_langfuse_callback_handler


# Used whenever the prompt file is missing, unreadable, or empty.
_DEFAULT_ROUTER_PROMPT = (
    "You are an intelligent agent router. Analyze the query and respond with "
    "exactly one of: RETRIEVAL, EXECUTION, or CRITIC"
)


def load_router_prompt(path: str = "./prompts/router_prompt.txt") -> str:
    """Load the router system prompt from a text file.

    Args:
        path: Prompt file location. The default is a relative path, so it is
            resolved against the current working directory.

    Returns:
        The file contents with surrounding whitespace stripped, or a built-in
        default prompt if the file does not exist, cannot be read/decoded,
        or contains only whitespace.
    """
    try:
        with open(path, "r", encoding="utf-8") as f:
            prompt = f.read().strip()
    except FileNotFoundError:
        return _DEFAULT_ROUTER_PROMPT
    except (OSError, UnicodeDecodeError) as e:
        # The file exists but is unusable; fall back instead of failing the
        # whole routing step, but make the misconfiguration visible.
        print(f"Router Node: Could not read router prompt ({e}), using default")
        return _DEFAULT_ROUTER_PROMPT
    # An empty file would otherwise produce an empty system prompt.
    return prompt or _DEFAULT_ROUTER_PROMPT


# Reasoning models such as qwen-qwq-32b can prepend a "<think>...</think>"
# section to their reply; only the text after the last closing tag is the answer.
_THINK_CLOSE_RE = re.compile(r"</think>", re.IGNORECASE)
# Match a routing label not embedded in a longer word, so "CRITICAL" is not
# read as "CRITIC"; surrounding punctuation/underscores are still allowed.
_AGENT_LABEL_RE = re.compile(r"(?<![A-Z])(RETRIEVAL|EXECUTION|CRITIC)(?![A-Z])")


def _last_user_query(messages: Any) -> Any:
    """Return the content of the most recent human message, or None if there is none."""
    for msg in reversed(messages):
        # getattr tolerates non-message entries (e.g. plain dicts) instead of raising.
        if getattr(msg, "type", None) == "human":
            return msg.content
    return None


def _decision_from_response(raw: str) -> str:
    """Drop any reasoning preamble and return the upper-cased final answer text."""
    return _THINK_CLOSE_RE.split(raw)[-1].strip().upper()


def _agent_for_decision(decision: str) -> str:
    """Map an upper-cased decision to an agent name; the earliest label wins, default 'retrieval'."""
    match = _AGENT_LABEL_RE.search(decision)
    return match.group(1).lower() if match else "retrieval"


def router_node(state: Dict[str, Any]) -> Dict[str, Any]:
    """
    Router node that analyzes the user query and determines which agent should handle it.

    Args:
        state: Graph state. Its "messages" entry is scanned (newest first) for
            the most recent message whose ``type`` is "human".

    Returns:
        A new dict containing everything in ``state`` plus ``next_agent``
        ('retrieval' | 'execution' | 'critic') and ``routing_reason``. When the
        LLM produced a decision, ``routing_decision`` and ``current_step`` are
        set as well. Missing queries and any exception fall back to 'retrieval'.
    """
    print("Router Node: Analyzing query for agent selection")

    try:
        # Find the query first so no LLM client is built when there is nothing to route.
        user_query = _last_user_query(state.get("messages") or [])

        if not user_query or (isinstance(user_query, str) and not user_query.strip()):
            print("Router Node: No user query found, defaulting to retrieval")
            return {
                **state,
                "next_agent": "retrieval",
                "routing_reason": "No user query found"
            }

        router_prompt = load_router_prompt()

        # Temperature 0 for consistent routing decisions
        llm = ChatGroq(model="qwen-qwq-32b", temperature=0.0)

        # Tracing callback is optional
        callback_handler = get_langfuse_callback_handler()
        callbacks = [callback_handler] if callback_handler else []

        routing_messages = [
            SystemMessage(content=router_prompt),
            HumanMessage(content=f"Query to route: {user_query}")
        ]

        response = llm.invoke(routing_messages, config={"callbacks": callbacks})
        # Parse only the final answer: substring checks over the raw reply would
        # match labels mentioned inside the model's reasoning section.
        routing_decision = _decision_from_response(response.content)
        next_agent = _agent_for_decision(routing_decision)

        print(f"Router Node: Routing to {next_agent} agent (decision: {routing_decision})")

        return {
            **state,
            "next_agent": next_agent,
            "routing_decision": routing_decision,
            "routing_reason": f"Query analysis resulted in: {routing_decision}",
            "current_step": next_agent
        }

    except Exception as e:
        # Best-effort routing: never fail the graph here, fall back to retrieval.
        print(f"Router Node Error: {e}")
        return {
            **state,
            "next_agent": "retrieval",
            "routing_reason": f"Error in routing: {e}"
        }


def should_route_to_agent(state: Dict[str, Any]) -> Literal["retrieval", "execution", "critic"]:
    """
    Conditional edge function that determines which agent to route to.

    Reads ``state["next_agent"]``. String values are normalized (stripped,
    lower-cased). A missing key, ``None``, or any value other than the three
    known agents falls back to 'retrieval', so the return value always matches
    the declared ``Literal`` type.
    """
    valid_agents = ("retrieval", "execution", "critic")
    next_agent = state.get("next_agent", "retrieval")
    if isinstance(next_agent, str):
        next_agent = next_agent.strip().lower()
    if next_agent not in valid_agents:
        print(f"Routing: unknown agent {next_agent!r}, falling back to retrieval")
        next_agent = "retrieval"
    print(f"Routing to: {next_agent}")
    return next_agent