|
|
|
|
|
""" |
|
|
Router Agent for GAIA Question Classification |
|
|
Analyzes questions and routes them to appropriate specialized agents |
|
|
""" |
|
|
|
|
|
import logging
import re
from typing import Any, Dict, List, Optional
from urllib.parse import urlparse

from agents.state import GAIAAgentState, QuestionType, AgentRole, AgentResult
from models.qwen_client import QwenClient, ModelTier
|
|
|
|
|
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
|
|
|
|
|
class RouterAgent:
    """
    Router agent that classifies GAIA questions and determines processing strategy.

    Routing is rule-based: an attached file's extension decides first, then
    Wikipedia/YouTube hints, then keyword scoring. Questions assessed as
    "complex", or that no rule matches (``QuestionType.UNKNOWN``), additionally
    get an LLM analysis stored under ``state.routing_decision["llm_analysis"]``.
    """

    # Attachment extensions (lower-case, no dot) routed to code execution.
    # Every other attachment -- images, audio, spreadsheets and unrecognised
    # types alike -- is routed to file processing.
    _CODE_EXTENSIONS = frozenset({"py", "js", "java", "cpp", "c"})

    # Source-specific hints, checked before keyword scoring; the first type
    # with any matching pattern wins.
    _URL_PATTERNS = {
        QuestionType.WIKIPEDIA: (
            r"wikipedia\.org", r"\bwiki", r"featured article", r"promoted.*wikipedia",
        ),
        QuestionType.YOUTUBE: (
            r"youtube\.com", r"youtu\.be", r"watch\?v=", r"\bvideo",
        ),
    }

    # Keyword patterns scored against the lower-cased question. Keywords are
    # anchored with \b so they only match at the start of a word; plain
    # substring matching sent "decode" to CODE_EXECUTION (via "code"),
    # "transcript" (via "script"), "assume" to MATHEMATICAL (via "sum"), etc.
    # Dict order matters: on a score tie the earliest type wins.
    _CLASSIFICATION_PATTERNS = {
        QuestionType.MATHEMATICAL: (
            r"\bcalculate", r"\bcompute", r"\bsolve", r"\bequation", r"\bformula",
            r"\bsums?\b", r"\btotal", r"\baverage", r"\bpercentage", r"\bratio",
            r"\bhow many", r"\bhow much",
            # Two separate numbers. The former r"\d+.*\d+" also matched any
            # single multi-digit number, such as a year ("2019").
            r"\d+\D+\d+",
            r"\bmath",
        ),
        QuestionType.CODE_EXECUTION: (
            r"\bcode", r"\bprogram", r"\bscript", r"\bfunction", r"\balgorithm",
            r"\bexecute", r"\brun\b.*\bcode", r"\bpython", r"\bjavascript",
        ),
        QuestionType.TEXT_MANIPULATION: (
            r"\breverse", r"\bencode", r"\bdecode", r"\btransform", r"\bconvert",
            r"\buppercase", r"\blowercase", r"\breplace", r"\bextract",
        ),
        QuestionType.REASONING: (
            r"\bwhy", r"\bexplain", r"\banalyze", r"\breasoning", r"\blogic",
            r"\brelationship", r"\bcompare", r"\bcontrast", r"\bconclusion",
        ),
        QuestionType.WEB_RESEARCH: (
            r"\bsearch", r"\bfind\b.*\binformation", r"\bresearch", r"\blook up",
            r"\bwebsite", r"\bonline", r"\binternet",
        ),
    }

    # Phrases counted by _assess_complexity (matched at word starts).
    _COMPLEX_INDICATORS = (
        "multi-step", "multiple", "several", "complex", "detailed",
        "analyze", "explain why", "reasoning", "relationship",
        "compare and contrast", "comprehensive", "thorough",
    )
    _SIMPLE_INDICATORS = (
        "what is", "who is", "when", "where", "yes or no",
        "true or false", "simple", "quick", "name", "list",
    )

    # Cost model used by _estimate_cost: a base cost per complexity level
    # plus a flat amount per selected agent.
    _BASE_COSTS = {"simple": 0.005, "medium": 0.015, "complex": 0.035}
    _DEFAULT_BASE_COST = 0.015
    _COST_PER_AGENT = 0.005

    # Human-readable reason per question type (UNKNOWN contributes none).
    _TYPE_REASONS = {
        QuestionType.WIKIPEDIA: "Question references Wikipedia content",
        QuestionType.YOUTUBE: "Question involves YouTube video analysis",
        QuestionType.FILE_PROCESSING: "Question requires file processing",
        QuestionType.MATHEMATICAL: "Question involves mathematical computation",
        QuestionType.CODE_EXECUTION: "Question requires code execution",
        QuestionType.REASONING: "Question requires logical reasoning",
        QuestionType.TEXT_MANIPULATION: "Question involves text manipulation",
        QuestionType.WEB_RESEARCH: "Question requires web research",
    }

    def __init__(self, llm_client: QwenClient):
        """Store the LLM client used for enhanced routing of hard questions."""
        self.llm_client = llm_client

    def route_question(self, state: GAIAAgentState) -> GAIAAgentState:
        """
        Analyze the question in *state* and record the routing decision on it.

        Sets ``question_type``, ``complexity_assessment``, ``selected_agents``,
        ``estimated_cost`` and ``routing_decision`` on *state*, appending a
        processing step for each, then optionally runs LLM-enhanced routing.

        Returns:
            The same (mutated) state object.
        """
        logger.info("Routing question: %s...", state.question[:100])
        state.add_processing_step("Router: Starting question analysis")

        question_type = self._classify_question_type(state.question, state.file_name)
        state.question_type = question_type
        state.add_processing_step(f"Router: Classified as {question_type.value}")

        complexity = self._assess_complexity(state.question)
        state.complexity_assessment = complexity
        state.add_processing_step(f"Router: Assessed complexity as {complexity}")

        # Truthiness, not `is not None`: an empty file name is not an
        # attachment, matching the `if file_name` check in classification.
        selected_agents = self._select_agents(question_type, bool(state.file_name))
        state.selected_agents = selected_agents
        state.add_processing_step(f"Router: Selected agents: {[a.value for a in selected_agents]}")

        estimated_cost = self._estimate_cost(complexity, selected_agents)
        state.estimated_cost = estimated_cost
        state.add_processing_step(f"Router: Estimated cost: ${estimated_cost:.4f}")

        state.routing_decision = {
            "question_type": question_type.value,
            "complexity": complexity,
            "agents": [agent.value for agent in selected_agents],
            "estimated_cost": estimated_cost,
            "reasoning": self._get_routing_reasoning(question_type, complexity, selected_agents),
        }

        # Only hard or unclassifiable questions justify an extra LLM call.
        if complexity == "complex" or question_type == QuestionType.UNKNOWN:
            state = self._llm_enhanced_routing(state)

        logger.info(
            "✅ Routing complete: %s -> %s",
            question_type.value,
            [a.value for a in selected_agents],
        )
        return state

    def _classify_question_type(self, question: str, file_name: Optional[str] = None) -> QuestionType:
        """
        Classify the question type using rule-based analysis.

        Precedence: attachment extension, then Wikipedia/YouTube hints, then
        the highest keyword score (earliest type wins ties). Returns
        ``QuestionType.UNKNOWN`` when nothing matches.
        """
        if file_name:
            lowered_name = file_name.lower()
            file_ext = lowered_name.rsplit(".", 1)[-1] if "." in lowered_name else ""
            if file_ext in self._CODE_EXTENSIONS:
                return QuestionType.CODE_EXECUTION
            return QuestionType.FILE_PROCESSING

        text = question.lower()

        for question_type, patterns in self._URL_PATTERNS.items():
            if any(re.search(pattern, text) for pattern in patterns):
                return question_type

        scores = {
            question_type: sum(1 for pattern in patterns if re.search(pattern, text))
            for question_type, patterns in self._CLASSIFICATION_PATTERNS.items()
        }
        # max() returns the first maximal item, so dict order breaks ties.
        best_type, best_score = max(scores.items(), key=lambda item: item[1])
        return best_type if best_score > 0 else QuestionType.UNKNOWN

    @staticmethod
    def _count_indicators(text: str, indicators) -> int:
        """Count how many *indicators* occur in *text* starting at a word boundary."""
        return sum(1 for phrase in indicators if re.search(r"\b" + re.escape(phrase), text))

    def _assess_complexity(self, question: str) -> str:
        """
        Assess question complexity as "simple", "medium" or "complex".

        Combines indicator-phrase counts with length heuristics: over 200
        characters, over 30 words, or more than two question marks each add
        one point of complexity.
        """
        text = question.lower()
        complex_score = self._count_indicators(text, self._COMPLEX_INDICATORS)
        simple_score = self._count_indicators(text, self._SIMPLE_INDICATORS)

        if len(question) > 200:
            complex_score += 1
        if len(question.split()) > 30:
            complex_score += 1
        if question.count("?") > 2:
            complex_score += 1

        if complex_score >= 2:
            return "complex"
        if simple_score >= 2 and complex_score == 0:
            return "simple"
        return "medium"

    def _select_agents(self, question_type: QuestionType, has_file: bool) -> List[AgentRole]:
        """
        Select agents for the question type and attachment presence.

        The synthesizer is always first. UNKNOWN questions get both the web
        researcher and the reasoning agent. Duplicates are removed while
        keeping first-seen order.
        """
        agents = [AgentRole.SYNTHESIZER]

        if question_type in (QuestionType.WIKIPEDIA, QuestionType.WEB_RESEARCH, QuestionType.YOUTUBE):
            agents.append(AgentRole.WEB_RESEARCHER)
        elif question_type == QuestionType.FILE_PROCESSING:
            agents.append(AgentRole.FILE_PROCESSOR)
        elif question_type == QuestionType.CODE_EXECUTION:
            agents.append(AgentRole.CODE_EXECUTOR)
        elif question_type in (
            QuestionType.MATHEMATICAL,
            QuestionType.REASONING,
            QuestionType.TEXT_MANIPULATION,
        ):
            agents.append(AgentRole.REASONING_AGENT)
        else:
            agents.extend([AgentRole.WEB_RESEARCHER, AgentRole.REASONING_AGENT])

        if has_file:
            agents.append(AgentRole.FILE_PROCESSOR)

        # dict preserves insertion order, giving an order-stable dedup.
        return list(dict.fromkeys(agents))

    def _estimate_cost(self, complexity: str, agents: List[AgentRole]) -> float:
        """Estimate processing cost: a complexity-based base plus a per-agent charge."""
        base_cost = self._BASE_COSTS.get(complexity, self._DEFAULT_BASE_COST)
        return base_cost + len(agents) * self._COST_PER_AGENT

    def _get_routing_reasoning(self, question_type: QuestionType, complexity: str, agents: List[AgentRole]) -> str:
        """Build a "; "-joined, human-readable explanation of the routing decision."""
        reasons = []

        type_reason = self._TYPE_REASONS.get(question_type)
        if type_reason:
            reasons.append(type_reason)

        if complexity == "complex":
            reasons.append("Complex reasoning required")
        elif complexity == "simple":
            reasons.append("Straightforward question")

        agent_names = [agent.value.replace("_", " ") for agent in agents]
        reasons.append(f"Selected agents: {', '.join(agent_names)}")

        return "; ".join(reasons)

    def _llm_enhanced_routing(self, state: GAIAAgentState) -> GAIAAgentState:
        """
        Ask the LLM for routing guidance on complex/unknown questions.

        On success the response text is stored in
        ``state.routing_decision["llm_analysis"]``; on failure an error is
        recorded on *state*. The rule-based routing decision is not changed.

        Returns:
            The same (mutated) state object.
        """
        file_label = state.file_name if state.file_name else "None"
        prompt = (
            "Analyze this GAIA benchmark question and provide routing guidance:\n"
            "\n"
            f"Question: {state.question}\n"
            f"File attached: {file_label}\n"
            f"Current classification: {state.question_type.value}\n"
            f"Current complexity: {state.complexity_assessment}\n"
            "\n"
            "Please provide:\n"
            "1. Confirm or correct the question type\n"
            "2. Confirm or adjust complexity assessment\n"
            "3. Key challenges in answering this question\n"
            "4. Recommended approach\n"
            "\n"
            "Keep response concise and focused on routing decisions.\n"
        )

        try:
            # Complex questions get the main model; everything else the cheaper router tier.
            tier = ModelTier.ROUTER if state.complexity_assessment != "complex" else ModelTier.MAIN
            result = self.llm_client.generate(prompt, tier=tier, max_tokens=200)

            if result.success:
                state.add_processing_step("Router: Enhanced with LLM analysis")
                state.routing_decision["llm_analysis"] = result.response
                logger.info("✅ LLM enhanced routing completed")
            else:
                state.add_error(f"LLM routing enhancement failed: {result.error}")

        except Exception as e:
            # Best-effort enhancement: record and log (with traceback) but never
            # let it abort routing.
            state.add_error(f"LLM routing error: {str(e)}")
            logger.exception("LLM routing failed: %s", e)

        return state