# Chris
# Final 5.0
# a248c93
# raw
# history blame
# 14.8 kB
#!/usr/bin/env python3
"""
Router Agent for GAIA Question Classification
Analyzes questions and routes them to appropriate specialized agents
"""
import re
import logging
from typing import List, Dict, Any
from urllib.parse import urlparse
from agents.state import GAIAAgentState, QuestionType, AgentRole, AgentResult
from models.qwen_client import QwenClient, ModelTier
logger = logging.getLogger(__name__)
class RouterAgent:
    """
    Router agent that classifies GAIA questions and determines processing strategy.

    Routing is rule-based (regular expressions and keyword scoring); an LLM is
    consulted only for questions assessed as "complex" or typed as UNKNOWN, and
    its answer is stored as advisory text rather than used to change the route.
    """

    # Phrases that each add one point to the "complex" score.
    _COMPLEX_INDICATORS = (
        'multi-step', 'multiple', 'several', 'complex', 'detailed',
        'analyze', 'explain why', 'reasoning', 'relationship',
        'compare and contrast', 'comprehensive', 'thorough',
    )

    # Phrases that each add one point to the "simple" score.
    _SIMPLE_INDICATORS = (
        'what is', 'who is', 'when', 'where', 'yes or no',
        'true or false', 'simple', 'quick', 'name', 'list',
    )

    # Base cost per complexity level, before the per-agent surcharge.
    _BASE_COSTS = {
        "simple": 0.005,   # Router model mostly
        "medium": 0.015,   # Mix of router and main
        "complex": 0.035,  # Include complex model usage
    }
    _DEFAULT_BASE_COST = 0.015  # Used when the complexity label is not recognised
    _PER_AGENT_COST = 0.005

    def __init__(self, llm_client: "QwenClient"):
        """Store the LLM client used for the optional enhanced-routing step."""
        self.llm_client = llm_client

    def route_question(self, state: "GAIAAgentState") -> "GAIAAgentState":
        """
        Main routing function - analyzes question and updates state with routing decisions.

        Args:
            state: Agent state. ``question`` and ``file_name`` are read; the
                question type, complexity, selected agents, estimated cost and
                the ``routing_decision`` summary are written back onto it.

        Returns:
            The state after routing (the object returned by the LLM
            enhancement step when that step runs).
        """
        logger.info("Routing question: %s...", state.question[:100])
        state.add_processing_step("Router: Starting question analysis")

        # Step 1: Rule-based classification
        question_type = self._classify_question_type(state.question, state.file_name)
        state.question_type = question_type
        state.add_processing_step(f"Router: Classified as {question_type.value}")

        # Step 2: Complexity assessment
        complexity = self._assess_complexity(state.question)
        state.complexity_assessment = complexity
        state.add_processing_step(f"Router: Assessed complexity as {complexity}")

        # Step 3: Select appropriate agents.
        # Truthiness (not "is not None") so an empty file name is treated the
        # same way here as it is in _classify_question_type.
        selected_agents = self._select_agents(question_type, bool(state.file_name))
        state.selected_agents = selected_agents
        agent_values = [agent.value for agent in selected_agents]
        state.add_processing_step(f"Router: Selected agents: {agent_values}")

        # Step 4: Estimate cost
        estimated_cost = self._estimate_cost(complexity, selected_agents)
        state.estimated_cost = estimated_cost
        state.add_processing_step(f"Router: Estimated cost: ${estimated_cost:.4f}")

        # Step 5: Create routing decision summary
        state.routing_decision = {
            "question_type": question_type.value,
            "complexity": complexity,
            "agents": agent_values,
            "estimated_cost": estimated_cost,
            "reasoning": self._get_routing_reasoning(question_type, complexity, selected_agents),
        }

        # Step 6: Use LLM for complex routing decisions if needed
        if complexity == "complex" or question_type == QuestionType.UNKNOWN:
            state = self._llm_enhanced_routing(state)

        logger.info("✅ Routing complete: %s -> %s", question_type.value, agent_values)
        return state

    def _classify_question_type(self, question: str, file_name: str = None) -> "QuestionType":
        """Classify question type using rule-based analysis.

        Precedence: an attached file decides the type outright; otherwise
        Wikipedia/YouTube patterns are checked; otherwise keyword categories
        are scored and the highest score wins.

        Args:
            question: The question text.
            file_name: Name of an attached file, if any. Falsy values mean
                "no file".

        Returns:
            The selected QuestionType; WEB_RESEARCH when nothing matches.
        """
        question_lower = question.lower()

        # File processing questions: source-code extensions go to code
        # execution, every other attachment to file processing.
        if file_name:
            file_ext = file_name.lower().split('.')[-1] if '.' in file_name else ""
            if file_ext in ('py', 'js', 'java', 'cpp', 'c'):
                return QuestionType.CODE_EXECUTION
            return QuestionType.FILE_PROCESSING

        # URL-based classification (first matching category wins).
        # NOTE(review): 'wiki' and 'video' are unanchored and match inside
        # longer words - confirm this breadth is intended.
        url_patterns = {
            QuestionType.WIKIPEDIA: [
                r'wikipedia\.org', r'wiki', r'featured article', r'promoted.*wikipedia'
            ],
            QuestionType.YOUTUBE: [
                r'youtube\.com', r'youtu\.be', r'watch\?v=', r'video'
            ],
        }
        for candidate, patterns in url_patterns.items():
            if any(re.search(pattern, question_lower) for pattern in patterns):
                return candidate

        # Content-based classification
        classification_patterns = {
            QuestionType.MATHEMATICAL: [
                r'\bcalculate\b', r'\bcompute\b', r'\bsolve\b', r'\bequation\b', r'\bformula\b',
                r'\bsum\b', r'\btotal\b', r'\baverage\b', r'\bpercentage\b', r'\bratio\b',
                r'\bhow many\b', r'\bhow much\b', r'\d+\s*[\+\-\*/]\s*\d+', r'\bmath\b',
                r'\bsquare root\b', r'\bfactorial\b', r'\bdivided by\b', r'\bmultiply\b'
            ],
            QuestionType.CODE_EXECUTION: [
                r'\bcode\b', r'\bprogram\b', r'\bscript\b', r'\bfunction\b', r'\balgorithm\b',
                r'\bexecute\b', r'\brun.*code\b', r'\bpython\b', r'\bjavascript\b'
            ],
            QuestionType.TEXT_MANIPULATION: [
                r'\breverse\b', r'\bencode\b', r'\bdecode\b', r'\btransform\b', r'\bconvert\b',
                r'\buppercase\b', r'\blowercase\b', r'\breplace\b', r'\bextract\b'
            ],
            QuestionType.REASONING: [
                r'\bwhy\b', r'\bexplain\b', r'\banalyze\b', r'\breasoning\b', r'\blogic\b',
                r'\brelationship\b', r'\bcompare\b', r'\bcontrast\b', r'\bconclusion\b'
            ],
            QuestionType.WEB_RESEARCH: [
                r'\bsearch\b', r'\bfind.*information\b', r'\bresearch\b', r'\blook up\b',
                r'\bwebsite\b', r'\bonline\b', r'\binternet\b', r'\bwho\s+(?:is|was|are|were)\b',
                r'\bwhat\s+(?:is|was|are|were)\b', r'\bwhen\s+(?:is|was|did|does)\b',
                r'\bwhere\s+(?:is|was|are|were)\b'
            ],
        }

        # Score each category: one point per pattern occurrence. Only
        # categories with at least one hit are recorded, in the order above,
        # which is also the order max() uses to break ties.
        type_scores = {}
        for candidate, patterns in classification_patterns.items():
            score = sum(len(re.findall(pattern, question_lower)) for pattern in patterns)
            if score > 0:
                type_scores[candidate] = score

        def boost(target, points):
            type_scores[target] = type_scores.get(target, 0) + points

        # Questions about fictional/non-existent content are treated as research
        if any(term in question_lower for term in ('fictional', 'imaginary', 'non-existent', 'nonexistent')):
            boost(QuestionType.WEB_RESEARCH, 2)

        # Research questions about people, places, things
        if re.search(r'\bwho\s+(?:is|was|are|were|did|does)\b', question_lower):
            boost(QuestionType.WEB_RESEARCH, 2)

        # Historical or factual queries
        if any(term in question_lower for term in ('history', 'historical', 'century', 'year', 'published', 'author')):
            boost(QuestionType.WEB_RESEARCH, 1)

        # Explicit arithmetic such as "12 * 7" strongly suggests a math question
        if re.search(r'\d+\s*[\+\-\*/]\s*\d+', question_lower):
            boost(QuestionType.MATHEMATICAL, 3)

        if not type_scores:
            # Default to WEB_RESEARCH for unknown informational questions
            return QuestionType.WEB_RESEARCH

        best_type = max(type_scores, key=lambda t: type_scores[t])

        # A winning score of 1 is weak evidence: if the text reads like a
        # general informational question, prefer WEB_RESEARCH instead.
        if type_scores[best_type] <= 1:
            info_patterns = (r'\bwhat\b', r'\bwho\b', r'\bwhen\b', r'\bwhere\b', r'\bhow\b')
            if any(re.search(pattern, question_lower) for pattern in info_patterns):
                return QuestionType.WEB_RESEARCH
        return best_type

    @staticmethod
    def _count_indicators(text: str, indicators) -> int:
        """Count how many indicator phrases occur in *text* at the start of a word.

        Anchoring on a leading word boundary keeps inflected forms
        ("analyzed", "lists", "thoroughly") while rejecting hits buried inside
        unrelated words ("list" in "finalist", "name" in "tournament").
        """
        return sum(
            1 for phrase in indicators
            if re.search(r'\b' + re.escape(phrase), text)
        )

    def _assess_complexity(self, question: str) -> str:
        """Assess question complexity.

        Args:
            question: The question text.

        Returns:
            "complex" when at least two complexity signals are present,
            "simple" when at least two simple indicators and no complexity
            signal are present, otherwise "medium".
        """
        question_lower = question.lower()
        complex_score = self._count_indicators(question_lower, self._COMPLEX_INDICATORS)
        simple_score = self._count_indicators(question_lower, self._SIMPLE_INDICATORS)

        # Additional complexity factors: long text and several sub-questions
        if len(question) > 200:
            complex_score += 1
        if len(question.split()) > 30:
            complex_score += 1
        if question.count('?') > 2:  # Multiple questions
            complex_score += 1

        if complex_score >= 2:
            return "complex"
        if simple_score >= 2 and complex_score == 0:
            return "simple"
        return "medium"

    def _select_agents(self, question_type: "QuestionType", has_file: bool) -> List["AgentRole"]:
        """Select appropriate agents based on question type and presence of files.

        Args:
            question_type: Classified type of the question.
            has_file: Whether a file is attached to the question.

        Returns:
            Agent roles without duplicates, in selection order; the
            synthesizer is always first.
        """
        # Always include synthesizer for final answer compilation
        agents = [AgentRole.SYNTHESIZER]

        if question_type in (QuestionType.WIKIPEDIA, QuestionType.WEB_RESEARCH, QuestionType.YOUTUBE):
            agents.append(AgentRole.WEB_RESEARCHER)
        elif question_type == QuestionType.FILE_PROCESSING:
            agents.append(AgentRole.FILE_PROCESSOR)
        elif question_type == QuestionType.CODE_EXECUTION:
            agents.append(AgentRole.CODE_EXECUTOR)
        elif question_type in (
            QuestionType.MATHEMATICAL,
            QuestionType.REASONING,
            QuestionType.TEXT_MANIPULATION,  # Reasoning agent can handle text operations
        ):
            agents.append(AgentRole.REASONING_AGENT)
        else:
            # UNKNOWN or any other type: use multiple agents for better coverage
            agents.extend([AgentRole.WEB_RESEARCHER, AgentRole.REASONING_AGENT])

        if has_file:
            agents.append(AgentRole.FILE_PROCESSOR)

        # dict.fromkeys drops duplicates while preserving first-seen order
        return list(dict.fromkeys(agents))

    def _estimate_cost(self, complexity: str, agents: List["AgentRole"]) -> float:
        """Estimate processing cost based on complexity and agents.

        Args:
            complexity: "simple", "medium" or "complex"; any other value is
                priced at the medium base cost.
            agents: Selected agents; each adds a fixed surcharge.

        Returns:
            Base cost for the complexity level plus the per-agent surcharge.
        """
        base_cost = self._BASE_COSTS.get(complexity, self._DEFAULT_BASE_COST)
        return base_cost + len(agents) * self._PER_AGENT_COST

    def _get_routing_reasoning(self, question_type: "QuestionType", complexity: str, agents: List["AgentRole"]) -> str:
        """Generate human-readable reasoning for routing decision.

        Args:
            question_type: Classified type of the question.
            complexity: Complexity label from the assessment step.
            agents: Selected agents, listed at the end of the summary.

        Returns:
            The applicable reasons joined with "; ". Question types without a
            dedicated sentence contribute nothing; the agent list is always
            included.
        """
        type_reasons = {
            QuestionType.WIKIPEDIA: "Question references Wikipedia content",
            QuestionType.YOUTUBE: "Question involves YouTube video analysis",
            QuestionType.FILE_PROCESSING: "Question requires file processing",
            QuestionType.MATHEMATICAL: "Question involves mathematical computation",
            QuestionType.CODE_EXECUTION: "Question requires code execution",
            QuestionType.REASONING: "Question requires logical reasoning",
        }
        complexity_reasons = {
            "complex": "Complex reasoning required",
            "simple": "Straightforward question",
        }

        reasons = []
        if question_type in type_reasons:
            reasons.append(type_reasons[question_type])
        if complexity in complexity_reasons:
            reasons.append(complexity_reasons[complexity])

        agent_names = [agent.value.replace('_', ' ') for agent in agents]
        reasons.append(f"Selected agents: {', '.join(agent_names)}")
        return "; ".join(reasons)

    def _llm_enhanced_routing(self, state: "GAIAAgentState") -> "GAIAAgentState":
        """Use LLM for enhanced routing analysis of complex/unknown questions.

        The LLM response is stored under ``routing_decision["llm_analysis"]``;
        it does not change the question type, complexity or selected agents.
        Failures are recorded on the state and logged, never raised.

        Args:
            state: Agent state that already carries the rule-based routing
                decision.

        Returns:
            The same state object.
        """
        prompt = (
            "\n"
            "Analyze this GAIA benchmark question and provide routing guidance:\n"
            f"Question: {state.question}\n"
            f"File attached: {state.file_name or 'None'}\n"
            f"Current classification: {state.question_type.value}\n"
            f"Current complexity: {state.complexity_assessment}\n"
            "Please provide:\n"
            "1. Confirm or correct the question type\n"
            "2. Confirm or adjust complexity assessment\n"
            "3. Key challenges in answering this question\n"
            "4. Recommended approach\n"
            "Keep response concise and focused on routing decisions.\n"
        )

        try:
            # Router tier by default; questions assessed as complex use the main tier.
            tier = ModelTier.MAIN if state.complexity_assessment == "complex" else ModelTier.ROUTER
            result = self.llm_client.generate(prompt, tier=tier, max_tokens=200)
            if result.success:
                state.add_processing_step("Router: Enhanced with LLM analysis")
                state.routing_decision["llm_analysis"] = result.response
                logger.info("✅ LLM enhanced routing completed")
            else:
                state.add_error(f"LLM routing enhancement failed: {result.error}")
        except Exception as e:
            # Best-effort step: routing must still succeed if the LLM call
            # fails, so record the error and keep the traceback in the log.
            state.add_error(f"LLM routing error: {str(e)}")
            logger.exception("LLM routing failed: %s", e)
        return state