Spaces:
Sleeping
Sleeping
File size: 14,806 Bytes
225a75e a248c93 225a75e a248c93 225a75e a248c93 225a75e a248c93 225a75e a248c93 225a75e a248c93 225a75e a248c93 225a75e a248c93 225a75e a248c93 225a75e a248c93 225a75e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 |
#!/usr/bin/env python3
"""
Router Agent for GAIA Question Classification
Analyzes questions and routes them to appropriate specialized agents
"""
import logging
import re
from typing import Any, Dict, List, Optional
from urllib.parse import urlparse

from agents.state import AgentResult, AgentRole, GAIAAgentState, QuestionType
from models.qwen_client import ModelTier, QwenClient
logger = logging.getLogger(__name__)
class RouterAgent:
    """
    Router agent that classifies GAIA benchmark questions and determines the
    processing strategy.

    For each question the router runs a rule-based pipeline:
      1. classify the question type (attached file, URL patterns, then scored
         content patterns),
      2. assess complexity ("simple" / "medium" / "complex"),
      3. select the specialist agents to invoke,
      4. estimate processing cost,
      5. record a routing-decision summary on the state,
      6. optionally refine the decision with an LLM call for complex or
         unclassified questions.
    """

    # Extensions treated as runnable source code; any other attachment is
    # routed to generic file processing.
    _CODE_EXTENSIONS = frozenset({'py', 'js', 'java', 'cpp', 'c'})

    def __init__(self, llm_client: QwenClient):
        """
        Args:
            llm_client: Client used for the optional LLM-enhanced routing pass.
        """
        self.llm_client = llm_client

    def route_question(self, state: GAIAAgentState) -> GAIAAgentState:
        """
        Analyze the question on ``state`` and populate it with routing
        decisions (type, complexity, selected agents, estimated cost).

        Args:
            state: Mutable per-question state; updated in place.

        Returns:
            The same state object with routing fields filled in.
        """
        logger.info("Routing question: %s...", state.question[:100])
        state.add_processing_step("Router: Starting question analysis")

        # Step 1: Rule-based classification
        question_type = self._classify_question_type(state.question, state.file_name)
        state.question_type = question_type
        state.add_processing_step(f"Router: Classified as {question_type.value}")

        # Step 2: Complexity assessment
        complexity = self._assess_complexity(state.question)
        state.complexity_assessment = complexity
        state.add_processing_step(f"Router: Assessed complexity as {complexity}")

        # Step 3: Select appropriate agents
        selected_agents = self._select_agents(question_type, state.file_name is not None)
        state.selected_agents = selected_agents
        state.add_processing_step(f"Router: Selected agents: {[a.value for a in selected_agents]}")

        # Step 4: Estimate cost
        estimated_cost = self._estimate_cost(complexity, selected_agents)
        state.estimated_cost = estimated_cost
        state.add_processing_step(f"Router: Estimated cost: ${estimated_cost:.4f}")

        # Step 5: Routing decision summary (consumed downstream / for auditing)
        state.routing_decision = {
            "question_type": question_type.value,
            "complexity": complexity,
            "agents": [agent.value for agent in selected_agents],
            "estimated_cost": estimated_cost,
            "reasoning": self._get_routing_reasoning(question_type, complexity, selected_agents),
        }

        # Step 6: LLM second opinion only where the rules are weakest
        if complexity == "complex" or question_type == QuestionType.UNKNOWN:
            state = self._llm_enhanced_routing(state)

        logger.info("✅ Routing complete: %s -> %s",
                    question_type.value, [a.value for a in selected_agents])
        return state

    def _classify_question_type(self, question: str,
                                file_name: Optional[str] = None) -> QuestionType:
        """
        Classify question type using rule-based analysis.

        Precedence: attached-file type, then URL patterns, then scored
        content patterns, with WEB_RESEARCH as the informational default.

        Args:
            question: Raw question text.
            file_name: Name of an attached file, if any.

        Returns:
            The inferred QuestionType.
        """
        question_lower = question.lower()

        # An attached file dominates: code files are executed, everything else
        # (images, audio, spreadsheets, unknown extensions) is file processing.
        if file_name:
            file_ext = file_name.lower().split('.')[-1] if '.' in file_name else ""
            if file_ext in self._CODE_EXTENSIONS:
                return QuestionType.CODE_EXECUTION
            return QuestionType.FILE_PROCESSING

        # URL-based classification: an explicit site reference wins outright.
        url_patterns = {
            QuestionType.WIKIPEDIA: [
                r'wikipedia\.org', r'wiki', r'featured article', r'promoted.*wikipedia'
            ],
            QuestionType.YOUTUBE: [
                r'youtube\.com', r'youtu\.be', r'watch\?v=', r'video'
            ]
        }
        for question_type, patterns in url_patterns.items():
            if any(re.search(pattern, question_lower) for pattern in patterns):
                return question_type

        # Content-based classification patterns, scored below.
        classification_patterns = {
            QuestionType.MATHEMATICAL: [
                r'\bcalculate\b', r'\bcompute\b', r'\bsolve\b', r'\bequation\b', r'\bformula\b',
                r'\bsum\b', r'\btotal\b', r'\baverage\b', r'\bpercentage\b', r'\bratio\b',
                r'\bhow many\b', r'\bhow much\b', r'\d+\s*[\+\-\*/]\s*\d+', r'\bmath\b',
                r'\bsquare root\b', r'\bfactorial\b', r'\bdivided by\b', r'\bmultiply\b'
            ],
            QuestionType.CODE_EXECUTION: [
                r'\bcode\b', r'\bprogram\b', r'\bscript\b', r'\bfunction\b', r'\balgorithm\b',
                r'\bexecute\b', r'\brun.*code\b', r'\bpython\b', r'\bjavascript\b'
            ],
            QuestionType.TEXT_MANIPULATION: [
                r'\breverse\b', r'\bencode\b', r'\bdecode\b', r'\btransform\b', r'\bconvert\b',
                r'\buppercase\b', r'\blowercase\b', r'\breplace\b', r'\bextract\b'
            ],
            QuestionType.REASONING: [
                r'\bwhy\b', r'\bexplain\b', r'\banalyze\b', r'\breasoning\b', r'\blogic\b',
                r'\brelationship\b', r'\bcompare\b', r'\bcontrast\b', r'\bconclusion\b'
            ],
            QuestionType.WEB_RESEARCH: [
                r'\bsearch\b', r'\bfind.*information\b', r'\bresearch\b', r'\blook up\b',
                r'\bwebsite\b', r'\bonline\b', r'\binternet\b', r'\bwho\s+(?:is|was|are|were)\b',
                r'\bwhat\s+(?:is|was|are|were)\b', r'\bwhen\s+(?:is|was|did|does)\b',
                r'\bwhere\s+(?:is|was|are|were)\b'
            ]
        }

        # Score each category: one point per regex match occurrence.
        type_scores: Dict[QuestionType, int] = {}
        for question_type, patterns in classification_patterns.items():
            score = sum(len(re.findall(pattern, question_lower)) for pattern in patterns)
            if score > 0:
                type_scores[question_type] = score

        # Heuristic boosts for specific question patterns.
        # Fictional/non-existent content is best handled by web research.
        if any(term in question_lower for term in ['fictional', 'imaginary', 'non-existent', 'nonexistent']):
            type_scores[QuestionType.WEB_RESEARCH] = type_scores.get(QuestionType.WEB_RESEARCH, 0) + 2
        # "Who is/was/did..." questions about people are research questions.
        if re.search(r'\bwho\s+(?:is|was|are|were|did|does)\b', question_lower):
            type_scores[QuestionType.WEB_RESEARCH] = type_scores.get(QuestionType.WEB_RESEARCH, 0) + 2
        # Historical or factual vocabulary leans toward research.
        if any(term in question_lower for term in ['history', 'historical', 'century', 'year', 'published', 'author']):
            type_scores[QuestionType.WEB_RESEARCH] = type_scores.get(QuestionType.WEB_RESEARCH, 0) + 1
        # Explicit arithmetic expressions strongly indicate math.
        if re.search(r'\d+\s*[\+\-\*/]\s*\d+', question_lower):
            type_scores[QuestionType.MATHEMATICAL] = type_scores.get(QuestionType.MATHEMATICAL, 0) + 3

        if type_scores:
            best_type = max(type_scores, key=type_scores.get)
            # Low-confidence winner: generic informational phrasing falls back
            # to WEB_RESEARCH.
            if type_scores[best_type] <= 1:
                info_patterns = [r'\bwhat\b', r'\bwho\b', r'\bwhen\b', r'\bwhere\b', r'\bhow\b']
                if any(re.search(pattern, question_lower) for pattern in info_patterns):
                    return QuestionType.WEB_RESEARCH
            return best_type

        # Default to WEB_RESEARCH for unknown informational questions.
        return QuestionType.WEB_RESEARCH

    def _assess_complexity(self, question: str) -> str:
        """
        Assess question complexity from keyword and length heuristics.

        Args:
            question: Raw question text.

        Returns:
            One of "simple", "medium", or "complex".
        """
        question_lower = question.lower()

        complex_indicators = [
            'multi-step', 'multiple', 'several', 'complex', 'detailed',
            'analyze', 'explain why', 'reasoning', 'relationship',
            'compare and contrast', 'comprehensive', 'thorough'
        ]
        simple_indicators = [
            'what is', 'who is', 'when', 'where', 'yes or no',
            'true or false', 'simple', 'quick', 'name', 'list'
        ]

        complex_score = sum(1 for indicator in complex_indicators if indicator in question_lower)
        simple_score = sum(1 for indicator in simple_indicators if indicator in question_lower)

        # Length-based complexity factors: long text, many words, or several
        # question marks (multi-part questions) each add one point.
        if len(question) > 200:
            complex_score += 1
        if len(question.split()) > 30:
            complex_score += 1
        if question.count('?') > 2:
            complex_score += 1

        if complex_score >= 2:
            return "complex"
        elif simple_score >= 2 and complex_score == 0:
            return "simple"
        else:
            return "medium"

    def _select_agents(self, question_type: QuestionType, has_file: bool) -> List[AgentRole]:
        """
        Select the agents appropriate for a question type and file presence.

        Args:
            question_type: Classified question type.
            has_file: Whether the question has an attached file.

        Returns:
            Ordered, duplicate-free list of agent roles to invoke.
        """
        # Synthesizer always participates to compile the final answer.
        agents = [AgentRole.SYNTHESIZER]

        if question_type in (QuestionType.WIKIPEDIA, QuestionType.WEB_RESEARCH, QuestionType.YOUTUBE):
            agents.append(AgentRole.WEB_RESEARCHER)
        elif question_type == QuestionType.FILE_PROCESSING:
            agents.append(AgentRole.FILE_PROCESSOR)
        elif question_type == QuestionType.CODE_EXECUTION:
            agents.append(AgentRole.CODE_EXECUTOR)
        elif question_type in (QuestionType.MATHEMATICAL, QuestionType.REASONING):
            agents.append(AgentRole.REASONING_AGENT)
        elif question_type == QuestionType.TEXT_MANIPULATION:
            # The reasoning agent also handles text operations.
            agents.append(AgentRole.REASONING_AGENT)
        else:
            # UNKNOWN or complex cases: use multiple agents for coverage.
            agents.extend([AgentRole.WEB_RESEARCHER, AgentRole.REASONING_AGENT])

        if has_file:
            agents.append(AgentRole.FILE_PROCESSOR)

        # dict.fromkeys removes duplicates while preserving insertion order.
        return list(dict.fromkeys(agents))

    def _estimate_cost(self, complexity: str, agents: List[AgentRole]) -> float:
        """
        Estimate processing cost in dollars from complexity and agent count.

        Args:
            complexity: "simple" / "medium" / "complex" (anything else is
                priced as "medium").
            agents: Agents selected for this question.

        Returns:
            Estimated cost: complexity base plus $0.005 per agent.
        """
        base_costs = {
            "simple": 0.005,   # Router model mostly
            "medium": 0.015,   # Mix of router and main
            "complex": 0.035,  # Include complex model usage
        }
        base_cost = base_costs.get(complexity, 0.015)
        agent_cost = len(agents) * 0.005
        return base_cost + agent_cost

    def _get_routing_reasoning(self, question_type: QuestionType, complexity: str,
                               agents: List[AgentRole]) -> str:
        """
        Generate human-readable reasoning for the routing decision.

        Returns:
            Semicolon-joined summary of type, complexity, and agent choices.
        """
        reasons = []

        # Question-type reasoning (UNKNOWN and some types add no sentence).
        if question_type == QuestionType.WIKIPEDIA:
            reasons.append("Question references Wikipedia content")
        elif question_type == QuestionType.YOUTUBE:
            reasons.append("Question involves YouTube video analysis")
        elif question_type == QuestionType.FILE_PROCESSING:
            reasons.append("Question requires file processing")
        elif question_type == QuestionType.MATHEMATICAL:
            reasons.append("Question involves mathematical computation")
        elif question_type == QuestionType.CODE_EXECUTION:
            reasons.append("Question requires code execution")
        elif question_type == QuestionType.REASONING:
            reasons.append("Question requires logical reasoning")

        # Complexity reasoning ("medium" adds no sentence).
        if complexity == "complex":
            reasons.append("Complex reasoning required")
        elif complexity == "simple":
            reasons.append("Straightforward question")

        agent_names = [agent.value.replace('_', ' ') for agent in agents]
        reasons.append(f"Selected agents: {', '.join(agent_names)}")
        return "; ".join(reasons)

    def _llm_enhanced_routing(self, state: GAIAAgentState) -> GAIAAgentState:
        """
        Use an LLM for enhanced routing analysis of complex/unknown questions.

        Failures are recorded on the state rather than raised, so routing
        always completes with the rule-based decision as a fallback.

        Args:
            state: Current question state (already rule-routed).

        Returns:
            The same state, possibly annotated with an "llm_analysis" entry.
        """
        prompt = f"""
Analyze this GAIA benchmark question and provide routing guidance:
Question: {state.question}
File attached: {state.file_name if state.file_name else "None"}
Current classification: {state.question_type.value}
Current complexity: {state.complexity_assessment}
Please provide:
1. Confirm or correct the question type
2. Confirm or adjust complexity assessment
3. Key challenges in answering this question
4. Recommended approach
Keep response concise and focused on routing decisions.
"""
        try:
            # Cheaper router-tier model unless the question is complex.
            tier = ModelTier.ROUTER if state.complexity_assessment != "complex" else ModelTier.MAIN
            result = self.llm_client.generate(prompt, tier=tier, max_tokens=200)
            if result.success:
                state.add_processing_step("Router: Enhanced with LLM analysis")
                state.routing_decision["llm_analysis"] = result.response
                logger.info("✅ LLM enhanced routing completed")
            else:
                state.add_error(f"LLM routing enhancement failed: {result.error}")
        except Exception as e:
            state.add_error(f"LLM routing error: {str(e)}")
            logger.error(f"LLM routing failed: {e}")
        return state