Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| Router Agent for GAIA Question Classification | |
| Analyzes questions and routes them to appropriate specialized agents | |
| """ | |
import logging
import re
from typing import Any, Dict, List, Optional
from urllib.parse import urlparse

from agents.state import GAIAAgentState, QuestionType, AgentRole, AgentResult
from models.qwen_client import QwenClient, ModelTier
| logger = logging.getLogger(__name__) | |
class RouterAgent:
    """
    Router agent that classifies GAIA questions and determines the processing
    strategy: question type, complexity level, agent selection, and a cost
    estimate.

    Classification is rule-based first (cheap regex/keyword scoring); an LLM
    pass is added only for complex or unclassifiable questions so routing
    stays inexpensive on the common path.
    """

    def __init__(self, llm_client: QwenClient):
        """
        Args:
            llm_client: Client used for the optional LLM-enhanced routing pass.
        """
        self.llm_client = llm_client

    def route_question(self, state: GAIAAgentState) -> GAIAAgentState:
        """
        Main routing function - analyzes the question and updates the state
        with routing decisions.

        Mutates ``state`` in place (question type, complexity, selected
        agents, estimated cost, routing decision summary) and returns it.

        Args:
            state: Current agent state carrying the question and optional file.

        Returns:
            The same state object, updated with routing decisions.
        """
        # Lazy %-style args: the string is only formatted if INFO is enabled.
        logger.info("Routing question: %s...", state.question[:100])
        state.add_processing_step("Router: Starting question analysis")

        # Step 1: Rule-based classification
        question_type = self._classify_question_type(state.question, state.file_name)
        state.question_type = question_type
        state.add_processing_step(f"Router: Classified as {question_type.value}")

        # Step 2: Complexity assessment
        complexity = self._assess_complexity(state.question)
        state.complexity_assessment = complexity
        state.add_processing_step(f"Router: Assessed complexity as {complexity}")

        # Step 3: Select appropriate agents
        selected_agents = self._select_agents(question_type, state.file_name is not None)
        state.selected_agents = selected_agents
        state.add_processing_step(f"Router: Selected agents: {[a.value for a in selected_agents]}")

        # Step 4: Estimate cost
        estimated_cost = self._estimate_cost(complexity, selected_agents)
        state.estimated_cost = estimated_cost
        state.add_processing_step(f"Router: Estimated cost: ${estimated_cost:.4f}")

        # Step 5: Create routing decision summary for downstream consumers.
        state.routing_decision = {
            "question_type": question_type.value,
            "complexity": complexity,
            "agents": [agent.value for agent in selected_agents],
            "estimated_cost": estimated_cost,
            "reasoning": self._get_routing_reasoning(question_type, complexity, selected_agents)
        }

        # Step 6: LLM pass only where the cheap rules are least reliable.
        if complexity == "complex" or question_type == QuestionType.UNKNOWN:
            state = self._llm_enhanced_routing(state)

        logger.info("✅ Routing complete: %s -> %s",
                    question_type.value, [a.value for a in selected_agents])
        return state

    def _classify_question_type(self, question: str,
                                file_name: Optional[str] = None) -> QuestionType:
        """
        Classify the question type using rule-based analysis.

        Precedence: attached-file extension, then URL patterns, then keyword
        scoring; falls back to ``QuestionType.UNKNOWN`` when nothing matches.

        Args:
            question: Raw question text.
            file_name: Name of an attached file, if any.

        Returns:
            The best-matching question type.
        """
        question_lower = question.lower()

        # An attached file dominates classification. Code files route to the
        # code executor; everything else (images, audio, spreadsheets, and
        # unknown extensions alike) routes to the generic file processor.
        if file_name:
            file_ext = file_name.lower().split('.')[-1] if '.' in file_name else ""
            if file_ext in ('py', 'js', 'java', 'cpp', 'c'):
                return QuestionType.CODE_EXECUTION
            return QuestionType.FILE_PROCESSING

        # URL-based classification: first matching category wins outright.
        url_patterns = {
            QuestionType.WIKIPEDIA: [
                r'wikipedia\.org', r'wiki', r'featured article', r'promoted.*wikipedia'
            ],
            QuestionType.YOUTUBE: [
                r'youtube\.com', r'youtu\.be', r'watch\?v=', r'video'
            ]
        }
        for question_type, patterns in url_patterns.items():
            if any(re.search(pattern, question_lower) for pattern in patterns):
                return question_type

        # Content-based classification: one point per matching pattern.
        classification_patterns = {
            QuestionType.MATHEMATICAL: [
                r'calculate', r'compute', r'solve', r'equation', r'formula',
                r'sum', r'total', r'average', r'percentage', r'ratio',
                r'how many', r'how much', r'\d+.*\d+', r'math'
            ],
            QuestionType.CODE_EXECUTION: [
                r'code', r'program', r'script', r'function', r'algorithm',
                r'execute', r'run.*code', r'python', r'javascript'
            ],
            QuestionType.TEXT_MANIPULATION: [
                r'reverse', r'encode', r'decode', r'transform', r'convert',
                r'uppercase', r'lowercase', r'replace', r'extract'
            ],
            QuestionType.REASONING: [
                r'why', r'explain', r'analyze', r'reasoning', r'logic',
                r'relationship', r'compare', r'contrast', r'conclusion'
            ],
            QuestionType.WEB_RESEARCH: [
                r'search', r'find.*information', r'research', r'look up',
                r'website', r'online', r'internet'
            ]
        }

        # Score each category; only categories with at least one hit compete.
        type_scores: Dict[QuestionType, int] = {}
        for question_type, patterns in classification_patterns.items():
            score = sum(1 for pattern in patterns if re.search(pattern, question_lower))
            if score > 0:
                type_scores[question_type] = score

        # Return the highest-scoring type, or UNKNOWN if no clear match.
        if type_scores:
            return max(type_scores, key=type_scores.get)
        return QuestionType.UNKNOWN

    def _assess_complexity(self, question: str) -> str:
        """
        Assess question complexity.

        Args:
            question: Raw question text.

        Returns:
            One of ``"simple"``, ``"medium"``, or ``"complex"``.
        """
        question_lower = question.lower()

        # Keyword cues that typically signal multi-step / analytical work.
        complex_indicators = [
            'multi-step', 'multiple', 'several', 'complex', 'detailed',
            'analyze', 'explain why', 'reasoning', 'relationship',
            'compare and contrast', 'comprehensive', 'thorough'
        ]
        # Keyword cues that typically signal a direct factual lookup.
        simple_indicators = [
            'what is', 'who is', 'when', 'where', 'yes or no',
            'true or false', 'simple', 'quick', 'name', 'list'
        ]

        complex_score = sum(1 for indicator in complex_indicators if indicator in question_lower)
        simple_score = sum(1 for indicator in simple_indicators if indicator in question_lower)

        # Structural complexity factors: long text, many words, or several
        # question marks (i.e. multiple questions packed into one).
        if len(question) > 200:
            complex_score += 1
        if len(question.split()) > 30:
            complex_score += 1
        if question.count('?') > 2:
            complex_score += 1

        # "simple" requires both clear simple cues AND zero complex cues;
        # everything ambiguous lands on "medium".
        if complex_score >= 2:
            return "complex"
        if simple_score >= 2 and complex_score == 0:
            return "simple"
        return "medium"

    def _select_agents(self, question_type: QuestionType, has_file: bool) -> List[AgentRole]:
        """
        Select appropriate agents based on question type and file presence.

        Args:
            question_type: Classified question type.
            has_file: Whether the question has an attached file.

        Returns:
            Ordered, de-duplicated list of agent roles to run.
        """
        # The synthesizer always runs to compile the final answer.
        agents = [AgentRole.SYNTHESIZER]

        # Type-specific agent selection.
        if question_type in (QuestionType.WIKIPEDIA, QuestionType.WEB_RESEARCH, QuestionType.YOUTUBE):
            agents.append(AgentRole.WEB_RESEARCHER)
        elif question_type == QuestionType.FILE_PROCESSING:
            agents.append(AgentRole.FILE_PROCESSOR)
        elif question_type == QuestionType.CODE_EXECUTION:
            agents.append(AgentRole.CODE_EXECUTOR)
        elif question_type in (QuestionType.MATHEMATICAL, QuestionType.REASONING,
                               QuestionType.TEXT_MANIPULATION):
            # The reasoning agent also handles text manipulation operations.
            agents.append(AgentRole.REASONING_AGENT)
        else:
            # UNKNOWN or unexpected types: use multiple agents for coverage.
            agents.extend([AgentRole.WEB_RESEARCHER, AgentRole.REASONING_AGENT])

        if has_file:
            agents.append(AgentRole.FILE_PROCESSOR)

        # dict.fromkeys removes duplicates while preserving insertion order.
        return list(dict.fromkeys(agents))

    def _estimate_cost(self, complexity: str, agents: List[AgentRole]) -> float:
        """
        Estimate processing cost based on complexity and agent count.

        Args:
            complexity: Complexity label from :meth:`_assess_complexity`.
            agents: Agents selected for this question.

        Returns:
            Estimated cost in dollars.
        """
        base_costs = {
            "simple": 0.005,   # Router model mostly
            "medium": 0.015,   # Mix of router and main
            "complex": 0.035   # Include complex model usage
        }
        # Unknown labels fall back to the "medium" base cost.
        base_cost = base_costs.get(complexity, 0.015)

        # Flat additional cost per selected agent.
        return base_cost + len(agents) * 0.005

    def _get_routing_reasoning(self, question_type: QuestionType, complexity: str,
                               agents: List[AgentRole]) -> str:
        """
        Generate human-readable reasoning for the routing decision.

        Args:
            question_type: Classified question type.
            complexity: Complexity label.
            agents: Selected agents.

        Returns:
            Semicolon-joined reasoning string.
        """
        reasons = []

        # Question-type reasoning (types without a message add nothing).
        type_reasons = {
            QuestionType.WIKIPEDIA: "Question references Wikipedia content",
            QuestionType.YOUTUBE: "Question involves YouTube video analysis",
            QuestionType.FILE_PROCESSING: "Question requires file processing",
            QuestionType.MATHEMATICAL: "Question involves mathematical computation",
            QuestionType.CODE_EXECUTION: "Question requires code execution",
            QuestionType.REASONING: "Question requires logical reasoning",
        }
        if question_type in type_reasons:
            reasons.append(type_reasons[question_type])

        # Complexity reasoning ("medium" intentionally adds nothing).
        if complexity == "complex":
            reasons.append("Complex reasoning required")
        elif complexity == "simple":
            reasons.append("Straightforward question")

        # Agent reasoning.
        agent_names = [agent.value.replace('_', ' ') for agent in agents]
        reasons.append(f"Selected agents: {', '.join(agent_names)}")

        return "; ".join(reasons)

    def _llm_enhanced_routing(self, state: GAIAAgentState) -> GAIAAgentState:
        """
        Use the LLM for enhanced routing analysis of complex/unknown questions.

        Best-effort: failures are recorded on the state as errors rather than
        raised, so routing always completes with the rule-based decision.

        Args:
            state: Current agent state (mutated in place).

        Returns:
            The same state object, possibly enriched with an LLM analysis.
        """
        prompt = f"""
Analyze this GAIA benchmark question and provide routing guidance:
Question: {state.question}
File attached: {state.file_name if state.file_name else "None"}
Current classification: {state.question_type.value}
Current complexity: {state.complexity_assessment}
Please provide:
1. Confirm or correct the question type
2. Confirm or adjust complexity assessment
3. Key challenges in answering this question
4. Recommended approach
Keep response concise and focused on routing decisions.
"""
        try:
            # The cheap router model suffices unless the question is complex.
            tier = ModelTier.ROUTER if state.complexity_assessment != "complex" else ModelTier.MAIN
            result = self.llm_client.generate(prompt, tier=tier, max_tokens=200)
            if result.success:
                state.add_processing_step("Router: Enhanced with LLM analysis")
                state.routing_decision["llm_analysis"] = result.response
                logger.info("✅ LLM enhanced routing completed")
            else:
                state.add_error(f"LLM routing enhancement failed: {result.error}")
        except Exception as e:
            # Swallow deliberately: routing must not fail because the optional
            # LLM pass did.
            state.add_error(f"LLM routing error: {str(e)}")
            logger.error(f"LLM routing failed: {e}")
        return state