| | |
| | """ |
| | Router Agent for GAIA Question Classification |
| | Analyzes questions and routes them to appropriate specialized agents |
| | """ |
| |
|
| | import re |
| | import logging |
| | from typing import List, Dict, Any |
| | from urllib.parse import urlparse |
| |
|
| | from agents.state import GAIAAgentState, QuestionType, AgentRole, AgentResult |
| | from models.qwen_client import QwenClient, ModelTier |
| |
|
# Module-level logger named after this module; RouterAgent reports routing
# progress and LLM routing failures through it.
logger = logging.getLogger(__name__)
| |
|
class RouterAgent:
    """Classify GAIA questions and decide how they should be processed.

    Question typing, complexity assessment, agent selection and cost
    estimation are all rule-based.  An LLM is consulted only for questions
    rated "complex" or typed ``QuestionType.UNKNOWN``; its reply is stored
    as advisory text and does not alter the rule-based decision.
    """

    def __init__(self, llm_client: QwenClient):
        """Keep a reference to the LLM client used for enhanced routing."""
        self.llm_client = llm_client

    def route_question(self, state: GAIAAgentState) -> GAIAAgentState:
        """Analyze the question held in ``state`` and record routing decisions.

        Sets ``question_type``, ``complexity_assessment``, ``selected_agents``,
        ``estimated_cost`` and ``routing_decision`` on ``state`` and logs a
        processing step after each stage.

        Args:
            state: Agent state; ``state.question`` and ``state.file_name``
                are read.

        Returns:
            The same ``state`` object, updated in place.
        """
        logger.info("Routing question: %s...", state.question[:100])
        state.add_processing_step("Router: Starting question analysis")

        question_type = self._classify_question_type(state.question, state.file_name)
        state.question_type = question_type
        state.add_processing_step(f"Router: Classified as {question_type.value}")

        complexity = self._assess_complexity(state.question)
        state.complexity_assessment = complexity
        state.add_processing_step(f"Router: Assessed complexity as {complexity}")

        # Use truthiness, exactly as _classify_question_type does, so that an
        # empty file name is not reported as an attached file.
        has_file = bool(state.file_name)
        selected_agents = self._select_agents(question_type, has_file)
        state.selected_agents = selected_agents
        agent_names = [agent.value for agent in selected_agents]
        state.add_processing_step(f"Router: Selected agents: {agent_names}")

        estimated_cost = self._estimate_cost(complexity, selected_agents)
        state.estimated_cost = estimated_cost
        state.add_processing_step(f"Router: Estimated cost: ${estimated_cost:.4f}")

        state.routing_decision = {
            "question_type": question_type.value,
            "complexity": complexity,
            "agents": agent_names,
            "estimated_cost": estimated_cost,
            "reasoning": self._get_routing_reasoning(question_type, complexity, selected_agents),
        }

        # Hard or unclassifiable questions additionally get an LLM opinion.
        if complexity == "complex" or question_type == QuestionType.UNKNOWN:
            state = self._llm_enhanced_routing(state)

        logger.info("✅ Routing complete: %s -> %s", question_type.value, agent_names)
        return state

    def _classify_question_type(self, question: str, file_name: str = None) -> QuestionType:
        """Classify a question with rule-based heuristics.

        Checks run in priority order:

        1. An attached file decides the type by extension alone: source-code
           extensions give ``CODE_EXECUTION``, anything else
           ``FILE_PROCESSING``.
        2. Wikipedia / YouTube patterns in the question text.
        3. Keyword scoring across the remaining types, plus a few boosts.

        Args:
            question: The question text.
            file_name: Name of the attached file; ``None`` or an empty
                string means there is no attachment.

        Returns:
            The chosen ``QuestionType``; ``WEB_RESEARCH`` when nothing
            matches.
        """
        question_lower = question.lower()

        if file_name:
            # Text after the last dot; a name without a dot has no extension.
            _, dot, extension = file_name.lower().rpartition('.')
            if dot and extension in ('py', 'js', 'java', 'cpp', 'c'):
                return QuestionType.CODE_EXECUTION
            # Images, audio, spreadsheets and every other file kind.
            return QuestionType.FILE_PROCESSING

        # Source-specific hints win over generic keyword scoring.
        url_patterns = {
            QuestionType.WIKIPEDIA: [
                r'wikipedia\.org', r'wiki', r'featured article', r'promoted.*wikipedia'
            ],
            QuestionType.YOUTUBE: [
                r'youtube\.com', r'youtu\.be', r'watch\?v=', r'video'
            ],
        }
        for candidate, patterns in url_patterns.items():
            if any(re.search(pattern, question_lower) for pattern in patterns):
                return candidate

        classification_patterns = {
            QuestionType.MATHEMATICAL: [
                r'\bcalculate\b', r'\bcompute\b', r'\bsolve\b', r'\bequation\b', r'\bformula\b',
                r'\bsum\b', r'\btotal\b', r'\baverage\b', r'\bpercentage\b', r'\bratio\b',
                r'\bhow many\b', r'\bhow much\b', r'\d+\s*[\+\-\*/]\s*\d+', r'\bmath\b',
                r'\bsquare root\b', r'\bfactorial\b', r'\bdivided by\b', r'\bmultiply\b'
            ],
            QuestionType.CODE_EXECUTION: [
                r'\bcode\b', r'\bprogram\b', r'\bscript\b', r'\bfunction\b', r'\balgorithm\b',
                r'\bexecute\b', r'\brun.*code\b', r'\bpython\b', r'\bjavascript\b'
            ],
            QuestionType.TEXT_MANIPULATION: [
                r'\breverse\b', r'\bencode\b', r'\bdecode\b', r'\btransform\b', r'\bconvert\b',
                r'\buppercase\b', r'\blowercase\b', r'\breplace\b', r'\bextract\b'
            ],
            QuestionType.REASONING: [
                r'\bwhy\b', r'\bexplain\b', r'\banalyze\b', r'\breasoning\b', r'\blogic\b',
                r'\brelationship\b', r'\bcompare\b', r'\bcontrast\b', r'\bconclusion\b'
            ],
            QuestionType.WEB_RESEARCH: [
                r'\bsearch\b', r'\bfind.*information\b', r'\bresearch\b', r'\blook up\b',
                r'\bwebsite\b', r'\bonline\b', r'\binternet\b', r'\bwho\s+(?:is|was|are|were)\b',
                r'\bwhat\s+(?:is|was|are|were)\b', r'\bwhen\s+(?:is|was|did|does)\b',
                r'\bwhere\s+(?:is|was|are|were)\b'
            ],
        }

        # One point per pattern occurrence; types that never match are left
        # out so that an empty dict means "nothing matched at all".
        type_scores = {}
        for candidate, patterns in classification_patterns.items():
            score = sum(len(re.findall(pattern, question_lower)) for pattern in patterns)
            if score > 0:
                type_scores[candidate] = score

        def boost(target, amount):
            type_scores[target] = type_scores.get(target, 0) + amount

        # Questions about possibly made-up entities still need a lookup.
        if any(term in question_lower
               for term in ['fictional', 'imaginary', 'non-existent', 'nonexistent']):
            boost(QuestionType.WEB_RESEARCH, 2)

        # "Who is/was/did ..." questions are weighted towards research.
        if re.search(r'\bwho\s+(?:is|was|are|were|did|does)\b', question_lower):
            boost(QuestionType.WEB_RESEARCH, 2)

        # Historical / bibliographic vocabulary (plain substring match).
        if any(term in question_lower
               for term in ['history', 'historical', 'century', 'year', 'published', 'author']):
            boost(QuestionType.WEB_RESEARCH, 1)

        # An explicit arithmetic expression is a strong mathematical signal.
        if re.search(r'\d+\s*[\+\-\*/]\s*\d+', question_lower):
            boost(QuestionType.MATHEMATICAL, 3)

        if not type_scores:
            return QuestionType.WEB_RESEARCH

        # On ties, max() keeps the first type in insertion order.
        best_type = max(type_scores, key=type_scores.get)

        # A single weak hit is not trusted over a plain information question.
        if type_scores[best_type] <= 1:
            info_patterns = [r'\bwhat\b', r'\bwho\b', r'\bwhen\b', r'\bwhere\b', r'\bhow\b']
            if any(re.search(pattern, question_lower) for pattern in info_patterns):
                return QuestionType.WEB_RESEARCH

        return best_type

    def _assess_complexity(self, question: str) -> str:
        """Rate a question as "simple", "medium" or "complex".

        Indicator phrases are matched as case-insensitive substrings, so an
        indicator also counts when it appears inside a longer word.  Long
        questions (over 200 characters or over 30 words) and questions with
        more than two question marks each add one point to the complex score.

        Args:
            question: The question text.

        Returns:
            "complex" when the complex score is at least 2; "simple" when at
            least two simple indicators match and no complex signal is
            present; "medium" otherwise.
        """
        question_lower = question.lower()

        complex_indicators = (
            'multi-step', 'multiple', 'several', 'complex', 'detailed',
            'analyze', 'explain why', 'reasoning', 'relationship',
            'compare and contrast', 'comprehensive', 'thorough',
        )
        simple_indicators = (
            'what is', 'who is', 'when', 'where', 'yes or no',
            'true or false', 'simple', 'quick', 'name', 'list',
        )

        complex_score = sum(1 for phrase in complex_indicators if phrase in question_lower)
        simple_score = sum(1 for phrase in simple_indicators if phrase in question_lower)

        # Structural signals: sheer length and multi-part questions.
        if len(question) > 200:
            complex_score += 1
        if len(question.split()) > 30:
            complex_score += 1
        if question.count('?') > 2:
            complex_score += 1

        if complex_score >= 2:
            return "complex"
        if simple_score >= 2 and complex_score == 0:
            return "simple"
        return "medium"

    def _select_agents(self, question_type: QuestionType, has_file: bool) -> List[AgentRole]:
        """Choose the agents that should handle a question.

        The synthesizer is always first.  Known question types add exactly
        one specialist; any other type falls back to web research plus
        reasoning, with the file processor added when a file is attached.

        Args:
            question_type: Result of question classification.
            has_file: Whether a file is attached; only consulted in the
                fallback branch.

        Returns:
            Agent roles in selection order, without duplicates.
        """
        agents = [AgentRole.SYNTHESIZER]

        if question_type in [QuestionType.WIKIPEDIA, QuestionType.WEB_RESEARCH,
                             QuestionType.YOUTUBE]:
            agents.append(AgentRole.WEB_RESEARCHER)
        elif question_type == QuestionType.FILE_PROCESSING:
            agents.append(AgentRole.FILE_PROCESSOR)
        elif question_type == QuestionType.CODE_EXECUTION:
            agents.append(AgentRole.CODE_EXECUTOR)
        elif question_type in [QuestionType.MATHEMATICAL, QuestionType.REASONING,
                               QuestionType.TEXT_MANIPULATION]:
            agents.append(AgentRole.REASONING_AGENT)
        else:
            # Unrecognized type: cover the most general capabilities.
            agents.extend([AgentRole.WEB_RESEARCHER, AgentRole.REASONING_AGENT])
            if has_file:
                agents.append(AgentRole.FILE_PROCESSOR)

        # dict.fromkeys removes duplicates while keeping first-seen order.
        return list(dict.fromkeys(agents))

    def _estimate_cost(self, complexity: str, agents: List[AgentRole]) -> float:
        """Estimate the processing cost of a routing decision.

        Args:
            complexity: "simple", "medium" or "complex"; any other value is
                priced like "medium".
            agents: Selected agents; only their count matters.

        Returns:
            The base cost for the complexity level plus 0.005 per agent.
        """
        base_costs = {
            "simple": 0.005,
            "medium": 0.015,
            "complex": 0.035,
        }
        base_cost = base_costs.get(complexity, 0.015)
        agent_cost = len(agents) * 0.005
        return base_cost + agent_cost

    def _get_routing_reasoning(self, question_type: QuestionType, complexity: str,
                               agents: List[AgentRole]) -> str:
        """Build a human-readable explanation of a routing decision.

        Args:
            question_type: Classified question type.  Types without an entry
                in the lookup table contribute no type-specific sentence.
            complexity: Complexity label; "medium" adds no sentence.
            agents: Selected agents, listed by value with underscores
                replaced by spaces.

        Returns:
            The individual reasons joined with "; ".
        """
        type_reasons = {
            QuestionType.WIKIPEDIA: "Question references Wikipedia content",
            QuestionType.YOUTUBE: "Question involves YouTube video analysis",
            QuestionType.FILE_PROCESSING: "Question requires file processing",
            QuestionType.MATHEMATICAL: "Question involves mathematical computation",
            QuestionType.CODE_EXECUTION: "Question requires code execution",
            QuestionType.REASONING: "Question requires logical reasoning",
        }

        reasons = []

        type_reason = type_reasons.get(question_type)
        if type_reason is not None:
            reasons.append(type_reason)

        if complexity == "complex":
            reasons.append("Complex reasoning required")
        elif complexity == "simple":
            reasons.append("Straightforward question")

        agent_names = [agent.value.replace('_', ' ') for agent in agents]
        reasons.append(f"Selected agents: {', '.join(agent_names)}")

        return "; ".join(reasons)

    def _llm_enhanced_routing(self, state: GAIAAgentState) -> GAIAAgentState:
        """Ask the LLM for routing guidance on a complex or unknown question.

        This step is best-effort: on success the reply is stored under
        ``state.routing_decision["llm_analysis"]``; a reported failure or a
        raised exception is recorded with ``state.add_error`` and never
        propagates to the caller.

        Args:
            state: Agent state whose question, file name, type and
                complexity are embedded in the prompt.

        Returns:
            The same ``state`` object.
        """
        # Assembled line by line so that source indentation never leaks
        # into the text sent to the model.
        prompt = "\n".join([
            "Analyze this GAIA benchmark question and provide routing guidance:",
            "",
            f"Question: {state.question}",
            f"File attached: {state.file_name if state.file_name else 'None'}",
            f"Current classification: {state.question_type.value}",
            f"Current complexity: {state.complexity_assessment}",
            "",
            "Please provide:",
            "1. Confirm or correct the question type",
            "2. Confirm or adjust complexity assessment",
            "3. Key challenges in answering this question",
            "4. Recommended approach",
            "",
            "Keep response concise and focused on routing decisions.",
        ])

        try:
            # Complex questions use the main model tier; others the router tier.
            if state.complexity_assessment == "complex":
                tier = ModelTier.MAIN
            else:
                tier = ModelTier.ROUTER
            result = self.llm_client.generate(prompt, tier=tier, max_tokens=200)

            if result.success:
                state.add_processing_step("Router: Enhanced with LLM analysis")
                state.routing_decision["llm_analysis"] = result.response
                logger.info("✅ LLM enhanced routing completed")
            else:
                state.add_error(f"LLM routing enhancement failed: {result.error}")
        except Exception as e:
            # Deliberately broad: routing must survive any LLM client failure.
            state.add_error(f"LLM routing error: {str(e)}")
            logger.exception("LLM routing failed: %s", e)

        return state