File size: 14,806 Bytes
225a75e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a248c93
 
 
 
225a75e
 
a248c93
 
225a75e
 
a248c93
 
225a75e
 
a248c93
 
225a75e
 
a248c93
 
 
 
225a75e
 
 
a248c93
225a75e
 
a248c93
 
 
 
225a75e
 
 
a248c93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
225a75e
a248c93
 
 
 
 
 
 
 
 
 
 
225a75e
a248c93
 
225a75e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
#!/usr/bin/env python3
"""
Router Agent for GAIA Question Classification
Analyzes questions and routes them to appropriate specialized agents
"""

import re
import logging
from typing import Any, Dict, List, Optional
from urllib.parse import urlparse

from agents.state import GAIAAgentState, QuestionType, AgentRole, AgentResult
from models.qwen_client import QwenClient, ModelTier

logger = logging.getLogger(__name__)

class RouterAgent:
    """
    Router agent that classifies GAIA questions and determines processing strategy.

    The router is deliberately cheap-first: it runs rule-based classification,
    complexity scoring, agent selection and cost estimation without any model
    call, and only spends an LLM request (``_llm_enhanced_routing``) on
    questions assessed as complex or left unclassified.
    """

    def __init__(self, llm_client: QwenClient):
        # The LLM client is only used by the optional enhanced-routing step.
        self.llm_client = llm_client

    def route_question(self, state: GAIAAgentState) -> GAIAAgentState:
        """
        Main routing function - analyzes question and updates state with routing decisions.

        Mutates ``state`` in place (question_type, complexity_assessment,
        selected_agents, estimated_cost, routing_decision) and returns it so
        the call composes in a pipeline.
        """
        logger.info(f"Routing question: {state.question[:100]}...")
        state.add_processing_step("Router: Starting question analysis")

        # Step 1: Rule-based classification (no LLM cost).
        question_type = self._classify_question_type(state.question, state.file_name)
        state.question_type = question_type
        state.add_processing_step(f"Router: Classified as {question_type.value}")

        # Step 2: Complexity assessment ("simple" / "medium" / "complex").
        complexity = self._assess_complexity(state.question)
        state.complexity_assessment = complexity
        state.add_processing_step(f"Router: Assessed complexity as {complexity}")

        # Step 3: Select appropriate agents for the classified type.
        selected_agents = self._select_agents(question_type, state.file_name is not None)
        state.selected_agents = selected_agents
        state.add_processing_step(f"Router: Selected agents: {[a.value for a in selected_agents]}")

        # Step 4: Estimate processing cost in dollars.
        estimated_cost = self._estimate_cost(complexity, selected_agents)
        state.estimated_cost = estimated_cost
        state.add_processing_step(f"Router: Estimated cost: ${estimated_cost:.4f}")

        # Step 5: Create routing decision summary for downstream consumers.
        state.routing_decision = {
            "question_type": question_type.value,
            "complexity": complexity,
            "agents": [agent.value for agent in selected_agents],
            "estimated_cost": estimated_cost,
            "reasoning": self._get_routing_reasoning(question_type, complexity, selected_agents)
        }

        # Step 6: Use LLM for complex routing decisions if needed.
        # NOTE(review): _classify_question_type never returns UNKNOWN (it
        # falls back to WEB_RESEARCH), so the UNKNOWN half of this condition
        # only fires when state arrives pre-classified — confirm intended.
        if complexity == "complex" or question_type == QuestionType.UNKNOWN:
            state = self._llm_enhanced_routing(state)

        logger.info(f"✅ Routing complete: {question_type.value} -> {[a.value for a in selected_agents]}")
        return state

    def _classify_question_type(self, question: str, file_name: Optional[str] = None) -> QuestionType:
        """Classify question type using rule-based analysis.

        Order of precedence: attached file extension, then URL/domain
        patterns, then content-keyword scoring. Falls back to WEB_RESEARCH
        rather than UNKNOWN so informational questions always get a route.

        Args:
            question: Raw question text.
            file_name: Name of an attached file, if any (extension drives
                FILE_PROCESSING / CODE_EXECUTION classification).
        """

        question_lower = question.lower()

        # File processing questions: any attachment short-circuits content
        # analysis; only code-file extensions divert to CODE_EXECUTION.
        if file_name:
            file_ext = file_name.lower().split('.')[-1] if '.' in file_name else ""

            if file_ext in ['jpg', 'jpeg', 'png', 'gif', 'bmp', 'svg']:
                return QuestionType.FILE_PROCESSING
            elif file_ext in ['mp3', 'wav', 'ogg', 'flac', 'm4a']:
                return QuestionType.FILE_PROCESSING
            elif file_ext in ['xlsx', 'xls', 'csv']:
                return QuestionType.FILE_PROCESSING
            elif file_ext in ['py', 'js', 'java', 'cpp', 'c']:
                return QuestionType.CODE_EXECUTION
            else:
                # Unknown extension (or no extension): treat as a generic file.
                return QuestionType.FILE_PROCESSING

        # URL-based classification: a single domain hit decides immediately.
        url_patterns = {
            QuestionType.WIKIPEDIA: [
                r'wikipedia\.org', r'wiki', r'featured article', r'promoted.*wikipedia'
            ],
            QuestionType.YOUTUBE: [
                r'youtube\.com', r'youtu\.be', r'watch\?v=', r'video'
            ]
        }

        for question_type, patterns in url_patterns.items():
            if any(re.search(pattern, question_lower) for pattern in patterns):
                return question_type

        # Content-based classification: each keyword match adds one point.
        classification_patterns = {
            QuestionType.MATHEMATICAL: [
                r'\bcalculate\b', r'\bcompute\b', r'\bsolve\b', r'\bequation\b', r'\bformula\b',
                r'\bsum\b', r'\btotal\b', r'\baverage\b', r'\bpercentage\b', r'\bratio\b',
                r'\bhow many\b', r'\bhow much\b', r'\d+\s*[\+\-\*/]\s*\d+', r'\bmath\b',
                r'\bsquare root\b', r'\bfactorial\b', r'\bdivided by\b', r'\bmultiply\b'
            ],
            QuestionType.CODE_EXECUTION: [
                r'\bcode\b', r'\bprogram\b', r'\bscript\b', r'\bfunction\b', r'\balgorithm\b',
                r'\bexecute\b', r'\brun.*code\b', r'\bpython\b', r'\bjavascript\b'
            ],
            QuestionType.TEXT_MANIPULATION: [
                r'\breverse\b', r'\bencode\b', r'\bdecode\b', r'\btransform\b', r'\bconvert\b',
                r'\buppercase\b', r'\blowercase\b', r'\breplace\b', r'\bextract\b'
            ],
            QuestionType.REASONING: [
                r'\bwhy\b', r'\bexplain\b', r'\banalyze\b', r'\breasoning\b', r'\blogic\b',
                r'\brelationship\b', r'\bcompare\b', r'\bcontrast\b', r'\bconclusion\b'
            ],
            QuestionType.WEB_RESEARCH: [
                r'\bsearch\b', r'\bfind.*information\b', r'\bresearch\b', r'\blook up\b',
                r'\bwebsite\b', r'\bonline\b', r'\binternet\b', r'\bwho\s+(?:is|was|are|were)\b',
                r'\bwhat\s+(?:is|was|are|were)\b', r'\bwhen\s+(?:is|was|did|does)\b',
                r'\bwhere\s+(?:is|was|are|were)\b'
            ]
        }

        # Score each category; every pattern occurrence counts once.
        type_scores = {}
        for question_type, patterns in classification_patterns.items():
            score = 0
            for pattern in patterns:
                matches = re.findall(pattern, question_lower)
                score += len(matches)
            if score > 0:
                type_scores[question_type] = score

        # Special handling for specific question patterns

        # Fictional/non-existent content is best resolved via web research.
        if any(term in question_lower for term in ['fictional', 'imaginary', 'non-existent', 'nonexistent']):
            type_scores[QuestionType.WEB_RESEARCH] = type_scores.get(QuestionType.WEB_RESEARCH, 0) + 2

        # "Who is/was/did..." questions about people are research questions.
        if re.search(r'\bwho\s+(?:is|was|are|were|did|does)\b', question_lower):
            type_scores[QuestionType.WEB_RESEARCH] = type_scores.get(QuestionType.WEB_RESEARCH, 0) + 2

        # Historical or factual vocabulary nudges toward web research.
        if any(term in question_lower for term in ['history', 'historical', 'century', 'year', 'published', 'author']):
            type_scores[QuestionType.WEB_RESEARCH] = type_scores.get(QuestionType.WEB_RESEARCH, 0) + 1

        # An explicit arithmetic expression strongly suggests MATHEMATICAL.
        if re.search(r'\d+\s*[\+\-\*/]\s*\d+', question_lower):
            type_scores[QuestionType.MATHEMATICAL] = type_scores.get(QuestionType.MATHEMATICAL, 0) + 3

        # Return highest scoring type, or WEB_RESEARCH as default for informational questions
        if type_scores:
            best_type = max(type_scores, key=type_scores.get)

            # On a weak/tied signal, prefer WEB_RESEARCH for generic
            # interrogatives rather than trusting a single keyword hit.
            max_score = type_scores[best_type]
            if max_score <= 1:
                info_patterns = [r'\bwhat\b', r'\bwho\b', r'\bwhen\b', r'\bwhere\b', r'\bhow\b']
                if any(re.search(pattern, question_lower) for pattern in info_patterns):
                    return QuestionType.WEB_RESEARCH

            return best_type

        # Default to WEB_RESEARCH for unknown informational questions
        return QuestionType.WEB_RESEARCH

    def _assess_complexity(self, question: str) -> str:
        """Assess question complexity.

        Returns one of "simple", "medium" or "complex" based on keyword
        counts plus length/multi-question heuristics.
        """

        question_lower = question.lower()

        # Phrases that indicate multi-step or analytical work.
        complex_indicators = [
            'multi-step', 'multiple', 'several', 'complex', 'detailed',
            'analyze', 'explain why', 'reasoning', 'relationship',
            'compare and contrast', 'comprehensive', 'thorough'
        ]

        # Phrases that indicate a single-fact lookup.
        simple_indicators = [
            'what is', 'who is', 'when', 'where', 'yes or no',
            'true or false', 'simple', 'quick', 'name', 'list'
        ]

        complex_score = sum(1 for indicator in complex_indicators if indicator in question_lower)
        simple_score = sum(1 for indicator in simple_indicators if indicator in question_lower)

        # Length and multi-question structure also add complexity.
        if len(question) > 200:
            complex_score += 1
        if len(question.split()) > 30:
            complex_score += 1
        if question.count('?') > 2:  # Multiple questions
            complex_score += 1

        # Simple wins only when nothing complex was detected at all.
        if complex_score >= 2:
            return "complex"
        elif simple_score >= 2 and complex_score == 0:
            return "simple"
        else:
            return "medium"

    def _select_agents(self, question_type: QuestionType, has_file: bool) -> List[AgentRole]:
        """Select appropriate agents based on question type and presence of files.

        The synthesizer is always included; ``has_file`` only matters in the
        fallback branch because file-bearing questions are already classified
        as FILE_PROCESSING / CODE_EXECUTION upstream.
        """

        agents = []

        # Always include synthesizer for final answer compilation
        agents.append(AgentRole.SYNTHESIZER)

        # Type-specific agent selection
        if question_type in [QuestionType.WIKIPEDIA, QuestionType.WEB_RESEARCH, QuestionType.YOUTUBE]:
            agents.append(AgentRole.WEB_RESEARCHER)

        elif question_type == QuestionType.FILE_PROCESSING:
            agents.append(AgentRole.FILE_PROCESSOR)

        elif question_type == QuestionType.CODE_EXECUTION:
            agents.append(AgentRole.CODE_EXECUTOR)

        elif question_type in [QuestionType.MATHEMATICAL, QuestionType.REASONING]:
            agents.append(AgentRole.REASONING_AGENT)

        elif question_type == QuestionType.TEXT_MANIPULATION:
            agents.append(AgentRole.REASONING_AGENT)  # Can handle text operations

        else:  # UNKNOWN or complex cases
            # Use multiple agents for better coverage
            agents.extend([AgentRole.WEB_RESEARCHER, AgentRole.REASONING_AGENT])
            if has_file:
                agents.append(AgentRole.FILE_PROCESSOR)

        # De-duplicate while preserving order (dicts keep insertion order).
        return list(dict.fromkeys(agents))

    def _estimate_cost(self, complexity: str, agents: List[AgentRole]) -> float:
        """Estimate processing cost in dollars from complexity tier plus a
        flat per-agent surcharge. Unknown complexity falls back to "medium".
        """

        base_costs = {
            "simple": 0.005,   # Router model mostly
            "medium": 0.015,   # Mix of router and main
            "complex": 0.035   # Include complex model usage
        }

        base_cost = base_costs.get(complexity, 0.015)

        # Additional cost per agent
        agent_cost = len(agents) * 0.005

        return base_cost + agent_cost

    def _get_routing_reasoning(self, question_type: QuestionType, complexity: str, agents: List[AgentRole]) -> str:
        """Generate human-readable reasoning for routing decision.

        Returns a single semicolon-joined string; stored in
        ``routing_decision["reasoning"]`` for operators/debugging.
        """

        reasons = []

        # Question type reasoning (some types intentionally add nothing).
        if question_type == QuestionType.WIKIPEDIA:
            reasons.append("Question references Wikipedia content")
        elif question_type == QuestionType.YOUTUBE:
            reasons.append("Question involves YouTube video analysis")
        elif question_type == QuestionType.FILE_PROCESSING:
            reasons.append("Question requires file processing")
        elif question_type == QuestionType.MATHEMATICAL:
            reasons.append("Question involves mathematical computation")
        elif question_type == QuestionType.CODE_EXECUTION:
            reasons.append("Question requires code execution")
        elif question_type == QuestionType.REASONING:
            reasons.append("Question requires logical reasoning")

        # Complexity reasoning ("medium" intentionally adds nothing).
        if complexity == "complex":
            reasons.append("Complex reasoning required")
        elif complexity == "simple":
            reasons.append("Straightforward question")

        # Agent reasoning
        agent_names = [agent.value.replace('_', ' ') for agent in agents]
        reasons.append(f"Selected agents: {', '.join(agent_names)}")

        return "; ".join(reasons)

    def _llm_enhanced_routing(self, state: GAIAAgentState) -> GAIAAgentState:
        """Use LLM for enhanced routing analysis of complex/unknown questions.

        Best-effort: any failure is recorded via ``state.add_error`` and the
        rule-based routing decision is kept unchanged. Always returns state.
        """

        prompt = f"""
        Analyze this GAIA benchmark question and provide routing guidance:
        
        Question: {state.question}
        File attached: {state.file_name if state.file_name else "None"}
        Current classification: {state.question_type.value}
        Current complexity: {state.complexity_assessment}
        
        Please provide:
        1. Confirm or correct the question type
        2. Confirm or adjust complexity assessment  
        3. Key challenges in answering this question
        4. Recommended approach
        
        Keep response concise and focused on routing decisions.
        """

        try:
            # Cheap router-tier model unless the question is complex.
            tier = ModelTier.ROUTER if state.complexity_assessment != "complex" else ModelTier.MAIN
            result = self.llm_client.generate(prompt, tier=tier, max_tokens=200)

            if result.success:
                state.add_processing_step("Router: Enhanced with LLM analysis")
                state.routing_decision["llm_analysis"] = result.response
                logger.info("✅ LLM enhanced routing completed")
            else:
                state.add_error(f"LLM routing enhancement failed: {result.error}")

        except Exception as e:
            # Deliberate broad catch: routing must not crash the pipeline.
            state.add_error(f"LLM routing error: {str(e)}")
            logger.error(f"LLM routing failed: {e}")

        return state