"""
Test Router Agent for GAIA Agent System
Tests question classification and agent selection logic
"""

import sys
from pathlib import Path

# Put this script's own directory at the front of sys.path; presumably this is
# what lets the project-local `agents` and `models` imports below resolve when
# the file is executed directly.
sys.path.insert(0, str(Path(__file__).parent))

from agents.state import GAIAAgentState, QuestionType, AgentRole
from agents.router import RouterAgent
from models.qwen_client import QwenClient


def test_router_agent():
    """Run the router against a fixed set of sample questions and report results.

    For each sample a fresh ``GAIAAgentState`` is populated with the question
    (and a placeholder file name/path when the sample sets ``has_file``) and
    handed to ``RouterAgent.route_question``. A sample passes when the returned
    state's ``question_type`` is one of the expected types and at least one of
    the expected agents appears in its ``selected_agents``.

    Progress and a final summary are printed to stdout.

    Returns:
        bool: ``True`` when at least 80% of the samples pass; ``False``
        otherwise, including when the router cannot be constructed.
    """
    print("🧭 GAIA Router Agent Test")
    print("=" * 40)

    # Building the client/router is the first thing that can fail; report it
    # and bail out instead of letting the script crash.
    try:
        llm_client = QwenClient()
        router = RouterAgent(llm_client)
    except Exception as e:
        print(f"❌ Failed to initialize router: {e}")
        return False

    # Each case lists every acceptable classification and the agents of which
    # at least one must be selected.
    test_cases = [
        {
            "question": "What is the capital of France?",
            "expected_type": [QuestionType.WIKIPEDIA, QuestionType.WEB_RESEARCH, QuestionType.UNKNOWN],
            "expected_agents": [AgentRole.WEB_RESEARCHER],
        },
        {
            "question": "Calculate 25% of 400 and add 50",
            "expected_type": [QuestionType.MATHEMATICAL],
            "expected_agents": [AgentRole.REASONING_AGENT],
        },
        {
            "question": "What information can you extract from this CSV file?",
            "expected_type": [QuestionType.FILE_PROCESSING],
            "expected_agents": [AgentRole.FILE_PROCESSOR],
            "has_file": True,
        },
        {
            "question": "Search for recent news about artificial intelligence",
            "expected_type": [QuestionType.WEB_RESEARCH],
            "expected_agents": [AgentRole.WEB_RESEARCHER],
        },
        {
            "question": "What does this Python code do and how can it be improved?",
            "expected_type": [QuestionType.CODE_EXECUTION, QuestionType.FILE_PROCESSING],
            "expected_agents": [AgentRole.FILE_PROCESSOR, AgentRole.CODE_EXECUTOR],
            "has_file": True,
        },
    ]

    results = []

    for i, test_case in enumerate(test_cases, 1):
        print(f"\n--- Test {i}: {test_case['question'][:50]}... ---")

        state = GAIAAgentState()
        state.question = test_case["question"]
        if test_case.get("has_file"):
            state.file_name = "test_file.csv"
            state.file_path = "/tmp/test_file.csv"

        success = False
        try:
            result_state = router.route_question(state)

            type_correct = result_state.question_type in test_case["expected_type"]
            agents_correct = any(
                agent in result_state.selected_agents
                for agent in test_case["expected_agents"]
            )
            success = type_correct and agents_correct

            type_mark = "✅" if type_correct else "❌"
            agents_mark = "✅" if agents_correct else "❌"
            print(f" Question Type: {result_state.question_type.value} ({type_mark})")
            print(f" Selected Agents: {[a.value for a in result_state.selected_agents]} ({agents_mark})")
            print(f" Complexity: {result_state.complexity_assessment}")
            print(f" Overall: {'✅ PASS' if success else '❌ FAIL'}")
        except Exception as e:
            # Any error while routing or reporting counts as a failed case.
            print(f" ❌ FAIL: {e}")
            success = False

        # Record exactly one outcome per case so the total always equals the
        # number of cases, even when reporting itself raised.
        results.append(success)

    passed = sum(results)
    total = len(results)
    # Guard the division so an empty case list reports 0% instead of raising.
    pass_rate = (passed / total) * 100 if total else 0.0

    print("\n" + "=" * 40)
    print(f"🎯 ROUTER RESULTS: {passed}/{total} ({pass_rate:.1f}%)")

    if pass_rate >= 80:
        print("🚀 Router working correctly!")
        return True

    print("⚠️ Router needs improvement")
    return False


if __name__ == "__main__":
    # Mirror the test outcome in the process exit status (0 = pass, 1 = fail)
    # so shells and CI jobs can detect a failing run.
    success = test_router_agent()
    sys.exit(0 if success else 1)