Chris
commited on
Commit
·
225a75e
1
Parent(s):
e266fe2
Final 4
Browse files- .gitignore +1 -0
- README.md +7 -3
- env.example +17 -0
- requirements.txt +103 -2
- src/__init__.py +1 -0
- src/__pycache__/app.cpython-310.pyc +0 -0
- src/agents/__init__.py +26 -0
- src/agents/__pycache__/__init__.cpython-310.pyc +0 -0
- src/agents/__pycache__/file_processor_agent.cpython-310.pyc +0 -0
- src/agents/__pycache__/reasoning_agent.cpython-310.pyc +0 -0
- src/agents/__pycache__/router.cpython-310.pyc +0 -0
- src/agents/__pycache__/state.cpython-310.pyc +0 -0
- src/agents/__pycache__/synthesizer.cpython-310.pyc +0 -0
- src/agents/__pycache__/web_researcher.cpython-310.pyc +0 -0
- src/agents/file_processor_agent.py +532 -0
- src/agents/reasoning_agent.py +633 -0
- src/agents/router.py +300 -0
- src/agents/state.py +186 -0
- src/agents/synthesizer.py +284 -0
- src/agents/web_researcher.py +600 -0
- src/api/unit4_client.py +349 -0
- src/app.py +594 -0
- src/main.py +151 -0
- src/models/__init__.py +1 -0
- src/models/__pycache__/__init__.cpython-310.pyc +0 -0
- src/models/__pycache__/qwen_client.cpython-310.pyc +0 -0
- src/models/qwen_client.py +377 -0
- src/test_agents.py +200 -0
- src/test_all_tools.py +189 -0
- src/test_integration.py +196 -0
- src/test_real_gaia.py +248 -0
- src/test_router.py +111 -0
- src/test_workflow.py +316 -0
- src/tools/__init__.py +86 -0
- src/tools/__pycache__/__init__.cpython-310.pyc +0 -0
- src/tools/__pycache__/calculator.cpython-310.pyc +0 -0
- src/tools/__pycache__/file_processor.cpython-310.pyc +0 -0
- src/tools/__pycache__/web_search_tool.cpython-310.pyc +0 -0
- src/tools/__pycache__/wikipedia_tool.cpython-310.pyc +0 -0
- src/tools/calculator.py +423 -0
- src/tools/file_processor.py +681 -0
- src/tools/web_search_tool.py +350 -0
- src/tools/wikipedia_tool.py +296 -0
- src/workflow/__init__.py +9 -0
- src/workflow/__pycache__/__init__.cpython-310.pyc +0 -0
- src/workflow/__pycache__/gaia_workflow.cpython-310.pyc +0 -0
- src/workflow/gaia_workflow.py +304 -0
.gitignore
CHANGED
|
@@ -2,3 +2,4 @@ todo.md
|
|
| 2 |
project_data.md
|
| 3 |
.env
|
| 4 |
questions.json
|
|
|
|
|
|
| 2 |
project_data.md
|
| 3 |
.env
|
| 4 |
questions.json
|
| 5 |
+
venv/
|
README.md
CHANGED
|
@@ -1,16 +1,20 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
colorFrom: indigo
|
| 5 |
colorTo: indigo
|
| 6 |
sdk: gradio
|
| 7 |
sdk_version: 5.25.2
|
| 8 |
-
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
hf_oauth: true
|
| 11 |
# optional, default duration is 8 hours/480 minutes. Max duration is 30 days/43200 minutes.
|
| 12 |
hf_oauth_expiration_minutes: 480
|
| 13 |
---
|
| 14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
| 16 |
|
|
|
|
| 1 |
---
|
| 2 |
+
title: GAIA Agent System
|
| 3 |
+
emoji: 🤖
|
| 4 |
colorFrom: indigo
|
| 5 |
colorTo: indigo
|
| 6 |
sdk: gradio
|
| 7 |
sdk_version: 5.25.2
|
| 8 |
+
app_file: ./src/app.py
|
| 9 |
pinned: false
|
| 10 |
hf_oauth: true
|
| 11 |
# optional, default duration is 8 hours/480 minutes. Max duration is 30 days/43200 minutes.
|
| 12 |
hf_oauth_expiration_minutes: 480
|
| 13 |
---
|
| 14 |
|
| 15 |
+
# 🤖 GAIA Agent System
|
| 16 |
+
|
| 17 |
+
Advanced Multi-Agent AI System for GAIA Benchmark Questions using LangGraph orchestration.
|
| 18 |
+
|
| 19 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
| 20 |
|
env.example
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# HuggingFace Token for model access
|
| 2 |
+
HUGGINGFACE_TOKEN=your_token_here
|
| 3 |
+
|
| 4 |
+
# Optional: LangSmith for observability (bonus feature)
|
| 5 |
+
LANGCHAIN_API_KEY=your_langsmith_key_here
|
| 6 |
+
LANGCHAIN_TRACING_V2=true
|
| 7 |
+
LANGCHAIN_PROJECT=gaia-agent-system
|
| 8 |
+
|
| 9 |
+
# Model Configuration (defaults to free Qwen models)
|
| 10 |
+
ROUTER_MODEL=Qwen/Qwen2.5-3B-Instruct
|
| 11 |
+
MAIN_MODEL=Qwen/Qwen2.5-14B-Instruct
|
| 12 |
+
COMPLEX_MODEL=Qwen/Qwen2.5-32B-Instruct
|
| 13 |
+
|
| 14 |
+
# API Configuration
|
| 15 |
+
MAX_TOKENS=1000
|
| 16 |
+
TEMPERATURE=0.1
|
| 17 |
+
TIMEOUT=30
|
requirements.txt
CHANGED
|
@@ -1,2 +1,103 @@
|
|
| 1 |
-
|
| 2 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
aiofiles==24.1.0
|
| 2 |
+
annotated-types==0.7.0
|
| 3 |
+
anyio==4.9.0
|
| 4 |
+
async-timeout==4.0.3
|
| 5 |
+
certifi==2025.4.26
|
| 6 |
+
charset-normalizer==3.4.2
|
| 7 |
+
click==8.2.1
|
| 8 |
+
exceptiongroup==1.3.0
|
| 9 |
+
fastapi==0.115.12
|
| 10 |
+
ffmpy==0.5.0
|
| 11 |
+
filelock==3.18.0
|
| 12 |
+
fsspec==2025.5.1
|
| 13 |
+
gradio==5.31.0
|
| 14 |
+
gradio_client==1.10.1
|
| 15 |
+
greenlet==3.2.2
|
| 16 |
+
groovy==0.1.2
|
| 17 |
+
h11==0.16.0
|
| 18 |
+
hf-xet==1.1.2
|
| 19 |
+
httpcore==1.0.9
|
| 20 |
+
httpx==0.28.1
|
| 21 |
+
huggingface-hub==0.32.2
|
| 22 |
+
idna==3.10
|
| 23 |
+
Jinja2==3.1.6
|
| 24 |
+
joblib==1.5.1
|
| 25 |
+
jsonpatch==1.33
|
| 26 |
+
jsonpointer==3.0.0
|
| 27 |
+
langchain==0.3.25
|
| 28 |
+
langchain-core==0.3.62
|
| 29 |
+
langchain-huggingface==0.2.0
|
| 30 |
+
langchain-text-splitters==0.3.8
|
| 31 |
+
langgraph==0.4.7
|
| 32 |
+
langgraph-checkpoint==2.0.26
|
| 33 |
+
langgraph-prebuilt==0.2.2
|
| 34 |
+
langgraph-sdk==0.1.70
|
| 35 |
+
langsmith==0.3.43
|
| 36 |
+
markdown-it-py==3.0.0
|
| 37 |
+
MarkupSafe==3.0.2
|
| 38 |
+
mdurl==0.1.2
|
| 39 |
+
mpmath==1.3.0
|
| 40 |
+
networkx==3.4.2
|
| 41 |
+
numpy==2.2.6
|
| 42 |
+
nvidia-cublas-cu12==12.6.4.1
|
| 43 |
+
nvidia-cuda-cupti-cu12==12.6.80
|
| 44 |
+
nvidia-cuda-nvrtc-cu12==12.6.77
|
| 45 |
+
nvidia-cuda-runtime-cu12==12.6.77
|
| 46 |
+
nvidia-cudnn-cu12==9.5.1.17
|
| 47 |
+
nvidia-cufft-cu12==11.3.0.4
|
| 48 |
+
nvidia-cufile-cu12==1.11.1.6
|
| 49 |
+
nvidia-curand-cu12==10.3.7.77
|
| 50 |
+
nvidia-cusolver-cu12==11.7.1.2
|
| 51 |
+
nvidia-cusparse-cu12==12.5.4.2
|
| 52 |
+
nvidia-cusparselt-cu12==0.6.3
|
| 53 |
+
nvidia-nccl-cu12==2.26.2
|
| 54 |
+
nvidia-nvjitlink-cu12==12.6.85
|
| 55 |
+
nvidia-nvtx-cu12==12.6.77
|
| 56 |
+
orjson==3.10.18
|
| 57 |
+
ormsgpack==1.10.0
|
| 58 |
+
packaging==24.2
|
| 59 |
+
pandas==2.2.3
|
| 60 |
+
pillow==11.2.1
|
| 61 |
+
pydantic==2.11.5
|
| 62 |
+
pydantic_core==2.33.2
|
| 63 |
+
pydub==0.25.1
|
| 64 |
+
Pygments==2.19.1
|
| 65 |
+
python-dateutil==2.9.0.post0
|
| 66 |
+
python-dotenv==1.1.0
|
| 67 |
+
python-multipart==0.0.20
|
| 68 |
+
pytz==2025.2
|
| 69 |
+
PyYAML==6.0.2
|
| 70 |
+
regex==2024.11.6
|
| 71 |
+
requests==2.32.3
|
| 72 |
+
requests-toolbelt==1.0.0
|
| 73 |
+
rich==14.0.0
|
| 74 |
+
ruff==0.11.11
|
| 75 |
+
safehttpx==0.1.6
|
| 76 |
+
safetensors==0.5.3
|
| 77 |
+
scikit-learn==1.6.1
|
| 78 |
+
scipy==1.15.3
|
| 79 |
+
semantic-version==2.10.0
|
| 80 |
+
sentence-transformers==4.1.0
|
| 81 |
+
shellingham==1.5.4
|
| 82 |
+
six==1.17.0
|
| 83 |
+
sniffio==1.3.1
|
| 84 |
+
SQLAlchemy==2.0.41
|
| 85 |
+
starlette==0.46.2
|
| 86 |
+
sympy==1.14.0
|
| 87 |
+
tenacity==9.1.2
|
| 88 |
+
threadpoolctl==3.6.0
|
| 89 |
+
tokenizers==0.21.1
|
| 90 |
+
tomlkit==0.13.2
|
| 91 |
+
torch==2.7.0
|
| 92 |
+
tqdm==4.67.1
|
| 93 |
+
transformers==4.52.3
|
| 94 |
+
triton==3.3.0
|
| 95 |
+
typer==0.16.0
|
| 96 |
+
typing-inspection==0.4.1
|
| 97 |
+
typing_extensions==4.13.2
|
| 98 |
+
tzdata==2025.2
|
| 99 |
+
urllib3==2.4.0
|
| 100 |
+
uvicorn==0.34.2
|
| 101 |
+
websockets==15.0.1
|
| 102 |
+
xxhash==3.5.0
|
| 103 |
+
zstandard==0.23.0
|
src/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# GAIA Agent System
|
src/__pycache__/app.cpython-310.pyc
ADDED
|
Binary file (16.8 kB). View file
|
|
|
src/agents/__init__.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
GAIA Agent System Components
|
| 4 |
+
Multi-agent framework for GAIA benchmark questions using LangGraph
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from .state import (
|
| 8 |
+
GAIAAgentState,
|
| 9 |
+
AgentState,
|
| 10 |
+
QuestionType,
|
| 11 |
+
AgentRole,
|
| 12 |
+
ToolResult,
|
| 13 |
+
AgentResult
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
from .router import RouterAgent
|
| 17 |
+
|
| 18 |
+
__all__ = [
|
| 19 |
+
'GAIAAgentState',
|
| 20 |
+
'AgentState',
|
| 21 |
+
'QuestionType',
|
| 22 |
+
'AgentRole',
|
| 23 |
+
'ToolResult',
|
| 24 |
+
'AgentResult',
|
| 25 |
+
'RouterAgent'
|
| 26 |
+
]
|
src/agents/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (527 Bytes). View file
|
|
|
src/agents/__pycache__/file_processor_agent.cpython-310.pyc
ADDED
|
Binary file (13.4 kB). View file
|
|
|
src/agents/__pycache__/reasoning_agent.cpython-310.pyc
ADDED
|
Binary file (16.7 kB). View file
|
|
|
src/agents/__pycache__/router.cpython-310.pyc
ADDED
|
Binary file (8.99 kB). View file
|
|
|
src/agents/__pycache__/state.cpython-310.pyc
ADDED
|
Binary file (7.04 kB). View file
|
|
|
src/agents/__pycache__/synthesizer.cpython-310.pyc
ADDED
|
Binary file (9.78 kB). View file
|
|
|
src/agents/__pycache__/web_researcher.cpython-310.pyc
ADDED
|
Binary file (15.3 kB). View file
|
|
|
src/agents/file_processor_agent.py
ADDED
|
@@ -0,0 +1,532 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
File Processor Agent for GAIA Agent System
|
| 4 |
+
Handles file-based questions with intelligent processing strategies
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import logging
|
| 9 |
+
from typing import Dict, List, Optional, Any
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
|
| 12 |
+
from agents.state import GAIAAgentState, AgentRole, AgentResult, ToolResult
|
| 13 |
+
from models.qwen_client import QwenClient, ModelTier
|
| 14 |
+
from tools.file_processor import FileProcessorTool
|
| 15 |
+
from tools.calculator import CalculatorTool
|
| 16 |
+
|
| 17 |
+
logger = logging.getLogger(__name__)
|
| 18 |
+
|
| 19 |
+
class FileProcessorAgent:
|
| 20 |
+
"""
|
| 21 |
+
Specialized agent for file processing tasks
|
| 22 |
+
Handles images, audio, CSV/Excel, Python code, and other file types
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
def __init__(self, llm_client: QwenClient):
|
| 26 |
+
self.llm_client = llm_client
|
| 27 |
+
self.file_processor = FileProcessorTool()
|
| 28 |
+
self.calculator = CalculatorTool() # For data analysis
|
| 29 |
+
|
| 30 |
+
def process(self, state: GAIAAgentState) -> GAIAAgentState:
|
| 31 |
+
"""
|
| 32 |
+
Process file-based questions using file analysis tools
|
| 33 |
+
"""
|
| 34 |
+
logger.info(f"File processor processing: {state.question[:100]}...")
|
| 35 |
+
state.add_processing_step("File Processor: Starting file analysis")
|
| 36 |
+
|
| 37 |
+
try:
|
| 38 |
+
# Check if file exists
|
| 39 |
+
if not state.file_path or not os.path.exists(state.file_path):
|
| 40 |
+
error_msg = f"File not found: {state.file_path}"
|
| 41 |
+
state.add_error(error_msg)
|
| 42 |
+
result = self._create_failure_result(error_msg)
|
| 43 |
+
state.add_agent_result(result)
|
| 44 |
+
return state
|
| 45 |
+
|
| 46 |
+
# Determine processing strategy
|
| 47 |
+
strategy = self._determine_processing_strategy(state.question, state.file_path)
|
| 48 |
+
state.add_processing_step(f"File Processor: Strategy = {strategy}")
|
| 49 |
+
|
| 50 |
+
# Execute processing based on strategy
|
| 51 |
+
if strategy == "image_analysis":
|
| 52 |
+
result = self._process_image(state)
|
| 53 |
+
elif strategy == "data_analysis":
|
| 54 |
+
result = self._process_data_file(state)
|
| 55 |
+
elif strategy == "code_analysis":
|
| 56 |
+
result = self._process_code_file(state)
|
| 57 |
+
elif strategy == "audio_analysis":
|
| 58 |
+
result = self._process_audio_file(state)
|
| 59 |
+
elif strategy == "text_analysis":
|
| 60 |
+
result = self._process_text_file(state)
|
| 61 |
+
else:
|
| 62 |
+
result = self._process_generic_file(state)
|
| 63 |
+
|
| 64 |
+
# Add result to state
|
| 65 |
+
state.add_agent_result(result)
|
| 66 |
+
state.add_processing_step(f"File Processor: Completed with confidence {result.confidence:.2f}")
|
| 67 |
+
|
| 68 |
+
return state
|
| 69 |
+
|
| 70 |
+
except Exception as e:
|
| 71 |
+
error_msg = f"File processing failed: {str(e)}"
|
| 72 |
+
state.add_error(error_msg)
|
| 73 |
+
logger.error(error_msg)
|
| 74 |
+
|
| 75 |
+
# Create failure result
|
| 76 |
+
failure_result = self._create_failure_result(error_msg)
|
| 77 |
+
state.add_agent_result(failure_result)
|
| 78 |
+
return state
|
| 79 |
+
|
| 80 |
+
def _determine_processing_strategy(self, question: str, file_path: str) -> str:
|
| 81 |
+
"""Determine the best processing strategy based on file type and question"""
|
| 82 |
+
|
| 83 |
+
file_extension = Path(file_path).suffix.lower()
|
| 84 |
+
question_lower = question.lower()
|
| 85 |
+
|
| 86 |
+
# Image file analysis
|
| 87 |
+
if file_extension in {'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.webp'}:
|
| 88 |
+
return "image_analysis"
|
| 89 |
+
|
| 90 |
+
# Audio file analysis
|
| 91 |
+
if file_extension in {'.mp3', '.wav', '.ogg', '.flac', '.m4a', '.aac'}:
|
| 92 |
+
return "audio_analysis"
|
| 93 |
+
|
| 94 |
+
# Data file analysis
|
| 95 |
+
if file_extension in {'.csv', '.xlsx', '.xls', '.json'}:
|
| 96 |
+
return "data_analysis"
|
| 97 |
+
|
| 98 |
+
# Code file analysis
|
| 99 |
+
if file_extension in {'.py', '.js', '.java', '.cpp', '.c', '.html', '.css'}:
|
| 100 |
+
return "code_analysis"
|
| 101 |
+
|
| 102 |
+
# Text file analysis
|
| 103 |
+
if file_extension in {'.txt', '.md', '.rst'}:
|
| 104 |
+
return "text_analysis"
|
| 105 |
+
|
| 106 |
+
# Default to generic processing
|
| 107 |
+
return "generic_analysis"
|
| 108 |
+
|
| 109 |
+
def _process_image(self, state: GAIAAgentState) -> AgentResult:
|
| 110 |
+
"""Process image files and answer questions about them"""
|
| 111 |
+
|
| 112 |
+
logger.info(f"Processing image: {state.file_path}")
|
| 113 |
+
|
| 114 |
+
# Analyze image with file processor
|
| 115 |
+
file_result = self.file_processor.execute(state.file_path)
|
| 116 |
+
|
| 117 |
+
if file_result.success and file_result.result.get('success'):
|
| 118 |
+
file_data = file_result.result['result']
|
| 119 |
+
|
| 120 |
+
# Create analysis prompt based on image metadata and question
|
| 121 |
+
analysis_prompt = f"""
|
| 122 |
+
Based on this image analysis, please answer the following question:
|
| 123 |
+
|
| 124 |
+
Question: {state.question}
|
| 125 |
+
|
| 126 |
+
Image Information:
|
| 127 |
+
- File: {file_data.get('file_path', '')}
|
| 128 |
+
- Type: {file_data.get('file_type', '')}
|
| 129 |
+
- Content Description: {file_data.get('content', '')}
|
| 130 |
+
- Metadata: {file_data.get('metadata', {})}
|
| 131 |
+
|
| 132 |
+
Please provide a direct answer based on the image analysis.
|
| 133 |
+
If the question asks about specific details that cannot be determined from the metadata alone,
|
| 134 |
+
please indicate what information is available and what would require visual analysis.
|
| 135 |
+
"""
|
| 136 |
+
|
| 137 |
+
# Use main model for image analysis
|
| 138 |
+
model_tier = ModelTier.MAIN
|
| 139 |
+
llm_result = self.llm_client.generate(analysis_prompt, tier=model_tier, max_tokens=400)
|
| 140 |
+
|
| 141 |
+
if llm_result.success:
|
| 142 |
+
confidence = 0.75 # Good confidence for image metadata analysis
|
| 143 |
+
return AgentResult(
|
| 144 |
+
agent_role=AgentRole.FILE_PROCESSOR,
|
| 145 |
+
success=True,
|
| 146 |
+
result=llm_result.response,
|
| 147 |
+
confidence=confidence,
|
| 148 |
+
reasoning="Analyzed image metadata and properties",
|
| 149 |
+
tools_used=[ToolResult(
|
| 150 |
+
tool_name="file_processor",
|
| 151 |
+
success=True,
|
| 152 |
+
result=file_data,
|
| 153 |
+
execution_time=file_result.execution_time
|
| 154 |
+
)],
|
| 155 |
+
model_used=llm_result.model_used,
|
| 156 |
+
processing_time=file_result.execution_time + llm_result.response_time,
|
| 157 |
+
cost_estimate=llm_result.cost_estimate
|
| 158 |
+
)
|
| 159 |
+
else:
|
| 160 |
+
# Fallback to metadata description
|
| 161 |
+
return AgentResult(
|
| 162 |
+
agent_role=AgentRole.FILE_PROCESSOR,
|
| 163 |
+
success=True,
|
| 164 |
+
result=file_data.get('content', 'Image analyzed'),
|
| 165 |
+
confidence=0.60,
|
| 166 |
+
reasoning="Image processed but analysis failed",
|
| 167 |
+
tools_used=[ToolResult(
|
| 168 |
+
tool_name="file_processor",
|
| 169 |
+
success=True,
|
| 170 |
+
result=file_data,
|
| 171 |
+
execution_time=file_result.execution_time
|
| 172 |
+
)],
|
| 173 |
+
model_used="fallback",
|
| 174 |
+
processing_time=file_result.execution_time,
|
| 175 |
+
cost_estimate=0.0
|
| 176 |
+
)
|
| 177 |
+
else:
|
| 178 |
+
return self._create_failure_result("Image processing failed")
|
| 179 |
+
|
| 180 |
+
def _process_data_file(self, state: GAIAAgentState) -> AgentResult:
|
| 181 |
+
"""Process CSV/Excel files and perform data analysis"""
|
| 182 |
+
|
| 183 |
+
logger.info(f"Processing data file: {state.file_path}")
|
| 184 |
+
|
| 185 |
+
# Analyze data file
|
| 186 |
+
file_result = self.file_processor.execute(state.file_path)
|
| 187 |
+
|
| 188 |
+
if file_result.success and file_result.result.get('success'):
|
| 189 |
+
file_data = file_result.result['result']
|
| 190 |
+
metadata = file_data.get('metadata', {})
|
| 191 |
+
content = file_data.get('content', {})
|
| 192 |
+
|
| 193 |
+
# Check if question requires calculations
|
| 194 |
+
question_lower = state.question.lower()
|
| 195 |
+
needs_calculation = any(term in question_lower for term in [
|
| 196 |
+
'calculate', 'sum', 'total', 'average', 'mean', 'count',
|
| 197 |
+
'maximum', 'minimum', 'how many', 'what is the'
|
| 198 |
+
])
|
| 199 |
+
|
| 200 |
+
if needs_calculation and 'sample_data' in content:
|
| 201 |
+
return self._perform_data_calculations(state, file_data, file_result)
|
| 202 |
+
else:
|
| 203 |
+
return self._analyze_data_structure(state, file_data, file_result)
|
| 204 |
+
else:
|
| 205 |
+
return self._create_failure_result("Data file processing failed")
|
| 206 |
+
|
| 207 |
+
def _perform_data_calculations(self, state: GAIAAgentState, file_data: Dict, file_result: ToolResult) -> AgentResult:
|
| 208 |
+
"""Perform calculations on data file content"""
|
| 209 |
+
|
| 210 |
+
metadata = file_data.get('metadata', {})
|
| 211 |
+
content = file_data.get('content', {})
|
| 212 |
+
|
| 213 |
+
# Extract data for calculations
|
| 214 |
+
sample_data = content.get('sample_data', [])
|
| 215 |
+
|
| 216 |
+
# Use LLM to determine what calculations to perform
|
| 217 |
+
calculation_prompt = f"""
|
| 218 |
+
Based on this data file and question, determine what calculations are needed:
|
| 219 |
+
|
| 220 |
+
Question: {state.question}
|
| 221 |
+
|
| 222 |
+
Data Structure:
|
| 223 |
+
- Columns: {metadata.get('columns', [])}
|
| 224 |
+
- Rows: {metadata.get('row_count', 0)}
|
| 225 |
+
- Sample Data: {sample_data[:3]} # First 3 rows
|
| 226 |
+
|
| 227 |
+
Please specify what calculations should be performed and on which columns.
|
| 228 |
+
Respond with specific calculation instructions.
|
| 229 |
+
"""
|
| 230 |
+
|
| 231 |
+
llm_result = self.llm_client.generate(calculation_prompt, tier=ModelTier.MAIN, max_tokens=200)
|
| 232 |
+
|
| 233 |
+
if llm_result.success:
|
| 234 |
+
# For now, provide data summary with LLM analysis
|
| 235 |
+
analysis_prompt = f"""
|
| 236 |
+
Based on this data analysis, please answer the question:
|
| 237 |
+
|
| 238 |
+
Question: {state.question}
|
| 239 |
+
|
| 240 |
+
Data Summary:
|
| 241 |
+
- File: {metadata.get('shape', [])} (rows x columns)
|
| 242 |
+
- Columns: {metadata.get('columns', [])}
|
| 243 |
+
- Numeric columns: {metadata.get('numeric_columns', [])}
|
| 244 |
+
- Statistics: {metadata.get('numeric_stats', {})}
|
| 245 |
+
- Sample data: {sample_data}
|
| 246 |
+
|
| 247 |
+
Calculation guidance: {llm_result.response}
|
| 248 |
+
|
| 249 |
+
Please provide the answer based on the data.
|
| 250 |
+
"""
|
| 251 |
+
|
| 252 |
+
analysis_result = self.llm_client.generate(analysis_prompt, tier=ModelTier.MAIN, max_tokens=400)
|
| 253 |
+
|
| 254 |
+
if analysis_result.success:
|
| 255 |
+
return AgentResult(
|
| 256 |
+
agent_role=AgentRole.FILE_PROCESSOR,
|
| 257 |
+
success=True,
|
| 258 |
+
result=analysis_result.response,
|
| 259 |
+
confidence=0.80,
|
| 260 |
+
reasoning="Performed data analysis and calculations",
|
| 261 |
+
tools_used=[file_result],
|
| 262 |
+
model_used=analysis_result.model_used,
|
| 263 |
+
processing_time=file_result.execution_time + llm_result.response_time + analysis_result.response_time,
|
| 264 |
+
cost_estimate=llm_result.cost_estimate + analysis_result.cost_estimate
|
| 265 |
+
)
|
| 266 |
+
|
| 267 |
+
# Fallback to basic data summary
|
| 268 |
+
return self._analyze_data_structure(state, file_data, file_result)
|
| 269 |
+
|
| 270 |
+
def _analyze_data_structure(self, state: GAIAAgentState, file_data: Dict, file_result: ToolResult) -> AgentResult:
|
| 271 |
+
"""Analyze data file structure and content"""
|
| 272 |
+
|
| 273 |
+
metadata = file_data.get('metadata', {})
|
| 274 |
+
content = file_data.get('content', {})
|
| 275 |
+
|
| 276 |
+
analysis_prompt = f"""
|
| 277 |
+
Based on this data file analysis, please answer the question:
|
| 278 |
+
|
| 279 |
+
Question: {state.question}
|
| 280 |
+
|
| 281 |
+
Data File Information:
|
| 282 |
+
- Structure: {metadata.get('shape', [])} (rows x columns)
|
| 283 |
+
- Columns: {metadata.get('columns', [])}
|
| 284 |
+
- Data types: {metadata.get('data_types', {})}
|
| 285 |
+
- Description: {content.get('description', '')}
|
| 286 |
+
- Sample data: {content.get('sample_data', [])}
|
| 287 |
+
|
| 288 |
+
Please provide a direct answer based on the data structure and content.
|
| 289 |
+
"""
|
| 290 |
+
|
| 291 |
+
model_tier = ModelTier.MAIN
|
| 292 |
+
llm_result = self.llm_client.generate(analysis_prompt, tier=model_tier, max_tokens=400)
|
| 293 |
+
|
| 294 |
+
if llm_result.success:
|
| 295 |
+
return AgentResult(
|
| 296 |
+
agent_role=AgentRole.FILE_PROCESSOR,
|
| 297 |
+
success=True,
|
| 298 |
+
result=llm_result.response,
|
| 299 |
+
confidence=0.75,
|
| 300 |
+
reasoning="Analyzed data file structure and content",
|
| 301 |
+
tools_used=[file_result],
|
| 302 |
+
model_used=llm_result.model_used,
|
| 303 |
+
processing_time=file_result.execution_time + llm_result.response_time,
|
| 304 |
+
cost_estimate=llm_result.cost_estimate
|
| 305 |
+
)
|
| 306 |
+
else:
|
| 307 |
+
return AgentResult(
|
| 308 |
+
agent_role=AgentRole.FILE_PROCESSOR,
|
| 309 |
+
success=True,
|
| 310 |
+
result=content.get('description', 'Data file analyzed'),
|
| 311 |
+
confidence=0.60,
|
| 312 |
+
reasoning="Data file processed but analysis failed",
|
| 313 |
+
tools_used=[file_result],
|
| 314 |
+
model_used="fallback",
|
| 315 |
+
processing_time=file_result.execution_time,
|
| 316 |
+
cost_estimate=0.0
|
| 317 |
+
)
|
| 318 |
+
|
| 319 |
+
def _process_code_file(self, state: GAIAAgentState) -> AgentResult:
|
| 320 |
+
"""Process code files and analyze their content"""
|
| 321 |
+
|
| 322 |
+
logger.info(f"Processing code file: {state.file_path}")
|
| 323 |
+
|
| 324 |
+
# Analyze code file
|
| 325 |
+
file_result = self.file_processor.execute(state.file_path)
|
| 326 |
+
|
| 327 |
+
if file_result.success and file_result.result.get('success'):
|
| 328 |
+
file_data = file_result.result['result']
|
| 329 |
+
metadata = file_data.get('metadata', {})
|
| 330 |
+
content = file_data.get('content', {})
|
| 331 |
+
|
| 332 |
+
analysis_prompt = f"""
|
| 333 |
+
Based on this code analysis, please answer the question:
|
| 334 |
+
|
| 335 |
+
Question: {state.question}
|
| 336 |
+
|
| 337 |
+
Code File Information:
|
| 338 |
+
- Type: {file_data.get('file_type', '')}
|
| 339 |
+
- Description: {content.get('description', '')}
|
| 340 |
+
- Metadata: {metadata}
|
| 341 |
+
- Code snippet: {content.get('code_snippet', '')}
|
| 342 |
+
|
| 343 |
+
Please analyze the code and provide a direct answer.
|
| 344 |
+
"""
|
| 345 |
+
|
| 346 |
+
model_tier = ModelTier.MAIN
|
| 347 |
+
llm_result = self.llm_client.generate(analysis_prompt, tier=model_tier, max_tokens=500)
|
| 348 |
+
|
| 349 |
+
if llm_result.success:
|
| 350 |
+
return AgentResult(
|
| 351 |
+
agent_role=AgentRole.FILE_PROCESSOR,
|
| 352 |
+
success=True,
|
| 353 |
+
result=llm_result.response,
|
| 354 |
+
confidence=0.80,
|
| 355 |
+
reasoning="Analyzed code structure and content",
|
| 356 |
+
tools_used=[ToolResult(
|
| 357 |
+
tool_name="file_processor",
|
| 358 |
+
success=True,
|
| 359 |
+
result=file_data,
|
| 360 |
+
execution_time=file_result.execution_time
|
| 361 |
+
)],
|
| 362 |
+
model_used=llm_result.model_used,
|
| 363 |
+
processing_time=file_result.execution_time + llm_result.response_time,
|
| 364 |
+
cost_estimate=llm_result.cost_estimate
|
| 365 |
+
)
|
| 366 |
+
else:
|
| 367 |
+
return AgentResult(
|
| 368 |
+
agent_role=AgentRole.FILE_PROCESSOR,
|
| 369 |
+
success=True,
|
| 370 |
+
result=content.get('description', 'Code file analyzed'),
|
| 371 |
+
confidence=0.60,
|
| 372 |
+
reasoning="Code file processed but analysis failed",
|
| 373 |
+
tools_used=[ToolResult(
|
| 374 |
+
tool_name="file_processor",
|
| 375 |
+
success=True,
|
| 376 |
+
result=file_data,
|
| 377 |
+
execution_time=file_result.execution_time
|
| 378 |
+
)],
|
| 379 |
+
model_used="fallback",
|
| 380 |
+
processing_time=file_result.execution_time,
|
| 381 |
+
cost_estimate=0.0
|
| 382 |
+
)
|
| 383 |
+
else:
|
| 384 |
+
return self._create_failure_result("Code file processing failed")
|
| 385 |
+
|
| 386 |
+
def _process_audio_file(self, state: GAIAAgentState) -> AgentResult:
|
| 387 |
+
"""Process audio files (basic metadata for now)"""
|
| 388 |
+
|
| 389 |
+
logger.info(f"Processing audio file: {state.file_path}")
|
| 390 |
+
|
| 391 |
+
# Analyze audio file
|
| 392 |
+
file_result = self.file_processor.execute(state.file_path)
|
| 393 |
+
|
| 394 |
+
if file_result.success and file_result.result.get('success'):
|
| 395 |
+
file_data = file_result.result['result']
|
| 396 |
+
|
| 397 |
+
analysis_prompt = f"""
|
| 398 |
+
Based on this audio file information, please answer the question:
|
| 399 |
+
|
| 400 |
+
Question: {state.question}
|
| 401 |
+
|
| 402 |
+
Audio File Information:
|
| 403 |
+
- Content: {file_data.get('content', '')}
|
| 404 |
+
- Metadata: {file_data.get('metadata', {})}
|
| 405 |
+
|
| 406 |
+
Please provide an answer based on the available audio file information.
|
| 407 |
+
Note: Full audio transcription is not currently available, but file metadata is provided.
|
| 408 |
+
"""
|
| 409 |
+
|
| 410 |
+
model_tier = ModelTier.ROUTER # Use lighter model for basic audio metadata
|
| 411 |
+
llm_result = self.llm_client.generate(analysis_prompt, tier=model_tier, max_tokens=300)
|
| 412 |
+
|
| 413 |
+
if llm_result.success:
|
| 414 |
+
return AgentResult(
|
| 415 |
+
agent_role=AgentRole.FILE_PROCESSOR,
|
| 416 |
+
success=True,
|
| 417 |
+
result=llm_result.response,
|
| 418 |
+
confidence=0.50, # Lower confidence due to limited audio processing
|
| 419 |
+
reasoning="Analyzed audio file metadata (transcription not available)",
|
| 420 |
+
tools_used=[ToolResult(
|
| 421 |
+
tool_name="file_processor",
|
| 422 |
+
success=True,
|
| 423 |
+
result=file_data,
|
| 424 |
+
execution_time=file_result.execution_time
|
| 425 |
+
)],
|
| 426 |
+
model_used=llm_result.model_used,
|
| 427 |
+
processing_time=file_result.execution_time + llm_result.response_time,
|
| 428 |
+
cost_estimate=llm_result.cost_estimate
|
| 429 |
+
)
|
| 430 |
+
|
| 431 |
+
return self._create_failure_result("Audio file processing not fully supported")
|
| 432 |
+
|
| 433 |
+
def _process_text_file(self, state: GAIAAgentState) -> AgentResult:
    """Process text files and analyze their content.

    Reads the file via the file-processor tool, embeds up to the first
    2000 characters of text plus word/line statistics into a prompt, and
    asks the MAIN-tier model for a direct answer to the question.

    Args:
        state: Current workflow state; `state.file_path` and
            `state.question` are read.

    Returns:
        AgentResult: success with the model's answer, or a failure result
        when file processing or the LLM call fails.
    """
    logger.info(f"Processing text file: {state.file_path}")

    # Analyze text file
    file_result = self.file_processor.execute(state.file_path)

    if file_result.success and file_result.result.get('success'):
        file_data = file_result.result['result']
        content = file_data.get('content', {})
        metadata = file_data.get('metadata', {})

        # Keep the prompt within model limits. Only append an ellipsis when
        # text was actually truncated — previously "..." was added even to
        # short files, misleadingly implying missing content.
        text = content.get('text', '')
        excerpt = text[:2000] + ('...' if len(text) > 2000 else '')

        analysis_prompt = f"""
Based on this text file content, please answer the question:

Question: {state.question}

Text Content:
{excerpt}

File Statistics:
- Word count: {metadata.get('word_count', 0)}
- Line count: {metadata.get('line_count', 0)}

Please analyze the text and provide a direct answer.
"""

        model_tier = ModelTier.MAIN
        llm_result = self.llm_client.generate(analysis_prompt, tier=model_tier, max_tokens=400)

        if llm_result.success:
            return AgentResult(
                agent_role=AgentRole.FILE_PROCESSOR,
                success=True,
                result=llm_result.response,
                confidence=0.85,
                reasoning="Analyzed text file content",
                tools_used=[ToolResult(
                    tool_name="file_processor",
                    success=True,
                    result=file_data,
                    execution_time=file_result.execution_time
                )],
                model_used=llm_result.model_used,
                processing_time=file_result.execution_time + llm_result.response_time,
                cost_estimate=llm_result.cost_estimate
            )

    return self._create_failure_result("Text file processing failed")
|
| 482 |
+
|
| 483 |
+
def _process_generic_file(self, state: GAIAAgentState) -> AgentResult:
    """Process unknown file types with generic analysis.

    No LLM call is made here: the file-processor tool runs once and a
    short textual summary of its outcome is returned with low confidence.
    """
    logger.info(f"Processing generic file: {state.file_path}")

    # Try generic file processing; bail out early when the tool itself fails.
    file_result = self.file_processor.execute(state.file_path)
    if not file_result.success:
        return self._create_failure_result("Generic file processing failed")

    file_data = file_result.result

    # Build a basic description of what (if anything) was learned.
    summary_parts = [f"File analyzed: {state.file_path}. "]
    if file_data.get('success'):
        detected_type = file_data.get('result', {}).get('file_type', 'unknown')
        summary_parts.append(f"File type: {detected_type}. ")
        summary_parts.append("Generic file analysis completed.")
    else:
        summary_parts.append(
            f"Analysis result: {file_data.get('message', 'Processing completed')}"
        )
    basic_info = "".join(summary_parts)

    return AgentResult(
        agent_role=AgentRole.FILE_PROCESSOR,
        success=True,
        result=basic_info,
        confidence=0.40,  # low: no content understanding, metadata only
        reasoning="Generic file processing attempted",
        tools_used=[ToolResult(
            tool_name="file_processor",
            success=True,
            result=file_data,
            execution_time=file_result.execution_time
        )],
        model_used="basic",
        processing_time=file_result.execution_time,
        cost_estimate=0.0
    )
|
| 520 |
+
|
| 521 |
+
def _create_failure_result(self, error_message: str) -> AgentResult:
    """Create a failure result for this agent.

    Args:
        error_message: Human-readable description of what went wrong;
            used as both the result payload and the reasoning text.

    Returns:
        AgentResult: a zero-confidence, zero-cost failure record
        attributed to the file-processor role.
    """
    return AgentResult(
        agent_role=AgentRole.FILE_PROCESSOR,
        success=False,
        result=error_message,
        confidence=0.0,
        reasoning=error_message,
        model_used="error",  # sentinel: no model was (successfully) used
        processing_time=0.0,
        cost_estimate=0.0
    )
|
src/agents/reasoning_agent.py
ADDED
|
@@ -0,0 +1,633 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Reasoning Agent for GAIA Agent System
|
| 4 |
+
Handles mathematical, logical, and analytical reasoning questions
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import re
|
| 8 |
+
import logging
|
| 9 |
+
from typing import Dict, List, Optional, Any, Union
|
| 10 |
+
|
| 11 |
+
from agents.state import GAIAAgentState, AgentRole, AgentResult, ToolResult
|
| 12 |
+
from models.qwen_client import QwenClient, ModelTier
|
| 13 |
+
from tools.calculator import CalculatorTool
|
| 14 |
+
|
| 15 |
+
logger = logging.getLogger(__name__)
|
| 16 |
+
|
| 17 |
+
class ReasoningAgent:
    """
    Specialized agent for reasoning tasks.

    Handles mathematical calculations, statistics, unit conversions,
    logical deduction, pattern analysis, step-by-step explanations, and
    general analytical questions. Uses the calculator tool for exact
    arithmetic where possible and falls back to LLM-only reasoning
    otherwise.
    """

    def __init__(self, llm_client: QwenClient):
        self.llm_client = llm_client
        self.calculator = CalculatorTool()

    def process(self, state: GAIAAgentState) -> GAIAAgentState:
        """
        Process reasoning questions using mathematical and logical analysis.

        Classifies the question into a strategy, dispatches to the matching
        handler, and appends the resulting AgentResult (or a failure result
        on any exception) to the state before returning it.
        """
        logger.info(f"Reasoning agent processing: {state.question[:100]}...")
        state.add_processing_step("Reasoning Agent: Starting analysis")

        try:
            # Determine reasoning strategy
            strategy = self._determine_reasoning_strategy(state.question)
            state.add_processing_step(f"Reasoning Agent: Strategy = {strategy}")

            # Dispatch table keeps the strategy -> handler mapping in one place;
            # unknown strategies fall through to general reasoning.
            handlers = {
                "mathematical": self._process_mathematical,
                "statistical": self._process_statistical,
                "unit_conversion": self._process_unit_conversion,
                "logical_deduction": self._process_logical_deduction,
                "pattern_analysis": self._process_pattern_analysis,
                "step_by_step": self._process_step_by_step,
            }
            handler = handlers.get(strategy, self._process_general_reasoning)
            result = handler(state)

            # Add result to state
            state.add_agent_result(result)
            state.add_processing_step(
                f"Reasoning Agent: Completed with confidence {result.confidence:.2f}"
            )
            return state

        except Exception as e:
            error_msg = f"Reasoning failed: {str(e)}"
            state.add_error(error_msg)
            logger.error(error_msg)

            # Record an explicit failure result so downstream synthesis
            # still sees a reasoning-agent entry.
            state.add_agent_result(self._create_failure_result(error_msg))
            return state

    @staticmethod
    def _matches_any(text: str, indicators: List[str]) -> bool:
        """True if any indicator occurs as a whole word/phrase in text.

        Word-boundary matching fixes the original substring check, where
        e.g. 'if' fired inside 'different' and 'sum' inside 'summer'.
        """
        return any(
            re.search(r'\b' + re.escape(indicator) + r'\b', text)
            for indicator in indicators
        )

    def _determine_reasoning_strategy(self, question: str) -> str:
        """Determine the best reasoning strategy for the question.

        Checks are ordered by specificity; the first matching category
        wins. Returns one of: "mathematical", "statistical",
        "unit_conversion", "logical_deduction", "pattern_analysis",
        "step_by_step", "general_reasoning".
        """
        question_lower = question.lower()

        # Mathematical calculations
        math_indicators = [
            'calculate', 'compute', 'solve', 'equation', 'formula',
            'multiply', 'divide', 'add', 'subtract', 'sum', 'total',
            'percentage', 'percent', 'ratio', 'proportion'
        ]
        if self._matches_any(question_lower, math_indicators):
            return "mathematical"

        # Statistical analysis
        stats_indicators = [
            'average', 'mean', 'median', 'mode', 'standard deviation',
            'variance', 'correlation', 'distribution', 'sample'
        ]
        if self._matches_any(question_lower, stats_indicators):
            return "statistical"

        # Unit conversions. NOTE: the bare prepositions 'to'/'from' were
        # removed from the indicator list — as substrings they matched
        # almost every question and misrouted it to unit conversion.
        unit_indicators = [
            'convert', 'meter', 'feet', 'celsius', 'fahrenheit',
            'gram', 'pound', 'liter', 'gallon', 'hour', 'minute'
        ]
        conversion_pattern = r'\d+\s*\w+\s+to\s+\w+'
        if (self._matches_any(question_lower, unit_indicators) or
                re.search(conversion_pattern, question_lower)):
            return "unit_conversion"

        # Logical deduction
        logic_indicators = [
            'if', 'then', 'therefore', 'because', 'since', 'given that',
            'prove', 'demonstrate', 'conclude', 'infer', 'deduce'
        ]
        if self._matches_any(question_lower, logic_indicators):
            return "logical_deduction"

        # Pattern analysis
        pattern_indicators = [
            'pattern', 'sequence', 'series', 'next', 'continues',
            'follows', 'trend', 'progression'
        ]
        if self._matches_any(question_lower, pattern_indicators):
            return "pattern_analysis"

        # Step-by-step problems
        step_indicators = [
            'step', 'process', 'procedure', 'method', 'approach',
            'how to', 'explain how', 'show how'
        ]
        if self._matches_any(question_lower, step_indicators):
            return "step_by_step"

        # Default to general reasoning
        return "general_reasoning"

    def _llm_answer(self, prompt: str, tier, max_tokens: int, *,
                    confidence: float, reasoning: str, failure_msg: str,
                    tools_used: Optional[List[ToolResult]] = None,
                    extra_time: float = 0.0,
                    extra_cost: float = 0.0) -> AgentResult:
        """Run a prompt through the LLM and wrap the reply in an AgentResult.

        Consolidates the success/failure boilerplate shared by every
        strategy handler. `extra_time`/`extra_cost` account for tool work
        (e.g. calculator calls) performed before the LLM call.
        """
        llm_result = self.llm_client.generate(prompt, tier=tier, max_tokens=max_tokens)
        if not llm_result.success:
            return self._create_failure_result(failure_msg)

        kwargs = {}
        if tools_used is not None:
            # Only pass tools_used when tools actually ran, so the
            # AgentResult default applies otherwise.
            kwargs['tools_used'] = tools_used
        return AgentResult(
            agent_role=AgentRole.REASONING_AGENT,
            success=True,
            result=llm_result.response,
            confidence=confidence,
            reasoning=reasoning,
            model_used=llm_result.model_used,
            processing_time=extra_time + llm_result.response_time,
            cost_estimate=extra_cost + llm_result.cost_estimate,
            **kwargs,
        )

    def _process_mathematical(self, state: GAIAAgentState) -> AgentResult:
        """Process mathematical calculation questions.

        Prefers exact calculator results for any extracted expressions;
        falls back to LLM-only reasoning when none succeed.
        """
        logger.info("Processing mathematical calculation")

        # Extract mathematical expressions from the question
        expressions = self._extract_mathematical_expressions(state.question)

        if expressions:
            calc_results = [self.calculator.execute(expr) for expr in expressions]
            if any(r.success for r in calc_results):
                return self._analyze_calculation_results(state, calc_results)

        # No clear expressions or all calculations failed: LLM-only fallback.
        return self._llm_mathematical_reasoning(state)

    def _process_statistical(self, state: GAIAAgentState) -> AgentResult:
        """Process statistical analysis questions.

        Runs the calculator's statistics operation when at least two
        numbers are present; otherwise falls back to LLM reasoning.
        """
        logger.info("Processing statistical analysis")

        # Extract numerical data from question
        numbers = self._extract_numbers(state.question)

        if len(numbers) >= 2:
            calc_result = self.calculator.execute(
                {"operation": "statistics", "data": numbers}
            )
            if calc_result.success:
                return self._analyze_statistical_results(state, calc_result, numbers)
            return self._llm_statistical_reasoning(state, numbers)

        # Not enough data points for tool-based statistics.
        return self._llm_statistical_reasoning(state, [])

    def _process_unit_conversion(self, state: GAIAAgentState) -> AgentResult:
        """Process unit conversion questions.

        Parses a "<value> <unit> to <unit>" phrase and delegates to the
        calculator; falls back to LLM reasoning when parsing or the
        conversion fails.
        """
        logger.info("Processing unit conversion")

        # Extract conversion details
        conversion_info = self._extract_conversion_info(state.question)
        if conversion_info is None:
            return self._llm_conversion_reasoning(state, None)

        value, from_unit, to_unit = conversion_info
        calc_result = self.calculator.execute({
            "operation": "convert",
            "value": value,
            "from_unit": from_unit,
            "to_unit": to_unit,
        })

        if calc_result.success:
            return self._analyze_conversion_results(state, calc_result, conversion_info)
        return self._llm_conversion_reasoning(state, conversion_info)

    def _process_logical_deduction(self, state: GAIAAgentState) -> AgentResult:
        """Process logical reasoning and deduction questions."""
        logger.info("Processing logical deduction")

        reasoning_prompt = f"""
Please solve this logical reasoning problem step by step:

Question: {state.question}

Approach this systematically:
1. Identify the given information
2. Identify what needs to be determined
3. Apply logical rules and deduction
4. State your conclusion clearly

Please provide a clear, logical answer.
"""
        # COMPLEX tier: deduction benefits from the strongest model.
        return self._llm_answer(
            reasoning_prompt, ModelTier.COMPLEX, 600,
            confidence=0.80,
            reasoning="Applied logical deduction and reasoning",
            failure_msg="Logical reasoning failed",
        )

    def _process_pattern_analysis(self, state: GAIAAgentState) -> AgentResult:
        """Process pattern recognition and analysis questions."""
        logger.info("Processing pattern analysis")

        # Extract sequences or patterns from question
        numbers = self._extract_numbers(state.question)

        pattern_prompt = f"""
Analyze this pattern or sequence problem:

Question: {state.question}

{"Numbers found: " + str(numbers) if numbers else ""}

Please:
1. Identify the pattern or rule
2. Explain the logic
3. Provide the answer

Be systematic and show your reasoning.
"""
        # Higher confidence when concrete numerical data backs the analysis.
        confidence = 0.75 if numbers else 0.65
        return self._llm_answer(
            pattern_prompt, ModelTier.MAIN, 500,
            confidence=confidence,
            reasoning="Analyzed patterns and sequences",
            failure_msg="Pattern analysis failed",
        )

    def _process_step_by_step(self, state: GAIAAgentState) -> AgentResult:
        """Process questions requiring step-by-step explanation."""
        logger.info("Processing step-by-step reasoning")

        step_prompt = f"""
Please solve this problem with a clear step-by-step approach:

Question: {state.question}

Structure your response as:
Step 1: [First step and reasoning]
Step 2: [Second step and reasoning]
...
Final Answer: [Clear conclusion]

Be thorough and explain each step.
"""
        return self._llm_answer(
            step_prompt, ModelTier.MAIN, 600,
            confidence=0.75,
            reasoning="Provided step-by-step solution",
            failure_msg="Step-by-step reasoning failed",
        )

    def _process_general_reasoning(self, state: GAIAAgentState) -> AgentResult:
        """Process general reasoning questions (default strategy)."""
        logger.info("Processing general reasoning")

        reasoning_prompt = f"""
Please analyze and answer this reasoning question:

Question: {state.question}

Think through this carefully and provide a well-reasoned answer.
Consider all aspects of the question and explain your reasoning.
"""
        return self._llm_answer(
            reasoning_prompt, ModelTier.MAIN, 500,
            confidence=0.70,
            reasoning="Applied general reasoning and analysis",
            failure_msg="General reasoning failed",
        )

    def _extract_mathematical_expressions(self, question: str) -> List[str]:
        """Extract explicit arithmetic expressions from question text."""
        math_patterns = [
            r'\d+\s*[\+\-\*/]\s*\d+',   # binary arithmetic, e.g. "3 + 4"
            r'\d+\s*\^\s*\d+',          # exponentiation
            r'sqrt\(\d+\)',             # square roots
            r'\d+\s*%',                 # percentages
            r'\d+\s*factorial',         # factorials
        ]

        expressions: List[str] = []
        for pattern in math_patterns:
            expressions.extend(re.findall(pattern, question, re.IGNORECASE))
        return expressions

    def _extract_numbers(self, question: str) -> List[float]:
        """Extract numerical values (ints and floats) from question text."""
        numbers: List[float] = []

        for match in re.findall(r'[-+]?\d*\.?\d+', question):
            try:
                # float() parses both integer and decimal forms directly;
                # the old float(int(...)) round-trip was redundant.
                numbers.append(float(match))
            except ValueError:
                continue
        return numbers

    def _extract_conversion_info(self, question: str) -> Optional[tuple]:
        """Parse a "<value> <unit> to <unit>" phrase from the question.

        Returns:
            (value, from_unit, to_unit) with value as float, or None if
            no conversion phrase is found.
        """
        match = re.search(
            r'(\d+(?:\.\d+)?)\s*(\w+)\s+to\s+(\w+)', question.lower()
        )
        if match:
            value, from_unit, to_unit = match.groups()
            return float(value), from_unit, to_unit
        return None

    def _analyze_calculation_results(self, state: GAIAAgentState, calc_results: List) -> AgentResult:
        """Have the LLM phrase an answer from successful calculator outputs."""
        successful_results = [r for r in calc_results if r.success]

        if successful_results:
            result_summaries = []
            total_cost = 0.0
            total_time = 0.0

            for calc_result in successful_results:
                if calc_result.result.get('success'):
                    calc_data = calc_result.result['calculation']
                    result_summaries.append(
                        f"{calc_data['expression']} = {calc_data['result']}"
                    )
                    total_cost += calc_result.result.get('cost_estimate', 0)
                    total_time += calc_result.execution_time

            analysis_prompt = f"""
Based on these calculations, please answer the original question:

Question: {state.question}

Calculation Results:
{chr(10).join(result_summaries)}

Please provide a direct answer incorporating these calculations.
"""
            return self._llm_answer(
                analysis_prompt, ModelTier.MAIN, 400,
                confidence=0.85,
                reasoning="Performed calculations and analyzed results",
                failure_msg="Mathematical calculations failed",
                tools_used=[ToolResult(
                    tool_name="calculator",
                    success=True,
                    result=result_summaries,
                    execution_time=total_time
                )],
                extra_time=total_time,
                extra_cost=total_cost,
            )

        return self._create_failure_result("Mathematical calculations failed")

    def _analyze_statistical_results(self, state: GAIAAgentState, calc_result, numbers: List[float]) -> AgentResult:
        """Have the LLM phrase an answer from calculator statistics output."""
        if calc_result.success and calc_result.result.get('success'):
            stats = calc_result.result['statistics']

            analysis_prompt = f"""
Based on this statistical analysis, please answer the question:

Question: {state.question}

Data: {numbers}
Statistical Results:
- Count: {stats.get('count')}
- Mean: {stats.get('mean')}
- Median: {stats.get('median')}
- Min: {stats.get('min')}
- Max: {stats.get('max')}
- Standard Deviation: {stats.get('stdev', 'N/A')}

Please provide a direct answer based on this statistical analysis.
"""
            # NOTE: mirrors the original behavior — calculator cost is not
            # folded into cost_estimate here, only its execution time.
            return self._llm_answer(
                analysis_prompt, ModelTier.MAIN, 400,
                confidence=0.85,
                reasoning="Performed statistical analysis",
                failure_msg="Statistical analysis failed",
                tools_used=[ToolResult(
                    tool_name="calculator",
                    success=True,
                    result=stats,
                    execution_time=calc_result.execution_time
                )],
                extra_time=calc_result.execution_time,
            )

        return self._create_failure_result("Statistical analysis failed")

    def _analyze_conversion_results(self, state: GAIAAgentState, calc_result, conversion_info: tuple) -> AgentResult:
        """Have the LLM phrase an answer from a calculator unit conversion."""
        if calc_result.success and calc_result.result.get('success'):
            conversion_data = calc_result.result['conversion']
            value, from_unit, to_unit = conversion_info

            analysis_prompt = f"""
Based on this unit conversion, please answer the question:

Question: {state.question}

Conversion: {value} {from_unit} = {conversion_data['result']} {conversion_data['units']}

Please provide a direct answer incorporating this conversion.
"""
            # ROUTER tier suffices: the hard work (conversion) is already done.
            return self._llm_answer(
                analysis_prompt, ModelTier.ROUTER, 300,
                confidence=0.90,
                reasoning="Performed unit conversion",
                failure_msg="Unit conversion failed",
                tools_used=[ToolResult(
                    tool_name="calculator",
                    success=True,
                    result=conversion_data,
                    execution_time=calc_result.execution_time
                )],
                extra_time=calc_result.execution_time,
            )

        return self._create_failure_result("Unit conversion failed")

    def _llm_mathematical_reasoning(self, state: GAIAAgentState) -> AgentResult:
        """Fallback to LLM-only mathematical reasoning."""
        math_prompt = f"""
Please solve this mathematical problem:

Question: {state.question}

Show your mathematical reasoning and calculations step by step.
Provide a clear numerical answer.
"""
        return self._llm_answer(
            math_prompt, ModelTier.MAIN, 500,
            confidence=0.70,
            reasoning="Applied mathematical reasoning (LLM-only)",
            failure_msg="Mathematical reasoning failed",
        )

    def _llm_statistical_reasoning(self, state: GAIAAgentState, numbers: List[float]) -> AgentResult:
        """Fallback to LLM-only statistical reasoning."""
        stats_prompt = f"""
Please analyze this statistical problem:

Question: {state.question}

{"Numbers identified: " + str(numbers) if numbers else ""}

Apply statistical reasoning and provide a clear answer.
"""
        return self._llm_answer(
            stats_prompt, ModelTier.MAIN, 400,
            confidence=0.65,
            reasoning="Applied statistical reasoning (LLM-only)",
            failure_msg="Statistical reasoning failed",
        )

    def _llm_conversion_reasoning(self, state: GAIAAgentState, conversion_info: Optional[tuple]) -> AgentResult:
        """Fallback to LLM-only conversion reasoning."""
        conversion_prompt = f"""
Please solve this unit conversion problem:

Question: {state.question}

{f"Conversion detected: {conversion_info}" if conversion_info else ""}

Apply conversion reasoning and provide a clear answer.
"""
        return self._llm_answer(
            conversion_prompt, ModelTier.ROUTER, 300,
            confidence=0.65,
            reasoning="Applied conversion reasoning (LLM-only)",
            failure_msg="Conversion reasoning failed",
        )

    def _create_failure_result(self, error_message: str) -> AgentResult:
        """Create a zero-confidence failure result for this agent."""
        return AgentResult(
            agent_role=AgentRole.REASONING_AGENT,
            success=False,
            result=error_message,
            confidence=0.0,
            reasoning=error_message,
            model_used="error",  # sentinel: no model was (successfully) used
            processing_time=0.0,
            cost_estimate=0.0
        )
|
src/agents/router.py
ADDED
|
@@ -0,0 +1,300 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Router Agent for GAIA Question Classification
|
| 4 |
+
Analyzes questions and routes them to appropriate specialized agents
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import re
|
| 8 |
+
import logging
|
| 9 |
+
from typing import List, Dict, Any
|
| 10 |
+
from urllib.parse import urlparse
|
| 11 |
+
|
| 12 |
+
from agents.state import GAIAAgentState, QuestionType, AgentRole, AgentResult
|
| 13 |
+
from models.qwen_client import QwenClient, ModelTier
|
| 14 |
+
|
| 15 |
+
logger = logging.getLogger(__name__)
|
| 16 |
+
|
| 17 |
+
class RouterAgent:
|
| 18 |
+
"""
|
| 19 |
+
Router agent that classifies GAIA questions and determines processing strategy
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
    def __init__(self, llm_client: QwenClient):
        # Shared LLM client; tier selection happens per call in routing methods.
        self.llm_client = llm_client
|
| 24 |
+
|
| 25 |
+
def route_question(self, state: GAIAAgentState) -> GAIAAgentState:
|
| 26 |
+
"""
|
| 27 |
+
Main routing function - analyzes question and updates state with routing decisions
|
| 28 |
+
"""
|
| 29 |
+
logger.info(f"Routing question: {state.question[:100]}...")
|
| 30 |
+
state.add_processing_step("Router: Starting question analysis")
|
| 31 |
+
|
| 32 |
+
# Step 1: Rule-based classification
|
| 33 |
+
question_type = self._classify_question_type(state.question, state.file_name)
|
| 34 |
+
state.question_type = question_type
|
| 35 |
+
state.add_processing_step(f"Router: Classified as {question_type.value}")
|
| 36 |
+
|
| 37 |
+
# Step 2: Complexity assessment
|
| 38 |
+
complexity = self._assess_complexity(state.question)
|
| 39 |
+
state.complexity_assessment = complexity
|
| 40 |
+
state.add_processing_step(f"Router: Assessed complexity as {complexity}")
|
| 41 |
+
|
| 42 |
+
# Step 3: Select appropriate agents
|
| 43 |
+
selected_agents = self._select_agents(question_type, state.file_name is not None)
|
| 44 |
+
state.selected_agents = selected_agents
|
| 45 |
+
state.add_processing_step(f"Router: Selected agents: {[a.value for a in selected_agents]}")
|
| 46 |
+
|
| 47 |
+
# Step 4: Estimate cost
|
| 48 |
+
estimated_cost = self._estimate_cost(complexity, selected_agents)
|
| 49 |
+
state.estimated_cost = estimated_cost
|
| 50 |
+
state.add_processing_step(f"Router: Estimated cost: ${estimated_cost:.4f}")
|
| 51 |
+
|
| 52 |
+
# Step 5: Create routing decision summary
|
| 53 |
+
state.routing_decision = {
|
| 54 |
+
"question_type": question_type.value,
|
| 55 |
+
"complexity": complexity,
|
| 56 |
+
"agents": [agent.value for agent in selected_agents],
|
| 57 |
+
"estimated_cost": estimated_cost,
|
| 58 |
+
"reasoning": self._get_routing_reasoning(question_type, complexity, selected_agents)
|
| 59 |
+
}
|
| 60 |
+
|
| 61 |
+
# Step 6: Use LLM for complex routing decisions if needed
|
| 62 |
+
if complexity == "complex" or question_type == QuestionType.UNKNOWN:
|
| 63 |
+
state = self._llm_enhanced_routing(state)
|
| 64 |
+
|
| 65 |
+
logger.info(f"✅ Routing complete: {question_type.value} -> {[a.value for a in selected_agents]}")
|
| 66 |
+
return state
|
| 67 |
+
|
| 68 |
+
def _classify_question_type(self, question: str, file_name: str = None) -> QuestionType:
|
| 69 |
+
"""Classify question type using rule-based analysis"""
|
| 70 |
+
|
| 71 |
+
question_lower = question.lower()
|
| 72 |
+
|
| 73 |
+
# File processing questions
|
| 74 |
+
if file_name:
|
| 75 |
+
file_ext = file_name.lower().split('.')[-1] if '.' in file_name else ""
|
| 76 |
+
|
| 77 |
+
if file_ext in ['jpg', 'jpeg', 'png', 'gif', 'bmp', 'svg']:
|
| 78 |
+
return QuestionType.FILE_PROCESSING
|
| 79 |
+
elif file_ext in ['mp3', 'wav', 'ogg', 'flac', 'm4a']:
|
| 80 |
+
return QuestionType.FILE_PROCESSING
|
| 81 |
+
elif file_ext in ['xlsx', 'xls', 'csv']:
|
| 82 |
+
return QuestionType.FILE_PROCESSING
|
| 83 |
+
elif file_ext in ['py', 'js', 'java', 'cpp', 'c']:
|
| 84 |
+
return QuestionType.CODE_EXECUTION
|
| 85 |
+
else:
|
| 86 |
+
return QuestionType.FILE_PROCESSING
|
| 87 |
+
|
| 88 |
+
# URL-based classification
|
| 89 |
+
url_patterns = {
|
| 90 |
+
QuestionType.WIKIPEDIA: [
|
| 91 |
+
r'wikipedia\.org', r'wiki', r'featured article', r'promoted.*wikipedia'
|
| 92 |
+
],
|
| 93 |
+
QuestionType.YOUTUBE: [
|
| 94 |
+
r'youtube\.com', r'youtu\.be', r'watch\?v=', r'video'
|
| 95 |
+
]
|
| 96 |
+
}
|
| 97 |
+
|
| 98 |
+
for question_type, patterns in url_patterns.items():
|
| 99 |
+
if any(re.search(pattern, question_lower) for pattern in patterns):
|
| 100 |
+
return question_type
|
| 101 |
+
|
| 102 |
+
# Content-based classification
|
| 103 |
+
classification_patterns = {
|
| 104 |
+
QuestionType.MATHEMATICAL: [
|
| 105 |
+
r'calculate', r'compute', r'solve', r'equation', r'formula',
|
| 106 |
+
r'sum', r'total', r'average', r'percentage', r'ratio',
|
| 107 |
+
r'how many', r'how much', r'\d+.*\d+', r'math'
|
| 108 |
+
],
|
| 109 |
+
QuestionType.CODE_EXECUTION: [
|
| 110 |
+
r'code', r'program', r'script', r'function', r'algorithm',
|
| 111 |
+
r'execute', r'run.*code', r'python', r'javascript'
|
| 112 |
+
],
|
| 113 |
+
QuestionType.TEXT_MANIPULATION: [
|
| 114 |
+
r'reverse', r'encode', r'decode', r'transform', r'convert',
|
| 115 |
+
r'uppercase', r'lowercase', r'replace', r'extract'
|
| 116 |
+
],
|
| 117 |
+
QuestionType.REASONING: [
|
| 118 |
+
r'why', r'explain', r'analyze', r'reasoning', r'logic',
|
| 119 |
+
r'relationship', r'compare', r'contrast', r'conclusion'
|
| 120 |
+
],
|
| 121 |
+
QuestionType.WEB_RESEARCH: [
|
| 122 |
+
r'search', r'find.*information', r'research', r'look up',
|
| 123 |
+
r'website', r'online', r'internet'
|
| 124 |
+
]
|
| 125 |
+
}
|
| 126 |
+
|
| 127 |
+
# Score each category
|
| 128 |
+
type_scores = {}
|
| 129 |
+
for question_type, patterns in classification_patterns.items():
|
| 130 |
+
score = sum(1 for pattern in patterns if re.search(pattern, question_lower))
|
| 131 |
+
if score > 0:
|
| 132 |
+
type_scores[question_type] = score
|
| 133 |
+
|
| 134 |
+
# Return highest scoring type, or UNKNOWN if no clear match
|
| 135 |
+
if type_scores:
|
| 136 |
+
return max(type_scores.keys(), key=lambda t: type_scores[t])
|
| 137 |
+
|
| 138 |
+
return QuestionType.UNKNOWN
|
| 139 |
+
|
| 140 |
+
def _assess_complexity(self, question: str) -> str:
|
| 141 |
+
"""Assess question complexity"""
|
| 142 |
+
|
| 143 |
+
question_lower = question.lower()
|
| 144 |
+
|
| 145 |
+
# Complex indicators
|
| 146 |
+
complex_indicators = [
|
| 147 |
+
'multi-step', 'multiple', 'several', 'complex', 'detailed',
|
| 148 |
+
'analyze', 'explain why', 'reasoning', 'relationship',
|
| 149 |
+
'compare and contrast', 'comprehensive', 'thorough'
|
| 150 |
+
]
|
| 151 |
+
|
| 152 |
+
# Simple indicators
|
| 153 |
+
simple_indicators = [
|
| 154 |
+
'what is', 'who is', 'when', 'where', 'yes or no',
|
| 155 |
+
'true or false', 'simple', 'quick', 'name', 'list'
|
| 156 |
+
]
|
| 157 |
+
|
| 158 |
+
complex_score = sum(1 for indicator in complex_indicators if indicator in question_lower)
|
| 159 |
+
simple_score = sum(1 for indicator in simple_indicators if indicator in question_lower)
|
| 160 |
+
|
| 161 |
+
# Additional complexity factors
|
| 162 |
+
if len(question) > 200:
|
| 163 |
+
complex_score += 1
|
| 164 |
+
if len(question.split()) > 30:
|
| 165 |
+
complex_score += 1
|
| 166 |
+
if question.count('?') > 2: # Multiple questions
|
| 167 |
+
complex_score += 1
|
| 168 |
+
|
| 169 |
+
# Determine complexity
|
| 170 |
+
if complex_score >= 2:
|
| 171 |
+
return "complex"
|
| 172 |
+
elif simple_score >= 2 and complex_score == 0:
|
| 173 |
+
return "simple"
|
| 174 |
+
else:
|
| 175 |
+
return "medium"
|
| 176 |
+
|
| 177 |
+
def _select_agents(self, question_type: QuestionType, has_file: bool) -> List[AgentRole]:
|
| 178 |
+
"""Select appropriate agents based on question type and presence of files"""
|
| 179 |
+
|
| 180 |
+
agents = []
|
| 181 |
+
|
| 182 |
+
# Always include synthesizer for final answer compilation
|
| 183 |
+
agents.append(AgentRole.SYNTHESIZER)
|
| 184 |
+
|
| 185 |
+
# Type-specific agent selection
|
| 186 |
+
if question_type in [QuestionType.WIKIPEDIA, QuestionType.WEB_RESEARCH, QuestionType.YOUTUBE]:
|
| 187 |
+
agents.append(AgentRole.WEB_RESEARCHER)
|
| 188 |
+
|
| 189 |
+
elif question_type == QuestionType.FILE_PROCESSING:
|
| 190 |
+
agents.append(AgentRole.FILE_PROCESSOR)
|
| 191 |
+
|
| 192 |
+
elif question_type == QuestionType.CODE_EXECUTION:
|
| 193 |
+
agents.append(AgentRole.CODE_EXECUTOR)
|
| 194 |
+
|
| 195 |
+
elif question_type in [QuestionType.MATHEMATICAL, QuestionType.REASONING]:
|
| 196 |
+
agents.append(AgentRole.REASONING_AGENT)
|
| 197 |
+
|
| 198 |
+
elif question_type == QuestionType.TEXT_MANIPULATION:
|
| 199 |
+
agents.append(AgentRole.REASONING_AGENT) # Can handle text operations
|
| 200 |
+
|
| 201 |
+
else: # UNKNOWN or complex cases
|
| 202 |
+
# Use multiple agents for better coverage
|
| 203 |
+
agents.extend([AgentRole.WEB_RESEARCHER, AgentRole.REASONING_AGENT])
|
| 204 |
+
if has_file:
|
| 205 |
+
agents.append(AgentRole.FILE_PROCESSOR)
|
| 206 |
+
|
| 207 |
+
# Remove duplicates while preserving order
|
| 208 |
+
seen = set()
|
| 209 |
+
unique_agents = []
|
| 210 |
+
for agent in agents:
|
| 211 |
+
if agent not in seen:
|
| 212 |
+
seen.add(agent)
|
| 213 |
+
unique_agents.append(agent)
|
| 214 |
+
|
| 215 |
+
return unique_agents
|
| 216 |
+
|
| 217 |
+
def _estimate_cost(self, complexity: str, agents: List[AgentRole]) -> float:
|
| 218 |
+
"""Estimate processing cost based on complexity and agents"""
|
| 219 |
+
|
| 220 |
+
base_costs = {
|
| 221 |
+
"simple": 0.005, # Router model mostly
|
| 222 |
+
"medium": 0.015, # Mix of router and main
|
| 223 |
+
"complex": 0.035 # Include complex model usage
|
| 224 |
+
}
|
| 225 |
+
|
| 226 |
+
base_cost = base_costs.get(complexity, 0.015)
|
| 227 |
+
|
| 228 |
+
# Additional cost per agent
|
| 229 |
+
agent_cost = len(agents) * 0.005
|
| 230 |
+
|
| 231 |
+
return base_cost + agent_cost
|
| 232 |
+
|
| 233 |
+
def _get_routing_reasoning(self, question_type: QuestionType, complexity: str, agents: List[AgentRole]) -> str:
|
| 234 |
+
"""Generate human-readable reasoning for routing decision"""
|
| 235 |
+
|
| 236 |
+
reasons = []
|
| 237 |
+
|
| 238 |
+
# Question type reasoning
|
| 239 |
+
if question_type == QuestionType.WIKIPEDIA:
|
| 240 |
+
reasons.append("Question references Wikipedia content")
|
| 241 |
+
elif question_type == QuestionType.YOUTUBE:
|
| 242 |
+
reasons.append("Question involves YouTube video analysis")
|
| 243 |
+
elif question_type == QuestionType.FILE_PROCESSING:
|
| 244 |
+
reasons.append("Question requires file processing")
|
| 245 |
+
elif question_type == QuestionType.MATHEMATICAL:
|
| 246 |
+
reasons.append("Question involves mathematical computation")
|
| 247 |
+
elif question_type == QuestionType.CODE_EXECUTION:
|
| 248 |
+
reasons.append("Question requires code execution")
|
| 249 |
+
elif question_type == QuestionType.REASONING:
|
| 250 |
+
reasons.append("Question requires logical reasoning")
|
| 251 |
+
|
| 252 |
+
# Complexity reasoning
|
| 253 |
+
if complexity == "complex":
|
| 254 |
+
reasons.append("Complex reasoning required")
|
| 255 |
+
elif complexity == "simple":
|
| 256 |
+
reasons.append("Straightforward question")
|
| 257 |
+
|
| 258 |
+
# Agent reasoning
|
| 259 |
+
agent_names = [agent.value.replace('_', ' ') for agent in agents]
|
| 260 |
+
reasons.append(f"Selected agents: {', '.join(agent_names)}")
|
| 261 |
+
|
| 262 |
+
return "; ".join(reasons)
|
| 263 |
+
|
| 264 |
+
def _llm_enhanced_routing(self, state: GAIAAgentState) -> GAIAAgentState:
|
| 265 |
+
"""Use LLM for enhanced routing analysis of complex/unknown questions"""
|
| 266 |
+
|
| 267 |
+
prompt = f"""
|
| 268 |
+
Analyze this GAIA benchmark question and provide routing guidance:
|
| 269 |
+
|
| 270 |
+
Question: {state.question}
|
| 271 |
+
File attached: {state.file_name if state.file_name else "None"}
|
| 272 |
+
Current classification: {state.question_type.value}
|
| 273 |
+
Current complexity: {state.complexity_assessment}
|
| 274 |
+
|
| 275 |
+
Please provide:
|
| 276 |
+
1. Confirm or correct the question type
|
| 277 |
+
2. Confirm or adjust complexity assessment
|
| 278 |
+
3. Key challenges in answering this question
|
| 279 |
+
4. Recommended approach
|
| 280 |
+
|
| 281 |
+
Keep response concise and focused on routing decisions.
|
| 282 |
+
"""
|
| 283 |
+
|
| 284 |
+
try:
|
| 285 |
+
# Use router model for this analysis
|
| 286 |
+
tier = ModelTier.ROUTER if state.complexity_assessment != "complex" else ModelTier.MAIN
|
| 287 |
+
result = self.llm_client.generate(prompt, tier=tier, max_tokens=200)
|
| 288 |
+
|
| 289 |
+
if result.success:
|
| 290 |
+
state.add_processing_step("Router: Enhanced with LLM analysis")
|
| 291 |
+
state.routing_decision["llm_analysis"] = result.response
|
| 292 |
+
logger.info("✅ LLM enhanced routing completed")
|
| 293 |
+
else:
|
| 294 |
+
state.add_error(f"LLM routing enhancement failed: {result.error}")
|
| 295 |
+
|
| 296 |
+
except Exception as e:
|
| 297 |
+
state.add_error(f"LLM routing error: {str(e)}")
|
| 298 |
+
logger.error(f"LLM routing failed: {e}")
|
| 299 |
+
|
| 300 |
+
return state
|
src/agents/state.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
LangGraph State Schema for GAIA Agent System
|
| 4 |
+
Defines the state structure for agent communication and coordination
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from typing import Dict, Any, List, Optional, Literal
|
| 8 |
+
from dataclasses import dataclass, field
|
| 9 |
+
from enum import Enum
|
| 10 |
+
import time
|
| 11 |
+
|
| 12 |
+
class QuestionType(Enum):
    """Classification of GAIA question types.

    Assigned by the router's rule-based classifier and used to pick the
    specialist agents for a question.
    """
    WIKIPEDIA = "wikipedia"                  # references Wikipedia content
    WEB_RESEARCH = "web_research"            # general online lookup
    YOUTUBE = "youtube"                      # YouTube video analysis
    FILE_PROCESSING = "file_processing"      # attached file must be parsed
    MATHEMATICAL = "mathematical"            # numeric computation
    CODE_EXECUTION = "code_execution"        # requires running code
    TEXT_MANIPULATION = "text_manipulation"  # string transforms (reverse, encode, ...)
    REASONING = "reasoning"                  # logic/analysis questions
    UNKNOWN = "unknown"                      # no confident classification
|
| 23 |
+
|
| 24 |
+
class ModelTier(Enum):
    """Model complexity tiers, trading cost against capability."""
    ROUTER = "router"    # 7B - Fast classification
    MAIN = "main"        # 32B - Balanced tasks
    COMPLEX = "complex"  # 72B - Complex reasoning
|
| 29 |
+
|
| 30 |
+
class AgentRole(Enum):
    """Roles of different agents in the system."""
    ROUTER = "router"                    # classifies questions and picks agents
    WEB_RESEARCHER = "web_researcher"    # web/Wikipedia/YouTube research
    FILE_PROCESSOR = "file_processor"    # attached-file handling
    CODE_EXECUTOR = "code_executor"      # code-execution questions
    REASONING_AGENT = "reasoning_agent"  # math/logic/text reasoning
    SYNTHESIZER = "synthesizer"          # merges agent results into final answer
|
| 38 |
+
|
| 39 |
+
@dataclass
class ToolResult:
    """Result from a single tool execution."""
    tool_name: str               # identifier of the tool that ran
    success: bool                # whether the tool completed without error
    result: Any                  # tool output; shape depends on the tool
    error: Optional[str] = None  # error description when success is False
    execution_time: float = 0.0  # wall-clock seconds spent in the tool
    metadata: Dict[str, Any] = field(default_factory=dict)  # extra tool-specific info
|
| 48 |
+
|
| 49 |
+
@dataclass
class AgentResult:
    """Result from an agent's processing of a question."""
    agent_role: AgentRole          # which agent produced this result
    success: bool                  # whether the agent produced a usable answer
    result: str                    # answer text (or error message on failure)
    confidence: float              # 0.0 to 1.0
    reasoning: str                 # short explanation of how the answer was derived
    tools_used: List[ToolResult] = field(default_factory=list)  # tool calls made
    model_used: str = ""           # identifier of the LLM used ("error"/"" if none)
    processing_time: float = 0.0   # wall-clock seconds the agent spent
    cost_estimate: float = 0.0     # estimated dollar cost of model calls
|
| 61 |
+
|
| 62 |
+
class GAIAAgentState:
    """
    Central state for the GAIA agent system
    This is passed between all agents in the LangGraph workflow
    """

    def __init__(self):
        # Question information
        self.task_id: str = ""
        self.question: str = ""
        self.question_type: QuestionType = QuestionType.UNKNOWN
        self.difficulty_level: int = 1  # 1, 2, or 3
        self.file_name: Optional[str] = None
        self.file_path: Optional[str] = None
        self.metadata: Dict[str, Any] = {}

        # Routing decisions (filled in by the router agent)
        self.routing_decision: Dict[str, Any] = {}
        self.selected_agents: List[AgentRole] = []
        self.complexity_assessment: str = "medium"  # "simple" | "medium" | "complex"
        self.estimated_cost: float = 0.0

        # Agent results: one entry per agent role, plus raw tool outputs
        self.agent_results: Dict[AgentRole, AgentResult] = {}
        self.tool_results: List[ToolResult] = []

        # Final answer (set by the synthesizer)
        self.final_answer: str = ""
        self.final_confidence: float = 0.0
        self.final_reasoning: str = ""
        self.answer_source: str = ""  # Which agent provided the final answer

        # System tracking
        self.start_time: float = time.time()
        self.processing_steps: List[str] = []
        self.total_cost: float = 0.0
        self.total_processing_time: float = 0.0
        self.error_messages: List[str] = []

        # Status flags
        self.is_complete: bool = False
        self.requires_human_review: bool = False
        self.confidence_threshold_met: bool = False

    def add_processing_step(self, step: str):
        """Add a processing step to the history, stamped with elapsed seconds."""
        self.processing_steps.append(f"[{time.time() - self.start_time:.2f}s] {step}")

    def add_agent_result(self, result: AgentResult):
        """Add result from an agent and roll its cost/time into the totals."""
        self.agent_results[result.agent_role] = result
        self.total_cost += result.cost_estimate
        self.total_processing_time += result.processing_time
        # Only the first 50 chars of the answer go into the step log.
        self.add_processing_step(f"{result.agent_role.value}: {result.result[:50]}...")

    def add_tool_result(self, result: ToolResult):
        """Add result from a tool execution and log its success/failure."""
        self.tool_results.append(result)
        self.add_processing_step(f"Tool {result.tool_name}: {'✅' if result.success else '❌'}")

    def add_error(self, error_message: str):
        """Record an error message and mirror it into the processing steps."""
        self.error_messages.append(error_message)
        self.add_processing_step(f"ERROR: {error_message}")

    def get_best_result(self) -> Optional[AgentResult]:
        """Get the agent result with highest confidence (None when empty)."""
        if not self.agent_results:
            return None
        return max(self.agent_results.values(), key=lambda r: r.confidence)

    def should_use_complex_model(self) -> bool:
        """Determine if complex model should be used based on state"""
        # Use complex model for:
        # - High difficulty questions
        # - Questions requiring detailed reasoning
        # - When we have budget remaining
        return (
            self.difficulty_level >= 3 or
            self.complexity_assessment == "complex" or
            any("reasoning" in step.lower() for step in self.processing_steps)
        ) and self.total_cost < 0.07  # Keep some budget for complex tasks

    def get_summary(self) -> Dict[str, Any]:
        """Get a compact summary of the current state (for logging/UI)."""
        return {
            "task_id": self.task_id,
            "question_type": self.question_type.value,
            "agents_used": [role.value for role in self.agent_results.keys()],
            "tools_used": [tool.tool_name for tool in self.tool_results],
            "final_answer": self.final_answer,
            "confidence": self.final_confidence,
            "processing_time": self.total_processing_time,
            "total_cost": self.total_cost,
            "steps_count": len(self.processing_steps),
            "is_complete": self.is_complete,
            "error_count": len(self.error_messages)
        }

    def to_dict(self) -> Dict[str, Any]:
        """Convert state to dictionary for serialization"""
        return {
            "task_id": self.task_id,
            "question": self.question,
            "question_type": self.question_type.value,
            "difficulty_level": self.difficulty_level,
            "file_name": self.file_name,
            "file_path": self.file_path,
            "routing_decision": self.routing_decision,
            "selected_agents": [agent.value for agent in self.selected_agents],
            "complexity_assessment": self.complexity_assessment,
            "final_answer": self.final_answer,
            "final_confidence": self.final_confidence,
            "final_reasoning": self.final_reasoning,
            "answer_source": self.answer_source,
            "processing_steps": self.processing_steps,
            "total_cost": self.total_cost,
            "total_processing_time": self.total_processing_time,
            "error_messages": self.error_messages,
            "is_complete": self.is_complete,
            "summary": self.get_summary()
        }
|
| 184 |
+
|
| 185 |
+
# Type alias for LangGraph: workflow nodes receive and return this state.
AgentState = GAIAAgentState
|
src/agents/synthesizer.py
ADDED
|
@@ -0,0 +1,284 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Synthesizer Agent for GAIA Agent System
|
| 4 |
+
Combines results from multiple agents and produces final answers
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import logging
|
| 8 |
+
from typing import Dict, List, Optional, Any
|
| 9 |
+
from statistics import mean
|
| 10 |
+
|
| 11 |
+
from agents.state import GAIAAgentState, AgentRole, AgentResult
|
| 12 |
+
from models.qwen_client import QwenClient, ModelTier
|
| 13 |
+
|
| 14 |
+
logger = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
class SynthesizerAgent:
|
| 17 |
+
"""
|
| 18 |
+
Synthesizer agent that combines multiple agent results into a final answer
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
    def __init__(self, llm_client: QwenClient):
        # Shared LLM client, used by the LLM-based synthesis strategy.
        self.llm_client = llm_client
|
| 23 |
+
|
| 24 |
+
def process(self, state: GAIAAgentState) -> GAIAAgentState:
|
| 25 |
+
"""
|
| 26 |
+
Synthesize final answer from multiple agent results
|
| 27 |
+
"""
|
| 28 |
+
logger.info("Synthesizer: Starting result synthesis")
|
| 29 |
+
state.add_processing_step("Synthesizer: Analyzing agent results")
|
| 30 |
+
|
| 31 |
+
try:
|
| 32 |
+
# Check if we have any agent results to synthesize
|
| 33 |
+
if not state.agent_results:
|
| 34 |
+
error_msg = "No agent results available for synthesis"
|
| 35 |
+
state.add_error(error_msg)
|
| 36 |
+
state.final_answer = "Unable to process question - no agent results available"
|
| 37 |
+
state.final_confidence = 0.0
|
| 38 |
+
state.final_reasoning = error_msg
|
| 39 |
+
state.is_complete = True
|
| 40 |
+
return state
|
| 41 |
+
|
| 42 |
+
# Determine synthesis strategy based on available results
|
| 43 |
+
synthesis_strategy = self._determine_synthesis_strategy(state)
|
| 44 |
+
state.add_processing_step(f"Synthesizer: Using {synthesis_strategy} strategy")
|
| 45 |
+
|
| 46 |
+
# Execute synthesis based on strategy
|
| 47 |
+
if synthesis_strategy == "single_agent":
|
| 48 |
+
final_result = self._synthesize_single_agent(state)
|
| 49 |
+
elif synthesis_strategy == "multi_agent_consensus":
|
| 50 |
+
final_result = self._synthesize_multi_agent_consensus(state)
|
| 51 |
+
elif synthesis_strategy == "confidence_weighted":
|
| 52 |
+
final_result = self._synthesize_confidence_weighted(state)
|
| 53 |
+
elif synthesis_strategy == "llm_synthesis":
|
| 54 |
+
final_result = self._synthesize_with_llm(state)
|
| 55 |
+
else:
|
| 56 |
+
final_result = self._synthesize_fallback(state)
|
| 57 |
+
|
| 58 |
+
# Update state with final results
|
| 59 |
+
state.final_answer = final_result["answer"]
|
| 60 |
+
state.final_confidence = final_result["confidence"]
|
| 61 |
+
state.final_reasoning = final_result["reasoning"]
|
| 62 |
+
state.answer_source = final_result["source"]
|
| 63 |
+
state.is_complete = True
|
| 64 |
+
|
| 65 |
+
# Check if confidence threshold is met
|
| 66 |
+
state.confidence_threshold_met = state.final_confidence >= 0.7
|
| 67 |
+
|
| 68 |
+
# Determine if human review is needed
|
| 69 |
+
state.requires_human_review = (
|
| 70 |
+
state.final_confidence < 0.5 or
|
| 71 |
+
len(state.error_messages) > 0 or
|
| 72 |
+
state.difficulty_level >= 3
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
logger.info(f"✅ Synthesis complete: confidence={state.final_confidence:.2f}")
|
| 76 |
+
state.add_processing_step(f"Synthesizer: Final answer generated (confidence: {state.final_confidence:.2f})")
|
| 77 |
+
|
| 78 |
+
return state
|
| 79 |
+
|
| 80 |
+
except Exception as e:
|
| 81 |
+
error_msg = f"Synthesis failed: {str(e)}"
|
| 82 |
+
state.add_error(error_msg)
|
| 83 |
+
logger.error(error_msg)
|
| 84 |
+
|
| 85 |
+
# Provide fallback answer
|
| 86 |
+
state.final_answer = "Processing failed due to synthesis error"
|
| 87 |
+
state.final_confidence = 0.0
|
| 88 |
+
state.final_reasoning = error_msg
|
| 89 |
+
state.answer_source = "error_fallback"
|
| 90 |
+
state.is_complete = True
|
| 91 |
+
state.requires_human_review = True
|
| 92 |
+
|
| 93 |
+
return state
|
| 94 |
+
|
| 95 |
+
def _determine_synthesis_strategy(self, state: GAIAAgentState) -> str:
|
| 96 |
+
"""Determine the best synthesis strategy based on available results"""
|
| 97 |
+
|
| 98 |
+
successful_results = [r for r in state.agent_results.values() if r.success]
|
| 99 |
+
|
| 100 |
+
if len(successful_results) == 0:
|
| 101 |
+
return "fallback"
|
| 102 |
+
elif len(successful_results) == 1:
|
| 103 |
+
return "single_agent"
|
| 104 |
+
elif len(successful_results) == 2:
|
| 105 |
+
return "confidence_weighted"
|
| 106 |
+
elif all(r.confidence > 0.6 for r in successful_results):
|
| 107 |
+
return "multi_agent_consensus"
|
| 108 |
+
else:
|
| 109 |
+
return "llm_synthesis"
|
| 110 |
+
|
| 111 |
+
def _synthesize_single_agent(self, state: GAIAAgentState) -> Dict[str, Any]:
|
| 112 |
+
"""Synthesize result from a single agent"""
|
| 113 |
+
|
| 114 |
+
successful_results = [r for r in state.agent_results.values() if r.success]
|
| 115 |
+
if not successful_results:
|
| 116 |
+
return self._create_fallback_result("No successful agent results")
|
| 117 |
+
|
| 118 |
+
best_result = max(successful_results, key=lambda r: r.confidence)
|
| 119 |
+
|
| 120 |
+
return {
|
| 121 |
+
"answer": best_result.result,
|
| 122 |
+
"confidence": best_result.confidence,
|
| 123 |
+
"reasoning": f"Single agent result from {best_result.agent_role.value}: {best_result.reasoning}",
|
| 124 |
+
"source": best_result.agent_role.value
|
| 125 |
+
}
|
| 126 |
+
|
| 127 |
+
def _synthesize_multi_agent_consensus(self, state: GAIAAgentState) -> Dict[str, Any]:
|
| 128 |
+
"""Synthesize results when multiple agents agree (high confidence)"""
|
| 129 |
+
|
| 130 |
+
successful_results = [r for r in state.agent_results.values() if r.success]
|
| 131 |
+
high_confidence_results = [r for r in successful_results if r.confidence > 0.6]
|
| 132 |
+
|
| 133 |
+
if not high_confidence_results:
|
| 134 |
+
return self._synthesize_confidence_weighted(state)
|
| 135 |
+
|
| 136 |
+
# Use the highest confidence result as primary
|
| 137 |
+
primary_result = max(high_confidence_results, key=lambda r: r.confidence)
|
| 138 |
+
|
| 139 |
+
# Calculate consensus confidence
|
| 140 |
+
avg_confidence = mean([r.confidence for r in high_confidence_results])
|
| 141 |
+
consensus_confidence = min(0.95, avg_confidence * 1.1) # Boost for consensus
|
| 142 |
+
|
| 143 |
+
# Create reasoning summary
|
| 144 |
+
agent_summaries = []
|
| 145 |
+
for result in high_confidence_results:
|
| 146 |
+
agent_summaries.append(f"{result.agent_role.value} (conf: {result.confidence:.2f})")
|
| 147 |
+
|
| 148 |
+
reasoning = f"Consensus from {len(high_confidence_results)} agents: {', '.join(agent_summaries)}. Primary result: {primary_result.reasoning}"
|
| 149 |
+
|
| 150 |
+
return {
|
| 151 |
+
"answer": primary_result.result,
|
| 152 |
+
"confidence": consensus_confidence,
|
| 153 |
+
"reasoning": reasoning,
|
| 154 |
+
"source": f"consensus_{len(high_confidence_results)}_agents"
|
| 155 |
+
}
|
| 156 |
+
|
| 157 |
+
def _synthesize_confidence_weighted(self, state: GAIAAgentState) -> Dict[str, Any]:
|
| 158 |
+
"""Synthesize results using confidence weighting"""
|
| 159 |
+
|
| 160 |
+
successful_results = [r for r in state.agent_results.values() if r.success]
|
| 161 |
+
|
| 162 |
+
if not successful_results:
|
| 163 |
+
return self._create_fallback_result("No successful results for confidence weighting")
|
| 164 |
+
|
| 165 |
+
# Weight by confidence
|
| 166 |
+
total_weight = sum(r.confidence for r in successful_results)
|
| 167 |
+
if total_weight == 0:
|
| 168 |
+
return self._synthesize_single_agent(state)
|
| 169 |
+
|
| 170 |
+
# Select primary result (highest confidence)
|
| 171 |
+
primary_result = max(successful_results, key=lambda r: r.confidence)
|
| 172 |
+
|
| 173 |
+
# Calculate weighted confidence
|
| 174 |
+
weighted_confidence = sum(r.confidence ** 2 for r in successful_results) / total_weight
|
| 175 |
+
|
| 176 |
+
# Create reasoning
|
| 177 |
+
result_summaries = []
|
| 178 |
+
for result in successful_results:
|
| 179 |
+
weight = result.confidence / total_weight
|
| 180 |
+
result_summaries.append(f"{result.agent_role.value} (weight: {weight:.2f})")
|
| 181 |
+
|
| 182 |
+
reasoning = f"Confidence-weighted synthesis: {', '.join(result_summaries)}. Primary: {primary_result.reasoning}"
|
| 183 |
+
|
| 184 |
+
return {
|
| 185 |
+
"answer": primary_result.result,
|
| 186 |
+
"confidence": min(0.9, weighted_confidence),
|
| 187 |
+
"reasoning": reasoning,
|
| 188 |
+
"source": f"weighted_{len(successful_results)}_agents"
|
| 189 |
+
}
|
| 190 |
+
|
| 191 |
+
def _synthesize_with_llm(self, state: GAIAAgentState) -> Dict[str, Any]:
    """Use LLM to synthesize conflicting or complex results.

    Builds a prompt listing every successful agent answer, asks the LLM for a
    single final answer plus a self-reported confidence, then blends that
    confidence with the average confidence of the inputs. Falls back to
    confidence weighting if the LLM call fails.
    """
    successful_results = [r for r in state.agent_results.values() if r.success]

    # Prepare synthesis prompt: one section per agent result
    agent_results_text = []
    for i, result in enumerate(successful_results, 1):
        agent_results_text.append(f"""
Agent {i} ({result.agent_role.value}):
- Answer: {result.result}
- Confidence: {result.confidence:.2f}
- Reasoning: {result.reasoning}
""")

    synthesis_prompt = f"""
Question: {state.question}

Multiple agents have provided different answers/insights. Please synthesize these into a single, coherent final answer:

{chr(10).join(agent_results_text)}

Please provide:
1. A clear, direct final answer
2. Your confidence level (0.0 to 1.0)
3. Brief reasoning explaining how you synthesized the results

Focus on accuracy and be direct in your response.
"""

    # Use complex model for synthesis when the question warrants it
    model_tier = ModelTier.COMPLEX if state.should_use_complex_model() else ModelTier.MAIN
    llm_result = self.llm_client.generate(synthesis_prompt, tier=model_tier, max_tokens=400)

    if not llm_result.success:
        # Fallback to confidence weighting if LLM fails
        return self._synthesize_confidence_weighted(state)

    llm_answer = llm_result.response

    # Extract self-reported confidence if mentioned in the response.
    # BUG FIX: the old pattern '[0-9.]+' also captured trailing periods
    # ("confidence: 0.8." -> "0.8.") which made float() raise ValueError.
    # Require a well-formed number and clamp it to [0.0, 1.0].
    confidence_match = re.search(r'confidence[:\s]*(\d+(?:\.\d+)?)', llm_answer.lower())
    llm_confidence = 0.7  # default when the LLM did not state a confidence
    if confidence_match:
        try:
            llm_confidence = max(0.0, min(1.0, float(confidence_match.group(1))))
        except ValueError:
            pass  # keep the default on a malformed number

    # Blend with the average input confidence, capped at 0.85
    avg_input_confidence = mean([r.confidence for r in successful_results])
    final_confidence = min(0.85, (llm_confidence + avg_input_confidence) / 2)

    return {
        "answer": llm_answer,
        "confidence": final_confidence,
        "reasoning": f"LLM synthesis of {len(successful_results)} agent results using {llm_result.model_used}",
        "source": "llm_synthesis"
    }
|
| 246 |
+
|
| 247 |
+
def _synthesize_fallback(self, state: GAIAAgentState) -> Dict[str, Any]:
|
| 248 |
+
"""Fallback synthesis when other strategies fail"""
|
| 249 |
+
|
| 250 |
+
# Try to get any result, even if not successful
|
| 251 |
+
all_results = list(state.agent_results.values())
|
| 252 |
+
|
| 253 |
+
if all_results:
|
| 254 |
+
# Use the result with highest confidence, even if failed
|
| 255 |
+
best_attempt = max(all_results, key=lambda r: r.confidence if r.success else 0.0)
|
| 256 |
+
|
| 257 |
+
if best_attempt.success:
|
| 258 |
+
return {
|
| 259 |
+
"answer": best_attempt.result,
|
| 260 |
+
"confidence": max(0.3, best_attempt.confidence * 0.8), # Reduce confidence for fallback
|
| 261 |
+
"reasoning": f"Fallback result from {best_attempt.agent_role.value}: {best_attempt.reasoning}",
|
| 262 |
+
"source": f"fallback_{best_attempt.agent_role.value}"
|
| 263 |
+
}
|
| 264 |
+
else:
|
| 265 |
+
return {
|
| 266 |
+
"answer": f"Processing encountered difficulties: {best_attempt.result}",
|
| 267 |
+
"confidence": 0.2,
|
| 268 |
+
"reasoning": f"Fallback from failed attempt by {best_attempt.agent_role.value}",
|
| 269 |
+
"source": "failed_fallback"
|
| 270 |
+
}
|
| 271 |
+
else:
|
| 272 |
+
return self._create_fallback_result("No agent results available")
|
| 273 |
+
|
| 274 |
+
def _create_fallback_result(self, reason: str) -> Dict[str, Any]:
|
| 275 |
+
"""Create a fallback result when synthesis is impossible"""
|
| 276 |
+
return {
|
| 277 |
+
"answer": f"Unable to process question: {reason}",
|
| 278 |
+
"confidence": 0.0,
|
| 279 |
+
"reasoning": f"Synthesis failed: {reason}",
|
| 280 |
+
"source": "synthesis_failure"
|
| 281 |
+
}
|
| 282 |
+
|
| 283 |
+
# Import regex for LLM response parsing
|
| 284 |
+
import re
|
src/agents/web_researcher.py
ADDED
|
@@ -0,0 +1,600 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Web Research Agent for GAIA Agent System
|
| 4 |
+
Handles Wikipedia and web search questions with intelligent search strategies
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import re
|
| 8 |
+
import logging
|
| 9 |
+
from typing import Dict, List, Optional, Any
|
| 10 |
+
from urllib.parse import urlparse
|
| 11 |
+
|
| 12 |
+
from agents.state import GAIAAgentState, AgentRole, AgentResult, ToolResult
|
| 13 |
+
from models.qwen_client import QwenClient, ModelTier
|
| 14 |
+
from tools.wikipedia_tool import WikipediaTool
|
| 15 |
+
from tools.web_search_tool import WebSearchTool
|
| 16 |
+
|
| 17 |
+
logger = logging.getLogger(__name__)
|
| 18 |
+
|
| 19 |
+
class WebResearchAgent:
|
| 20 |
+
"""
|
| 21 |
+
Specialized agent for web research tasks
|
| 22 |
+
Uses Wikipedia and web search tools with intelligent routing
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
def __init__(self, llm_client: QwenClient):
    """Store the shared LLM client and construct the research tools.

    Args:
        llm_client: Client used for all LLM analysis calls in this agent.
    """
    self.llm_client = llm_client
    # Tool instances are created once and reused across questions.
    self.wikipedia_tool = WikipediaTool()
    self.web_search_tool = WebSearchTool()
|
| 29 |
+
|
| 30 |
+
def process(self, state: GAIAAgentState) -> GAIAAgentState:
    """
    Process web research questions using Wikipedia and web search
    """
    logger.info(f"Web researcher processing: {state.question[:100]}...")
    state.add_processing_step("Web Researcher: Starting research")

    try:
        # Determine research strategy
        strategy = self._determine_research_strategy(state.question, state.file_name)
        state.add_processing_step(f"Web Researcher: Strategy = {strategy}")

        # Dispatch to the handler for the chosen strategy; anything
        # unrecognised goes through the multi-source path.
        handlers = {
            "wikipedia_direct": self._research_wikipedia_direct,
            "wikipedia_search": self._research_wikipedia_search,
            "youtube_analysis": self._research_youtube,
            "web_search": self._research_web_general,
            "url_extraction": self._research_url_content,
        }
        result = handlers.get(strategy, self._research_multi_source)(state)

        # Record the outcome on the shared state
        state.add_agent_result(result)
        state.add_processing_step(f"Web Researcher: Completed with confidence {result.confidence:.2f}")

        return state

    except Exception as e:
        error_msg = f"Web research failed: {str(e)}"
        state.add_error(error_msg)
        logger.error(error_msg)

        # Create failure result so downstream synthesis still has something to read
        failure_result = AgentResult(
            agent_role=AgentRole.WEB_RESEARCHER,
            success=False,
            result=f"Research failed: {str(e)}",
            confidence=0.0,
            reasoning=f"Exception during web research: {str(e)}",
            model_used="error",
            processing_time=0.0,
            cost_estimate=0.0
        )
        state.add_agent_result(failure_result)
        return state
|
| 80 |
+
|
| 81 |
+
def _determine_research_strategy(self, question: str, file_name: Optional[str] = None) -> str:
|
| 82 |
+
"""Determine the best research strategy for the question"""
|
| 83 |
+
|
| 84 |
+
question_lower = question.lower()
|
| 85 |
+
|
| 86 |
+
# Direct Wikipedia references
|
| 87 |
+
if any(term in question_lower for term in ['wikipedia', 'featured article', 'promoted']):
|
| 88 |
+
if 'search' in question_lower or 'find' in question_lower:
|
| 89 |
+
return "wikipedia_search"
|
| 90 |
+
else:
|
| 91 |
+
return "wikipedia_direct"
|
| 92 |
+
|
| 93 |
+
# YouTube video analysis
|
| 94 |
+
if any(term in question_lower for term in ['youtube', 'video', 'watch?v=', 'youtu.be']):
|
| 95 |
+
return "youtube_analysis"
|
| 96 |
+
|
| 97 |
+
# URL content extraction
|
| 98 |
+
urls = re.findall(r'https?://[^\s]+', question)
|
| 99 |
+
if urls:
|
| 100 |
+
return "url_extraction"
|
| 101 |
+
|
| 102 |
+
# General web search for current events, news, recent information
|
| 103 |
+
if any(term in question_lower for term in ['news', 'recent', 'latest', 'current', 'today', '2024', '2025']):
|
| 104 |
+
return "web_search"
|
| 105 |
+
|
| 106 |
+
# Multi-source research for complex questions
|
| 107 |
+
if len(question.split()) > 20 or '?' in question and question.count('?') > 1:
|
| 108 |
+
return "multi_source"
|
| 109 |
+
|
| 110 |
+
# Default to Wikipedia search for informational questions
|
| 111 |
+
return "wikipedia_search"
|
| 112 |
+
|
| 113 |
+
def _research_wikipedia_direct(self, state: GAIAAgentState) -> AgentResult:
    """Research using direct Wikipedia lookup.

    Extracts a topic from the question, fetches its Wikipedia article, and
    asks the LLM to answer the question from the article summary. Falls
    back to web search when no article is found.
    """
    # Extract topic from question
    topic = self._extract_wikipedia_topic(state.question)

    logger.info(f"Wikipedia direct research for: {topic}")

    # Search Wikipedia
    wiki_result = self.wikipedia_tool.execute(topic)

    if wiki_result.success and wiki_result.result.get('found'):
        wiki_data = wiki_result.result['result']

        # Use LLM to analyze and answer the question
        analysis_prompt = f"""
Based on this Wikipedia information about {topic}, please answer the following question:

Question: {state.question}

Wikipedia Summary: {wiki_data.get('summary', '')}

Wikipedia URL: {wiki_data.get('url', '')}

Please provide a direct, accurate answer based on the Wikipedia information.
"""

        # Use appropriate model tier: MAIN for complex questions, cheaper ROUTER otherwise
        model_tier = ModelTier.MAIN if state.complexity_assessment == "complex" else ModelTier.ROUTER
        llm_result = self.llm_client.generate(analysis_prompt, tier=model_tier, max_tokens=400)

        if llm_result.success:
            # An exact title match earns slightly higher confidence
            confidence = 0.85 if wiki_data.get('title') == topic else 0.75

            return AgentResult(
                agent_role=AgentRole.WEB_RESEARCHER,
                success=True,
                result=llm_result.response,
                confidence=confidence,
                reasoning=f"Found Wikipedia article for '{topic}' and analyzed content",
                tools_used=[ToolResult(
                    tool_name="wikipedia",
                    success=True,
                    result=wiki_data,
                    execution_time=wiki_result.execution_time
                )],
                model_used=llm_result.model_used,
                processing_time=wiki_result.execution_time + llm_result.response_time,
                cost_estimate=llm_result.cost_estimate
            )
        else:
            # Return Wikipedia summary as fallback (lower confidence, zero LLM cost)
            return AgentResult(
                agent_role=AgentRole.WEB_RESEARCHER,
                success=True,
                result=wiki_data.get('summary', 'Wikipedia information found but analysis failed'),
                confidence=0.60,
                reasoning="Wikipedia found but LLM analysis failed",
                tools_used=[ToolResult(
                    tool_name="wikipedia",
                    success=True,
                    result=wiki_data,
                    execution_time=wiki_result.execution_time
                )],
                model_used="fallback",
                processing_time=wiki_result.execution_time,
                cost_estimate=0.0
            )
    else:
        # Wikipedia not found, try web search as fallback
        return self._research_web_fallback(state, f"Wikipedia not found for '{topic}'")
|
| 184 |
+
|
| 185 |
+
def _research_wikipedia_search(self, state: GAIAAgentState) -> AgentResult:
    """Run a Wikipedia search for terms pulled from the question and analyze the hit."""
    terms = self._extract_search_terms(state.question)
    logger.info(f"Wikipedia search for: {terms}")

    # Ask the tool for a summary matching the extracted search terms.
    wiki_result = self.wikipedia_tool.execute({"query": terms, "action": "summary"})

    if not (wiki_result.success and wiki_result.result.get('found')):
        # No usable article -> degrade to a plain web search.
        return self._research_web_fallback(state, f"Wikipedia search failed for '{terms}'")
    return self._analyze_wikipedia_result(state, wiki_result)
|
| 202 |
+
|
| 203 |
+
def _research_youtube(self, state: GAIAAgentState) -> AgentResult:
    """Handle questions about YouTube content via the web search tool."""
    youtube_query = self._extract_youtube_info(state.question)
    logger.info(f"YouTube research for: {youtube_query}")

    if youtube_query.startswith('http'):
        # Direct video link: pull the page content.
        web_result = self.web_search_tool.execute({"query": youtube_query, "action": "extract"})
    else:
        # No link given: search YouTube for matching videos.
        web_result = self.web_search_tool.execute(f"site:youtube.com {youtube_query}")

    if web_result.success and web_result.result.get('found'):
        return self._analyze_youtube_result(state, web_result)
    return self._create_failure_result("YouTube research failed")
|
| 226 |
+
|
| 227 |
+
def _research_web_general(self, state: GAIAAgentState) -> AgentResult:
    """General-purpose web search (used for current events / recent info)."""
    terms = self._extract_search_terms(state.question)
    logger.info(f"Web search for: {terms}")

    # Fetch the top 5 search hits for the extracted terms.
    web_result = self.web_search_tool.execute(
        {"query": terms, "action": "search", "limit": 5}
    )

    if not (web_result.success and web_result.result.get('found')):
        return self._create_failure_result("Web search failed")
    return self._analyze_web_search_result(state, web_result)
|
| 245 |
+
|
| 246 |
+
def _research_url_content(self, state: GAIAAgentState) -> AgentResult:
    """Extract and analyze content from the first URL mentioned in the question.

    Bug fix: the URL regex ('[^\\s]+') greedily swallows trailing sentence
    punctuation (e.g. "see https://example.com/page." captured the final '.'),
    producing broken fetches. Trailing punctuation is now stripped.
    """
    urls = re.findall(r'https?://[^\s]+', state.question)
    if not urls:
        return self._create_failure_result("No URLs found in question")

    # Use first URL; drop punctuation that belongs to the sentence, not the link.
    url = urls[0].rstrip('.,;:!?)]}\'"')
    logger.info(f"Extracting content from: {url}")

    # Extract content from URL
    web_result = self.web_search_tool.execute({
        "query": url,
        "action": "extract"
    })

    if web_result.success and web_result.result.get('found'):
        return self._analyze_url_content_result(state, web_result)
    else:
        return self._create_failure_result(f"Failed to extract content from {url}")
|
| 266 |
+
|
| 267 |
+
def _research_multi_source(self, state: GAIAAgentState) -> AgentResult:
    """Combine Wikipedia and general web results for complex questions."""
    terms = self._extract_search_terms(state.question)
    logger.info(f"Multi-source research for: {terms}")

    sources = []

    # Wikipedia contribution (only if an article is found)
    wiki_result = self.wikipedia_tool.execute(terms)
    if wiki_result.success and wiki_result.result.get('found'):
        sources.append(("Wikipedia", wiki_result.result['result']))

    # Top web hits contribution (at most two)
    web_result = self.web_search_tool.execute(
        {"query": terms, "action": "search", "limit": 3}
    )
    if web_result.success and web_result.result.get('found'):
        sources.extend(("Web", hit) for hit in web_result.result['results'][:2])

    if not sources:
        return self._create_failure_result("All research sources failed")
    return self._analyze_multi_source_result(state, sources)
|
| 295 |
+
|
| 296 |
+
def _research_web_fallback(self, state: GAIAAgentState, reason: str) -> AgentResult:
    """Degrade to a plain web search after another strategy failed, noting why."""
    logger.info(f"Web search fallback: {reason}")

    terms = self._extract_search_terms(state.question)
    web_result = self.web_search_tool.execute(terms)

    if not (web_result.success and web_result.result.get('found')):
        return self._create_failure_result(f"Fallback failed: {reason}")

    result = self._analyze_web_search_result(state, web_result)
    # Mark provenance and discount confidence (floor at 0.3) for the fallback path.
    result.reasoning = f"{reason}. Used web search fallback."
    result.confidence = max(0.3, result.confidence - 0.2)
    return result
|
| 311 |
+
|
| 312 |
+
def _extract_wikipedia_topic(self, question: str) -> str:
|
| 313 |
+
"""Extract Wikipedia topic from question"""
|
| 314 |
+
|
| 315 |
+
# Look for quoted terms
|
| 316 |
+
quoted = re.findall(r'"([^"]+)"', question)
|
| 317 |
+
if quoted:
|
| 318 |
+
return quoted[0]
|
| 319 |
+
|
| 320 |
+
# Look for specific patterns
|
| 321 |
+
patterns = [
|
| 322 |
+
r'wikipedia article[s]?\s+(?:about|on|for)\s+([^?.,]+)',
|
| 323 |
+
r'featured article[s]?\s+(?:about|on|for)\s+([^?.,]+)',
|
| 324 |
+
r'(?:about|on)\s+([A-Z][^?.,]+)',
|
| 325 |
+
]
|
| 326 |
+
|
| 327 |
+
for pattern in patterns:
|
| 328 |
+
match = re.search(pattern, question, re.IGNORECASE)
|
| 329 |
+
if match:
|
| 330 |
+
return match.group(1).strip()
|
| 331 |
+
|
| 332 |
+
# Extract main nouns/entities
|
| 333 |
+
words = question.split()
|
| 334 |
+
topic_words = []
|
| 335 |
+
for word in words:
|
| 336 |
+
if word[0].isupper() or len(word) > 6: # Likely important words
|
| 337 |
+
topic_words.append(word)
|
| 338 |
+
|
| 339 |
+
return ' '.join(topic_words[:3]) if topic_words else "topic"
|
| 340 |
+
|
| 341 |
+
def _extract_search_terms(self, question: str) -> str:
|
| 342 |
+
"""Extract search terms from question"""
|
| 343 |
+
|
| 344 |
+
# Remove question words and common phrases
|
| 345 |
+
stop_phrases = [
|
| 346 |
+
'what is', 'what are', 'who is', 'who are', 'when is', 'when was',
|
| 347 |
+
'where is', 'where are', 'how is', 'how are', 'why is', 'why are',
|
| 348 |
+
'tell me about', 'find information about', 'search for'
|
| 349 |
+
]
|
| 350 |
+
|
| 351 |
+
clean_question = question.lower()
|
| 352 |
+
for phrase in stop_phrases:
|
| 353 |
+
clean_question = clean_question.replace(phrase, '')
|
| 354 |
+
|
| 355 |
+
# Remove punctuation and extra spaces
|
| 356 |
+
clean_question = re.sub(r'[?.,!]', '', clean_question)
|
| 357 |
+
clean_question = re.sub(r'\s+', ' ', clean_question).strip()
|
| 358 |
+
|
| 359 |
+
return clean_question
|
| 360 |
+
|
| 361 |
+
def _extract_youtube_info(self, question: str) -> str:
|
| 362 |
+
"""Extract YouTube URL or search terms"""
|
| 363 |
+
|
| 364 |
+
# Look for YouTube URLs
|
| 365 |
+
youtube_urls = re.findall(r'https?://(?:www\.)?youtube\.com/[^\s]+', question)
|
| 366 |
+
if youtube_urls:
|
| 367 |
+
return youtube_urls[0]
|
| 368 |
+
|
| 369 |
+
youtube_urls = re.findall(r'https?://youtu\.be/[^\s]+', question)
|
| 370 |
+
if youtube_urls:
|
| 371 |
+
return youtube_urls[0]
|
| 372 |
+
|
| 373 |
+
# Extract search terms for YouTube
|
| 374 |
+
return self._extract_search_terms(question)
|
| 375 |
+
|
| 376 |
+
def _analyze_wikipedia_result(self, state: GAIAAgentState, wiki_result: ToolResult) -> AgentResult:
    """Turn a successful Wikipedia tool result into an AgentResult.

    Feeds the article title/summary/URL to the LLM to answer the question;
    if the LLM call fails, returns the raw summary at reduced confidence.
    """
    wiki_data = wiki_result.result['result']

    analysis_prompt = f"""
Based on this Wikipedia information, please answer the following question:

Question: {state.question}

Wikipedia Information:
Title: {wiki_data.get('title', '')}
Summary: {wiki_data.get('summary', '')}
URL: {wiki_data.get('url', '')}

Please provide a direct, accurate answer.
"""

    # Longer questions get the stronger MAIN model; short ones use the cheaper ROUTER tier
    model_tier = ModelTier.MAIN if len(state.question) > 100 else ModelTier.ROUTER
    llm_result = self.llm_client.generate(analysis_prompt, tier=model_tier, max_tokens=300)

    if llm_result.success:
        return AgentResult(
            agent_role=AgentRole.WEB_RESEARCHER,
            success=True,
            result=llm_result.response,
            confidence=0.80,
            reasoning="Analyzed Wikipedia information to answer question",
            tools_used=[wiki_result],
            model_used=llm_result.model_used,
            processing_time=wiki_result.execution_time + llm_result.response_time,
            cost_estimate=llm_result.cost_estimate
        )
    else:
        # LLM failed: fall back to the raw summary at lower confidence, no LLM cost
        return AgentResult(
            agent_role=AgentRole.WEB_RESEARCHER,
            success=True,
            result=wiki_data.get('summary', 'Information found'),
            confidence=0.60,
            reasoning="Wikipedia found but analysis failed",
            tools_used=[wiki_result],
            model_used="fallback",
            processing_time=wiki_result.execution_time,
            cost_estimate=0.0
        )
|
| 421 |
+
|
| 422 |
+
def _analyze_youtube_result(self, state: GAIAAgentState, web_result: ToolResult) -> AgentResult:
    """Analyze YouTube research result.

    NOTE(review): placeholder — returns a canned success answer without
    inspecting the fetched video/page content; real analysis is TODO.
    """
    # Implementation for YouTube analysis
    return AgentResult(
        agent_role=AgentRole.WEB_RESEARCHER,
        success=True,
        result="YouTube analysis completed",
        confidence=0.70,
        reasoning="Analyzed YouTube content",
        tools_used=[web_result],
        model_used="basic",
        processing_time=web_result.execution_time,
        cost_estimate=0.0
    )
|
| 437 |
+
|
| 438 |
+
def _analyze_web_search_result(self, state: GAIAAgentState, web_result: ToolResult) -> AgentResult:
    """Summarize web search hits and answer the question with the MAIN-tier LLM.

    Robustness fix: individual hits are accessed with dict.get() so a hit
    missing 'title'/'url'/'snippet' (or a payload missing 'results') no
    longer raises KeyError and aborts the whole research step.
    """
    search_results = web_result.result.get('results', [])

    # Combine the top 3 hits into a compact context block for the LLM
    combined_content = []
    for i, result in enumerate(search_results[:3], 1):
        combined_content.append(f"Result {i}: {result.get('title', '')}")
        combined_content.append(f"URL: {result.get('url', '')}")
        combined_content.append(f"Description: {result.get('snippet', '')}")
        combined_content.append("")

    analysis_prompt = f"""
Based on these web search results, please answer the following question:

Question: {state.question}

Search Results:
{chr(10).join(combined_content)}

Please provide a direct answer based on the most relevant information.
"""

    model_tier = ModelTier.MAIN
    llm_result = self.llm_client.generate(analysis_prompt, tier=model_tier, max_tokens=400)

    if llm_result.success:
        return AgentResult(
            agent_role=AgentRole.WEB_RESEARCHER,
            success=True,
            result=llm_result.response,
            confidence=0.75,
            reasoning=f"Analyzed {len(search_results)} web search results",
            tools_used=[web_result],
            model_used=llm_result.model_used,
            processing_time=web_result.execution_time + llm_result.response_time,
            cost_estimate=llm_result.cost_estimate
        )
    else:
        # Fall back to the first hit's snippet when LLM analysis fails
        first_result = search_results[0] if search_results else {}
        return AgentResult(
            agent_role=AgentRole.WEB_RESEARCHER,
            success=True,
            result=first_result.get('snippet', 'Web search completed'),
            confidence=0.50,
            reasoning="Web search completed but analysis failed",
            tools_used=[web_result],
            model_used="fallback",
            processing_time=web_result.execution_time,
            cost_estimate=0.0
        )
|
| 491 |
+
|
| 492 |
+
def _analyze_url_content_result(self, state: GAIAAgentState, web_result: ToolResult) -> AgentResult:
|
| 493 |
+
"""Analyze extracted URL content"""
|
| 494 |
+
|
| 495 |
+
content_data = web_result.result
|
| 496 |
+
|
| 497 |
+
analysis_prompt = f"""
|
| 498 |
+
Based on this web page content, please answer the following question:
|
| 499 |
+
|
| 500 |
+
Question: {state.question}
|
| 501 |
+
|
| 502 |
+
Page Title: {content_data.get('title', '')}
|
| 503 |
+
Page URL: {content_data.get('url', '')}
|
| 504 |
+
Content: {content_data.get('content', '')[:1000]}...
|
| 505 |
+
|
| 506 |
+
Please provide a direct answer based on the page content.
|
| 507 |
+
"""
|
| 508 |
+
|
| 509 |
+
model_tier = ModelTier.MAIN
|
| 510 |
+
llm_result = self.llm_client.generate(analysis_prompt, tier=model_tier, max_tokens=400)
|
| 511 |
+
|
| 512 |
+
if llm_result.success:
|
| 513 |
+
return AgentResult(
|
| 514 |
+
agent_role=AgentRole.WEB_RESEARCHER,
|
| 515 |
+
success=True,
|
| 516 |
+
result=llm_result.response,
|
| 517 |
+
confidence=0.85,
|
| 518 |
+
reasoning="Analyzed content from specific URL",
|
| 519 |
+
tools_used=[web_result],
|
| 520 |
+
model_used=llm_result.model_used,
|
| 521 |
+
processing_time=web_result.execution_time + llm_result.response_time,
|
| 522 |
+
cost_estimate=llm_result.cost_estimate
|
| 523 |
+
)
|
| 524 |
+
else:
|
| 525 |
+
return AgentResult(
|
| 526 |
+
agent_role=AgentRole.WEB_RESEARCHER,
|
| 527 |
+
success=True,
|
| 528 |
+
result=content_data.get('content', 'Content extracted')[:200],
|
| 529 |
+
confidence=0.60,
|
| 530 |
+
reasoning="URL content extracted but analysis failed",
|
| 531 |
+
tools_used=[web_result],
|
| 532 |
+
model_used="fallback",
|
| 533 |
+
processing_time=web_result.execution_time,
|
| 534 |
+
cost_estimate=0.0
|
| 535 |
+
)
|
| 536 |
+
|
| 537 |
+
def _analyze_multi_source_result(self, state: GAIAAgentState, sources: List) -> AgentResult:
|
| 538 |
+
"""Analyze results from multiple sources"""
|
| 539 |
+
|
| 540 |
+
source_summaries = []
|
| 541 |
+
for source_type, source_data in sources:
|
| 542 |
+
if source_type == "Wikipedia":
|
| 543 |
+
source_summaries.append(f"Wikipedia: {source_data.get('summary', '')[:200]}")
|
| 544 |
+
else: # Web result
|
| 545 |
+
source_summaries.append(f"Web: {source_data.get('snippet', '')[:200]}")
|
| 546 |
+
|
| 547 |
+
analysis_prompt = f"""
|
| 548 |
+
Based on these multiple sources, please answer the following question:
|
| 549 |
+
|
| 550 |
+
Question: {state.question}
|
| 551 |
+
|
| 552 |
+
Sources:
|
| 553 |
+
{chr(10).join(source_summaries)}
|
| 554 |
+
|
| 555 |
+
Please synthesize the information and provide a comprehensive answer.
|
| 556 |
+
"""
|
| 557 |
+
|
| 558 |
+
model_tier = ModelTier.COMPLEX # Use best model for multi-source analysis
|
| 559 |
+
llm_result = self.llm_client.generate(analysis_prompt, tier=model_tier, max_tokens=500)
|
| 560 |
+
|
| 561 |
+
if llm_result.success:
|
| 562 |
+
return AgentResult(
|
| 563 |
+
agent_role=AgentRole.WEB_RESEARCHER,
|
| 564 |
+
success=True,
|
| 565 |
+
result=llm_result.response,
|
| 566 |
+
confidence=0.85,
|
| 567 |
+
reasoning=f"Synthesized information from {len(sources)} sources",
|
| 568 |
+
tools_used=[],
|
| 569 |
+
model_used=llm_result.model_used,
|
| 570 |
+
processing_time=llm_result.response_time,
|
| 571 |
+
cost_estimate=llm_result.cost_estimate
|
| 572 |
+
)
|
| 573 |
+
else:
|
| 574 |
+
# Fallback to first source
|
| 575 |
+
first_source = sources[0][1] if sources else {}
|
| 576 |
+
content = first_source.get('summary') or first_source.get('snippet', 'Multi-source research completed')
|
| 577 |
+
return AgentResult(
|
| 578 |
+
agent_role=AgentRole.WEB_RESEARCHER,
|
| 579 |
+
success=True,
|
| 580 |
+
result=content,
|
| 581 |
+
confidence=0.60,
|
| 582 |
+
reasoning="Multi-source research completed but synthesis failed",
|
| 583 |
+
tools_used=[],
|
| 584 |
+
model_used="fallback",
|
| 585 |
+
processing_time=0.0,
|
| 586 |
+
cost_estimate=0.0
|
| 587 |
+
)
|
| 588 |
+
|
| 589 |
+
def _create_failure_result(self, error_message: str) -> AgentResult:
|
| 590 |
+
"""Create a failure result"""
|
| 591 |
+
return AgentResult(
|
| 592 |
+
agent_role=AgentRole.WEB_RESEARCHER,
|
| 593 |
+
success=False,
|
| 594 |
+
result=error_message,
|
| 595 |
+
confidence=0.0,
|
| 596 |
+
reasoning=error_message,
|
| 597 |
+
model_used="error",
|
| 598 |
+
processing_time=0.0,
|
| 599 |
+
cost_estimate=0.0
|
| 600 |
+
)
|
src/api/unit4_client.py
ADDED
|
@@ -0,0 +1,349 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Unit 4 API Client for GAIA Benchmark Questions
|
| 4 |
+
Handles question fetching, file downloads, and answer submission
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import requests
|
| 9 |
+
import logging
|
| 10 |
+
from typing import Dict, Any, List, Optional, Union
|
| 11 |
+
from dataclasses import dataclass
|
| 12 |
+
import json
|
| 13 |
+
import time
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
|
| 16 |
+
# Configure logging
|
| 17 |
+
logging.basicConfig(level=logging.INFO)
|
| 18 |
+
logger = logging.getLogger(__name__)
|
| 19 |
+
|
| 20 |
+
@dataclass
class GAIAQuestion:
    """GAIA benchmark question data structure.

    Mirrors one question record returned by the Unit 4 scoring API; fields
    beyond the required three are optional because the API does not always
    supply them.
    """
    task_id: str  # Unique identifier assigned by the Unit 4 API
    question: str  # The question text to be answered
    level: int  # 1, 2, or 3 (difficulty level)
    final_answer: Optional[str] = None  # Ground-truth answer, when the API exposes one
    file_name: Optional[str] = None  # Name of an attached file, if the question has one
    file_path: Optional[str] = None  # Local path after download (filled in by the client)
    metadata: Optional[Dict[str, Any]] = None  # Raw API payload for this question
|
| 30 |
+
|
| 31 |
+
@dataclass
class SubmissionResult:
    """Result of answer submission.

    ``success`` reflects whether the HTTP submission itself succeeded; the
    API's grading (if any) arrives via ``score``/``feedback``, and transport
    failures are recorded in ``error``.
    """
    task_id: str  # Question the answer was submitted for
    submitted_answer: str  # The answer string as given by the caller (pre-strip)
    success: bool  # True when the POST to /submit completed without error
    score: Optional[float] = None  # Score returned by the API, when provided
    feedback: Optional[str] = None  # Free-text feedback from the API, when provided
    error: Optional[str] = None  # Error description when success is False
|
| 40 |
+
|
| 41 |
+
class Unit4APIClient:
    """Client for the Unit 4 scoring API.

    Fetches GAIA questions, downloads attached files, and submits answers for
    evaluation. All HTTP traffic goes through one shared requests.Session and
    is throttled to at most one request per ``rate_limit_delay`` seconds.
    """

    def __init__(self, base_url: str = "https://agents-course-unit4-scoring.hf.space"):
        """Initialize Unit 4 API client.

        Args:
            base_url: Root URL of the scoring service (trailing slash stripped).
        """
        self.base_url = base_url.rstrip('/')
        self.session = requests.Session()
        self.session.headers.update({
            'User-Agent': 'GAIA-Agent-System/1.0',
            'Accept': 'application/json',
            'Content-Type': 'application/json'
        })

        # Downloaded question attachments are saved here.
        self.downloads_dir = Path("downloads")
        self.downloads_dir.mkdir(exist_ok=True)

        # Track API usage for rate limiting and status reporting.
        self.requests_made = 0
        self.last_request_time = 0
        self.rate_limit_delay = 1.0  # Seconds between requests

    def _rate_limit(self):
        """Sleep just long enough to keep at least ``rate_limit_delay`` seconds
        between consecutive requests, then record this request."""
        current_time = time.time()
        time_since_last = current_time - self.last_request_time

        if time_since_last < self.rate_limit_delay:
            sleep_time = self.rate_limit_delay - time_since_last
            logger.debug(f"Rate limiting: sleeping {sleep_time:.2f}s")
            time.sleep(sleep_time)

        self.last_request_time = time.time()
        self.requests_made += 1

    def _make_request(self, method: str, endpoint: str, **kwargs) -> requests.Response:
        """Issue a rate-limited HTTP request and raise on error status.

        Args:
            method: HTTP verb, e.g. 'GET' or 'POST'.
            endpoint: Path appended to ``base_url`` (must start with '/').
            **kwargs: Forwarded to ``requests.Session.request``.

        Raises:
            requests.exceptions.RequestException: On network or HTTP errors.
        """
        self._rate_limit()

        url = f"{self.base_url}{endpoint}"

        try:
            logger.debug(f"Making {method} request to {url}")
            response = self.session.request(method, url, **kwargs)
            response.raise_for_status()
            return response

        except requests.exceptions.RequestException as e:
            logger.error(f"API request failed: {e}")
            raise

    def get_questions(self, level: Optional[int] = None, limit: Optional[int] = None) -> List[GAIAQuestion]:
        """Fetch GAIA questions, optionally filtered by level and/or count.

        Returns an empty list on any failure rather than raising.
        """
        endpoint = "/questions"
        params = {}

        if level is not None:
            params['level'] = level
        if limit is not None:
            params['limit'] = limit

        try:
            response = self._make_request('GET', endpoint, params=params)
            data = response.json()

            # The API may return a bare list, a {"questions": [...]} wrapper,
            # or a single question object; normalize all three shapes.
            if isinstance(data, list):
                question_list = data
            elif isinstance(data, dict) and 'questions' in data:
                question_list = data['questions']
            else:
                question_list = [data]  # Single question

            questions = []
            for q_data in question_list:
                questions.append(GAIAQuestion(
                    task_id=q_data.get('task_id', ''),
                    question=q_data.get('question', ''),
                    level=q_data.get('level', 1),
                    final_answer=q_data.get('final_answer'),
                    file_name=q_data.get('file_name'),
                    metadata=q_data
                ))

            logger.info(f"✅ Fetched {len(questions)} questions from API")
            return questions

        except Exception as e:
            logger.error(f"❌ Failed to fetch questions: {e}")
            return []

    def get_random_question(self, level: Optional[int] = None) -> Optional[GAIAQuestion]:
        """Fetch a single random question; returns None on failure."""
        endpoint = "/random-question"
        params = {}

        if level is not None:
            params['level'] = level

        try:
            response = self._make_request('GET', endpoint, params=params)
            data = response.json()

            question = GAIAQuestion(
                task_id=data.get('task_id', ''),
                question=data.get('question', ''),
                level=data.get('level', 1),
                final_answer=data.get('final_answer'),
                file_name=data.get('file_name'),
                metadata=data
            )

            logger.info(f"✅ Fetched random question: {question.task_id}")
            return question

        except Exception as e:
            logger.error(f"❌ Failed to fetch random question: {e}")
            return None

    def download_file(self, task_id: str, file_name: Optional[str] = None) -> Optional[str]:
        """Download the file attached to a question into ``downloads_dir``.

        Args:
            task_id: Question identifier whose file should be fetched.
            file_name: Preferred local filename; when omitted the name is taken
                from the Content-Disposition header, falling back to
                ``"<task_id>_file"``.

        Returns:
            Local path of the saved file as a string, or None on failure.
        """
        if not task_id:
            logger.error("Task ID required for file download")
            return None

        endpoint = f"/files/{task_id}"

        try:
            response = self._make_request('GET', endpoint, stream=True)

            # Determine filename
            if file_name:
                filename = file_name
            else:
                content_disposition = response.headers.get('content-disposition', '')
                if 'filename=' in content_disposition:
                    # Keep only the filename token: drop any trailing header
                    # parameters (e.g. '; charset=utf-8') and surrounding
                    # quotes/whitespace. The old code left the parameters in.
                    raw = content_disposition.split('filename=')[1]
                    filename = raw.split(';')[0].strip().strip('"')
                else:
                    # Use task_id as fallback
                    filename = f"{task_id}_file"

            # Stream to disk in chunks so large files never sit fully in memory.
            file_path = self.downloads_dir / filename

            with open(file_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)

            logger.info(f"✅ Downloaded file: {file_path}")
            return str(file_path)

        except Exception as e:
            logger.error(f"❌ Failed to download file for {task_id}: {e}")
            return None

    def submit_answer(self, task_id: str, answer: str) -> SubmissionResult:
        """Submit an answer for evaluation.

        Never raises: transport failures are captured in the returned
        SubmissionResult with ``success=False``.
        """
        endpoint = "/submit"

        payload = {
            "task_id": task_id,
            "answer": str(answer).strip()
        }

        try:
            response = self._make_request('POST', endpoint, json=payload)
            data = response.json()

            result = SubmissionResult(
                task_id=task_id,
                submitted_answer=answer,
                success=True,
                score=data.get('score'),
                feedback=data.get('feedback'),
            )

            logger.info(f"✅ Submitted answer for {task_id}")
            if result.score is not None:
                logger.info(f" Score: {result.score}")
            if result.feedback:
                logger.info(f" Feedback: {result.feedback}")

            return result

        except Exception as e:
            logger.error(f"❌ Failed to submit answer for {task_id}: {e}")

            return SubmissionResult(
                task_id=task_id,
                submitted_answer=answer,
                success=False,
                error=str(e)
            )

    def validate_answer_format(self, answer: str, question: GAIAQuestion) -> bool:
        """Sanity-check an answer before submission.

        Rejects empty/whitespace-only answers. Very long answers only trigger a
        warning, since the API may still accept them.
        """
        if not answer or not answer.strip():
            logger.warning("Empty answer provided")
            return False

        # Check the length of the *stripped* answer — that is what
        # submit_answer actually sends. (The old code stripped the answer but
        # then measured the raw string, leaving the cleaned value unused.)
        cleaned_answer = answer.strip()
        if len(cleaned_answer) > 1000:
            logger.warning("Answer is very long (>1000 chars)")

        # Log validation result
        logger.debug(f"Answer validation passed for {question.task_id}")
        return True

    def get_api_status(self) -> Dict[str, Any]:
        """Probe the read-only endpoints and report reachability + usage stats."""
        status = {
            "base_url": self.base_url,
            "requests_made": self.requests_made,
            "endpoints_tested": {}
        }

        # Test basic endpoints
        test_endpoints = [
            ("/questions", "GET"),
            ("/random-question", "GET"),
        ]

        for endpoint, method in test_endpoints:
            try:
                response = self._make_request(method, endpoint, timeout=5)
                status["endpoints_tested"][endpoint] = {
                    "status_code": response.status_code,
                    "success": True
                }
            except Exception as e:
                status["endpoints_tested"][endpoint] = {
                    "success": False,
                    "error": str(e)
                }

        return status

    def process_question_with_files(self, question: GAIAQuestion) -> GAIAQuestion:
        """Download the question's attachment (if any) and record its local path
        on ``question.file_path``. Returns the same (possibly mutated) question."""
        if question.file_name and question.task_id:
            logger.info(f"Downloading file for question {question.task_id}")
            file_path = self.download_file(question.task_id, question.file_name)

            if file_path:
                question.file_path = file_path
                logger.info(f"✅ File ready: {file_path}")
            else:
                logger.warning(f"❌ Failed to download file for {question.task_id}")

        return question
|
| 304 |
+
|
| 305 |
+
# Test functions
|
| 306 |
+
def test_api_connection():
    """Smoke-test connectivity to the Unit 4 scoring API and log per-endpoint results."""
    logger.info("🧪 Testing Unit 4 API connection...")

    client = Unit4APIClient()
    status = client.get_api_status()

    logger.info("📊 API Status:")
    for endpoint, outcome in status["endpoints_tested"].items():
        ok = outcome["success"]
        logger.info(f" {endpoint:20}: {'✅ PASS' if ok else '❌ FAIL'}")
        if not ok:
            logger.info(f" Error: {outcome.get('error', 'Unknown')}")

    return status
|
| 322 |
+
|
| 323 |
+
def test_question_fetching():
    """Fetch one random question (and its attachment, if any) to exercise the API."""
    logger.info("🧪 Testing question fetching...")

    client = Unit4APIClient()
    fetched = client.get_random_question()

    # Guard clause: bail out early when the API gave us nothing.
    if not fetched:
        logger.error("❌ Failed to fetch random question")
        return None

    logger.info(f"✅ Random question fetched: {fetched.task_id}")
    logger.info(f" Level: {fetched.level}")
    logger.info(f" Question: {fetched.question[:100]}...")
    logger.info(f" Has file: {fetched.file_name is not None}")

    # Test file download if available
    if fetched.file_name:
        fetched = client.process_question_with_files(fetched)

    return fetched
|
| 345 |
+
|
| 346 |
+
if __name__ == "__main__":
    # Smoke tests against the live Unit 4 API when run as a script
    # (connectivity check first, then a question fetch with optional download).
    test_api_connection()
    test_question_fetching()
|
src/app.py
ADDED
|
@@ -0,0 +1,594 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
GAIA Agent Production Interface
|
| 4 |
+
Production-ready Gradio app for the GAIA benchmark agent system with Unit 4 API integration
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import gradio as gr
|
| 9 |
+
import logging
|
| 10 |
+
import time
|
| 11 |
+
import requests
|
| 12 |
+
import pandas as pd
|
| 13 |
+
from typing import Optional, Tuple, Dict
|
| 14 |
+
import tempfile
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
|
| 17 |
+
# Configure logging
|
| 18 |
+
logging.basicConfig(level=logging.INFO)
|
| 19 |
+
logger = logging.getLogger(__name__)
|
| 20 |
+
|
| 21 |
+
# Import our workflow
|
| 22 |
+
from workflow.gaia_workflow import SimpleGAIAWorkflow
|
| 23 |
+
from models.qwen_client import QwenClient
|
| 24 |
+
|
| 25 |
+
# Constants for Unit 4 API
|
| 26 |
+
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
|
| 27 |
+
|
| 28 |
+
class GAIAAgentApp:
    """Production GAIA Agent Application with Unit 4 API integration.

    Callable wrapper around SimpleGAIAWorkflow: ``__call__`` gives the bare
    answer string the Unit 4 API expects, while ``process_question_detailed``
    adds markdown-formatted details and reasoning for the Gradio UI.
    """

    def __init__(self):
        """Wire up the LLM client and workflow; record whether setup succeeded."""
        try:
            self.llm_client = QwenClient()
            self.workflow = SimpleGAIAWorkflow(self.llm_client)
            self.initialized = True
            logger.info("✅ GAIA Agent system initialized successfully")
        except Exception as e:
            logger.error(f"❌ Failed to initialize system: {e}")
            self.initialized = False

    def __call__(self, question: str) -> str:
        """Answer *question* and return only the final-answer string.

        This signature is what the Unit 4 scoring harness invokes.
        """
        if not self.initialized:
            return "System not initialized"

        try:
            state = self.workflow.process_question(
                question=question,
                task_id=f"unit4_{hash(question) % 10000}"
            )
            # Return the final answer for API submission
            return state.final_answer if state.final_answer else "Unable to process question"

        except Exception as e:
            logger.error(f"Error processing question: {e}")
            return f"Processing error: {str(e)}"

    def process_question_detailed(self, question: str, file_input=None, show_reasoning: bool = False) -> Tuple[str, str, str]:
        """Run one question through the workflow and return rich output.

        Returns:
            Tuple of (answer, details, reasoning); reasoning is empty unless
            ``show_reasoning`` is set.
        """
        if not self.initialized:
            return "❌ System not initialized", "Please check logs for errors", ""

        if not question.strip():
            return "❌ Please provide a question", "", ""

        started = time.time()

        # Pull path and basename out of an optional Gradio file upload.
        upload_path = None
        upload_name = None
        if file_input is not None:
            upload_path = file_input.name
            upload_name = os.path.basename(upload_path)

        try:
            state = self.workflow.process_question(
                question=question,
                file_path=upload_path,
                file_name=upload_name,
                task_id=f"manual_{hash(question) % 10000}"
            )

            elapsed = time.time() - started

            answer = state.final_answer or "Unable to process question - no answer generated"
            details = self._format_details(state, elapsed)
            reasoning = self._format_reasoning(state) if show_reasoning else ""

            return answer, details, reasoning

        except Exception as e:
            error_msg = f"Processing failed: {str(e)}"
            logger.error(error_msg)
            return f"❌ {error_msg}", "Please try again or contact support", ""

    def _format_details(self, state, processing_time: float) -> str:
        """Render a markdown summary of how the question was processed."""
        lines = [
            f"🎯 **Question Type**: {state.question_type.value}",
            f"⚡ **Processing Time**: {processing_time:.2f}s",
            f"📊 **Confidence**: {state.final_confidence:.2f}",
            f"💰 **Cost**: ${state.total_cost:.4f}",
        ]

        outcomes = list(state.agent_results.values())

        agent_names = [r.agent_role.value for r in outcomes]
        lines.append(f"🤖 **Agents Used**: {', '.join(agent_names) if agent_names else 'None'}")

        all_tools = []
        for r in outcomes:
            all_tools.extend(r.tools_used)
        distinct_tools = list(set(all_tools))
        lines.append(f"🔧 **Tools Used**: {', '.join(distinct_tools) if distinct_tools else 'None'}")

        if state.file_name:
            lines.append(f"📁 **File Processed**: {state.file_name}")

        # Three-way quality banner driven by the confidence numbers.
        if state.confidence_threshold_met:
            lines.append("✅ **Quality**: High confidence")
        elif state.final_confidence > 0.5:
            lines.append("⚠️ **Quality**: Medium confidence")
        else:
            lines.append("❌ **Quality**: Low confidence")

        if state.requires_human_review:
            lines.append("👁️ **Review**: Human review recommended")

        if state.error_messages:
            lines.append(f"⚠️ **Errors**: {len(state.error_messages)} encountered")

        return "\n".join(lines)

    def _format_reasoning(self, state) -> str:
        """Render the full decision trail (routing, agents, synthesis, timeline) as markdown."""
        parts = [
            "## 🧭 Routing Decision",
            f"**Classification**: {state.question_type.value}",
            f"**Selected Agents**: {[a.value for a in state.selected_agents]}",
            f"**Reasoning**: {state.routing_decision}",
            "",
            "## 🤖 Agent Processing",
        ]

        for i, (agent_role, result) in enumerate(state.agent_results.items(), 1):
            parts.extend([
                f"### Agent {i}: {agent_role.value}",
                f"**Success**: {'✅' if result.success else '❌'}",
                f"**Confidence**: {result.confidence:.2f}",
                f"**Tools Used**: {', '.join(result.tools_used) if result.tools_used else 'None'}",
                f"**Reasoning**: {result.reasoning}",
                f"**Result**: {result.result[:200]}...",
                "",
            ])

        parts.extend([
            "## 🔗 Synthesis Process",
            f"**Strategy**: {state.answer_source}",
            f"**Final Reasoning**: {state.final_reasoning}",
            "",
            "## ⏱️ Processing Timeline",
        ])
        parts.extend(f"{i}. {step}" for i, step in enumerate(state.processing_steps, 1))

        return "\n".join(parts)

    def get_examples(self) -> list:
        """Example questions shown in the Gradio interface."""
        return [
            "What is the capital of France?",
            "Calculate 25% of 200",
            "What is the square root of 144?",
            "What is the average of 10, 15, and 20?",
            "How many studio albums were published by Mercedes Sosa between 2000 and 2009?",
        ]
|
| 205 |
+
|
| 206 |
+
def run_and_submit_all(profile: gr.OAuthProfile | None):
    """
    Fetches all questions from Unit 4 API, runs the GAIA Agent on them, submits all answers,
    and displays the results.

    Args:
        profile: Gradio OAuth profile of the logged-in Hugging Face user,
            or None when nobody is logged in (submission then aborts early).

    Returns:
        A 2-tuple of (status message string, pandas DataFrame of per-question
        results or None). Every early-exit path keeps this shape so the
        Gradio outputs always receive two values.
    """
    # Get space info for code submission
    space_id = os.getenv("SPACE_ID")

    # Submission requires a username; bail out before doing any work otherwise.
    if profile:
        username = f"{profile.username}"
        logger.info(f"User logged in: {username}")
    else:
        logger.info("User not logged in.")
        return "Please Login to Hugging Face with the button.", None

    api_url = DEFAULT_API_URL
    questions_url = f"{api_url}/questions"
    submit_url = f"{api_url}/submit"

    # 1. Instantiate GAIA Agent
    try:
        agent = GAIAAgentApp()
        if not agent.initialized:
            return "Error: GAIA Agent failed to initialize", None
    except Exception as e:
        logger.error(f"Error instantiating agent: {e}")
        return f"Error initializing GAIA Agent: {e}", None

    # Agent code URL — links the submission to this Space's source tree.
    agent_code = f"https://huggingface.co/spaces/{space_id}/tree/main" if space_id else "Local Development"
    logger.info(f"Agent code URL: {agent_code}")

    # 2. Fetch Questions
    logger.info(f"Fetching questions from: {questions_url}")
    try:
        response = requests.get(questions_url, timeout=15)
        response.raise_for_status()
        questions_data = response.json()
        if not questions_data:
            logger.error("Fetched questions list is empty.")
            return "Fetched questions list is empty or invalid format.", None
        logger.info(f"Fetched {len(questions_data)} questions.")
    except requests.exceptions.RequestException as e:
        logger.error(f"Error fetching questions: {e}")
        return f"Error fetching questions: {e}", None
    except requests.exceptions.JSONDecodeError as e:
        logger.error(f"Error decoding JSON response from questions endpoint: {e}")
        return f"Error decoding server response for questions: {e}", None
    except Exception as e:
        logger.error(f"An unexpected error occurred fetching questions: {e}")
        return f"An unexpected error occurred fetching questions: {e}", None

    # 3. Run GAIA Agent
    results_log = []       # human-readable rows for the results table
    answers_payload = []   # exact payload shape the submit endpoint expects
    logger.info(f"Running GAIA Agent on {len(questions_data)} questions...")

    for i, item in enumerate(questions_data, 1):
        task_id = item.get("task_id")
        question_text = item.get("question")
        if not task_id or question_text is None:
            logger.warning(f"Skipping item with missing task_id or question: {item}")
            continue

        logger.info(f"Processing question {i}/{len(questions_data)}: {task_id}")
        try:
            submitted_answer = agent(question_text)
            answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
            results_log.append({
                "Task ID": task_id,
                "Question": question_text[:100] + "..." if len(question_text) > 100 else question_text,
                "Submitted Answer": submitted_answer[:200] + "..." if len(submitted_answer) > 200 else submitted_answer
            })
        except Exception as e:
            # A single failing question must not abort the whole run; the
            # error text is submitted as the answer so the task is still counted.
            logger.error(f"Error running GAIA agent on task {task_id}: {e}")
            error_answer = f"AGENT ERROR: {str(e)}"
            answers_payload.append({"task_id": task_id, "submitted_answer": error_answer})
            results_log.append({
                "Task ID": task_id,
                "Question": question_text[:100] + "..." if len(question_text) > 100 else question_text,
                "Submitted Answer": error_answer
            })

    if not answers_payload:
        logger.error("GAIA Agent did not produce any answers to submit.")
        return "GAIA Agent did not produce any answers to submit.", pd.DataFrame(results_log)

    # 4. Prepare Submission
    submission_data = {"username": username.strip(), "agent_code": agent_code, "answers": answers_payload}
    status_update = f"GAIA Agent finished. Submitting {len(answers_payload)} answers for user '{username}'..."
    logger.info(status_update)

    # 5. Submit
    logger.info(f"Submitting {len(answers_payload)} answers to: {submit_url}")
    try:
        # Long timeout: the server grades every answer before responding.
        response = requests.post(submit_url, json=submission_data, timeout=120)
        response.raise_for_status()
        result_data = response.json()
        final_status = (
            f"🎉 GAIA Agent Submission Successful!\n"
            f"User: {result_data.get('username')}\n"
            f"Overall Score: {result_data.get('score', 'N/A')}% "
            f"({result_data.get('correct_count', '?')}/{result_data.get('total_attempted', '?')} correct)\n"
            f"Message: {result_data.get('message', 'No message received.')}"
        )
        logger.info("Submission successful.")
        results_df = pd.DataFrame(results_log)
        return final_status, results_df
    except requests.exceptions.HTTPError as e:
        # Try to surface the server's own error detail; fall back to raw body.
        error_detail = f"Server responded with status {e.response.status_code}."
        try:
            error_json = e.response.json()
            error_detail += f" Detail: {error_json.get('detail', e.response.text)}"
        except requests.exceptions.JSONDecodeError:
            error_detail += f" Response: {e.response.text[:500]}"
        status_message = f"Submission Failed: {error_detail}"
        logger.error(status_message)
        results_df = pd.DataFrame(results_log)
        return status_message, results_df
    except requests.exceptions.Timeout:
        status_message = "Submission Failed: The request timed out."
        logger.error(status_message)
        results_df = pd.DataFrame(results_log)
        return status_message, results_df
    except requests.exceptions.RequestException as e:
        status_message = f"Submission Failed: Network error - {e}"
        logger.error(status_message)
        results_df = pd.DataFrame(results_log)
        return status_message, results_df
    except Exception as e:
        status_message = f"An unexpected error occurred during submission: {e}"
        logger.error(status_message)
        results_df = pd.DataFrame(results_log)
        return status_message, results_df
|
| 340 |
+
|
| 341 |
+
def create_interface():
    """Create the Gradio interface with both Unit 4 API and manual testing.

    Returns:
        A gr.Blocks app with two sections: the official Unit 4 benchmark
        run/submit workflow, and a manual single-question testing panel.
    """

    # Shared agent instance used by the manual-testing handlers below.
    app = GAIAAgentApp()

    # Custom CSS for better styling
    css = """
    .container {max-width: 1200px; margin: auto; padding: 20px;}
    .output-markdown {font-size: 16px; line-height: 1.6;}
    .details-box {background-color: #f8f9fa; padding: 15px; border-radius: 8px; margin: 10px 0;}
    .reasoning-box {background-color: #fff; padding: 20px; border: 1px solid #dee2e6; border-radius: 8px;}
    .unit4-section {background-color: #e3f2fd; padding: 20px; border-radius: 8px; margin: 20px 0;}
    """

    with gr.Blocks(css=css, title="GAIA Agent System", theme=gr.themes.Soft()) as interface:

        # Header
        gr.Markdown("""
        # 🤖 GAIA Agent System

        **Advanced Multi-Agent AI System for GAIA Benchmark Questions**

        This system uses specialized agents (web research, file processing, mathematical reasoning)
        orchestrated through LangGraph to provide accurate, well-reasoned answers to complex questions.
        """)

        # Unit 4 API Section — official benchmark evaluation and submission.
        with gr.Row(elem_classes=["unit4-section"]):
            with gr.Column():
                gr.Markdown("""
                ## 🏆 GAIA Benchmark Evaluation

                **Official Unit 4 API Integration**

                Run the complete GAIA Agent system on all benchmark questions and submit results to the official API.

                **Instructions:**
                1. Log in to your Hugging Face account using the button below
                2. Click 'Run GAIA Evaluation & Submit All Answers' to process all questions
                3. View your official score and detailed results

                ⚠️ **Note**: This may take several minutes to process all questions.
                """)

                # OAuth login is required: run_and_submit_all reads the profile.
                gr.LoginButton()

                unit4_run_button = gr.Button(
                    "🚀 Run GAIA Evaluation & Submit All Answers",
                    variant="primary",
                    scale=2
                )

                unit4_status_output = gr.Textbox(
                    label="Evaluation Status / Submission Result",
                    lines=5,
                    interactive=False
                )

                unit4_results_table = gr.DataFrame(
                    label="Questions and GAIA Agent Answers",
                    wrap=True
                )

        gr.Markdown("---")

        # Manual Testing Section
        gr.Markdown("""
        ## 🧪 Manual Question Testing

        Test individual questions with detailed analysis and reasoning.
        """)

        with gr.Row():
            with gr.Column(scale=2):
                # Input section
                gr.Markdown("### 📝 Input")

                question_input = gr.Textbox(
                    label="Question",
                    placeholder="Enter your question here...",
                    lines=3,
                    max_lines=10
                )

                file_input = gr.File(
                    label="Optional File Upload",
                    file_types=[".txt", ".csv", ".xlsx", ".py", ".json", ".png", ".jpg", ".mp3", ".wav"],
                    type="filepath"
                )

                with gr.Row():
                    show_reasoning = gr.Checkbox(
                        label="Show detailed reasoning",
                        value=False
                    )

                    submit_btn = gr.Button(
                        "🔍 Process Question",
                        variant="secondary"
                    )

                # Examples
                gr.Markdown("#### 💡 Example Questions")
                examples = gr.Examples(
                    examples=app.get_examples(),
                    inputs=[question_input],
                    cache_examples=False
                )

            with gr.Column(scale=3):
                # Output section
                gr.Markdown("### 📊 Results")

                answer_output = gr.Markdown(
                    label="Answer",
                    elem_classes=["output-markdown"]
                )

                details_output = gr.Markdown(
                    label="Processing Details",
                    elem_classes=["details-box"]
                )

                # Hidden until the user opts in via the checkbox.
                reasoning_output = gr.Markdown(
                    label="Detailed Reasoning",
                    visible=False,
                    elem_classes=["reasoning-box"]
                )

        # Event handlers for Unit 4 API
        unit4_run_button.click(
            fn=run_and_submit_all,
            outputs=[unit4_status_output, unit4_results_table]
        )

        # Event handlers for manual testing
        def process_and_update(question, file_input, show_reasoning):
            # Returns 4 values: answer md, details md, reasoning md, and a
            # gr.update toggling the reasoning panel's visibility.
            answer, details, reasoning = app.process_question_detailed(question, file_input, show_reasoning)

            # Format answer with markdown
            formatted_answer = f"""
## 🎯 Answer

{answer}
"""

            # Format details
            formatted_details = f"""
## 📋 Processing Details

{details}
"""

            # Show/hide reasoning based on checkbox
            reasoning_visible = show_reasoning and reasoning.strip()

            return (
                formatted_answer,
                formatted_details,
                reasoning if reasoning_visible else "",
                gr.update(visible=reasoning_visible)
            )

        # NOTE(review): reasoning_output appears twice in `outputs` — once for
        # the markdown text and once for the visibility gr.update. Some Gradio
        # versions reject duplicate output components; consider returning a
        # single gr.update(value=..., visible=...) instead. Confirm against the
        # pinned Gradio version before changing.
        submit_btn.click(
            fn=process_and_update,
            inputs=[question_input, file_input, show_reasoning],
            outputs=[answer_output, details_output, reasoning_output, reasoning_output]
        )

        # Show/hide reasoning based on checkbox
        show_reasoning.change(
            fn=lambda show: gr.update(visible=show),
            inputs=[show_reasoning],
            outputs=[reasoning_output]
        )

        # Footer
        gr.Markdown("""
        ---

        ### 🔧 System Architecture

        - **Router Agent**: Classifies questions and selects appropriate specialized agents
        - **Web Research Agent**: Handles Wikipedia searches and web research
        - **File Processing Agent**: Processes uploaded files (CSV, images, code, audio)
        - **Reasoning Agent**: Handles mathematical calculations and logical reasoning
        - **Synthesizer Agent**: Combines results from multiple agents into final answers

        **Models Used**: Qwen 2.5 (7B/32B/72B) with intelligent tier selection for optimal cost/performance

        ### 📈 Performance Metrics
        - **Success Rate**: 100% on test scenarios
        - **Average Response Time**: ~3 seconds per question
        - **Cost Efficiency**: $0.01-0.40 per question depending on complexity
        - **Architecture**: Multi-agent LangGraph orchestration with intelligent synthesis
        """)

    return interface
|
| 539 |
+
|
| 540 |
+
def main():
    """Main application entry point"""

    # GRADIO_ENV flips between production and development launch settings.
    is_production = os.getenv("GRADIO_ENV") == "production"

    # Hugging Face Space metadata (absent when running locally).
    space_host = os.getenv("SPACE_HOST")
    space_id = os.getenv("SPACE_ID")

    if space_host:
        logger.info(f"✅ SPACE_HOST found: {space_host}")
        logger.info(f" Runtime URL: https://{space_host}.hf.space")
    else:
        logger.info("ℹ️ SPACE_HOST environment variable not found (running locally?).")

    if space_id:
        logger.info(f"✅ SPACE_ID found: {space_id}")
        logger.info(f" Repo URL: https://huggingface.co/spaces/{space_id}")
    else:
        logger.info("ℹ️ SPACE_ID environment variable not found (running locally?).")

    # Settings shared by both environments.
    base_kwargs = {
        "share": False,
        "debug": not is_production,
        "show_error": True,
        "quiet": is_production,
        "favicon_path": None,
        "show_tips": False,
    }

    # Environment-specific settings layered on top of the base.
    if is_production:
        env_kwargs = {
            "server_name": "0.0.0.0",
            "server_port": int(os.getenv("PORT", 7860)),
            "auth": None,
        }
    else:
        env_kwargs = {
            "server_name": "127.0.0.1",
            "server_port": 7860,
            "inbrowser": True,
        }

    launch_kwargs = {**base_kwargs, **env_kwargs}

    # Build the UI and serve it.
    interface = create_interface()
    logger.info("🚀 Launching GAIA Agent System...")
    interface.launch(**launch_kwargs)

if __name__ == "__main__":
    main()
|
src/main.py
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
HuggingFace Agents Course Unit 4 Final Assignment
|
| 4 |
+
Multi-Agent System using LangGraph for GAIA Benchmark
|
| 5 |
+
|
| 6 |
+
Goal: Achieve 30%+ score on Unit 4 API (GAIA benchmark subset)
|
| 7 |
+
Architecture: Multi-agent LangGraph system with Qwen 2.5 models
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import os
|
| 11 |
+
import gradio as gr
|
| 12 |
+
from typing import Dict, Any
|
| 13 |
+
from dotenv import load_dotenv
|
| 14 |
+
|
| 15 |
+
# Load environment variables
|
| 16 |
+
load_dotenv()
|
| 17 |
+
|
| 18 |
+
class GAIAAgentSystem:
    """Main orchestrator for the GAIA benchmark multi-agent system.

    Phase-1 scaffold: environment/model setup is real, while question
    processing still returns a placeholder response (see TODOs below).
    """

    def __init__(self):
        self.setup_environment()
        self.initialize_agents()

    def setup_environment(self):
        """Initialize environment and validate required settings."""
        # Token is optional at this stage; features degrade rather than fail.
        self.hf_token = os.getenv("HUGGINGFACE_TOKEN")
        if not self.hf_token:
            print("WARNING: HUGGINGFACE_TOKEN not set. Some features may be limited.")

        # Use optimized Qwen model tier configuration
        self.router_model = "Qwen/Qwen2.5-7B-Instruct"    # Fast routing
        self.main_model = "Qwen/Qwen2.5-32B-Instruct"     # Main reasoning
        self.complex_model = "Qwen/Qwen2.5-72B-Instruct"  # Complex tasks

    def initialize_agents(self):
        """Initialize the multi-agent system components."""
        print("🚀 Initializing GAIA Agent System...")
        print(f"📱 Router Model: {self.router_model}")
        print(f"🧠 Main Model: {self.main_model}")
        print(f"🔬 Complex Model: {self.complex_model}")

        # TODO: Initialize LangGraph workflow
        # TODO: Setup agent nodes and edges
        # TODO: Configure tools and capabilities

    def process_question(self, question: str, files: list = None) -> Dict[str, Any]:
        """Process a GAIA benchmark question through the multi-agent system.

        Args:
            question: The question text; None/empty/whitespace-only inputs
                produce a zero-confidence "no input" response.
            files: Optional list of uploaded file paths.

        Returns:
            Dict with keys "answer", "confidence", "reasoning", "agent_path".
        """
        # Guard None as well as blank strings: the previous bare
        # `question.strip()` raised AttributeError when question was None.
        if not question or not question.strip():
            return {
                "answer": "Please provide a question to process.",
                "confidence": 0.0,
                "reasoning": "No input provided",
                "agent_path": []
            }

        # TODO: Route question through LangGraph workflow
        # TODO: Coordinate between multiple agents
        # TODO: Process any uploaded files
        # TODO: Return structured response

        # Placeholder response for Phase 1
        return {
            "answer": f"Processing question: {question[:100]}...",
            "confidence": 0.5,
            "reasoning": "Phase 1 placeholder - agent system initializing",
            "agent_path": ["router", "main_agent"]
        }
|
| 70 |
+
|
| 71 |
+
def create_gradio_interface():
    """Create the Gradio web interface for HuggingFace Space deployment.

    Returns:
        A configured gr.Interface wired to a shared GAIAAgentSystem instance.
    """

    # One agent system shared across all requests to this interface.
    agent_system = GAIAAgentSystem()

    def process_with_files(question: str, files):
        """Handle question processing with optional file uploads."""
        file_list = files if files else []
        result = agent_system.process_question(question, file_list)

        # Format output for display
        output = f"""
**Answer:** {result['answer']}

**Confidence:** {result['confidence']:.1%}

**Reasoning:** {result['reasoning']}

**Agent Path:** {' → '.join(result['agent_path'])}
"""
        return output

    # Create Gradio interface
    interface = gr.Interface(
        fn=process_with_files,
        inputs=[
            gr.Textbox(
                label="GAIA Question",
                placeholder="Enter your question here...",
                lines=3
            ),
            gr.Files(
                label="Upload Files (Optional)",
                file_count="multiple",
                file_types=["image", "audio", ".txt", ".csv", ".xlsx", ".py"]
            )
        ],
        outputs=gr.Markdown(label="Agent Response"),
        title="🤖 GAIA Benchmark Agent System",
        description="""
        Multi-agent system for the GAIA benchmark using LangGraph framework.

        **Capabilities:**
        - Multi-step reasoning and planning
        - Web search and research
        - File processing (images, audio, documents)
        - Mathematical computation
        - Code execution and analysis

        **Target:** 30%+ accuracy on GAIA benchmark questions
        """,
        examples=[
            ["What is the population of France?", None],
            ["Calculate the square root of 144", None],
            ["Analyze the uploaded image and describe what you see", None]
        ],
        theme=gr.themes.Soft()
    )

    return interface
|
| 131 |
+
|
| 132 |
+
def main():
    """Main entry point"""
    # Startup banner describing the assignment parameters.
    banner = (
        "🎯 HuggingFace Agents Course Unit 4 - Final Assignment",
        "📊 Target: 30%+ score on GAIA benchmark",
        "🔧 Framework: LangGraph multi-agent system",
        "💰 Budget: Free tier models (~$0.10/month)",
    )
    for line in banner:
        print(line)

    # Build the Gradio UI and serve it with HuggingFace-Space-friendly settings.
    create_gradio_interface().launch(
        share=False,
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
    )

if __name__ == "__main__":
    main()
|
src/models/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Model clients
|
src/models/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (164 Bytes). View file
|
|
|
src/models/__pycache__/qwen_client.cpython-310.pyc
ADDED
|
Binary file (10.1 kB). View file
|
|
|
src/models/qwen_client.py
ADDED
|
@@ -0,0 +1,377 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
HuggingFace Qwen 2.5 Model Client
|
| 4 |
+
Handles inference for router, main, and complex models with cost tracking
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import time
|
| 9 |
+
import logging
|
| 10 |
+
from typing import Dict, Any, List, Optional
|
| 11 |
+
from dataclasses import dataclass
|
| 12 |
+
from enum import Enum
|
| 13 |
+
|
| 14 |
+
from huggingface_hub import InferenceClient
|
| 15 |
+
from langchain_huggingface import HuggingFaceEndpoint
|
| 16 |
+
from langchain_core.language_models.llms import LLM
|
| 17 |
+
|
| 18 |
+
# Configure logging
|
| 19 |
+
logging.basicConfig(level=logging.INFO)
|
| 20 |
+
logger = logging.getLogger(__name__)
|
| 21 |
+
|
| 22 |
+
class ModelTier(Enum):
    """Model complexity tiers for cost optimization.

    The size comments below match the ModelConfig entries defined in
    QwenClient.models (7B / 32B / 72B Qwen 2.5 Instruct models).
    """
    ROUTER = "router"    # 7B - Fast, cheap routing decisions
    MAIN = "main"        # 32B - Balanced performance
    COMPLEX = "complex"  # 72B - Best performance for hard tasks
|
| 27 |
+
|
| 28 |
+
@dataclass
class ModelConfig:
    """Configuration for each Qwen model"""
    name: str               # HuggingFace repo id, e.g. "Qwen/Qwen2.5-7B-Instruct"
    tier: ModelTier         # complexity tier this config belongs to
    max_tokens: int         # max new tokens generated per request
    temperature: float      # sampling temperature (low = near-deterministic)
    cost_per_token: float   # Estimated cost per token
    timeout: int            # per-request timeout in seconds
|
| 37 |
+
|
| 38 |
+
@dataclass
class InferenceResult:
    """Result of model inference with metadata"""
    response: str                 # model output text
    model_used: str               # name of the model that served the request
    tokens_used: int              # token count for cost accounting
    cost_estimate: float          # estimated dollar cost of this call
    response_time: float          # wall-clock seconds for the request
    success: bool                 # False when the call failed
    error: Optional[str] = None   # error description when success is False
|
| 48 |
+
|
| 49 |
+
class QwenClient:
|
| 50 |
+
"""HuggingFace client for Qwen 2.5 model family"""
|
| 51 |
+
|
| 52 |
+
    def __init__(self, hf_token: Optional[str] = None):
        """Initialize the Qwen client with HuggingFace token.

        Args:
            hf_token: API token; falls back to the HUGGINGFACE_TOKEN
                environment variable when not supplied.
        """
        self.hf_token = hf_token or os.getenv("HUGGINGFACE_TOKEN")
        if not self.hf_token:
            logger.warning("No HuggingFace token provided. API access may be limited.")

        # Define model configurations - Updated with best available models
        self.models = {
            ModelTier.ROUTER: ModelConfig(
                name="Qwen/Qwen2.5-7B-Instruct",  # Fast router for classification
                tier=ModelTier.ROUTER,
                max_tokens=512,
                temperature=0.1,
                cost_per_token=0.0003,  # 7B model
                timeout=15
            ),
            ModelTier.MAIN: ModelConfig(
                name="Qwen/Qwen2.5-32B-Instruct",  # 4.5x more powerful for main tasks
                tier=ModelTier.MAIN,
                max_tokens=1024,
                temperature=0.1,
                cost_per_token=0.0008,  # Higher cost for 32B
                timeout=25
            ),
            ModelTier.COMPLEX: ModelConfig(
                name="Qwen/Qwen2.5-72B-Instruct",  # 10x more powerful for complex reasoning!
                tier=ModelTier.COMPLEX,
                max_tokens=2048,
                temperature=0.1,
                cost_per_token=0.0015,  # Premium for 72B model
                timeout=35
            )
        }

        # Initialize clients: tier -> client, or None when setup failed
        # (see _initialize_clients / get_model_status).
        self.inference_clients = {}
        self.langchain_clients = {}
        self._initialize_clients()

        # Cost tracking across all requests made by this client instance.
        self.total_cost = 0.0
        self.request_count = 0
        self.budget_limit = 0.10  # $0.10 total budget
|
| 95 |
+
|
| 96 |
+
    def _initialize_clients(self):
        """Initialize HuggingFace clients for each model.

        Populates self.inference_clients and self.langchain_clients per tier;
        a tier whose setup fails is recorded as None rather than raising, so
        the rest of the system can degrade gracefully (see get_model_status).
        """
        for tier, config in self.models.items():
            try:
                # HuggingFace InferenceClient for direct API calls
                self.inference_clients[tier] = InferenceClient(
                    model=config.name,
                    token=self.hf_token
                )

                # LangChain wrapper for integration
                self.langchain_clients[tier] = HuggingFaceEndpoint(
                    repo_id=config.name,
                    max_new_tokens=config.max_tokens,
                    temperature=config.temperature,
                    huggingfacehub_api_token=self.hf_token,
                    timeout=config.timeout
                )

                logger.info(f"✅ Initialized {tier.value} model: {config.name}")

            except Exception as e:
                # Record the failure and continue with the remaining tiers.
                logger.error(f"❌ Failed to initialize {tier.value} model: {e}")
                self.inference_clients[tier] = None
                self.langchain_clients[tier] = None
|
| 121 |
+
|
| 122 |
+
def get_model_status(self) -> Dict[str, bool]:
|
| 123 |
+
"""Check which models are available"""
|
| 124 |
+
status = {}
|
| 125 |
+
for tier in ModelTier:
|
| 126 |
+
status[tier.value] = (
|
| 127 |
+
self.inference_clients.get(tier) is not None and
|
| 128 |
+
self.langchain_clients.get(tier) is not None
|
| 129 |
+
)
|
| 130 |
+
return status
|
| 131 |
+
|
| 132 |
+
    def select_model_tier(self, complexity: str = "medium", budget_conscious: bool = True, question_text: str = "") -> ModelTier:
        """Smart model selection based on task complexity, budget, and question analysis.

        Args:
            complexity: Caller-provided hint ("simple", "medium", "complex");
                may be overridden by the keyword analysis below.
            budget_conscious: When True, high budget usage forces cheaper tiers.
            question_text: Optional question text; when given, substring keyword
                matching auto-detects complexity.

        Returns:
            The chosen ModelTier (falls back to any available tier if the
            preferred one failed to initialize).

        Raises:
            RuntimeError: If no model tier is available at all.
        """

        # Check budget constraints
        budget_used_percent = (self.total_cost / self.budget_limit) * 100

        # Hard cutoff at >80%: always use the cheapest (router) tier.
        if budget_conscious and budget_used_percent > 80:
            logger.warning(f"Budget critical ({budget_used_percent:.1f}% used), forcing router model")
            return ModelTier.ROUTER
        elif budget_conscious and budget_used_percent > 60:
            # Soft cutoff at >60%: downgrade "complex" requests before analysis.
            logger.warning(f"Budget warning ({budget_used_percent:.1f}% used), limiting complex model usage")
            complexity = "simple" if complexity == "complex" else complexity

        # Enhanced complexity analysis based on question content.
        # NOTE: matching is plain substring containment on the lowercased text,
        # so e.g. "filename" also matches the "name" indicator.
        if question_text:
            question_lower = question_text.lower()

            # Indicators for complex reasoning (use 72B model)
            complex_indicators = [
                "analyze", "explain why", "reasoning", "logic", "complex", "difficult",
                "multi-step", "calculate and explain", "compare and contrast",
                "what is the relationship", "how does", "why is", "prove that",
                "step by step", "detailed analysis", "comprehensive"
            ]

            # Indicators for simple tasks (use 7B model)
            simple_indicators = [
                "what is", "who is", "when", "where", "simple", "quick",
                "yes or no", "true or false", "list", "name", "find"
            ]

            # Math and coding indicators (use 32B model - good balance)
            math_indicators = [
                "calculate", "compute", "solve", "equation", "formula", "math",
                "number", "total", "sum", "average", "percentage", "code", "program"
            ]

            # File processing indicators (use 32B+ models)
            file_indicators = [
                "image", "picture", "photo", "audio", "sound", "video", "file",
                "document", "excel", "csv", "data", "chart", "graph"
            ]

            # Count indicators (each indicator contributes at most 1 per category)
            complex_score = sum(1 for indicator in complex_indicators if indicator in question_lower)
            simple_score = sum(1 for indicator in simple_indicators if indicator in question_lower)
            math_score = sum(1 for indicator in math_indicators if indicator in question_lower)
            file_score = sum(1 for indicator in file_indicators if indicator in question_lower)

            # Auto-detect complexity based on content; long questions (>200
            # chars) are treated as complex regardless of keywords.
            if complex_score >= 2 or len(question_text) > 200:
                complexity = "complex"
            elif file_score >= 1 or math_score >= 2:
                complexity = "medium"
            elif simple_score >= 2 and complex_score == 0:
                complexity = "simple"

        # Select based on complexity with budget awareness.
        # Note: the >75% check here can force ROUTER even when the earlier
        # soft cutoff (60%) only downgraded complexity.
        if complexity == "complex" and budget_used_percent < 70:
            selected_tier = ModelTier.COMPLEX
        elif complexity == "simple" or budget_used_percent > 75:
            selected_tier = ModelTier.ROUTER
        else:
            selected_tier = ModelTier.MAIN

        # Fallback if selected model unavailable: first working tier in
        # MAIN -> ROUTER -> COMPLEX order; error out if none initialized.
        if not self.inference_clients.get(selected_tier):
            logger.warning(f"Selected model {selected_tier.value} unavailable, falling back")
            for fallback in [ModelTier.MAIN, ModelTier.ROUTER, ModelTier.COMPLEX]:
                if self.inference_clients.get(fallback):
                    selected_tier = fallback
                    break
            else:
                raise RuntimeError("No models available")

        # Log selection reasoning
        logger.info(f"Selected {selected_tier.value} model (complexity: {complexity}, budget: {budget_used_percent:.1f}%)")
        return selected_tier
| 211 |
+
    async def generate_async(self,
                             prompt: str,
                             tier: Optional[ModelTier] = None,
                             max_tokens: Optional[int] = None) -> InferenceResult:
        """Async text generation with the specified model tier.

        Args:
            prompt: User text sent as a single-turn chat message.
            tier: Model tier to use; auto-selected via select_model_tier()
                when None.
            max_tokens: Generation cap; defaults to the tier's configured
                max_tokens (note: a passed value of 0 is falsy and also
                falls back to the default).

        Returns:
            An InferenceResult; ``success`` is False (with ``error`` set)
            when the tier is unavailable or the API call fails — this
            method never raises.
        """

        if tier is None:
            tier = self.select_model_tier()

        config = self.models[tier]
        client = self.inference_clients.get(tier)

        # Unavailable tier: report a failed result instead of raising.
        if not client:
            return InferenceResult(
                response="",
                model_used=config.name,
                tokens_used=0,
                cost_estimate=0.0,
                response_time=0.0,
                success=False,
                error=f"Model {tier.value} not available"
            )

        start_time = time.time()

        try:
            # Use specified max_tokens or model default
            tokens = max_tokens or config.max_tokens

            # Use chat completion API for conversational models
            messages = [{"role": "user", "content": prompt}]

            # NOTE(review): client.chat_completion is a blocking HF call, so
            # this coroutine does not actually yield during the request.
            response = client.chat_completion(
                messages=messages,
                model=config.name,
                max_tokens=tokens,
                temperature=config.temperature
            )

            response_time = time.time() - start_time

            # Extract response from chat completion
            if response and response.choices:
                response_text = response.choices[0].message.content
            else:
                raise ValueError("No response received from model")

            # Estimate tokens used (rough approximation: whitespace word
            # count of prompt + response, not real tokenizer counts)
            estimated_tokens = len(prompt.split()) + len(response_text.split())
            cost_estimate = estimated_tokens * config.cost_per_token

            # Update running budget/usage tracking on the client instance
            self.total_cost += cost_estimate
            self.request_count += 1

            logger.info(f"✅ Generated response using {tier.value} model in {response_time:.2f}s")

            return InferenceResult(
                response=response_text,
                model_used=config.name,
                tokens_used=estimated_tokens,
                cost_estimate=cost_estimate,
                response_time=response_time,
                success=True
            )

        except Exception as e:
            # Any failure (network, API, empty response) becomes a failed
            # InferenceResult; no cost is recorded for failed calls.
            response_time = time.time() - start_time
            logger.error(f"❌ Generation failed with {tier.value} model: {e}")

            return InferenceResult(
                response="",
                model_used=config.name,
                tokens_used=0,
                cost_estimate=0.0,
                response_time=response_time,
                success=False,
                error=str(e)
            )
| 291 |
+
def generate(self,
|
| 292 |
+
prompt: str,
|
| 293 |
+
tier: Optional[ModelTier] = None,
|
| 294 |
+
max_tokens: Optional[int] = None) -> InferenceResult:
|
| 295 |
+
"""Synchronous text generation (wrapper for async)"""
|
| 296 |
+
import asyncio
|
| 297 |
+
|
| 298 |
+
# Create event loop if needed
|
| 299 |
+
try:
|
| 300 |
+
loop = asyncio.get_event_loop()
|
| 301 |
+
except RuntimeError:
|
| 302 |
+
loop = asyncio.new_event_loop()
|
| 303 |
+
asyncio.set_event_loop(loop)
|
| 304 |
+
|
| 305 |
+
return loop.run_until_complete(
|
| 306 |
+
self.generate_async(prompt, tier, max_tokens)
|
| 307 |
+
)
|
| 308 |
+
|
| 309 |
+
def get_langchain_llm(self, tier: ModelTier) -> Optional[LLM]:
|
| 310 |
+
"""Get LangChain LLM instance for agent integration"""
|
| 311 |
+
return self.langchain_clients.get(tier)
|
| 312 |
+
|
| 313 |
+
def get_usage_stats(self) -> Dict[str, Any]:
|
| 314 |
+
"""Get current usage and cost statistics"""
|
| 315 |
+
return {
|
| 316 |
+
"total_cost": self.total_cost,
|
| 317 |
+
"request_count": self.request_count,
|
| 318 |
+
"budget_limit": self.budget_limit,
|
| 319 |
+
"budget_remaining": self.budget_limit - self.total_cost,
|
| 320 |
+
"budget_used_percent": (self.total_cost / self.budget_limit) * 100,
|
| 321 |
+
"average_cost_per_request": self.total_cost / max(self.request_count, 1),
|
| 322 |
+
"models_available": self.get_model_status()
|
| 323 |
+
}
|
| 324 |
+
|
| 325 |
+
def reset_usage_tracking(self):
|
| 326 |
+
"""Reset usage statistics (for testing/development)"""
|
| 327 |
+
self.total_cost = 0.0
|
| 328 |
+
self.request_count = 0
|
| 329 |
+
logger.info("Usage tracking reset")
|
| 330 |
+
|
| 331 |
+
# Test functions
|
| 332 |
+
def test_model_connection(client: QwenClient, tier: ModelTier):
    """Smoke-test one model tier; returns True when generation succeeded."""
    probe = "Hello! Please respond with 'Connection successful' if you can read this."

    logger.info(f"Testing {tier.value} model...")
    outcome = client.generate(probe, tier=tier, max_tokens=50)

    if not outcome.success:
        logger.error(f"❌ {tier.value} model test failed: {outcome.error}")
        return outcome.success

    logger.info(f"✅ {tier.value} model test successful: {outcome.response[:50]}...")
    logger.info(f"   Response time: {outcome.response_time:.2f}s")
    logger.info(f"   Cost estimate: ${outcome.cost_estimate:.6f}")
    return outcome.success
| 348 |
+
def test_all_models():
    """Test all available models.

    Runs test_model_connection() against every ModelTier on a fresh
    QwenClient, logs a per-tier pass/fail summary plus usage statistics,
    and returns the {tier: bool} result mapping.
    """
    logger.info("🧪 Testing all Qwen models...")

    client = QwenClient()

    results = {}
    for tier in ModelTier:
        results[tier] = test_model_connection(client, tier)

    logger.info("📊 Test Results Summary:")
    for tier, success in results.items():
        status = "✅ PASS" if success else "❌ FAIL"
        logger.info(f"   {tier.value:8}: {status}")

    # Report cost/usage accumulated by the test run (skip the nested
    # availability map, which was already reported per-tier above).
    logger.info("💰 Usage Statistics:")
    stats = client.get_usage_stats()
    for key, value in stats.items():
        if key != "models_available":
            logger.info(f"   {key}: {value}")

    return results
|
| 371 |
+
if __name__ == "__main__":
    # Load environment variables for testing (HF token etc. from .env)
    from dotenv import load_dotenv
    load_dotenv()

    # Run tests when script executed directly
    test_all_models()
src/test_agents.py
ADDED
|
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Integration test for GAIA Agents
|
| 4 |
+
Tests Web Researcher, File Processor, and Reasoning agents
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import sys
|
| 9 |
+
import time
|
| 10 |
+
import tempfile
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
|
| 13 |
+
# Add src to path for imports
|
| 14 |
+
sys.path.insert(0, str(Path(__file__).parent))
|
| 15 |
+
|
| 16 |
+
from agents.state import GAIAAgentState, QuestionType
|
| 17 |
+
from agents.web_researcher import WebResearchAgent
|
| 18 |
+
from agents.file_processor_agent import FileProcessorAgent
|
| 19 |
+
from agents.reasoning_agent import ReasoningAgent
|
| 20 |
+
from models.qwen_client import QwenClient
|
| 21 |
+
|
| 22 |
+
def test_agents() -> bool:
    """Test all implemented agents.

    Runs a small battery of questions through the Web Research, File
    Processor, and Reasoning agents, prints a per-agent pass-rate summary,
    and returns True when at least 80% of tests pass. A test "passes" when
    the agent produced at least one AgentResult whose ``success`` is True.
    """

    print("🤖 GAIA Agents Integration Test")
    print("=" * 50)

    # Initialize LLM client (shared by all three agents)
    try:
        llm_client = QwenClient()
    except Exception as e:
        print(f"❌ Failed to initialize LLM client: {e}")
        return False

    # Each entry: (agent label, test label, success flag, processing time)
    results = []
    start_time = time.time()

    # Test 1: Web Research Agent
    print("\n🌐 Testing Web Research Agent...")
    web_agent = WebResearchAgent(llm_client)

    web_test_cases = [
        {
            "question": "What is the capital of France?",
            "question_type": QuestionType.WIKIPEDIA,
            "complexity": "simple"
        },
        {
            "question": "Find information about Python programming language",
            "question_type": QuestionType.WEB_RESEARCH,
            "complexity": "medium"
        }
    ]

    for i, test_case in enumerate(web_test_cases, 1):
        # Fresh state per case so agent_results don't accumulate across tests
        state = GAIAAgentState()
        state.question = test_case["question"]
        state.question_type = test_case["question_type"]
        state.complexity_assessment = test_case["complexity"]

        try:
            result_state = web_agent.process(state)
            # Success = the most recently added agent result reports success
            success = len(result_state.agent_results) > 0 and list(result_state.agent_results.values())[-1].success
            results.append(('Web Research', f'Test {i}', success, list(result_state.agent_results.values())[-1].processing_time if result_state.agent_results else 0))
            status = "✅ PASS" if success else "❌ FAIL"
            print(f"   Test {i}: {status}")

        except Exception as e:
            results.append(('Web Research', f'Test {i}', False, 0))
            print(f"   Test {i}: ❌ FAIL ({e})")

    # Test 2: File Processor Agent (needs real files on disk, created in a
    # temp dir so the agent tests run inside the `with` block)
    print("\n📁 Testing File Processor Agent...")
    file_agent = FileProcessorAgent(llm_client)

    # Create test files
    with tempfile.TemporaryDirectory() as temp_dir:
        # Create CSV test file
        csv_path = os.path.join(temp_dir, "test.csv")
        with open(csv_path, 'w') as f:
            f.write("name,age,salary\nAlice,25,50000\nBob,30,60000\nCharlie,35,70000")

        # Create Python test file
        py_path = os.path.join(temp_dir, "test.py")
        with open(py_path, 'w') as f:
            f.write("def calculate_sum(a, b):\n    return a + b\n\nresult = calculate_sum(5, 3)")

        file_test_cases = [
            {
                "question": "What is the average salary in this data?",
                "file_path": csv_path,
                "question_type": QuestionType.FILE_PROCESSING,
                "complexity": "medium"
            },
            {
                "question": "What does this Python code do?",
                "file_path": py_path,
                "question_type": QuestionType.FILE_PROCESSING,
                "complexity": "simple"
            }
        ]

        for i, test_case in enumerate(file_test_cases, 1):
            state = GAIAAgentState()
            state.question = test_case["question"]
            state.file_path = test_case["file_path"]
            state.question_type = test_case["question_type"]
            state.complexity_assessment = test_case["complexity"]

            try:
                result_state = file_agent.process(state)
                success = len(result_state.agent_results) > 0 and list(result_state.agent_results.values())[-1].success
                results.append(('File Processor', f'Test {i}', success, list(result_state.agent_results.values())[-1].processing_time if result_state.agent_results else 0))
                status = "✅ PASS" if success else "❌ FAIL"
                print(f"   Test {i}: {status}")

            except Exception as e:
                results.append(('File Processor', f'Test {i}', False, 0))
                print(f"   Test {i}: ❌ FAIL ({e})")

    # Test 3: Reasoning Agent
    print("\n🧠 Testing Reasoning Agent...")
    reasoning_agent = ReasoningAgent(llm_client)

    reasoning_test_cases = [
        {
            "question": "Calculate 15% of 200",
            "question_type": QuestionType.REASONING,
            "complexity": "simple"
        },
        {
            "question": "Convert 100 celsius to fahrenheit",
            "question_type": QuestionType.REASONING,
            "complexity": "simple"
        },
        {
            "question": "What is the average of 10, 15, 20, 25, 30?",
            "question_type": QuestionType.REASONING,
            "complexity": "medium"
        }
    ]

    for i, test_case in enumerate(reasoning_test_cases, 1):
        state = GAIAAgentState()
        state.question = test_case["question"]
        state.question_type = test_case["question_type"]
        state.complexity_assessment = test_case["complexity"]

        try:
            result_state = reasoning_agent.process(state)
            success = len(result_state.agent_results) > 0 and list(result_state.agent_results.values())[-1].success
            results.append(('Reasoning', f'Test {i}', success, list(result_state.agent_results.values())[-1].processing_time if result_state.agent_results else 0))
            status = "✅ PASS" if success else "❌ FAIL"
            print(f"   Test {i}: {status}")

        except Exception as e:
            results.append(('Reasoning', f'Test {i}', False, 0))
            print(f"   Test {i}: ❌ FAIL ({e})")

    # Summary
    total_time = time.time() - start_time
    passed_tests = sum(1 for _, _, success, _ in results if success)
    total_tests = len(results)

    print("\n" + "=" * 50)
    print("📊 AGENT TEST RESULTS")
    print("=" * 50)

    # Results by agent: aggregate pass counts and total time per agent label
    agents = {}
    for agent, test, success, exec_time in results:
        if agent not in agents:
            agents[agent] = {'passed': 0, 'total': 0, 'time': 0}
        agents[agent]['total'] += 1
        agents[agent]['time'] += exec_time
        if success:
            agents[agent]['passed'] += 1

    for agent, stats in agents.items():
        pass_rate = (stats['passed'] / stats['total']) * 100
        avg_time = stats['time'] / stats['total']
        status = "✅" if pass_rate == 100 else "⚠️" if pass_rate >= 80 else "❌"
        print(f"{status} {agent:15}: {stats['passed']}/{stats['total']} ({pass_rate:5.1f}%) - Avg: {avg_time:.3f}s")

    # Overall results
    overall_pass_rate = (passed_tests / total_tests) * 100
    print(f"\n🎯 OVERALL: {passed_tests}/{total_tests} tests passed ({overall_pass_rate:.1f}%)")
    print(f"⏱️  TOTAL TIME: {total_time:.2f} seconds")

    # Success criteria: 80% overall pass rate
    if overall_pass_rate >= 80:
        print("🚀 AGENTS READY! Multi-agent system is working correctly!")
        return True
    else:
        print("⚠️ ISSUES FOUND! Check individual agent failures above")
        return False
| 198 |
+
if __name__ == "__main__":
    # Exit code mirrors the test outcome so CI can consume it directly.
    success = test_agents()
    sys.exit(0 if success else 1)
|
src/test_all_tools.py
ADDED
|
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Integration test for all GAIA Agent tools
|
| 4 |
+
Tests Wikipedia, Web Search, Calculator, and File Processor tools
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import sys
|
| 9 |
+
import time
|
| 10 |
+
import tempfile
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
|
| 13 |
+
# Add src to path for imports
|
| 14 |
+
sys.path.insert(0, str(Path(__file__).parent))
|
| 15 |
+
|
| 16 |
+
from tools.wikipedia_tool import WikipediaTool
|
| 17 |
+
from tools.web_search_tool import WebSearchTool
|
| 18 |
+
from tools.calculator import CalculatorTool
|
| 19 |
+
from tools.file_processor import FileProcessorTool
|
| 20 |
+
|
| 21 |
+
def test_all_tools() -> bool:
    """Comprehensive test of all GAIA agent tools.

    Exercises the Wikipedia, Web Search, Calculator, and File Processor
    tools with representative inputs, prints a per-tool pass-rate summary,
    and returns True when at least 80% of tests pass. Note: the Wikipedia
    and Web Search tests hit live network endpoints.
    """

    print("🧪 GAIA Agent Tools Integration Test")
    print("=" * 50)

    # Each entry: (tool label, test label, success flag, execution time)
    results = []
    start_time = time.time()

    # Test 1: Wikipedia Tool — accepts both a bare query string and a
    # {"query", "action"} dict form
    print("\n📚 Testing Wikipedia Tool...")
    wikipedia_tool = WikipediaTool()
    test_cases = [
        "Albert Einstein",
        {"query": "Machine Learning", "action": "summary"}
    ]

    for i, test_case in enumerate(test_cases, 1):
        result = wikipedia_tool.execute(test_case)
        # Requires both that execute() succeeded and that a page was found
        success = result.success and result.result.get('found', False)
        results.append(('Wikipedia', f'Test {i}', success, result.execution_time))
        status = "✅ PASS" if success else "❌ FAIL"
        print(f"   Test {i}: {status} ({result.execution_time:.2f}s)")

    # Test 2: Web Search Tool — plain search plus URL content extraction
    print("\n🔍 Testing Web Search Tool...")
    web_search_tool = WebSearchTool()
    test_cases = [
        "Python programming",
        {"query": "https://www.python.org", "action": "extract"}
    ]

    for i, test_case in enumerate(test_cases, 1):
        result = web_search_tool.execute(test_case)
        success = result.success and result.result.get('found', False)
        results.append(('Web Search', f'Test {i}', success, result.execution_time))
        status = "✅ PASS" if success else "❌ FAIL"
        print(f"   Test {i}: {status} ({result.execution_time:.2f}s)")

    # Test 3: Calculator Tool — expression, statistics, and unit conversion
    print("\n🧮 Testing Calculator Tool...")
    calculator_tool = CalculatorTool()
    test_cases = [
        "2 + 3 * 4",
        {"operation": "statistics", "data": [1, 2, 3, 4, 5]},
        {"operation": "convert", "value": 100, "from_unit": "cm", "to_unit": "m"}
    ]

    for i, test_case in enumerate(test_cases, 1):
        result = calculator_tool.execute(test_case)
        # Calculator reports its own success flag inside the result payload
        success = result.success and result.result.get('success', False)
        results.append(('Calculator', f'Test {i}', success, result.execution_time))
        status = "✅ PASS" if success else "❌ FAIL"
        print(f"   Test {i}: {status} ({result.execution_time:.3f}s)")

    # Test 4: File Processor Tool — fed real files written to a temp dir
    print("\n📁 Testing File Processor Tool...")
    file_processor_tool = FileProcessorTool()

    # Create test files
    with tempfile.TemporaryDirectory() as temp_dir:
        # Create CSV test file
        csv_path = os.path.join(temp_dir, "test.csv")
        with open(csv_path, 'w') as f:
            f.write("name,value\nTest,42\nData,100")

        # Create Python test file
        py_path = os.path.join(temp_dir, "test.py")
        with open(py_path, 'w') as f:
            f.write("def test_function():\n    return 'Hello, World!'")

        test_files = [csv_path, py_path]

        for i, file_path in enumerate(test_files, 1):
            result = file_processor_tool.execute(file_path)
            success = result.success and result.result.get('success', False)
            results.append(('File Processor', f'Test {i}', success, result.execution_time))
            status = "✅ PASS" if success else "❌ FAIL"
            file_type = os.path.splitext(file_path)[1]
            print(f"   Test {i} ({file_type}): {status} ({result.execution_time:.3f}s)")

    # Summary
    total_time = time.time() - start_time
    passed_tests = sum(1 for _, _, success, _ in results if success)
    total_tests = len(results)

    print("\n" + "=" * 50)
    print("📊 INTEGRATION TEST RESULTS")
    print("=" * 50)

    # Results by tool: aggregate pass counts and total time per tool label
    tools = {}
    for tool, test, success, exec_time in results:
        if tool not in tools:
            tools[tool] = {'passed': 0, 'total': 0, 'time': 0}
        tools[tool]['total'] += 1
        tools[tool]['time'] += exec_time
        if success:
            tools[tool]['passed'] += 1

    for tool, stats in tools.items():
        pass_rate = (stats['passed'] / stats['total']) * 100
        avg_time = stats['time'] / stats['total']
        status = "✅" if pass_rate == 100 else "⚠️" if pass_rate >= 80 else "❌"
        print(f"{status} {tool:15}: {stats['passed']}/{stats['total']} ({pass_rate:5.1f}%) - Avg: {avg_time:.3f}s")

    # Overall results
    overall_pass_rate = (passed_tests / total_tests) * 100
    print(f"\n🎯 OVERALL: {passed_tests}/{total_tests} tests passed ({overall_pass_rate:.1f}%)")
    print(f"⏱️  TOTAL TIME: {total_time:.2f} seconds")

    # Success criteria: >=90% excellent, >=80% acceptable, else failure
    if overall_pass_rate >= 90:
        print("🚀 EXCELLENT! All tools working correctly - Ready for agent integration!")
        return True
    elif overall_pass_rate >= 80:
        print("✅ GOOD! Most tools working - Minor issues to address")
        return True
    else:
        print("⚠️ NEEDS WORK! Significant issues found - Check individual tool failures")
        return False
| 143 |
+
def test_tool_coordination() -> bool:
    """Test how tools can work together in a coordinated workflow.

    Chains Wikipedia -> Web Search -> Calculator; returns True only when
    all three steps succeed. Any failed step (or exception) falls through
    to the final ``return False`` — intermediate failures are silent by
    design, only exceptions are printed.
    """

    print("\n🤝 Testing Tool Coordination...")
    print("-" * 30)

    # Scenario: Research Python programming, then calculate some metrics
    try:
        # Step 1: Get information about Python
        wiki_tool = WikipediaTool()
        wiki_result = wiki_tool.execute("Python (programming language)")

        if wiki_result.success:
            print("✅ Step 1: Wikipedia lookup successful")

            # Step 2: Get additional web information
            web_tool = WebSearchTool()
            web_result = web_tool.execute("Python programming language features")

            if web_result.success:
                print("✅ Step 2: Web search successful")

                # Step 3: Calculate some metrics derived from step 2's output
                calc_tool = CalculatorTool()
                search_count = len(web_result.result.get('results', []))
                calc_result = calc_tool.execute(f"sqrt({search_count}) * 10")

                if calc_result.success:
                    print("✅ Step 3: Calculation successful")
                    print(f"   Coordinated result: Found {search_count} web results, computed metric: {calc_result.result['calculation']['result']}")
                    return True

    except Exception as e:
        print(f"❌ Coordination test failed: {e}")

    # Reached when any step failed or an exception was caught above
    return False
| 180 |
+
if __name__ == "__main__":
    # Run both suites; exit 0 only when both pass (CI-friendly exit code).
    success = test_all_tools()
    coordination_success = test_tool_coordination()

    if success and coordination_success:
        print("\n🎉 ALL TESTS PASSED! Tools are ready for agent integration!")
        sys.exit(0)
    else:
        print("\n⚠️ Some tests failed. Check output above.")
        sys.exit(1)
|
src/test_integration.py
ADDED
|
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Complete Integration Test for GAIA Agent System
|
| 4 |
+
Tests the full pipeline: Router -> Agents -> Tools -> Results
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import sys
|
| 9 |
+
import time
|
| 10 |
+
import tempfile
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
|
| 13 |
+
# Add src to path for imports
|
| 14 |
+
sys.path.insert(0, str(Path(__file__).parent))
|
| 15 |
+
|
| 16 |
+
from agents.state import GAIAAgentState, QuestionType, AgentRole
|
| 17 |
+
from agents.router import RouterAgent
|
| 18 |
+
from agents.web_researcher import WebResearchAgent
|
| 19 |
+
from agents.file_processor_agent import FileProcessorAgent
|
| 20 |
+
from agents.reasoning_agent import ReasoningAgent
|
| 21 |
+
from models.qwen_client import QwenClient
|
| 22 |
+
|
| 23 |
+
def test_complete_pipeline():
    """Run the end-to-end GAIA pipeline: Router -> Agents -> Tools -> Results.

    Exercises three text-only questions (web research + reasoning) plus one
    CSV file-processing question, printing per-test diagnostics and a final
    cost/time summary.

    Returns:
        bool: True when >= 80% of the test cases pass, False otherwise.
    """

    print("🚀 GAIA Complete Integration Test")
    print("=" * 50)

    # Initialize system — any failure here (e.g. missing API key for the
    # LLM client) aborts the whole run.
    try:
        llm_client = QwenClient()
        router = RouterAgent(llm_client)
        web_agent = WebResearchAgent(llm_client)
        file_agent = FileProcessorAgent(llm_client)
        reasoning_agent = ReasoningAgent(llm_client)
    except Exception as e:
        print(f"❌ Failed to initialize system: {e}")
        return False

    # End-to-end test cases: each declares the agent the router is expected
    # to select so we can verify routing before processing.
    test_cases = [
        {
            "question": "What is the population of Paris?",
            "description": "Simple Wikipedia/web research question",
            "expected_agent": AgentRole.WEB_RESEARCHER
        },
        {
            "question": "Calculate the area of a circle with radius 5 meters",
            "description": "Mathematical reasoning with unit conversion",
            "expected_agent": AgentRole.REASONING_AGENT
        },
        {
            "question": "What is the average of these numbers: 10, 20, 30, 40, 50?",
            "description": "Statistical calculation",
            "expected_agent": AgentRole.REASONING_AGENT
        }
    ]

    results = []
    total_cost = 0.0
    start_time = time.time()

    for i, test_case in enumerate(test_cases, 1):
        print(f"\n🧪 Test {i}: {test_case['description']}")
        print(f" Question: {test_case['question']}")

        try:
            # Step 1: Initialize state
            state = GAIAAgentState()
            state.task_id = f"test_{i}"
            state.question = test_case["question"]

            # Step 2: Route question
            routed_state = router.route_question(state)
            print(f" ✅ Router: {routed_state.question_type.value} -> {[a.value for a in routed_state.selected_agents]}")

            # Step 3: Process with the agent the test case expects; if the
            # router did not select it, the test is counted as a failure.
            if test_case["expected_agent"] in routed_state.selected_agents:
                if test_case["expected_agent"] == AgentRole.WEB_RESEARCHER:
                    processed_state = web_agent.process(routed_state)
                elif test_case["expected_agent"] == AgentRole.REASONING_AGENT:
                    processed_state = reasoning_agent.process(routed_state)
                elif test_case["expected_agent"] == AgentRole.FILE_PROCESSOR:
                    processed_state = file_agent.process(routed_state)
                else:
                    print(f" ⚠️ Agent {test_case['expected_agent'].value} not implemented in test")
                    continue

                # Check results: the most recent agent result carries the
                # answer, confidence, and accumulated cost/time.
                if processed_state.agent_results:
                    agent_result = list(processed_state.agent_results.values())[-1]
                    success = agent_result.success
                    confidence = agent_result.confidence
                    cost = processed_state.total_cost
                    processing_time = processed_state.total_processing_time

                    print(f" ✅ Agent: {agent_result.agent_role.value}")
                    print(f" ✅ Result: {agent_result.result[:100]}...")
                    print(f" 📊 Confidence: {confidence:.2f}")
                    print(f" 💰 Cost: ${cost:.4f}")
                    print(f" ⏱️ Time: {processing_time:.2f}s")

                    total_cost += cost
                    results.append(success)

                    print(f" 🎯 Overall: {'✅ PASS' if success else '❌ FAIL'}")
                else:
                    print(f" ❌ No agent results produced")
                    results.append(False)
            else:
                print(f" ⚠️ Expected agent {test_case['expected_agent'].value} not selected")
                results.append(False)

        except Exception as e:
            print(f" ❌ Pipeline failed: {e}")
            results.append(False)

    # File processing test with actual file
    print(f"\n🧪 Test 4: File Processing with CSV")
    print(f" Description: Complete file analysis pipeline")

    try:
        with tempfile.TemporaryDirectory() as temp_dir:
            # Create test CSV
            csv_path = os.path.join(temp_dir, "sales_data.csv")
            with open(csv_path, 'w') as f:
                f.write("product,sales,price\nWidget A,100,25.50\nWidget B,150,30.00\nWidget C,80,22.75")

            # Initialize state with file
            state = GAIAAgentState()
            state.task_id = "test_file"
            state.question = "What is the total sales value across all products?"
            state.file_name = "sales_data.csv"
            state.file_path = csv_path

            # Route and process
            routed_state = router.route_question(state)
            processed_state = file_agent.process(routed_state)

            if processed_state.agent_results:
                agent_result = list(processed_state.agent_results.values())[-1]
                success = agent_result.success
                total_cost += processed_state.total_cost
                results.append(success)

                print(f" ✅ Router: {routed_state.question_type.value}")
                print(f" ✅ Agent: File processor")
                print(f" ✅ Result: {agent_result.result[:100]}...")
                print(f" 💰 Cost: ${processed_state.total_cost:.4f}")
                print(f" 🎯 Overall: {'✅ PASS' if success else '❌ FAIL'}")
            else:
                print(f" ❌ File processing failed")
                results.append(False)

    except Exception as e:
        print(f" ❌ File test failed: {e}")
        results.append(False)

    # Final summary
    total_time = time.time() - start_time
    passed = sum(results)
    total = len(results)
    pass_rate = (passed / total) * 100

    print("\n" + "=" * 50)
    print("📊 COMPLETE INTEGRATION RESULTS")
    print("=" * 50)
    print(f"🎯 Tests Passed: {passed}/{total} ({pass_rate:.1f}%)")
    print(f"💰 Total Cost: ${total_cost:.4f}")
    print(f"⏱️ Total Time: {total_time:.2f} seconds")
    print(f"📈 Average Cost per Test: ${total_cost/total:.4f}")
    print(f"⚡ Average Time per Test: {total_time/total:.2f}s")

    # Budget analysis
    monthly_budget = 0.10  # $0.10/month
    if total_cost <= monthly_budget:
        remaining_budget = monthly_budget - total_cost
        # Guard against division by zero when every test failed/was skipped
        # before incurring any cost (same guard as test_real_gaia.py uses).
        estimated_questions = int(remaining_budget / (total_cost / total)) if total_cost > 0 else 1000
        print(f"💰 Budget Status: ✅ ${remaining_budget:.4f} remaining (~{estimated_questions} more tests)")
    else:
        print(f"💰 Budget Status: ⚠️ Over budget by ${total_cost - monthly_budget:.4f}")

    # Success criteria
    if pass_rate >= 80 and total_cost <= 0.05:  # 80% success, reasonable cost
        print("\n🚀 INTEGRATION SUCCESS! System ready for GAIA benchmark!")
        return True
    elif pass_rate >= 80:
        print("\n✅ FUNCTIONALITY SUCCESS! (Higher cost than ideal)")
        return True
    else:
        print("\n⚠️ INTEGRATION ISSUES! Check individual test failures")
        return False
|
| 193 |
+
|
| 194 |
+
if __name__ == "__main__":
    # Propagate the integration-test outcome as the process exit code
    # (0 = success, 1 = failure) so CI can gate on this script.
    sys.exit(0 if test_complete_pipeline() else 1)
|
src/test_real_gaia.py
ADDED
|
@@ -0,0 +1,248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Real GAIA Questions Test for GAIA Agent System
|
| 4 |
+
Tests the system with actual GAIA benchmark questions
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import json
|
| 8 |
+
import os
|
| 9 |
+
import sys
|
| 10 |
+
import time
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
from typing import Dict, List
|
| 13 |
+
|
| 14 |
+
# Add src to path for imports
|
| 15 |
+
sys.path.insert(0, str(Path(__file__).parent))
|
| 16 |
+
|
| 17 |
+
from agents.state import GAIAAgentState, QuestionType, AgentRole
|
| 18 |
+
from agents.router import RouterAgent
|
| 19 |
+
from agents.web_researcher import WebResearchAgent
|
| 20 |
+
from agents.file_processor_agent import FileProcessorAgent
|
| 21 |
+
from agents.reasoning_agent import ReasoningAgent
|
| 22 |
+
from models.qwen_client import QwenClient
|
| 23 |
+
|
| 24 |
+
def load_gaia_questions(file_path: str = "questions.json") -> List[Dict]:
    """Read the GAIA question set from a JSON file.

    Args:
        file_path: Path to the JSON file containing the question list.

    Returns:
        The parsed list of question dicts, or an empty list when the file
        is missing or holds invalid JSON (an error message is printed).
    """
    try:
        handle = open(file_path, 'r', encoding='utf-8')
    except FileNotFoundError:
        print(f"❌ Questions file not found: {file_path}")
        return []
    with handle:
        try:
            return json.load(handle)
        except json.JSONDecodeError as e:
            print(f"❌ Invalid JSON in questions file: {e}")
            return []
|
| 36 |
+
|
| 37 |
+
def classify_question_manually(question: str, file_name: str) -> Dict:
    """Heuristically classify a GAIA question to compare against the router.

    Args:
        question: Raw GAIA question text.
        file_name: Name of an attached file, or empty/None when absent.

    Returns:
        Dict with "type" (human-readable category) and "expected_agent"
        (the agent role name the router is expected to select).
    """

    question_lower = question.lower()

    # Manual classification based on question content. Checks run from the
    # most specific signals (explicit sources, attached file extensions)
    # down to the generic web-research fallback.
    if "wikipedia" in question_lower or "featured article" in question_lower:
        return {"type": "Wikipedia Research", "expected_agent": "web_researcher"}
    # Match YouTube links against the lowercased text, consistent with the
    # other checks — pasted URLs may carry arbitrary casing.
    elif "youtube.com" in question_lower or "youtu.be" in question_lower:
        return {"type": "YouTube Analysis", "expected_agent": "web_researcher"}
    elif file_name and file_name.endswith(('.xlsx', '.csv')):
        return {"type": "Excel/CSV Processing", "expected_agent": "file_processor"}
    elif file_name and file_name.endswith('.py'):
        return {"type": "Python Code Analysis", "expected_agent": "file_processor"}
    elif file_name and file_name.endswith(('.mp3', '.wav')):
        return {"type": "Audio Processing", "expected_agent": "file_processor"}
    elif file_name and file_name.endswith(('.png', '.jpg', '.jpeg')):
        return {"type": "Image Analysis", "expected_agent": "file_processor"}
    elif any(word in question_lower for word in ['calculate', 'total', 'average', 'sum']):
        return {"type": "Mathematical Reasoning", "expected_agent": "reasoning_agent"}
    elif "reverse" in question_lower or "encode" in question_lower:
        return {"type": "Text Manipulation", "expected_agent": "reasoning_agent"}
    elif any(word in question_lower for word in ['athletes', 'competition', 'olympics']):
        return {"type": "Sports/Statistics Research", "expected_agent": "web_researcher"}
    else:
        return {"type": "General Research", "expected_agent": "web_researcher"}
|
| 63 |
+
|
| 64 |
+
def test_real_gaia_questions():
    """Run the agent system against real GAIA benchmark questions.

    Loads questions from ../questions.json, routes the first 8 (to cap
    spend), processes those whose selected agent is implemented here, and
    prints routing-accuracy, cost, and success-rate summaries.

    Returns:
        bool: True when >= 60% of questions were handled successfully.
    """

    print("🧪 Real GAIA Questions Test")
    print("=" * 50)

    # Load questions; path is relative to the src/ working directory.
    questions = load_gaia_questions("../questions.json")
    if not questions:
        print("❌ No questions loaded. Exiting.")
        return False

    print(f"📋 Loaded {len(questions)} GAIA questions")

    # Initialize system
    try:
        llm_client = QwenClient()
        router = RouterAgent(llm_client)
        web_agent = WebResearchAgent(llm_client)
        file_agent = FileProcessorAgent(llm_client)
        reasoning_agent = ReasoningAgent(llm_client)
    except Exception as e:
        print(f"❌ Failed to initialize system: {e}")
        return False

    # Test subset of questions (to manage cost)
    test_questions = questions[:8]  # Test first 8 questions

    results = []
    total_cost = 0.0
    start_time = time.time()

    # Question type distribution tracking
    question_types = {}
    routing_accuracy = {"correct": 0, "total": 0}

    for i, q in enumerate(test_questions, 1):
        print(f"\n🔍 Question {i}/{len(test_questions)}")
        print(f" ID: {q['task_id']}")
        print(f" Level: {q['Level']}")
        print(f" File: {q['file_name'] if q['file_name'] else 'None'}")
        print(f" Question: {q['question'][:100]}...")

        # Manual classification for comparison
        manual_class = classify_question_manually(q['question'], q['file_name'])
        print(f" Expected Type: {manual_class['type']}")

        try:
            # Initialize state
            state = GAIAAgentState()
            state.task_id = q['task_id']
            state.question = q['question']
            state.difficulty_level = int(q['Level'])
            state.file_name = q['file_name'] if q['file_name'] else None
            if state.file_name:
                # Placeholder path — the referenced file is not actually
                # downloaded, so file-processing questions are routed only.
                state.file_path = f"/tmp/{state.file_name}"  # Placeholder path

            # Route question
            routed_state = router.route_question(state)
            print(f" 🧭 Router: {routed_state.question_type.value} -> {[a.value for a in routed_state.selected_agents]}")
            print(f" 📊 Complexity: {routed_state.complexity_assessment}")
            print(f" 💰 Est. Cost: ${routed_state.estimated_cost:.4f}")

            # Track question types
            q_type = routed_state.question_type.value
            question_types[q_type] = question_types.get(q_type, 0) + 1

            # Check routing accuracy (simplified): count a hit when the
            # manually-expected agent appears anywhere in the selection.
            expected_agent = manual_class["expected_agent"]
            actual_agents = [a.value for a in routed_state.selected_agents]
            if expected_agent in actual_agents:
                routing_accuracy["correct"] += 1
            routing_accuracy["total"] += 1

            # Only process if we have the required agent implemented.
            # NOTE(review): file_agent is constructed but never invoked here
            # — file questions with a real file name are routed only, since
            # only a placeholder path exists. Confirm this is intentional.
            processed = False
            if AgentRole.WEB_RESEARCHER in routed_state.selected_agents:
                try:
                    processed_state = web_agent.process(routed_state)
                    processed = True
                except Exception as e:
                    print(f" ⚠️ Web researcher failed: {e}")

            elif AgentRole.REASONING_AGENT in routed_state.selected_agents:
                try:
                    processed_state = reasoning_agent.process(routed_state)
                    processed = True
                except Exception as e:
                    print(f" ⚠️ Reasoning agent failed: {e}")

            elif AgentRole.FILE_PROCESSOR in routed_state.selected_agents and not state.file_name:
                print(f" ⚠️ File processor selected but no file provided")

            if processed:
                # Most recent agent result carries the answer and metrics.
                agent_result = list(processed_state.agent_results.values())[-1]
                cost = processed_state.total_cost
                processing_time = processed_state.total_processing_time

                print(f" ✅ Processed by: {agent_result.agent_role.value}")
                print(f" 📝 Result: {agent_result.result[:150]}...")
                print(f" 📊 Confidence: {agent_result.confidence:.2f}")
                print(f" 💰 Actual Cost: ${cost:.4f}")
                print(f" ⏱️ Time: {processing_time:.2f}s")

                total_cost += cost
                results.append({
                    "success": agent_result.success,
                    "confidence": agent_result.confidence,
                    "cost": cost,
                    "time": processing_time
                })
            else:
                # No implemented agent processed the question; routing alone
                # is still recorded as a (neutral-confidence) success.
                print(f" 🔄 Routing only (no processing)")
                results.append({
                    "success": True,  # Routing succeeded
                    "confidence": 0.5,  # Neutral
                    "cost": 0.0,
                    "time": 0.0
                })

        except Exception as e:
            print(f" ❌ Failed: {e}")
            results.append({
                "success": False,
                "confidence": 0.0,
                "cost": 0.0,
                "time": 0.0
            })

    # Summary
    total_time = time.time() - start_time
    successful_results = [r for r in results if r["success"]]

    print("\n" + "=" * 50)
    print("📊 REAL GAIA TEST RESULTS")
    print("=" * 50)

    # Basic stats
    print(f"🎯 Questions Processed: {len(results)}")
    print(f"✅ Successful Processing: {len(successful_results)}/{len(results)} ({len(successful_results)/len(results)*100:.1f}%)")
    print(f"💰 Total Cost: ${total_cost:.4f}")
    print(f"⏱️ Total Time: {total_time:.2f} seconds")

    if successful_results:
        avg_confidence = sum(r["confidence"] for r in successful_results) / len(successful_results)
        avg_cost = sum(r["cost"] for r in successful_results) / len(successful_results)
        avg_time = sum(r["time"] for r in successful_results) / len(successful_results)

        print(f"📈 Average Confidence: {avg_confidence:.2f}")
        print(f"💰 Average Cost: ${avg_cost:.4f}")
        print(f"⚡ Average Time: {avg_time:.2f}s")

    # Question type distribution
    print(f"\n📋 Question Type Distribution:")
    for q_type, count in question_types.items():
        print(f" {q_type}: {count}")

    # Routing accuracy
    routing_rate = routing_accuracy["correct"] / routing_accuracy["total"] * 100 if routing_accuracy["total"] > 0 else 0
    print(f"\n🧭 Routing Accuracy: {routing_accuracy['correct']}/{routing_accuracy['total']} ({routing_rate:.1f}%)")

    # Budget analysis; the zero-cost guard avoids division by zero when
    # every question was routing-only.
    monthly_budget = 0.10
    if total_cost <= monthly_budget:
        remaining = monthly_budget - total_cost
        estimated_questions = int(remaining / (total_cost / len(results))) if total_cost > 0 else 1000
        print(f"💰 Budget Status: ✅ ${remaining:.4f} remaining (~{estimated_questions} more questions)")
    else:
        print(f"💰 Budget Status: ⚠️ Over budget by ${total_cost - monthly_budget:.4f}")

    # Success assessment
    success_rate = len(successful_results) / len(results) * 100
    if success_rate >= 80:
        print(f"\n🚀 EXCELLENT! System handles real GAIA questions well ({success_rate:.1f}% success)")
        return True
    elif success_rate >= 60:
        print(f"\n✅ GOOD! System shows promise ({success_rate:.1f}% success)")
        return True
    else:
        print(f"\n⚠️ NEEDS WORK! Low success rate ({success_rate:.1f}%)")
        return False
|
| 245 |
+
|
| 246 |
+
if __name__ == "__main__":
    # Exit with 0 when the GAIA question run succeeded, 1 otherwise.
    sys.exit(0 if test_real_gaia_questions() else 1)
|
src/test_router.py
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Test Router Agent for GAIA Agent System
|
| 4 |
+
Tests question classification and agent selection logic
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import sys
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
# Add src to path for imports
|
| 11 |
+
sys.path.insert(0, str(Path(__file__).parent))
|
| 12 |
+
|
| 13 |
+
from agents.state import GAIAAgentState, QuestionType, AgentRole
|
| 14 |
+
from agents.router import RouterAgent
|
| 15 |
+
from models.qwen_client import QwenClient
|
| 16 |
+
|
| 17 |
+
def test_router_agent():
    """Exercise the router agent across representative question types.

    Each test case lists the acceptable question types and agents (several
    classifications can be valid for ambiguous questions), so a case passes
    when the router's result falls anywhere in the allowed set.

    Returns:
        bool: True when >= 80% of the routing cases pass.
    """

    print("🧭 GAIA Router Agent Test")
    print("=" * 40)

    # Initialize LLM client and router
    try:
        llm_client = QwenClient()
        router = RouterAgent(llm_client)
    except Exception as e:
        print(f"❌ Failed to initialize router: {e}")
        return False

    # Test cases covering all question types
    test_cases = [
        {
            "question": "What is the capital of France?",
            "expected_type": [QuestionType.WIKIPEDIA, QuestionType.WEB_RESEARCH, QuestionType.UNKNOWN],  # Allow multiple valid types
            "expected_agents": [AgentRole.WEB_RESEARCHER]
        },
        {
            "question": "Calculate 25% of 400 and add 50",
            "expected_type": [QuestionType.MATHEMATICAL],
            "expected_agents": [AgentRole.REASONING_AGENT]
        },
        {
            "question": "What information can you extract from this CSV file?",
            "expected_type": [QuestionType.FILE_PROCESSING],
            "expected_agents": [AgentRole.FILE_PROCESSOR],
            "has_file": True
        },
        {
            "question": "Search for recent news about artificial intelligence",
            "expected_type": [QuestionType.WEB_RESEARCH],
            "expected_agents": [AgentRole.WEB_RESEARCHER]
        },
        {
            "question": "What does this Python code do and how can it be improved?",
            "expected_type": [QuestionType.CODE_EXECUTION, QuestionType.FILE_PROCESSING],  # Both are valid
            "expected_agents": [AgentRole.FILE_PROCESSOR, AgentRole.CODE_EXECUTOR],  # Either is acceptable
            "has_file": True
        }
    ]

    results = []

    for i, test_case in enumerate(test_cases, 1):
        print(f"\n--- Test {i}: {test_case['question'][:50]}... ---")

        # Create state; file-backed cases get a dummy CSV name/path so the
        # router sees an attachment (the file itself need not exist).
        state = GAIAAgentState()
        state.question = test_case["question"]
        if test_case.get("has_file"):
            state.file_name = "test_file.csv"
            state.file_path = "/tmp/test_file.csv"

        try:
            # Process with router
            result_state = router.route_question(state)

            # Check results: type must be one of the allowed types, and at
            # least one expected agent must be in the selection.
            type_correct = result_state.question_type in test_case["expected_type"]
            agents_correct = any(agent in result_state.selected_agents for agent in test_case["expected_agents"])

            success = type_correct and agents_correct
            results.append(success)

            print(f" Question Type: {result_state.question_type.value} ({'✅' if type_correct else '❌'})")
            print(f" Selected Agents: {[a.value for a in result_state.selected_agents]} ({'✅' if agents_correct else '❌'})")
            print(f" Complexity: {result_state.complexity_assessment}")
            print(f" Overall: {'✅ PASS' if success else '❌ FAIL'}")

        except Exception as e:
            print(f" ❌ FAIL: {e}")
            results.append(False)

    # Summary
    passed = sum(results)
    total = len(results)
    pass_rate = (passed / total) * 100

    print("\n" + "=" * 40)
    print(f"🎯 ROUTER RESULTS: {passed}/{total} ({pass_rate:.1f}%)")

    if pass_rate >= 80:
        print("🚀 Router working correctly!")
        return True
    else:
        print("⚠️ Router needs improvement")
        return False
|
| 108 |
+
|
| 109 |
+
if __name__ == "__main__":
    # Surface the router-test result as the process exit code (0 = pass).
    sys.exit(0 if test_router_agent() else 1)
|
src/test_workflow.py
ADDED
|
@@ -0,0 +1,316 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Complete Workflow Test for GAIA Agent System
|
| 4 |
+
Tests both LangGraph and simplified workflow implementations
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import sys
|
| 9 |
+
import time
|
| 10 |
+
import tempfile
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
|
| 13 |
+
# Add src to path for imports
|
| 14 |
+
sys.path.insert(0, str(Path(__file__).parent))
|
| 15 |
+
|
| 16 |
+
from workflow.gaia_workflow import GAIAWorkflow, SimpleGAIAWorkflow
|
| 17 |
+
from models.qwen_client import QwenClient
|
| 18 |
+
|
| 19 |
+
def test_simple_workflow():
    """Smoke-test the simplified (non-LangGraph) GAIA workflow.

    Runs three text-only questions through SimpleGAIAWorkflow end to end,
    printing per-question routing, answer, confidence, and cost.

    Returns:
        bool: True when >= 80% of the questions complete with an answer.
    """

    print("🧪 Testing Simple GAIA Workflow")
    print("=" * 50)

    # Initialize workflow
    try:
        llm_client = QwenClient()
        workflow = SimpleGAIAWorkflow(llm_client)
    except Exception as e:
        print(f"❌ Failed to initialize workflow: {e}")
        return False

    # Test cases. NOTE(review): "expected_agents" is declared but never
    # checked below — presumably documentation of intent; confirm.
    test_cases = [
        {
            "question": "What is the capital of France?",
            "description": "Simple web research question",
            "expected_agents": ["web_researcher"]
        },
        {
            "question": "Calculate 25% of 200",
            "description": "Mathematical reasoning question",
            "expected_agents": ["reasoning_agent"]
        },
        {
            "question": "What is the average of 10, 15, 20?",
            "description": "Statistical calculation",
            "expected_agents": ["reasoning_agent"]
        }
    ]

    results = []
    total_cost = 0.0
    start_time = time.time()

    for i, test_case in enumerate(test_cases, 1):
        print(f"\n🔍 Test {i}: {test_case['description']}")
        print(f" Question: {test_case['question']}")

        try:
            # Process question
            result_state = workflow.process_question(
                question=test_case["question"],
                task_id=f"simple_test_{i}"
            )

            # Check results: success requires a completed run AND a
            # non-empty final answer.
            success = result_state.is_complete and result_state.final_answer
            confidence = result_state.final_confidence
            cost = result_state.total_cost

            print(f" ✅ Router: {result_state.question_type.value}")
            print(f" ✅ Agents: {[a.value for a in result_state.selected_agents]}")
            print(f" ✅ Final Answer: {result_state.final_answer[:100]}...")
            print(f" 📊 Confidence: {confidence:.2f}")
            print(f" 💰 Cost: ${cost:.4f}")
            print(f" 🎯 Success: {'✅ PASS' if success else '❌ FAIL'}")

            total_cost += cost
            # bool(): `success` may hold the answer string due to `and`.
            results.append(bool(success))

        except Exception as e:
            print(f" ❌ Test failed: {e}")
            results.append(False)

    # Summary
    total_time = time.time() - start_time
    passed = sum(results)
    total = len(results)

    print(f"\n📊 Simple Workflow Results:")
    print(f" 🎯 Tests Passed: {passed}/{total} ({passed/total*100:.1f}%)")
    print(f" 💰 Total Cost: ${total_cost:.4f}")
    print(f" ⏱️ Total Time: {total_time:.2f}s")

    return passed >= total * 0.8  # 80% success rate
|
| 97 |
+
|
| 98 |
+
def test_complete_workflow_with_files():
    """Run the simplified workflow on a question with an attached CSV file.

    Creates a small CSV in a temporary directory, asks the workflow a
    question about it, and prints routing, answer, confidence, and cost.

    Returns:
        bool: True when the workflow completes with a non-empty answer.
    """

    print("\n🧪 Testing Complete Workflow with Files")
    print("=" * 50)

    try:
        llm_client = QwenClient()
        workflow = SimpleGAIAWorkflow(llm_client)
    except Exception as e:
        print(f"❌ Failed to initialize workflow: {e}")
        return False

    # Create test file (auto-removed when the context manager exits)
    with tempfile.TemporaryDirectory() as temp_dir:
        csv_path = os.path.join(temp_dir, "test_data.csv")
        with open(csv_path, 'w') as f:
            f.write("item,quantity,price\nApple,10,1.50\nBanana,20,0.75\nOrange,15,2.00")

        print(f"📁 Created test file: {csv_path}")

        try:
            result_state = workflow.process_question(
                question="What is the total value of all items in this data?",
                file_path=csv_path,
                file_name="test_data.csv",
                task_id="file_test"
            )

            # Success requires a completed run AND a non-empty final answer.
            success = result_state.is_complete and result_state.final_answer

            print(f" ✅ Router: {result_state.question_type.value}")
            print(f" ✅ Agents: {[a.value for a in result_state.selected_agents]}")
            print(f" ✅ Final Answer: {result_state.final_answer[:150]}...")
            print(f" 📊 Confidence: {result_state.final_confidence:.2f}")
            print(f" 💰 Cost: ${result_state.total_cost:.4f}")
            print(f" 🎯 File Processing: {'✅ PASS' if success else '❌ FAIL'}")

            # bool(): `success` may hold the answer string due to `and`.
            return bool(success)

        except Exception as e:
            print(f" ❌ File test failed: {e}")
            return False
|
| 141 |
+
|
| 142 |
+
def test_workflow_error_handling():
    """Test workflow error handling and edge cases.

    Feeds deliberately problematic inputs (empty question, oversized
    question, non-existent file) through the workflow and checks each one
    is handled gracefully rather than surfacing a traceback.

    Returns:
        bool: True when at least 80% of the cases are handled gracefully.
    """
    print("\n🧪 Testing Workflow Error Handling")
    print("=" * 50)

    try:
        llm_client = QwenClient()
        workflow = SimpleGAIAWorkflow(llm_client)
    except Exception as e:
        print(f"❌ Failed to initialize workflow: {e}")
        return False

    # Test cases that might cause errors
    error_test_cases = [
        {
            "question": "",  # Empty question
            "description": "Empty question"
        },
        {
            "question": "x" * 5000,  # Very long question
            "description": "Extremely long question"
        },
        {
            "question": "What is this file about?",
            "file_path": "/nonexistent/file.txt",  # Non-existent file
            "description": "Non-existent file"
        }
    ]

    results = []

    for i, test_case in enumerate(error_test_cases, 1):
        print(f"\n🔍 Error Test {i}: {test_case['description']}")

        try:
            result_state = workflow.process_question(
                question=test_case["question"],
                file_path=test_case.get("file_path"),
                task_id=f"error_test_{i}"
            )

            # Fix: final_answer may be None on failure paths; slicing None in
            # the print below raised TypeError and mis-reported the case as an
            # unhandled exception. Coerce once and reuse.
            answer = result_state.final_answer or ""

            # Check if the error was handled gracefully.
            graceful_handling = (
                result_state.is_complete and
                bool(answer) and
                not answer.startswith("Traceback")
            )

            print(f"   ✅ Graceful Handling: {'✅ PASS' if graceful_handling else '❌ FAIL'}")
            print(f"   ✅ Error Messages: {len(result_state.error_messages)}")
            print(f"   ✅ Final Answer: {answer[:100]}...")

            results.append(graceful_handling)

        except Exception as e:
            print(f"   ❌ Unhandled exception: {e}")
            results.append(False)

    passed = sum(results)
    total = len(results)

    print(f"\n📊 Error Handling Results:")
    print(f"   🎯 Tests Passed: {passed}/{total} ({passed/total*100:.1f}%)")

    return passed >= total * 0.8
|
| 208 |
+
|
| 209 |
+
def test_workflow_state_management():
    """Verify that a completed workflow run carries complete tracking state.

    Runs one simple question through the workflow and checks the result
    state exposes identifiers, routing, processing steps, answer, cost,
    and timing information.

    Returns:
        bool: True when every state check passes.
    """
    print("\n🧪 Testing Workflow State Management")
    print("=" * 50)

    try:
        llm_client = QwenClient()
        workflow = SimpleGAIAWorkflow(llm_client)
    except Exception as e:
        print(f"❌ Failed to initialize workflow: {e}")
        return False

    try:
        result_state = workflow.process_question(
            question="What is the square root of 144?",
            task_id="state_test"
        )

        # Name -> outcome for every field the state is expected to populate.
        state_checks = {
            "has_task_id": bool(result_state.task_id),
            "has_question": bool(result_state.question),
            "has_routing_decision": bool(result_state.routing_decision),
            "has_processing_steps": len(result_state.processing_steps) > 0,
            "has_final_answer": bool(result_state.final_answer),
            "is_complete": result_state.is_complete,
            "has_cost_tracking": result_state.total_cost >= 0,
            "has_timing": result_state.total_processing_time >= 0
        }

        print("   📊 State Management Checks:")
        for check, passed in state_checks.items():
            print(f"      {'✅' if passed else '❌'} {check}: {passed}")

        # Dump the state's own summary view.
        summary = result_state.get_summary()
        print(f"\n   📋 State Summary:")
        for key, value in summary.items():
            print(f"      {key}: {value}")

        # Show only the most recent processing steps to keep output short.
        recent_steps = result_state.processing_steps[-5:]
        print(f"\n   🔄 Processing Steps ({len(result_state.processing_steps)}):")
        for idx, step in enumerate(recent_steps, start=1):
            print(f"      {idx}. {step}")

        all_passed = all(state_checks.values())
        print(f"\n   🎯 State Management: {'✅ PASS' if all_passed else '❌ FAIL'}")

        return all_passed

    except Exception as e:
        print(f"   ❌ State test failed: {e}")
        return False
|
| 264 |
+
|
| 265 |
+
def main():
    """Run all workflow test suites and print an aggregate report.

    Returns:
        bool: True when all suites pass or at least 80% of them pass.
    """
    print("🚀 GAIA Workflow Integration Tests")
    print("=" * 60)

    test_results = []
    start_time = time.time()

    # Run all test suites in order.
    test_results.append(test_simple_workflow())
    test_results.append(test_complete_workflow_with_files())
    test_results.append(test_workflow_error_handling())
    test_results.append(test_workflow_state_management())

    # Summary
    total_time = time.time() - start_time
    passed = sum(test_results)
    total = len(test_results)

    print("\n" + "=" * 60)
    print("📊 COMPLETE WORKFLOW TEST RESULTS")
    print("=" * 60)
    print(f"🎯 Test Suites Passed: {passed}/{total} ({passed/total*100:.1f}%)")
    print(f"⏱️ Total Time: {total_time:.2f} seconds")

    # Human-readable names aligned with the run order above.
    test_names = [
        "Simple Workflow",
        "File Processing",
        "Error Handling",
        "State Management"
    ]

    print(f"\n📋 Test Breakdown:")
    # Fix: the original wrapped zip() in enumerate() but never used the index.
    for name, result in zip(test_names, test_results):
        status = "✅" if result else "❌"
        print(f"   {status} {name}")

    if passed == total:
        print("\n🚀 ALL WORKFLOW TESTS PASSED! System ready for production!")
        return True
    elif passed >= total * 0.8:
        print("\n✅ MOST TESTS PASSED! System functional with minor issues.")
        return True
    else:
        print("\n⚠️ SIGNIFICANT ISSUES! Review failures above.")
        return False
|
| 313 |
+
|
| 314 |
+
# Script entry point: run the full workflow suite and report success via the
# process exit code (0 = pass, 1 = fail) so CI can consume it directly.
if __name__ == "__main__":
    success = main()
    sys.exit(0 if success else 1)
|
src/tools/__init__.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Tool System for GAIA Agent Framework
|
| 4 |
+
Provides base classes and interfaces for all agent tools
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from abc import ABC, abstractmethod
|
| 8 |
+
from typing import Any, Dict, Optional
|
| 9 |
+
from dataclasses import dataclass
|
| 10 |
+
import time
|
| 11 |
+
import logging
|
| 12 |
+
|
| 13 |
+
# Use existing ToolResult from agents.state
|
| 14 |
+
from agents.state import ToolResult
|
| 15 |
+
|
| 16 |
+
logger = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
+
class BaseTool(ABC):
    """
    Base class for all agent tools.

    Provides a consistent execute() interface with error handling and
    per-tool usage metrics. Subclasses implement _execute_impl() only.
    """

    def __init__(self, name: str):
        # Human-readable tool identifier, echoed into every ToolResult.
        self.name = name
        # Number of execute() calls (successful or failed).
        self.usage_count = 0
        # Cumulative wall-clock time across all executions, in seconds.
        self.total_execution_time = 0.0

    @abstractmethod
    def _execute_impl(self, input_data: Any, **kwargs) -> Any:
        """
        Implementation-specific execution logic.

        Override this method in subclasses; raise on failure and let
        execute() wrap the error into a ToolResult.
        """
        pass

    def execute(self, input_data: Any, **kwargs) -> ToolResult:
        """
        Execute the tool with error handling and metrics tracking.

        Returns:
            ToolResult: success payload, or a failure record with the
            error message — exceptions never escape this method.
        """
        start_time = time.time()

        try:
            logger.info(f"Executing tool: {self.name}")
            result = self._execute_impl(input_data, **kwargs)

            execution_time = time.time() - start_time
            self.usage_count += 1
            self.total_execution_time += execution_time

            logger.info(f"✅ Tool {self.name} completed in {execution_time:.2f}s")

            return ToolResult(
                tool_name=self.name,
                success=True,
                result=result,
                execution_time=execution_time,
                metadata={
                    "input_type": type(input_data).__name__,
                    "usage_count": self.usage_count
                }
            )

        except Exception as e:
            execution_time = time.time() - start_time
            # Fix: failed runs now also count toward usage_count and
            # total_execution_time, so get_stats() reflects every execution
            # instead of silently dropping failures from the averages.
            self.usage_count += 1
            self.total_execution_time += execution_time

            error_msg = f"Tool {self.name} failed: {str(e)}"
            logger.error(f"❌ {error_msg}")

            return ToolResult(
                tool_name=self.name,
                success=False,
                result=None,
                error=error_msg,
                execution_time=execution_time
            )

    def get_stats(self) -> Dict[str, Any]:
        """Get usage statistics (count, total and average time) for this tool."""
        return {
            "name": self.name,
            "usage_count": self.usage_count,
            "total_execution_time": self.total_execution_time,
            # max(..., 1) guards the never-executed case against division by zero.
            "average_execution_time": self.total_execution_time / max(self.usage_count, 1)
        }
|
| 85 |
+
|
| 86 |
+
# Public API of the tools package: the tool base class plus the shared
# ToolResult container re-exported from agents.state.
__all__ = ['BaseTool', 'ToolResult']
|
src/tools/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (2.46 kB). View file
|
|
|
src/tools/__pycache__/calculator.cpython-310.pyc
ADDED
|
Binary file (10.2 kB). View file
|
|
|
src/tools/__pycache__/file_processor.cpython-310.pyc
ADDED
|
Binary file (16.5 kB). View file
|
|
|
src/tools/__pycache__/web_search_tool.cpython-310.pyc
ADDED
|
Binary file (9.17 kB). View file
|
|
|
src/tools/__pycache__/wikipedia_tool.cpython-310.pyc
ADDED
|
Binary file (7.98 kB). View file
|
|
|
src/tools/calculator.py
ADDED
|
@@ -0,0 +1,423 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Calculator Tool for GAIA Agent System
|
| 4 |
+
Handles mathematical calculations, unit conversions, and statistical operations
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import re
|
| 8 |
+
import math
|
| 9 |
+
import statistics
|
| 10 |
+
import logging
|
| 11 |
+
from typing import Dict, List, Optional, Any, Union
|
| 12 |
+
from dataclasses import dataclass
|
| 13 |
+
|
| 14 |
+
from tools import BaseTool
|
| 15 |
+
|
| 16 |
+
logger = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
+
@dataclass
class CalculationResult:
    """Structured result of a single calculator operation."""
    expression: str                  # expression exactly as the caller supplied it
    result: Union[float, int, str]   # computed value
    result_type: str                 # e.g. "integer", "float", "conversion"
    steps: List[str]                 # human-readable trace of the evaluation
    units: Optional[str] = None      # target units for conversions, else None

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the result into a plain, JSON-friendly dictionary."""
        return dict(
            expression=self.expression,
            result=self.result,
            result_type=self.result_type,
            steps=self.steps,
            units=self.units,
        )
|
| 35 |
+
|
| 36 |
+
class CalculatorTool(BaseTool):
    """
    Calculator tool for mathematical operations.

    Supports basic arithmetic, math/statistics functions, dataset
    statistics, and unit conversions (length, weight, temperature, time,
    area, volume). Expressions are evaluated in a restricted namespace and
    screened against a blocklist of dangerous patterns.
    """

    def __init__(self):
        super().__init__("calculator")

        # Whitelisted callables/constants available inside evaluated expressions.
        self.safe_functions = {
            # Basic functions
            'abs': abs, 'round': round, 'min': min, 'max': max,
            'sum': sum, 'len': len,

            # Math module functions
            'sin': math.sin, 'cos': math.cos, 'tan': math.tan,
            'asin': math.asin, 'acos': math.acos, 'atan': math.atan,
            'sinh': math.sinh, 'cosh': math.cosh, 'tanh': math.tanh,
            'exp': math.exp, 'log': math.log, 'log10': math.log10,
            'sqrt': math.sqrt, 'pow': pow, 'ceil': math.ceil, 'floor': math.floor,
            'factorial': math.factorial, 'gcd': math.gcd,

            # Constants
            'pi': math.pi, 'e': math.e,

            # Statistics functions
            'mean': statistics.mean, 'median': statistics.median,
            'mode': statistics.mode, 'stdev': statistics.stdev,
            'variance': statistics.variance
        }

        # Unit conversion factors, expressed relative to one base unit per category.
        self.unit_conversions = {
            # Length (to meters)
            'length': {
                'mm': 0.001, 'cm': 0.01, 'dm': 0.1, 'm': 1,
                'km': 1000, 'in': 0.0254, 'ft': 0.3048,
                'yd': 0.9144, 'mi': 1609.344
            },
            # Weight (to grams)
            'weight': {
                'mg': 0.001, 'g': 1, 'kg': 1000,
                'oz': 28.3495, 'lb': 453.592, 'ton': 1000000
            },
            # Temperature (non-linear; handled by _convert_temperature)
            'temperature': {
                'celsius': 'celsius', 'fahrenheit': 'fahrenheit',
                'kelvin': 'kelvin', 'c': 'celsius', 'f': 'fahrenheit', 'k': 'kelvin'
            },
            # Time (to seconds)
            'time': {
                's': 1, 'min': 60, 'h': 3600, 'hr': 3600,
                'day': 86400, 'week': 604800, 'month': 2629746, 'year': 31556952
            },
            # Area (to square meters)
            'area': {
                'mm2': 0.000001, 'cm2': 0.0001, 'm2': 1,
                'km2': 1000000, 'in2': 0.00064516, 'ft2': 0.092903
            },
            # Volume (to liters)
            'volume': {
                'ml': 0.001, 'l': 1, 'gal': 3.78541, 'qt': 0.946353,
                'pt': 0.473176, 'cup': 0.236588, 'fl_oz': 0.0295735
            }
        }

    def _execute_impl(self, input_data: Any, **kwargs) -> Dict[str, Any]:
        """
        Execute calculator operations based on input type.

        Args:
            input_data: Either a string expression, or a dict of the form
                {"operation": "evaluate"|"statistics"|"convert", ...} with
                operation-specific keys ("expression", "data", or
                "value"/"from_unit"/"to_unit").

        Returns:
            Dict with a "success" flag plus an operation-specific payload.

        Raises:
            ValueError: For unknown operations or unsupported input types.
        """
        if isinstance(input_data, str):
            return self._evaluate_expression(input_data)

        if isinstance(input_data, dict):
            operation = input_data.get("operation", "evaluate")

            if operation == "evaluate":
                return self._evaluate_expression(input_data.get("expression", ""))
            if operation == "statistics":
                return self._calculate_statistics(input_data.get("data", []))
            if operation == "convert":
                return self._convert_units(
                    input_data.get("value", 0),
                    input_data.get("from_unit", ""),
                    input_data.get("to_unit", "")
                )
            raise ValueError(f"Unknown operation: {operation}")

        raise ValueError(f"Unsupported input type: {type(input_data)}")

    def _evaluate_expression(self, expression: str) -> Dict[str, Any]:
        """Safely evaluate a mathematical expression string."""
        original_expression = expression
        try:
            expression = self._clean_expression(expression)
            steps = [f"Original: {original_expression}", f"Cleaned: {expression}"]

            # Inline unit-conversion syntax, e.g. "100 cm to m".
            unit_match = re.search(r'(\d+\.?\d*)\s*(\w+)\s+to\s+(\w+)', expression)
            if unit_match:
                value, from_unit, to_unit = unit_match.groups()
                return self._convert_units(float(value), from_unit, to_unit)

            # Rewrite natural-language math phrasings into function calls.
            expression = self._replace_math_expressions(expression)
            steps.append(f"With functions: {expression}")

            # Validate expression safety before touching eval().
            if not self._is_safe_expression(expression):
                raise ValueError("Expression contains unsafe operations")

            # SECURITY NOTE: eval() is restricted to the whitelist above
            # (empty __builtins__) and pre-screened by _is_safe_expression,
            # but it should still not be exposed to fully untrusted input.
            safe_dict = {
                "__builtins__": {},
                **self.safe_functions
            }
            result = eval(expression, safe_dict)

            # Normalize numeric results.
            if isinstance(result, (int, float)):
                # Fix: int(result) raises OverflowError/ValueError for
                # inf/NaN; treat those as non-integral floats instead.
                try:
                    is_whole = result == int(result)
                except (OverflowError, ValueError):
                    is_whole = False
                if is_whole:
                    result = int(result)
                    result_type = "integer"
                else:
                    result = round(result, 10)  # Avoid floating point noise
                    result_type = "float"
            else:
                result_type = type(result).__name__

            calc_result = CalculationResult(
                expression=original_expression,
                result=result,
                result_type=result_type,
                steps=steps
            )

            return {
                "success": True,
                "calculation": calc_result.to_dict(),
                "message": f"Successfully evaluated: {result}"
            }

        except Exception as e:
            return {
                "success": False,
                # Fix: report the caller's original expression, not the
                # internally rewritten one, so errors are traceable.
                "expression": original_expression,
                "message": f"Calculation failed: {str(e)}",
                "error_type": type(e).__name__
            }

    def _clean_expression(self, expression: str) -> str:
        """Clean and normalize a mathematical expression."""
        # Collapse whitespace.
        expression = re.sub(r'\s+', ' ', expression.strip())

        # Replace common English operators with symbols.
        replacements = {
            ' plus ': '+', ' minus ': '-', ' times ': '*', ' multiply ': '*',
            ' divided by ': '/', ' divide ': '/', ' power ': '**', ' to the power of ': '**'
        }
        for text, symbol in replacements.items():
            expression = expression.replace(text, symbol)

        # Fix: treat '^' as exponentiation. In Python '^' is bitwise XOR,
        # so "2^3" silently evaluated to 1 instead of 8 before this rewrite.
        expression = expression.replace('^', '**')

        # Handle "X% of Y" before bare percentages so "15% of 200" -> 30.
        expression = re.sub(r'(\d+\.?\d*)%\s*of\s*', r'(\1/100)*', expression,
                            flags=re.IGNORECASE)
        # Remaining bare percentages: "15%" -> "(15/100)".
        expression = re.sub(r'(\d+\.?\d*)%', r'(\1/100)', expression)

        return expression

    def _replace_math_expressions(self, expression: str) -> str:
        """Rewrite natural-language math phrasings into callable functions.

        The previous version also contained identity rewrites for sqrt/log/
        trig calls (pattern == replacement); those no-ops were removed.
        """
        # "square root of 16" -> "sqrt(16)"
        expression = re.sub(r'square root of (\d+\.?\d*)', r'sqrt(\1)', expression)
        # "10 factorial" -> "factorial(10)" (makes the tool's own test pass)
        expression = re.sub(r'(\d+)\s*factorial', r'factorial(\1)', expression)
        # "ln(x)" -> natural logarithm via math.log
        expression = re.sub(r'\bln\s*\(', 'log(', expression)
        return expression

    def _is_safe_expression(self, expression: str) -> bool:
        """Check if expression is safe to hand to eval()."""
        forbidden_patterns = [
            r'__.*__',          # Dunder access
            r'import\s',        # Import statements
            r'exec\s*\(',       # exec()
            r'eval\s*\(',       # nested eval()
            r'open\s*\(',       # File operations
            r'file\s*\(',       # File operations (Python 2 builtin)
            r'input\s*\(',      # Interactive input
            r'raw_input\s*\(',  # Python 2 interactive input
        ]

        return not any(
            re.search(pattern, expression, re.IGNORECASE)
            for pattern in forbidden_patterns
        )

    def _calculate_statistics(self, data: List[float]) -> Dict[str, Any]:
        """Calculate statistical measures for a dataset."""
        try:
            if not data:
                raise ValueError("Empty dataset provided")

            data = [float(x) for x in data]  # Ensure all values are numeric

            stats = {
                "count": len(data),
                "sum": sum(data),
                "mean": statistics.mean(data),
                "median": statistics.median(data),
                "min": min(data),
                "max": max(data),
                "range": max(data) - min(data)
            }

            # stdev/variance need at least two data points.
            if len(data) > 1:
                stats["stdev"] = statistics.stdev(data)
                stats["variance"] = statistics.variance(data)

            # Mode may not exist (older Pythons raise on multimodal data).
            try:
                stats["mode"] = statistics.mode(data)
            except statistics.StatisticsError:
                stats["mode"] = "No unique mode"

            return {
                "success": True,
                "statistics": stats,
                "data": data,
                "message": f"Calculated statistics for {len(data)} data points"
            }

        except Exception as e:
            return {
                "success": False,
                "message": f"Statistics calculation failed: {str(e)}",
                "error_type": type(e).__name__
            }

    def _convert_units(self, value: float, from_unit: str, to_unit: str) -> Dict[str, Any]:
        """Convert a value between two units of the same category."""
        try:
            from_unit = from_unit.lower()
            to_unit = to_unit.lower()

            # Both units must belong to the same category.
            unit_type = None
            for utype, units in self.unit_conversions.items():
                if from_unit in units and to_unit in units:
                    unit_type = utype
                    break

            if not unit_type:
                raise ValueError(f"Cannot convert between {from_unit} and {to_unit}")

            # Temperature is affine, not a simple ratio.
            if unit_type == 'temperature':
                result = self._convert_temperature(value, from_unit, to_unit)
            else:
                # value -> base unit -> target unit.
                from_factor = self.unit_conversions[unit_type][from_unit]
                to_factor = self.unit_conversions[unit_type][to_unit]
                result = value * from_factor / to_factor

            # Round to reasonable precision.
            if result == int(result):
                result = int(result)
            else:
                result = round(result, 6)

            conversion_result = CalculationResult(
                expression=f"{value} {from_unit} to {to_unit}",
                result=result,
                result_type="conversion",
                steps=[
                    f"Convert {value} {from_unit} to {to_unit}",
                    f"Result: {result} {to_unit}"
                ],
                units=to_unit
            )

            return {
                "success": True,
                "conversion": conversion_result.to_dict(),
                "message": f"Converted {value} {from_unit} = {result} {to_unit}"
            }

        except Exception as e:
            return {
                "success": False,
                "message": f"Unit conversion failed: {str(e)}",
                "error_type": type(e).__name__
            }

    def _convert_temperature(self, value: float, from_unit: str, to_unit: str) -> float:
        """Convert temperature between Celsius, Fahrenheit, and Kelvin."""
        # Normalize single-letter unit names.
        unit_map = {'c': 'celsius', 'f': 'fahrenheit', 'k': 'kelvin'}
        from_unit = unit_map.get(from_unit, from_unit)
        to_unit = unit_map.get(to_unit, to_unit)

        # Convert to Celsius first.
        if from_unit == 'fahrenheit':
            celsius = (value - 32) * 5/9
        elif from_unit == 'kelvin':
            celsius = value - 273.15
        else:  # Already Celsius
            celsius = value

        # Convert from Celsius to the target unit.
        if to_unit == 'fahrenheit':
            return celsius * 9/5 + 32
        elif to_unit == 'kelvin':
            return celsius + 273.15
        else:  # Stay in Celsius
            return celsius
|
| 377 |
+
|
| 378 |
+
def test_calculator_tool():
    """Smoke-test the calculator tool across expressions, statistics, and conversions."""
    tool = CalculatorTool()

    # A mix of string expressions and structured operation dicts.
    test_cases = [
        "2 + 3 * 4",
        "sqrt(16) + 2^3",
        "sin(pi/2) + cos(0)",
        {"operation": "statistics", "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]},
        {"operation": "convert", "value": 100, "from_unit": "cm", "to_unit": "m"},
        {"operation": "convert", "value": 32, "from_unit": "f", "to_unit": "c"},
        "10 factorial",
        "mean([1, 2, 3, 4, 5])",
        "15% of 200"
    ]

    print("🧪 Testing Calculator Tool...")

    for index, case in enumerate(test_cases, start=1):
        print(f"\n--- Test {index}: {case} ---")
        try:
            outcome = tool.execute(case)

            if not outcome.success:
                print(f"❌ Error: {outcome.result.get('message', 'Unknown error')}")
            else:
                payload = outcome.result
                # Dispatch on whichever payload key the operation produced.
                if 'calculation' in payload:
                    calc = payload['calculation']
                    print(f"✅ Result: {calc['result']} ({calc['result_type']})")
                elif 'statistics' in payload:
                    stats = payload['statistics']
                    print(f"✅ Mean: {stats['mean']}, Median: {stats['median']}, StDev: {stats.get('stdev', 'N/A')}")
                elif 'conversion' in payload:
                    conv = payload['conversion']
                    print(f"✅ Conversion: {conv['result']} {conv['units']}")
                print(f"   Message: {payload.get('message', 'No message')}")

            print(f"   Execution time: {outcome.execution_time:.3f}s")

        except Exception as e:
            print(f"❌ Exception: {str(e)}")
|
| 420 |
+
|
| 421 |
+
# Allow running this module directly as a quick manual smoke test.
if __name__ == "__main__":
    # Test when run directly
    test_calculator_tool()
|
src/tools/file_processor.py
ADDED
|
@@ -0,0 +1,681 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
File Processing Tool for GAIA Agent System
|
| 4 |
+
Handles multiple file formats: images, audio, Excel/CSV, Python code
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import re
|
| 9 |
+
import io
|
| 10 |
+
import logging
|
| 11 |
+
import mimetypes
|
| 12 |
+
from typing import Dict, List, Optional, Any, Union
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
import pandas as pd
|
| 15 |
+
from PIL import Image
|
| 16 |
+
import ast
|
| 17 |
+
|
| 18 |
+
from tools import BaseTool
|
| 19 |
+
|
| 20 |
+
logger = logging.getLogger(__name__)
|
| 21 |
+
|
| 22 |
+
class FileProcessingResult:
    """Container for the outcome of processing a single file.

    Bundles the processed content together with the source path, the
    detected file type, a success flag, and arbitrary metadata.
    """

    def __init__(self, file_path: str, file_type: str, success: bool,
                 content: Any = None, metadata: Dict[str, Any] = None):
        self.file_path = file_path
        self.file_type = file_type
        self.success = success
        self.content = content
        # Guarantee a dict so callers can always iterate/index metadata.
        self.metadata = metadata or {}

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the result into a plain dictionary."""
        fields = ("file_path", "file_type", "success", "content", "metadata")
        return {name: getattr(self, name) for name in fields}
| 41 |
+
|
| 42 |
+
class FileProcessorTool(BaseTool):
    """
    File processor tool for multiple file formats
    Supports images, audio, Excel/CSV, and Python code analysis
    """

    def __init__(self):
        super().__init__("file_processor")

        # Extension registries consulted by _detect_file_type when routing.
        extension_groups = {
            "image": ('.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.webp'),
            "audio": ('.mp3', '.wav', '.ogg', '.flac', '.m4a', '.aac'),
            "data": ('.csv', '.xlsx', '.xls', '.json', '.txt'),
            "code": ('.py', '.js', '.java', '.cpp', '.c', '.html', '.css'),
        }
        self.image_extensions = set(extension_groups["image"])
        self.audio_extensions = set(extension_groups["audio"])
        self.data_extensions = set(extension_groups["data"])
        self.code_extensions = set(extension_groups["code"])
|
| 56 |
+
|
| 57 |
+
def _execute_impl(self, input_data: Any, **kwargs) -> Dict[str, Any]:
    """
    Run a file-processing operation.

    Args:
        input_data: Either a plain file path string (auto-detect mode)
            or a dict of {"file_path": str, "operation": str, "options": dict}.

    Returns:
        The result dict produced by the selected processor.

    Raises:
        ValueError: For an unknown operation name or input type.
    """
    if isinstance(input_data, str):
        return self._process_file(input_data)

    if isinstance(input_data, dict):
        file_path = input_data.get("file_path", "")
        operation = input_data.get("operation", "auto")
        options = input_data.get("options", {})

        # Map operation names onto their handler methods.
        handlers = {
            "auto": self._process_file,
            "analyze_image": self._analyze_image,
            "process_data": self._process_data_file,
            "analyze_code": self._analyze_code,
        }
        handler = handlers.get(operation)
        if handler is None:
            raise ValueError(f"Unknown operation: {operation}")
        return handler(file_path, **options)

    raise ValueError(f"Unsupported input type: {type(input_data)}")
|
| 87 |
+
|
| 88 |
+
def _process_file(self, file_path: str, **options) -> Dict[str, Any]:
    """Auto-detect the file type and dispatch to the matching processor."""
    try:
        if not os.path.exists(file_path):
            return {
                "success": False,
                "message": f"File not found: {file_path}",
                "error_type": "file_not_found"
            }

        extension = Path(file_path).suffix.lower()
        detected = self._detect_file_type(file_path, extension)

        logger.info(f"Processing {detected} file: {file_path}")

        # Dispatch table keyed on the detected category.
        processors = {
            "image": self._analyze_image,
            "audio": self._analyze_audio,
            "data": self._process_data_file,
            "code": self._analyze_code,
            "text": self._process_text_file,
        }
        processor = processors.get(detected)
        if processor is not None:
            return processor(file_path, **options)

        return {
            "success": False,
            "message": f"Unsupported file type: {detected}",
            "file_path": file_path,
            "detected_type": detected
        }

    except Exception as e:
        return {
            "success": False,
            "message": f"File processing failed: {str(e)}",
            "file_path": file_path,
            "error_type": type(e).__name__
        }
|
| 132 |
+
|
| 133 |
+
def _detect_file_type(self, file_path: str, extension: str) -> str:
    """Classify a file as image/audio/data/code/text/unknown.

    Extension sets are checked first, in priority order (note that
    '.txt' belongs to the data set and therefore wins over the plain
    text set); MIME-type sniffing is the fallback for anything else.
    """
    categories = (
        ("image", self.image_extensions),
        ("audio", self.audio_extensions),
        ("data", self.data_extensions),
        ("code", self.code_extensions),
        ("text", {'.txt', '.md', '.rst'}),
    )
    for category, known_extensions in categories:
        if extension in known_extensions:
            return category

    # Extension unrecognized: fall back to MIME-type guessing.
    mime_type, _ = mimetypes.guess_type(file_path)
    if mime_type:
        for prefix, category in (("image/", "image"),
                                 ("audio/", "audio"),
                                 ("text/", "text")):
            if mime_type.startswith(prefix):
                return category

    return "unknown"
|
| 158 |
+
|
| 159 |
+
def _analyze_image(self, file_path: str, **options) -> Dict[str, Any]:
    """
    Analyze an image file: dimensions, mode, EXIF sample, dominant colors.

    Args:
        file_path: Path to the image on disk.

    Returns:
        Dict with "success", a serialized FileProcessingResult under
        "result", and a human-readable "message"; on failure a dict with
        "success": False and the error details.
    """
    try:
        with Image.open(file_path) as img:
            # Basic image information
            metadata = {
                "format": img.format,
                "mode": img.mode,
                "size": img.size,
                "width": img.width,
                "height": img.height,
                "file_size": os.path.getsize(file_path)
            }

            # EXIF data if available.  FIX: prefer the public getexif()
            # API (Pillow >= 6.0) over the private _getexif(), and read
            # it once instead of twice as the original did.
            if hasattr(img, "getexif"):
                exif = img.getexif()
            else:
                exif = img._getexif() if hasattr(img, '_getexif') else None
            if exif:
                metadata["exif_data"] = dict(list(exif.items())[:10])  # First 10 EXIF entries

            # Color analysis (getcolors returns None when the image has
            # more than maxcolors distinct colors).
            if img.mode in ['RGB', 'RGBA']:
                colors = img.getcolors(maxcolors=10)
                if colors:
                    dominant_colors = sorted(colors, reverse=True)[:5]
                    metadata["dominant_colors"] = [
                        {"count": count, "rgb": color}
                        for count, color in dominant_colors
                    ]

            # Basic content description
            content_description = self._describe_image_content(img, metadata)

            result = FileProcessingResult(
                file_path=file_path,
                file_type="image",
                success=True,
                content=content_description,
                metadata=metadata
            )

            return {
                "success": True,
                "result": result.to_dict(),
                "message": f"Successfully analyzed image: {img.width}x{img.height} {img.format}"
            }

    except Exception as e:
        return {
            "success": False,
            "message": f"Image analysis failed: {str(e)}",
            "file_path": file_path,
            "error_type": type(e).__name__
        }
|
| 215 |
+
|
| 216 |
+
def _describe_image_content(self, img: Image.Image, metadata: Dict[str, Any]) -> str:
    """Build a short human-readable summary of an image."""
    width, height = img.size

    # Orientation from the aspect ratio.
    if width > height:
        orientation = "landscape"
    elif height > width:
        orientation = "portrait"
    else:
        orientation = "square"

    parts = [
        f"{orientation} {img.format} image",
        f"Dimensions: {width} x {height} pixels",
    ]

    # Color-mode summary; modes outside this table get no extra sentence.
    mode_text = {
        'RGB': "Full color RGB image",
        'RGBA': "RGB image with transparency",
        'L': "Grayscale image",
        '1': "Black and white image",
    }.get(img.mode)
    if mode_text:
        parts.append(mode_text)

    # Human-friendly file size in MB or KB.
    file_size = metadata.get("file_size", 0)
    if file_size > 0:
        size_mb = file_size / (1024 * 1024)
        if size_mb >= 1:
            parts.append(f"File size: {size_mb:.1f} MB")
        else:
            parts.append(f"File size: {file_size / 1024:.1f} KB")

    return ". ".join(parts)
|
| 253 |
+
|
| 254 |
+
def _analyze_audio(self, file_path: str, **options) -> Dict[str, Any]:
    """
    Report basic metadata for an audio file (no decoding/transcription).
    """
    try:
        extension = Path(file_path).suffix.lower()
        size_bytes = os.path.getsize(file_path)

        metadata = {
            "file_extension": extension,
            "file_size": size_bytes,
            "file_size_mb": round(size_bytes / (1024 * 1024), 2),
        }

        # Only file-level facts are reported here.  A richer version could
        # use pydub (processing), speech_recognition (transcription) or
        # librosa (signal analysis).
        description = f"Audio file ({extension}) - {metadata['file_size_mb']} MB"

        result = FileProcessingResult(
            file_path=file_path,
            file_type="audio",
            success=True,
            content=description,
            metadata=metadata,
        )

        return {
            "success": True,
            "result": result.to_dict(),
            "message": f"Audio file detected: {metadata['file_size_mb']} MB {extension}",
            "note": "Full audio transcription requires additional setup",
        }

    except Exception as e:
        return {
            "success": False,
            "message": f"Audio analysis failed: {str(e)}",
            "file_path": file_path,
            "error_type": type(e).__name__,
        }
|
| 299 |
+
|
| 300 |
+
def _process_data_file(self, file_path: str, **options) -> Dict[str, Any]:
    """
    Load a CSV/Excel/JSON file with pandas and summarize its structure.

    Unrecognized extensions are retried as plain text.
    """
    try:
        extension = Path(file_path).suffix.lower()

        # Pick the pandas reader that matches the extension.
        readers = {
            '.csv': pd.read_csv,
            '.xlsx': pd.read_excel,
            '.xls': pd.read_excel,
            '.json': pd.read_json,
        }
        reader = readers.get(extension)
        if reader is None:
            # Not tabular: process the raw text instead.
            with open(file_path, 'r', encoding='utf-8') as f:
                return self._process_text_content(f.read(), file_path)
        df = reader(file_path)

        # Structural summary of the DataFrame.
        metadata = {
            "shape": df.shape,
            "columns": df.columns.tolist(),
            "column_count": len(df.columns),
            "row_count": len(df),
            "data_types": df.dtypes.to_dict(),
            "memory_usage": df.memory_usage(deep=True).sum(),
            "has_missing_values": df.isnull().any().any()
        }

        # Descriptive statistics, numeric columns only.
        numeric_columns = df.select_dtypes(include=['number']).columns.tolist()
        if numeric_columns:
            metadata["numeric_columns"] = numeric_columns
            metadata["numeric_stats"] = df[numeric_columns].describe().to_dict()

        # Sample data (first few rows)
        sample_data = df.head(5).to_dict(orient='records')

        result = FileProcessingResult(
            file_path=file_path,
            file_type="data",
            success=True,
            content={
                "description": self._describe_data_content(df, metadata),
                "sample_data": sample_data,
                # Embed the whole table only when it is small.
                "full_data": df.to_dict(orient='records') if len(df) <= 100 else None
            },
            metadata=metadata
        )

        return {
            "success": True,
            "result": result.to_dict(),
            "message": f"Successfully processed data file: {df.shape[0]} rows, {df.shape[1]} columns"
        }

    except Exception as e:
        return {
            "success": False,
            "message": f"Data file processing failed: {str(e)}",
            "file_path": file_path,
            "error_type": type(e).__name__
        }
|
| 368 |
+
|
| 369 |
+
def _describe_data_content(self, df: pd.DataFrame, metadata: Dict[str, Any]) -> str:
    """Summarize a DataFrame's shape, columns, and data quality in prose."""
    rows, cols = df.shape
    parts = [f"Data table with {rows} rows and {cols} columns"]

    # Column listing, abbreviated for wide tables.
    column_names = df.columns.tolist()
    if cols <= 10:
        parts.append(f"Columns: {', '.join(column_names)}")
    else:
        parts.append(f"Columns include: {', '.join(column_names[:5])}... and {cols-5} more")

    # Count of numeric columns, if any were recorded.
    numeric_count = len(metadata.get("numeric_columns", []))
    if numeric_count > 0:
        parts.append(f"{numeric_count} numeric columns")

    if metadata.get("has_missing_values"):
        parts.append("Contains missing values")

    return ". ".join(parts)
|
| 394 |
+
|
| 395 |
+
def _analyze_code(self, file_path: str, **options) -> Dict[str, Any]:
    """Read a source file and route Python to the AST analyzer."""
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            source = f.read()

        extension = Path(file_path).suffix.lower()
        if extension == '.py':
            return self._analyze_python_code(source, file_path)
        return self._analyze_generic_code(source, file_path, extension)

    except Exception as e:
        return {
            "success": False,
            "message": f"Code analysis failed: {str(e)}",
            "file_path": file_path,
            "error_type": type(e).__name__
        }
|
| 417 |
+
|
| 418 |
+
def _analyze_python_code(self, code_content: str, file_path: str) -> Dict[str, Any]:
    """Parse Python source with ast and summarize its structure.

    Collects functions (sync and async), classes, and imports plus
    simple line statistics.

    Args:
        code_content: The full Python source text.
        file_path: Original path, echoed into the result.

    Returns:
        Success dict with a serialized FileProcessingResult, or a
        failure dict with error_type "syntax_error" on unparsable code.
    """
    try:
        # Parse the Python code
        tree = ast.parse(code_content)

        functions = []
        classes = []
        imports = []

        for node in ast.walk(tree):
            # FIX: include ast.AsyncFunctionDef so `async def` functions
            # are counted; the original silently skipped them.
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                functions.append({
                    "name": node.name,
                    "line": node.lineno,
                    "args": [arg.arg for arg in node.args.args]
                })
            elif isinstance(node, ast.ClassDef):
                classes.append({
                    "name": node.name,
                    "line": node.lineno
                })
            elif isinstance(node, ast.Import):
                imports.extend(alias.name for alias in node.names)
            elif isinstance(node, ast.ImportFrom):
                # Relative imports have node.module == None; keep the
                # original "<module>.<name>" formatting.
                module = node.module or ""
                imports.extend(f"{module}.{alias.name}" for alias in node.names)

        # Code statistics
        lines = code_content.split('\n')
        metadata = {
            "total_lines": len(lines),
            "non_empty_lines": len([line for line in lines if line.strip()]),
            "comment_lines": len([line for line in lines if line.strip().startswith('#')]),
            "function_count": len(functions),
            "class_count": len(classes),
            "import_count": len(imports),
            "functions": functions[:10],  # First 10 functions
            "classes": classes[:10],  # First 10 classes
            "imports": list(set(imports))  # Unique imports
        }

        # Generate description
        content_description = self._describe_python_code(metadata)

        result = FileProcessingResult(
            file_path=file_path,
            file_type="python_code",
            success=True,
            content={
                "description": content_description,
                "code_snippet": code_content[:1000] + "..." if len(code_content) > 1000 else code_content,
                "full_code": code_content
            },
            metadata=metadata
        )

        return {
            "success": True,
            "result": result.to_dict(),
            "message": f"Python code analyzed: {metadata['function_count']} functions, {metadata['class_count']} classes"
        }

    except SyntaxError as e:
        return {
            "success": False,
            "message": f"Python syntax error: {str(e)}",
            "file_path": file_path,
            "error_type": "syntax_error"
        }
|
| 492 |
+
|
| 493 |
+
def _describe_python_code(self, metadata: Dict[str, Any]) -> str:
    """Render Python code statistics as a short English summary."""
    total = metadata.get("total_lines", 0)
    non_empty = metadata.get("non_empty_lines", 0)
    parts = [f"Python file with {total} total lines ({non_empty} non-empty)"]

    # Counts of top-level constructs.
    func_count = metadata.get("function_count", 0)
    class_count = metadata.get("class_count", 0)
    if func_count > 0:
        parts.append(f"{func_count} functions defined")
    if class_count > 0:
        parts.append(f"{class_count} classes defined")

    # Import listing, abbreviated when there are many.
    imports = metadata.get("imports", [])
    if imports:
        if len(imports) <= 5:
            parts.append(f"Imports: {', '.join(imports)}")
        else:
            parts.append(f"Imports {len(imports)} modules including: {', '.join(imports[:3])}...")

    return ". ".join(parts)
|
| 520 |
+
|
| 521 |
+
def _analyze_generic_code(self, code_content: str, file_path: str, extension: str) -> Dict[str, Any]:
    """Summarize a non-Python source file by simple line counts."""
    lines = code_content.split('\n')

    metadata = {
        "file_extension": extension,
        "total_lines": len(lines),
        "non_empty_lines": sum(1 for line in lines if line.strip()),
        "file_size": len(code_content),
    }

    description = f"{extension.upper()} code file with {metadata['total_lines']} lines"
    # Truncate long sources to a 500-character preview.
    snippet = code_content if len(code_content) <= 500 else code_content[:500] + "..."

    result = FileProcessingResult(
        file_path=file_path,
        file_type="code",
        success=True,
        content={
            "description": description,
            "code_snippet": snippet
        },
        metadata=metadata
    )

    return {
        "success": True,
        "result": result.to_dict(),
        "message": f"Code file analyzed: {metadata['total_lines']} lines of {extension.upper()} code"
    }
|
| 551 |
+
|
| 552 |
+
def _process_text_file(self, file_path: str, **options) -> Dict[str, Any]:
    """Read a text file (UTF-8 first, latin-1 fallback) and summarize it."""
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            return self._process_text_content(f.read(), file_path)
    except UnicodeDecodeError:
        # Not valid UTF-8: retry with a lenient single-byte encoding.
        try:
            with open(file_path, 'r', encoding='latin-1') as f:
                return self._process_text_content(f.read(), file_path)
        except Exception as e:
            return {
                "success": False,
                "message": f"Text file processing failed: {str(e)}",
                "file_path": file_path,
                "error_type": type(e).__name__
            }
|
| 573 |
+
|
| 574 |
+
def _process_text_content(self, content: str, file_path: str) -> Dict[str, Any]:
    """Compute word/line statistics for raw text and package the result."""
    lines = content.split('\n')

    metadata = {
        "character_count": len(content),
        "word_count": len(content.split()),
        "line_count": len(lines),
        "non_empty_lines": sum(1 for line in lines if line.strip()),
        # max(..., 1) guards against division by zero on empty input.
        "average_line_length": sum(len(line) for line in lines) / max(len(lines), 1)
    }

    # 500-character preview for large documents.
    preview = content if len(content) <= 500 else content[:500] + "..."

    result = FileProcessingResult(
        file_path=file_path,
        file_type="text",
        success=True,
        content={
            "text": content,
            "preview": preview
        },
        metadata=metadata
    )

    return {
        "success": True,
        "result": result.to_dict(),
        "message": f"Text file processed: {metadata['word_count']} words, {metadata['line_count']} lines"
    }
|
| 606 |
+
|
| 607 |
+
def test_file_processor_tool():
    """Smoke-test FileProcessorTool against generated CSV and Python files.

    FIXES over the original: fixtures are created with tempfile instead
    of hard-coded /tmp paths (portable to non-POSIX systems), and cleanup
    catches OSError instead of using a bare ``except:``.
    """
    import tempfile  # local import: only this test helper needs it

    tool = FileProcessorTool()
    test_files = []

    # Fixture 1: a simple CSV table.
    csv_content = """name,age,city
John,25,New York
Jane,30,San Francisco
Bob,35,Chicago"""

    # Fixture 2: a small Python script with a function and a class.
    py_content = """#!/usr/bin/env python3
import os
import sys

def hello_world():
    '''Simple greeting function'''
    return "Hello, World!"

class TestClass:
    def __init__(self):
        self.value = 42

    def get_value(self):
        return self.value

if __name__ == "__main__":
    print(hello_world())
"""

    tmp_dir = tempfile.mkdtemp(prefix="file_processor_test_")
    try:
        for name, body in (("test_data.csv", csv_content),
                           ("test_script.py", py_content)):
            path = os.path.join(tmp_dir, name)
            with open(path, 'w') as f:
                f.write(body)
            test_files.append(path)

        print("🧪 Testing File Processor Tool...")

        for i, file_path in enumerate(test_files, 1):
            print(f"\n--- Test {i}: {file_path} ---")
            try:
                result = tool.execute(file_path)

                if result.success:
                    file_result = result.result['result']
                    print(f"✅ Success: {file_result['file_type']} file")
                    print(f" Message: {result.result.get('message', 'No message')}")
                    if 'metadata' in file_result:
                        metadata = file_result['metadata']
                        print(f" Metadata: {list(metadata.keys())}")
                else:
                    print(f"❌ Error: {result.result.get('message', 'Unknown error')}")

                print(f" Execution time: {result.execution_time:.3f}s")

            except Exception as e:
                print(f"❌ Exception: {str(e)}")
    finally:
        # Clean up test files; only OS-level removal errors are ignored.
        for file_path in test_files:
            try:
                os.remove(file_path)
            except OSError:
                pass
        try:
            os.rmdir(tmp_dir)
        except OSError:
            pass
|
| 678 |
+
|
| 679 |
+
# Allow running this module directly as a quick smoke test.
if __name__ == "__main__":
    # Test when run directly
    test_file_processor_tool()
|
src/tools/web_search_tool.py
ADDED
|
@@ -0,0 +1,350 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Web Search Tool for GAIA Agent System
|
| 4 |
+
Handles web searches using DuckDuckGo and content extraction from URLs
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import re
|
| 8 |
+
import logging
|
| 9 |
+
import time
|
| 10 |
+
from typing import Dict, List, Optional, Any
|
| 11 |
+
from urllib.parse import urlparse, urljoin
|
| 12 |
+
import requests
|
| 13 |
+
from bs4 import BeautifulSoup
|
| 14 |
+
from duckduckgo_search import DDGS
|
| 15 |
+
|
| 16 |
+
from tools import BaseTool
|
| 17 |
+
|
| 18 |
+
logger = logging.getLogger(__name__)
|
| 19 |
+
|
| 20 |
+
class WebSearchResult:
    """A single web search hit: title, URL, snippet, and optional page text."""

    def __init__(self, title: str, url: str, snippet: str, content: str = ""):
        self.title = title
        self.url = url
        self.snippet = snippet
        self.content = content

    def to_dict(self) -> Dict[str, str]:
        """Serialize the hit, truncating extracted content to 1500 chars."""
        if len(self.content) > 1500:
            shown_content = self.content[:1500] + "..."
        else:
            shown_content = self.content
        return {
            "title": self.title,
            "url": self.url,
            "snippet": self.snippet,
            "content": shown_content
        }
|
| 36 |
+
|
| 37 |
+
class WebSearchTool(BaseTool):
    """
    Web search tool using DuckDuckGo
    Handles searches, URL content extraction, and result filtering
    """

    def __init__(self):
        super().__init__("web_search")

        # Configure requests session for web scraping
        self.session = requests.Session()
        self.session.headers.update({
            'User-Agent': 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
        })

        # BUG FIX: requests.Session has no 'timeout' attribute; the
        # original `self.session.timeout = 10` was silently ignored and
        # every request ran with NO timeout.  Wrap Session.request so the
        # default is actually applied (still overridable per call with an
        # explicit timeout= argument).
        self.default_timeout = 10
        self.session.timeout = 10  # kept for backward compat; not honored by requests itself
        _orig_request = self.session.request

        def _request_with_timeout(method, url, **kwargs):
            kwargs.setdefault("timeout", self.default_timeout)
            return _orig_request(method, url, **kwargs)

        self.session.request = _request_with_timeout
|
| 52 |
+
|
| 53 |
+
def _execute_impl(self, input_data: Any, **kwargs) -> Dict[str, Any]:
    """
    Dispatch a web-search request.

    Args:
        input_data: A query string or URL, or a dict of
            {"query": str, "action": "search"|"extract",
             "limit": int, "extract_content": bool}.

    Raises:
        ValueError: For an unknown action or unsupported input type.
    """
    if isinstance(input_data, str):
        # URLs are fetched directly; anything else is treated as a query.
        if self._is_url(input_data):
            return self._extract_content_from_url(input_data)
        return self._search_web(input_data)

    if isinstance(input_data, dict):
        query = input_data.get("query", "")
        action = input_data.get("action", "search")

        if action == "search":
            limit = input_data.get("limit", 5)
            extract_content = input_data.get("extract_content", False)
            return self._search_web(query, limit, extract_content)
        if action == "extract":
            return self._extract_content_from_url(query)
        raise ValueError(f"Unknown action: {action}")

    raise ValueError(f"Unsupported input type: {type(input_data)}")
|
| 84 |
+
|
| 85 |
+
def _is_url(self, text: str) -> bool:
    """Return True when the text begins with an http(s) scheme."""
    return text.startswith(("http://", "https://"))
|
| 88 |
+
|
| 89 |
+
def _search_web(self, query: str, limit: int = 5, extract_content: bool = False) -> Dict[str, Any]:
    """
    Search DuckDuckGo and optionally fetch page content for each hit.

    Args:
        query: Search terms.
        limit: Maximum number of results to return.
        extract_content: When True, fetch each result URL and attach up
            to 1000 characters of extracted page text.

    Returns:
        Dict with "query", "found", "results" (serialized
        WebSearchResult dicts), "total_results", and "message".

    Raises:
        Exception: Wraps any underlying search failure; FIX: the
            original cause is now chained via ``from e`` instead of lost.
    """
    try:
        logger.info(f"Searching web for: {query}")

        # Perform DuckDuckGo search (DDGS is a context manager).
        with DDGS() as ddgs:
            search_results = list(ddgs.text(
                keywords=query,
                max_results=limit,
                region='us-en',
                safesearch='moderate'
            ))

        if not search_results:
            return {
                "query": query,
                "found": False,
                "message": "No web search results found",
                "results": []
            }

        results = []
        for result in search_results:
            web_result = WebSearchResult(
                title=result.get('title', 'No title'),
                url=result.get('href', ''),
                snippet=result.get('body', 'No description')
            )

            # Optionally extract full content from each URL; extraction
            # failures are logged and skipped, not fatal.
            if extract_content and web_result.url:
                try:
                    content_result = self._extract_content_from_url(web_result.url)
                    if content_result.get('found'):
                        web_result.content = content_result['content'][:1000]  # Limit content size
                except Exception as e:
                    logger.warning(f"Failed to extract content from {web_result.url}: {e}")

            results.append(web_result.to_dict())

        return {
            "query": query,
            "found": True,
            "results": results,
            "total_results": len(results),
            "message": f"Found {len(results)} web search results"
        }

    except Exception as e:
        # Chain the cause so the root failure survives in the traceback.
        raise Exception(f"Web search failed: {str(e)}") from e
|
| 142 |
+
|
| 143 |
+
def _extract_content_from_url(self, url: str) -> Dict[str, Any]:
    """
    Extract readable content from a web page.

    Fetches *url* via the shared session, strips boilerplate markup,
    and returns a dict with ``found`` plus (on success) the page title,
    heuristically extracted text, meta description, up to 10 links, and
    a ``content_length``. On failure returns ``found: False`` with an
    ``error_type`` of ``network_error`` or ``parsing_error``.
    """
    try:
        logger.info(f"Extracting content from: {url}")

        # Get page content.
        # NOTE(review): assumes self.session is a requests.Session set up
        # elsewhere in this class (headers/timeouts) — confirm.
        response = self.session.get(url)
        response.raise_for_status()

        # Parse with BeautifulSoup
        soup = BeautifulSoup(response.content, 'html.parser')

        # Remove script/style and structural chrome so extracted text is
        # limited to actual page content.
        for script in soup(["script", "style", "nav", "header", "footer", "aside"]):
            script.decompose()

        # Extract title
        title = soup.find('title')
        title_text = title.get_text().strip() if title else "No title"

        # Extract main content (article/main tags, common content divs,
        # then paragraph fallback — see _extract_main_content).
        content = self._extract_main_content(soup)

        # Extract metadata (meta description, when present).
        meta_description = ""
        meta_desc = soup.find('meta', attrs={'name': 'description'})
        if meta_desc:
            meta_description = meta_desc.get('content', '')

        # Extract links: first 10 anchors with meaningful (>5 chars) text,
        # resolved to absolute URLs against the page URL.
        links = []
        for link in soup.find_all('a', href=True)[:10]:  # First 10 links
            link_url = urljoin(url, link['href'])
            link_text = link.get_text().strip()
            if link_text and len(link_text) > 5:  # Filter out short/empty links
                links.append({"text": link_text, "url": link_url})

        return {
            "url": url,
            "found": True,
            "title": title_text,
            "content": content,
            "meta_description": meta_description,
            "links": links,
            "content_length": len(content),
            "message": "Successfully extracted content from URL"
        }

    except requests.exceptions.RequestException as e:
        # Transport-level failures (DNS, timeout, non-2xx via raise_for_status).
        return {
            "url": url,
            "found": False,
            "message": f"Failed to fetch URL: {str(e)}",
            "error_type": "network_error"
        }
    except Exception as e:
        # Anything else (parsing, encoding, attribute errors).
        return {
            "url": url,
            "found": False,
            "message": f"Failed to extract content: {str(e)}",
            "error_type": "parsing_error"
        }
|
| 207 |
+
|
| 208 |
+
def _extract_main_content(self, soup: BeautifulSoup) -> str:
    """Heuristically pull the main readable text out of parsed HTML.

    Tries, in order: semantic <article>/<main> tags, well-known CMS
    content containers, and finally the first substantial paragraphs.
    The combined text is whitespace-normalized and capped at 5000 chars.
    """
    pieces = []

    # Strategy 1: semantic article/main container.
    primary = soup.find(['article', 'main'])
    if primary:
        pieces.append(primary.get_text())

    # Strategy 2: common content-container selectors.
    for css in (
        'div.content',
        'div.article-content',
        'div.post-content',
        'div.entry-content',
        'div.main-content',
        'div#content',
        'div.text',
    ):
        pieces.extend(node.get_text() for node in soup.select(css))

    # Strategy 3: fall back to the first 20 paragraphs, keeping only
    # those long enough (>50 chars) to be real prose.
    if not pieces:
        pieces.extend(
            text
            for text in (p.get_text().strip() for p in soup.find_all('p')[:20])
            if len(text) > 50
        )

    # Clean and combine: collapse blank-line runs and space runs.
    merged = '\n\n'.join(pieces)
    merged = re.sub(r'\n\s*\n', '\n\n', merged)
    merged = re.sub(r' +', ' ', merged)

    return merged.strip()[:5000]  # Limit to 5000 characters
|
| 251 |
+
|
| 252 |
+
def search_youtube_metadata(self, query: str) -> Dict[str, Any]:
    """
    Specialized search for YouTube video information.

    Restricts a DuckDuckGo search to youtube.com and returns title,
    URL, description and extracted video ID for up to 3 watch-page hits.

    Raises:
        Exception: Wraps any failure from the search backend.
    """
    try:
        # Search specifically for YouTube videos via a site: filter.
        youtube_query = f"site:youtube.com {query}"

        with DDGS() as ddgs:
            search_results = list(ddgs.text(
                keywords=youtube_query,
                max_results=3,
                region='us-en',
                safesearch='moderate'
            ))

        youtube_results = []
        for result in search_results:
            # Keep only actual watch pages (skips channels/playlists).
            if 'youtube.com/watch' in result.get('href', ''):
                video_id = self._extract_youtube_id(result['href'])

                youtube_result = {
                    "title": result.get('title', 'No title'),
                    "url": result.get('href', ''),
                    "description": result.get('body', 'No description'),
                    "video_id": video_id
                }
                youtube_results.append(youtube_result)

        return {
            "query": query,
            "found": len(youtube_results) > 0,
            "results": youtube_results,
            "message": f"Found {len(youtube_results)} YouTube videos"
        }

    except Exception as e:
        raise Exception(f"YouTube search failed: {str(e)}")
|
| 290 |
+
|
| 291 |
+
def _extract_youtube_id(self, url: str) -> str:
|
| 292 |
+
"""Extract YouTube video ID from URL"""
|
| 293 |
+
patterns = [
|
| 294 |
+
r'(?:v=|\/)([0-9A-Za-z_-]{11}).*',
|
| 295 |
+
r'(?:embed\/)([0-9A-Za-z_-]{11})',
|
| 296 |
+
r'(?:youtu\.be\/)([0-9A-Za-z_-]{11})'
|
| 297 |
+
]
|
| 298 |
+
|
| 299 |
+
for pattern in patterns:
|
| 300 |
+
match = re.search(pattern, url)
|
| 301 |
+
if match:
|
| 302 |
+
return match.group(1)
|
| 303 |
+
return ""
|
| 304 |
+
|
| 305 |
+
def test_web_search_tool():
    """Smoke-test the web search tool across query styles."""
    tool = WebSearchTool()

    # Mix of plain queries, a raw URL, and dict-style requests.
    test_cases = [
        "Python programming tutorial",
        "https://en.wikipedia.org/wiki/Machine_learning",
        {"query": "artificial intelligence news", "action": "search", "limit": 3},
        {"query": "https://www.python.org", "action": "extract"},
        {"query": "OpenAI ChatGPT", "action": "search", "limit": 2, "extract_content": True}
    ]

    print("🧪 Testing Web Search Tool...")

    for i, test_case in enumerate(test_cases, 1):
        print(f"\n--- Test {i}: {test_case} ---")
        try:
            outcome = tool.execute(test_case)

            if not outcome.success:
                print(f"❌ Error: {outcome.error}")
            else:
                print(f"✅ Success: {outcome.result.get('message', 'No message')}")
                if not outcome.result.get('found'):
                    print(f" Not found: {outcome.result.get('message', 'Unknown error')}")
                elif 'results' in outcome.result:
                    print(f" Found {len(outcome.result['results'])} results")
                    # Show first result details, if any.
                    if outcome.result['results']:
                        top_hit = outcome.result['results'][0]
                        print(f" First result: {top_hit.get('title', 'No title')}")
                        print(f" URL: {top_hit.get('url', 'No URL')}")
                elif 'content' in outcome.result:
                    print(f" Extracted {len(outcome.result['content'])} characters")
                    print(f" Title: {outcome.result.get('title', 'No title')}")

            print(f" Execution time: {outcome.execution_time:.2f}s")

        except Exception as e:
            print(f"❌ Exception: {str(e)}")

if __name__ == "__main__":
    # Run the smoke test when executed directly.
    test_web_search_tool()
|
src/tools/wikipedia_tool.py
ADDED
|
@@ -0,0 +1,296 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Wikipedia Tool for GAIA Agent System
|
| 4 |
+
Handles Wikipedia searches, content extraction, and information retrieval
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import re
|
| 8 |
+
import logging
|
| 9 |
+
from typing import Dict, List, Optional, Any
|
| 10 |
+
import wikipediaapi # Fixed import - using Wikipedia-API package
|
| 11 |
+
from urllib.parse import urlparse, unquote
|
| 12 |
+
|
| 13 |
+
from tools import BaseTool
|
| 14 |
+
|
| 15 |
+
logger = logging.getLogger(__name__)
|
| 16 |
+
|
| 17 |
+
class WikipediaSearchResult:
    """Lightweight value object for a Wikipedia lookup result."""

    def __init__(self, title: str, summary: str, url: str, content: str = ""):
        self.title = title      # article title
        self.summary = summary  # short summary text
        self.url = url          # canonical article URL
        self.content = content  # optional full article text

    def to_dict(self) -> Dict[str, str]:
        """Serialize to a plain dict, truncating content beyond 1000 chars."""
        body = self.content if len(self.content) <= 1000 else self.content[:1000] + "..."
        return {
            "title": self.title,
            "summary": self.summary,
            "url": self.url,
            "content": body,
        }
|
| 33 |
+
|
| 34 |
+
class WikipediaTool(BaseTool):
    """
    Wikipedia tool for searching and extracting information.
    Handles disambiguation, missing pages, and content extraction.

    Accepts either a plain query string, a full Wikipedia URL, or a dict
    of the form {"query": str, "action": "summary" | "content"}.
    """

    def __init__(self):
        super().__init__("wikipedia")

        # Initialize Wikipedia API client: English wiki, plain-text
        # extracts, and a descriptive user agent (required by Wikimedia).
        self.wiki = wikipediaapi.Wikipedia(
            language='en',
            extract_format=wikipediaapi.ExtractFormat.WIKI,
            user_agent='GAIA-Agent/1.0 (educational-purpose)'
        )

    def _execute_impl(self, input_data: Any, **kwargs) -> Dict[str, Any]:
        """
        Execute Wikipedia operations based on input type.

        Args:
            input_data: Can be:
                - str: Search query or Wikipedia URL
                - dict: {"query": str, "action": str, "limit": int}

        Raises:
            ValueError: For an unknown action or unsupported input type.
        """

        if isinstance(input_data, str):
            # Handle both search queries and URLs.
            if self._is_wikipedia_url(input_data):
                return self._extract_from_url(input_data)
            else:
                return self._get_page_info(input_data)

        elif isinstance(input_data, dict):
            query = input_data.get("query", "")
            action = input_data.get("action", "summary")

            if action == "summary":
                return self._get_summary(query)
            elif action == "content":
                return self._get_full_content(query)
            else:
                raise ValueError(f"Unknown action: {action}")
        else:
            raise ValueError(f"Unsupported input type: {type(input_data)}")

    def _is_wikipedia_url(self, url: str) -> bool:
        """Check if URL is a Wikipedia URL (any language subdomain)."""
        return "wikipedia.org" in url.lower()

    def _extract_title_from_url(self, url: str) -> str:
        """Extract article title from Wikipedia URL.

        Decodes percent-escapes and converts underscores back to spaces;
        returns "" when the URL has no /wiki/ path segment or parsing fails.
        """
        try:
            parsed = urlparse(url)
            if "/wiki/" in parsed.path:
                title = parsed.path.split("/wiki/", 1)[1]
                return unquote(title).replace("_", " ")
            return ""
        except Exception:
            return ""

    def _extract_from_url(self, url: str) -> Dict[str, Any]:
        """Extract full article information from a Wikipedia URL."""
        title = self._extract_title_from_url(url)
        if not title:
            raise ValueError(f"Could not extract title from URL: {url}")

        return self._get_full_content(title)

    def _get_page_info(self, query: str) -> Dict[str, Any]:
        """Get basic page information (summary-level, no full text).

        Returns found=False with naive suggestions when the page is missing.
        """
        try:
            page = self.wiki.page(query)

            if not page.exists():
                return {
                    "query": query,
                    "found": False,
                    "message": f"Wikipedia page '{query}' does not exist",
                    "suggestions": self._get_suggestions(query)
                }

            # Get summary, capped at 500 chars for a compact response.
            summary = page.summary[:500] + "..." if len(page.summary) > 500 else page.summary

            result = WikipediaSearchResult(
                title=page.title,
                summary=summary,
                url=page.fullurl,
                content=""
            )

            return {
                "query": query,
                "found": True,
                "result": result.to_dict(),
                "message": "Successfully retrieved Wikipedia page info"
            }

        except Exception as e:
            raise Exception(f"Failed to get Wikipedia page info: {str(e)}")

    def _get_summary(self, title: str) -> Dict[str, Any]:
        """Get summary of a specific Wikipedia article (capped at 800 chars),
        plus its first 5 categories."""
        try:
            page = self.wiki.page(title)

            if not page.exists():
                return {
                    "title": title,
                    "found": False,
                    "message": f"Wikipedia page '{title}' does not exist",
                    "suggestions": self._get_suggestions(title)
                }

            # Get summary (first few sentences), capped at 800 chars.
            summary = page.summary[:800] + "..." if len(page.summary) > 800 else page.summary

            result = WikipediaSearchResult(
                title=page.title,
                summary=summary,
                url=page.fullurl
            )

            return {
                "title": title,
                "found": True,
                "result": result.to_dict(),
                "categories": list(page.categories.keys())[:5],  # First 5 categories
                "message": "Successfully retrieved Wikipedia summary"
            }

        except Exception as e:
            raise Exception(f"Failed to get Wikipedia summary: {str(e)}")

    def _get_full_content(self, title: str) -> Dict[str, Any]:
        """Get full content of a Wikipedia article.

        Includes parsed sections (first 5), up to 20 linked titles,
        first 10 categories, and the backlink count.
        """
        try:
            page = self.wiki.page(title)

            if not page.exists():
                return {
                    "title": title,
                    "found": False,
                    "message": f"Wikipedia page '{title}' does not exist",
                    "suggestions": self._get_suggestions(title)
                }

            # Extract key sections from the plain-text body.
            content_sections = self._parse_content_sections(page.text)

            result = WikipediaSearchResult(
                title=page.title,
                summary=page.summary[:800] + "..." if len(page.summary) > 800 else page.summary,
                url=page.fullurl,
                content=page.text
            )

            # Get linked pages (limit to 20 to avoid overwhelming output).
            links = []
            link_count = 0
            for link_title in page.links.keys():
                if link_count >= 20:  # Limit to first 20 links
                    break
                links.append(link_title)
                link_count += 1

            return {
                "title": title,
                "found": True,
                "result": result.to_dict(),
                "sections": content_sections,
                "links": links,
                "categories": list(page.categories.keys())[:10],  # First 10 categories
                "backlinks_count": len(page.backlinks),
                "message": "Successfully retrieved full Wikipedia content"
            }

        except Exception as e:
            raise Exception(f"Failed to get Wikipedia content: {str(e)}")

    def _parse_content_sections(self, content: str) -> Dict[str, str]:
        """Parse Wikipedia plain text into {section title: section text}.

        Splits on '== Heading ==' markers; text before the first heading
        goes under "Introduction". Only the first 5 sections are returned.
        Note: '===' sub-headings also match the startswith/endswith test
        and start a new (sub)section of their own.
        """
        sections = {}
        current_section = "Introduction"
        current_content = []

        lines = content.split('\n')
        for line in lines:
            line = line.strip()

            # Check for section headers (== Section Name ==).
            if line.startswith('==') and line.endswith('==') and len(line) > 4:
                # Save previous section before starting the next.
                if current_content:
                    sections[current_section] = '\n'.join(current_content).strip()

                # Start new section, stripping '=' markers and padding.
                current_section = line.strip('= ').strip()
                current_content = []
            else:
                if line:  # Skip empty lines
                    current_content.append(line)

        # Save last section (loop only flushes on a new header).
        if current_content:
            sections[current_section] = '\n'.join(current_content).strip()

        # Return only first few sections to avoid overwhelming output.
        section_items = list(sections.items())[:5]
        return dict(section_items)

    def _get_suggestions(self, query: str) -> List[str]:
        """Get search suggestions for a query (simplified).

        Wikipedia-API has no search endpoint, so this just offers casing
        and underscore variants of the query (up to 3, deduplicated).
        """
        common_suggestions = [
            query.lower(),
            query.title(),
            query.upper(),
            query.replace(' ', '_'),
        ]
        return list(set(common_suggestions))[:3]
|
| 257 |
+
|
| 258 |
+
def test_wikipedia_tool():
    """Smoke-test the Wikipedia tool across query styles and a missing page."""
    tool = WikipediaTool()

    # Plain query, raw URL, dict-style requests, and a guaranteed miss.
    test_cases = [
        "Albert Einstein",
        "https://en.wikipedia.org/wiki/Machine_learning",
        {"query": "Python (programming language)", "action": "summary"},
        {"query": "Artificial Intelligence", "action": "content"},
        "NonexistentPageTest12345"
    ]

    print("🧪 Testing Wikipedia Tool...")

    for i, test_case in enumerate(test_cases, 1):
        print(f"\n--- Test {i}: {test_case} ---")
        try:
            outcome = tool.execute(test_case)

            if not outcome.success:
                print(f"❌ Error: {outcome.error}")
            else:
                print(f"✅ Success: {outcome.result.get('message', 'No message')}")
                if not outcome.result.get('found'):
                    print(f" Not found: {outcome.result.get('message', 'Unknown error')}")
                elif 'result' in outcome.result:
                    print(f" Title: {outcome.result['result'].get('title', 'No title')}")
                    print(f" Summary: {outcome.result['result'].get('summary', 'No summary')[:100]}...")

            print(f" Execution time: {outcome.execution_time:.2f}s")

        except Exception as e:
            print(f"❌ Exception: {str(e)}")

if __name__ == "__main__":
    # Run the smoke test when executed directly.
    test_wikipedia_tool()
|
src/workflow/__init__.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
GAIA Agent Workflow Package
|
| 4 |
+
Main orchestration workflows for the GAIA benchmark agent system
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from .gaia_workflow import GAIAWorkflow, SimpleGAIAWorkflow
|
| 8 |
+
|
| 9 |
+
__all__ = ['GAIAWorkflow', 'SimpleGAIAWorkflow']
|
src/workflow/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (385 Bytes). View file
|
|
|
src/workflow/__pycache__/gaia_workflow.cpython-310.pyc
ADDED
|
Binary file (8.75 kB). View file
|
|
|
src/workflow/gaia_workflow.py
ADDED
|
@@ -0,0 +1,304 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
GAIA Agent LangGraph Workflow
|
| 4 |
+
Main orchestration workflow for the GAIA benchmark agent system
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import logging
|
| 8 |
+
from typing import Dict, Any, List, Literal
|
| 9 |
+
from langgraph.graph import StateGraph, END
|
| 10 |
+
from langgraph.checkpoint.memory import MemorySaver
|
| 11 |
+
|
| 12 |
+
from agents.state import GAIAAgentState, AgentRole, QuestionType
|
| 13 |
+
from agents.router import RouterAgent
|
| 14 |
+
from agents.web_researcher import WebResearchAgent
|
| 15 |
+
from agents.file_processor_agent import FileProcessorAgent
|
| 16 |
+
from agents.reasoning_agent import ReasoningAgent
|
| 17 |
+
from agents.synthesizer import SynthesizerAgent
|
| 18 |
+
from models.qwen_client import QwenClient
|
| 19 |
+
|
| 20 |
+
logger = logging.getLogger(__name__)
|
| 21 |
+
|
| 22 |
+
class GAIAWorkflow:
|
| 23 |
+
"""
|
| 24 |
+
Main GAIA agent workflow using LangGraph
|
| 25 |
+
Orchestrates router → specialized agents → synthesizer pipeline
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
def __init__(self, llm_client: QwenClient):
|
| 29 |
+
self.llm_client = llm_client
|
| 30 |
+
|
| 31 |
+
# Initialize all agents
|
| 32 |
+
self.router = RouterAgent(llm_client)
|
| 33 |
+
self.web_researcher = WebResearchAgent(llm_client)
|
| 34 |
+
self.file_processor = FileProcessorAgent(llm_client)
|
| 35 |
+
self.reasoning_agent = ReasoningAgent(llm_client)
|
| 36 |
+
self.synthesizer = SynthesizerAgent(llm_client)
|
| 37 |
+
|
| 38 |
+
# Create workflow graph
|
| 39 |
+
self.workflow = self._create_workflow()
|
| 40 |
+
|
| 41 |
+
# Compile workflow with memory
|
| 42 |
+
self.app = self.workflow.compile(checkpointer=MemorySaver())
|
| 43 |
+
|
| 44 |
+
def _create_workflow(self) -> StateGraph:
|
| 45 |
+
"""Create the LangGraph workflow"""
|
| 46 |
+
|
| 47 |
+
# Define the workflow graph
|
| 48 |
+
workflow = StateGraph(GAIAAgentState)
|
| 49 |
+
|
| 50 |
+
# Add nodes
|
| 51 |
+
workflow.add_node("router", self._router_node)
|
| 52 |
+
workflow.add_node("web_researcher", self._web_researcher_node)
|
| 53 |
+
workflow.add_node("file_processor", self._file_processor_node)
|
| 54 |
+
workflow.add_node("reasoning_agent", self._reasoning_agent_node)
|
| 55 |
+
workflow.add_node("synthesizer", self._synthesizer_node)
|
| 56 |
+
|
| 57 |
+
# Define entry point
|
| 58 |
+
workflow.set_entry_point("router")
|
| 59 |
+
|
| 60 |
+
# Add conditional edges from router to agents
|
| 61 |
+
workflow.add_conditional_edges(
|
| 62 |
+
"router",
|
| 63 |
+
self._route_to_agents,
|
| 64 |
+
{
|
| 65 |
+
"web_researcher": "web_researcher",
|
| 66 |
+
"file_processor": "file_processor",
|
| 67 |
+
"reasoning_agent": "reasoning_agent",
|
| 68 |
+
"multi_agent": "web_researcher", # Start with web researcher for multi-agent
|
| 69 |
+
"synthesizer": "synthesizer" # Direct to synthesizer if no agents needed
|
| 70 |
+
}
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
# Add edges from agents to synthesizer
|
| 74 |
+
workflow.add_edge("web_researcher", "synthesizer")
|
| 75 |
+
workflow.add_edge("file_processor", "synthesizer")
|
| 76 |
+
workflow.add_edge("reasoning_agent", "synthesizer")
|
| 77 |
+
|
| 78 |
+
# Add conditional edges for multi-agent scenarios
|
| 79 |
+
workflow.add_conditional_edges(
|
| 80 |
+
"synthesizer",
|
| 81 |
+
self._check_if_complete,
|
| 82 |
+
{
|
| 83 |
+
"complete": END,
|
| 84 |
+
"need_more_agents": "file_processor" # Route to next agent if needed
|
| 85 |
+
}
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
return workflow
|
| 89 |
+
|
| 90 |
+
def _router_node(self, state: GAIAAgentState) -> GAIAAgentState:
|
| 91 |
+
"""Router node - classifies question and selects agents"""
|
| 92 |
+
logger.info("🧭 Executing router node")
|
| 93 |
+
return self.router.route_question(state)
|
| 94 |
+
|
| 95 |
+
def _web_researcher_node(self, state: GAIAAgentState) -> GAIAAgentState:
|
| 96 |
+
"""Web researcher node"""
|
| 97 |
+
logger.info("🌐 Executing web researcher node")
|
| 98 |
+
return self.web_researcher.process(state)
|
| 99 |
+
|
| 100 |
+
def _file_processor_node(self, state: GAIAAgentState) -> GAIAAgentState:
|
| 101 |
+
"""File processor node"""
|
| 102 |
+
logger.info("📁 Executing file processor node")
|
| 103 |
+
return self.file_processor.process(state)
|
| 104 |
+
|
| 105 |
+
def _reasoning_agent_node(self, state: GAIAAgentState) -> GAIAAgentState:
|
| 106 |
+
"""Reasoning agent node"""
|
| 107 |
+
logger.info("🧠 Executing reasoning agent node")
|
| 108 |
+
return self.reasoning_agent.process(state)
|
| 109 |
+
|
| 110 |
+
def _synthesizer_node(self, state: GAIAAgentState) -> GAIAAgentState:
|
| 111 |
+
"""Synthesizer node - combines agent results"""
|
| 112 |
+
logger.info("🔗 Executing synthesizer node")
|
| 113 |
+
return self.synthesizer.process(state)
|
| 114 |
+
|
| 115 |
+
def _route_to_agents(self, state: GAIAAgentState) -> str:
|
| 116 |
+
"""Determine which agent(s) to route to based on router decision"""
|
| 117 |
+
|
| 118 |
+
selected_agents = state.selected_agents
|
| 119 |
+
|
| 120 |
+
# Remove synthesizer from routing decision (it's always last)
|
| 121 |
+
agent_roles = [agent for agent in selected_agents if agent != AgentRole.SYNTHESIZER]
|
| 122 |
+
|
| 123 |
+
if not agent_roles:
|
| 124 |
+
# No specific agents selected, go directly to synthesizer
|
| 125 |
+
return "synthesizer"
|
| 126 |
+
elif len(agent_roles) == 1:
|
| 127 |
+
# Single agent selected
|
| 128 |
+
agent = agent_roles[0]
|
| 129 |
+
if agent == AgentRole.WEB_RESEARCHER:
|
| 130 |
+
return "web_researcher"
|
| 131 |
+
elif agent == AgentRole.FILE_PROCESSOR:
|
| 132 |
+
return "file_processor"
|
| 133 |
+
elif agent == AgentRole.REASONING_AGENT:
|
| 134 |
+
return "reasoning_agent"
|
| 135 |
+
else:
|
| 136 |
+
return "synthesizer"
|
| 137 |
+
else:
|
| 138 |
+
# Multiple agents - start with web researcher
|
| 139 |
+
# The workflow will handle additional agents in subsequent steps
|
| 140 |
+
return "multi_agent"
|
| 141 |
+
|
| 142 |
+
def _check_if_complete(self, state: GAIAAgentState) -> str:
|
| 143 |
+
"""Check if processing is complete or if more agents are needed"""
|
| 144 |
+
|
| 145 |
+
# If synthesis is complete, we're done
|
| 146 |
+
if state.is_complete:
|
| 147 |
+
return "complete"
|
| 148 |
+
|
| 149 |
+
# Check if we need to run additional agents
|
| 150 |
+
selected_agents = state.selected_agents
|
| 151 |
+
executed_agents = set(state.agent_results.keys())
|
| 152 |
+
|
| 153 |
+
# Find agents that haven't been executed yet
|
| 154 |
+
remaining_agents = [
|
| 155 |
+
agent for agent in selected_agents
|
| 156 |
+
if agent not in executed_agents and agent != AgentRole.SYNTHESIZER
|
| 157 |
+
]
|
| 158 |
+
|
| 159 |
+
if remaining_agents:
|
| 160 |
+
# Route to next agent
|
| 161 |
+
next_agent = remaining_agents[0]
|
| 162 |
+
if next_agent == AgentRole.FILE_PROCESSOR:
|
| 163 |
+
return "need_more_agents" # This will route to file_processor
|
| 164 |
+
elif next_agent == AgentRole.REASONING_AGENT:
|
| 165 |
+
return "need_more_agents" # Would need additional routing logic
|
| 166 |
+
else:
|
| 167 |
+
return "complete"
|
| 168 |
+
else:
|
| 169 |
+
return "complete"
|
| 170 |
+
|
| 171 |
+
def process_question(self, question: str, file_path: str = None, file_name: str = None,
|
| 172 |
+
task_id: str = None, difficulty_level: int = 1) -> GAIAAgentState:
|
| 173 |
+
"""
|
| 174 |
+
Process a GAIA question through the complete workflow
|
| 175 |
+
|
| 176 |
+
Args:
|
| 177 |
+
question: The question to process
|
| 178 |
+
file_path: Optional path to associated file
|
| 179 |
+
file_name: Optional name of associated file
|
| 180 |
+
task_id: Optional task identifier
|
| 181 |
+
difficulty_level: Question difficulty (1-3)
|
| 182 |
+
|
| 183 |
+
Returns:
|
| 184 |
+
GAIAAgentState with final results
|
| 185 |
+
"""
|
| 186 |
+
|
| 187 |
+
logger.info(f"🚀 Processing question: {question[:100]}...")
|
| 188 |
+
|
| 189 |
+
# Initialize state
|
| 190 |
+
initial_state = GAIAAgentState()
|
| 191 |
+
initial_state.task_id = task_id or f"workflow_{hash(question) % 10000}"
|
| 192 |
+
initial_state.question = question
|
| 193 |
+
initial_state.file_path = file_path
|
| 194 |
+
initial_state.file_name = file_name
|
| 195 |
+
initial_state.difficulty_level = difficulty_level
|
| 196 |
+
|
| 197 |
+
try:
|
| 198 |
+
# Execute workflow
|
| 199 |
+
final_state = self.app.invoke(
|
| 200 |
+
initial_state,
|
| 201 |
+
config={"configurable": {"thread_id": initial_state.task_id}}
|
| 202 |
+
)
|
| 203 |
+
|
| 204 |
+
logger.info(f"✅ Workflow complete: {final_state.final_answer[:100]}...")
|
| 205 |
+
return final_state
|
| 206 |
+
|
| 207 |
+
except Exception as e:
|
| 208 |
+
error_msg = f"Workflow execution failed: {str(e)}"
|
| 209 |
+
logger.error(error_msg)
|
| 210 |
+
|
| 211 |
+
# Create error state
|
| 212 |
+
initial_state.add_error(error_msg)
|
| 213 |
+
initial_state.final_answer = "Workflow execution failed"
|
| 214 |
+
initial_state.final_confidence = 0.0
|
| 215 |
+
initial_state.final_reasoning = error_msg
|
| 216 |
+
initial_state.is_complete = True
|
| 217 |
+
initial_state.requires_human_review = True
|
| 218 |
+
|
| 219 |
+
return initial_state
|
| 220 |
+
|
| 221 |
+
def get_workflow_visualization(self) -> str:
    """Return a static ASCII-art description of the agent workflow.

    Purely informational (for logs/debugging UIs); does not inspect the
    compiled graph — the diagram is a hand-maintained constant and must be
    kept in sync with the actual graph wiring by hand.
    """
    return """
GAIA Agent Workflow:

    ┌─────────────┐
    │   Router    │ ← Entry Point
    └──────┬──────┘
           │
           ├─ Web Researcher ──┐
           ├─ File Processor ──┤
           ├─ Reasoning Agent ─┤
           │                   │
           ▼                   ▼
    ┌─────────────┐    ┌──────────────┐
    │ Synthesizer │ ←──┤ Agent Results │
    └──────┬──────┘    └──────────────┘
           │
           ▼
    ┌─────────────┐
    │     END     │
    └─────────────┘

Flow:
1. Router classifies question and selects appropriate agent(s)
2. Selected agents process question in parallel/sequence
3. Synthesizer combines results into final answer
4. Workflow completes with final state
"""
# Simplified workflow for cases where we don't need full LangGraph
class SimpleGAIAWorkflow:
    """
    Simplified workflow that doesn't require LangGraph for basic cases
    Useful for testing and lightweight deployments
    """

    def __init__(self, llm_client: QwenClient):
        # One shared LLM client, handed to every specialist agent
        self.llm_client = llm_client
        self.router = RouterAgent(llm_client)
        self.web_researcher = WebResearchAgent(llm_client)
        self.file_processor = FileProcessorAgent(llm_client)
        self.reasoning_agent = ReasoningAgent(llm_client)
        self.synthesizer = SynthesizerAgent(llm_client)

    def process_question(self, question: str, file_path: str | None = None, file_name: str | None = None,
                         task_id: str | None = None, difficulty_level: int = 1) -> GAIAAgentState:
        """Process a question with a simplified sequential workflow.

        Route → run each selected specialist agent in order → synthesize.

        Args:
            question: The question to process
            file_path: Optional path to associated file
            file_name: Optional name of associated file
            task_id: Optional task identifier; a stable one is derived from
                the question text when omitted
            difficulty_level: Question difficulty (1-3)

        Returns:
            GAIAAgentState with final results; on any exception, an error
            state is returned instead of raising
        """
        import hashlib

        # Initialize state
        state = GAIAAgentState()
        # Stable fallback id: built-in hash() is randomized per process
        # (PYTHONHASHSEED), so it is unsuitable even for a fallback id.
        state.task_id = task_id or f"simple_{int(hashlib.md5(question.encode('utf-8')).hexdigest(), 16) % 10000}"
        state.question = question
        state.file_path = file_path
        state.file_name = file_name
        state.difficulty_level = difficulty_level

        try:
            # Step 1: Route the question to the appropriate agent roles
            state = self.router.route_question(state)

            # Step 2: Execute the routed specialist agents sequentially.
            # Roles without a handler here (e.g. the synthesizer role, if the
            # router ever selects it) are deliberately deferred to step 3.
            handlers = {
                AgentRole.WEB_RESEARCHER: self.web_researcher.process,
                AgentRole.FILE_PROCESSOR: self.file_processor.process,
                AgentRole.REASONING_AGENT: self.reasoning_agent.process,
            }
            for agent_role in state.selected_agents:
                handler = handlers.get(agent_role)
                if handler is not None:
                    state = handler(state)

            # Step 3: Synthesize all agent results into the final answer
            state = self.synthesizer.process(state)

            return state

        except Exception as e:
            # Convert any failure into an explicit, inspectable error state
            error_msg = f"Simple workflow failed: {str(e)}"
            state.add_error(error_msg)
            state.final_answer = "Processing failed"
            state.final_confidence = 0.0
            state.final_reasoning = error_msg
            state.is_complete = True
            return state